marginaleffects 0.5.1__tar.gz → 0.6.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/PKG-INFO +2 -3
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/autodiff/__init__.py +4 -13
- marginaleffects-0.6.0/marginaleffects/autodiff/glm/__init__.py +1 -0
- marginaleffects-0.6.0/marginaleffects/autodiff/lower.py +219 -0
- marginaleffects-0.6.0/marginaleffects/autodiff/ops.py +60 -0
- marginaleffects-0.6.0/marginaleffects/autodiff/pipeline.py +237 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/by.py +42 -29
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/classes/model.py +3 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/comparisons.py +187 -85
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/datagrid.py +1 -1
- marginaleffects-0.6.0/marginaleffects/hypothesis_compile.py +213 -0
- marginaleffects-0.6.0/marginaleffects/plan.py +170 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/predictions.py +94 -86
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/settings.py +7 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/statsmodels/model.py +29 -61
- marginaleffects-0.6.0/marginaleffects/test/core.py +17 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/test/formula.py +1 -1
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/uncertainty.py +8 -2
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects.egg-info/PKG-INFO +2 -3
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects.egg-info/SOURCES.txt +11 -8
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects.egg-info/requires.txt +1 -3
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/pyproject.toml +2 -4
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_autodiff.py +35 -1
- marginaleffects-0.6.0/tests/test_autodiff_lower.py +180 -0
- marginaleffects-0.6.0/tests/test_autodiff_pipeline.py +240 -0
- marginaleffects-0.6.0/tests/test_comparison_plan.py +162 -0
- marginaleffects-0.6.0/tests/test_hypothesis_compile.py +51 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_jss.py +2 -5
- marginaleffects-0.6.0/tests/test_plan.py +82 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_plot_comparisons.py +2 -5
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_plot_predictions.py +3 -5
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_plot_slopes.py +13 -6
- marginaleffects-0.6.0/tests/test_prediction_plan.py +98 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_statsmodels_quantreg.py +0 -1
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/utilities.py +13 -0
- marginaleffects-0.5.1/marginaleffects/autodiff/comparisons.py +0 -69
- marginaleffects-0.5.1/marginaleffects/autodiff/dispatch.py +0 -310
- marginaleffects-0.5.1/marginaleffects/autodiff/glm/__init__.py +0 -5
- marginaleffects-0.5.1/marginaleffects/autodiff/glm/comparisons.py +0 -195
- marginaleffects-0.5.1/marginaleffects/autodiff/glm/predictions.py +0 -131
- marginaleffects-0.5.1/marginaleffects/autodiff/linear/__init__.py +0 -4
- marginaleffects-0.5.1/marginaleffects/autodiff/linear/comparisons.py +0 -147
- marginaleffects-0.5.1/marginaleffects/autodiff/linear/predictions.py +0 -71
- marginaleffects-0.5.1/marginaleffects/autodiff/utils.py +0 -31
- marginaleffects-0.5.1/marginaleffects/test/core.py +0 -155
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/README.md +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/benchmarks/__init__.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/benchmarks/benchmark_autodiff.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/__init__.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/autodiff/glm/families.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/classes/__init__.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/classes/result.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/datasets.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/docstrings/__init__.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/docstrings/params.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/docstrings/qmd.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/estimands.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/formula.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/linearmodels/__init__.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/linearmodels/model.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/plot/__init__.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/plot/common.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/plot/comparisons.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/plot/predictions.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/plot/slopes.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/pyfixest/__init__.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/pyfixest/model.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/sanitize/__init__.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/sanitize/by.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/sanitize/categorical.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/sanitize/comparison.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/sanitize/deprecated.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/sanitize/hypothesis_null.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/sanitize/newdata.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/sanitize/sanitize_model.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/sanitize/utils.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/sanitize/validation.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/sanitize/variables.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/sanitize/vcov.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/sklearn/__init__.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/sklearn/model.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/slopes.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/statsmodels/__init__.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/test/__init__.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/test/equivalence.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/test/joint.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/test/main.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/transform.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects/utils.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects.egg-info/dependency_links.txt +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/marginaleffects.egg-info/top_level.txt +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/setup.cfg +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/__init__.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/helpers.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_analytic.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_bugfix.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_by.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_categorical.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_categorical_validation.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_comparisons.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_comparisons_interaction.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_datagrid_01.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_datagrid_02.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_equivalence.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_formula.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_formulaic_utils.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_hypotheses.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_hypotheses_joint.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_hypothesis.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_linearmodels_panelols.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_missing.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_newdata.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_predictions.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_pyfixest.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_sklearn.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_slopes.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_statsmodels.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_statsmodels_logit.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_statsmodels_mixedlm.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_statsmodels_mnlogit.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_statsmodels_negativebinomial.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_statsmodels_ols.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_statsmodels_ordinal.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_statsmodels_poisson.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_statsmodels_probit.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_statsmodels_vcov.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_statsmodels_wls.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_typical.py +0 -0
- {marginaleffects-0.5.1 → marginaleffects-0.6.0}/tests/test_utils.py +0 -0
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: marginaleffects
|
|
3
|
-
Version: 0.5.1
|
|
3
|
+
Version: 0.6.0
|
|
4
4
|
Summary: Predictions, counterfactual comparisons, slopes, and hypothesis tests for statistical models.
|
|
5
5
|
License-Expression: GPL-3.0-or-later
|
|
6
6
|
Requires-Python: >=3.10
|
|
7
7
|
Description-Content-Type: text/markdown
|
|
8
8
|
Requires-Dist: formulaic>=1.0.2
|
|
9
|
+
Requires-Dist: jax>=0.4.0
|
|
9
10
|
Requires-Dist: narwhals>=1.34.0
|
|
10
11
|
Requires-Dist: numpy>=2.0.0
|
|
11
12
|
Requires-Dist: patsy>=1.0.1
|
|
@@ -14,8 +15,6 @@ Requires-Dist: pydantic>=2.10.3
|
|
|
14
15
|
Requires-Dist: plotnine>=0.14.5
|
|
15
16
|
Requires-Dist: scipy>=1.14.1
|
|
16
17
|
Requires-Dist: pyarrow>=19.0.1
|
|
17
|
-
Provides-Extra: autodiff
|
|
18
|
-
Requires-Dist: jax>=0.4.0; extra == "autodiff"
|
|
19
18
|
Provides-Extra: test
|
|
20
19
|
Requires-Dist: duckdb>=1.1.2; extra == "test"
|
|
21
20
|
Requires-Dist: matplotlib>=3.7.1; extra == "test"
|
|
@@ -29,18 +29,14 @@ if _JAX_AVAILABLE:
|
|
|
29
29
|
|
|
30
30
|
jax.config.update("jax_enable_x64", True)
|
|
31
31
|
|
|
32
|
-
# Import submodules to make them accessible
|
|
33
|
-
from . import linear as linear
|
|
34
32
|
from . import glm as glm
|
|
33
|
+
from . import pipeline as pipeline
|
|
35
34
|
|
|
36
|
-
# Re-export types for convenience
|
|
37
|
-
from .comparisons import ComparisonType as ComparisonType
|
|
38
35
|
from .glm.families import Family as Family, Link as Link
|
|
39
36
|
|
|
40
37
|
__all__ = [
|
|
41
|
-
"linear",
|
|
42
38
|
"glm",
|
|
43
|
-
"
|
|
39
|
+
"pipeline",
|
|
44
40
|
"Family",
|
|
45
41
|
"Link",
|
|
46
42
|
]
|
|
@@ -50,13 +46,8 @@ else:
|
|
|
50
46
|
def __getattr__(self, name):
|
|
51
47
|
_raise_jax_error()
|
|
52
48
|
|
|
53
|
-
linear = _DummyModule()
|
|
54
49
|
glm = _DummyModule()
|
|
55
|
-
|
|
56
|
-
# Create dummy enums that raise errors
|
|
57
|
-
class ComparisonType:
|
|
58
|
-
def __getattribute__(self, name):
|
|
59
|
-
_raise_jax_error()
|
|
50
|
+
pipeline = _DummyModule()
|
|
60
51
|
|
|
61
52
|
class Family:
|
|
62
53
|
def __getattribute__(self, name):
|
|
@@ -66,4 +57,4 @@ else:
|
|
|
66
57
|
def __getattribute__(self, name):
|
|
67
58
|
_raise_jax_error()
|
|
68
59
|
|
|
69
|
-
__all__ = []
|
|
60
|
+
__all__ = ["glm", "pipeline", "Family", "Link"]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from . import families as families
|
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
import warnings
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from ..settings import is_autodiff_enabled, is_autodiff_forced
|
|
9
|
+
from ..uncertainty import get_se
|
|
10
|
+
from .ops import COMPARISON_OPS
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class Lowered:
    """Outcome of lowering a plan for the autodiff pipeline.

    When ``ok`` is true, ``kwargs`` holds the keyword arguments that
    ``autodiff_try`` forwards to ``pipeline.compute``. When it is false,
    ``reason`` names the unsupported feature; an empty reason means the
    fallback happens without a warning.
    """

    # Whether lowering succeeded.
    ok: bool
    # Keyword arguments for ``pipeline.compute``; None when lowering failed.
    kwargs: dict | None = None
    # Name of the unsupported feature, used in the fallback warning text.
    reason: str = ""
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class AutodiffResult:
    """Standard errors and Jacobian produced by ``autodiff_try``."""

    # Standard errors from ``get_se(J, V)``; exact zeros are replaced by NaN.
    std_error: np.ndarray
    # Jacobian of the estimates w.r.t. the coefficients,
    # shaped (n_estimates, n_coefficients) by ``pipeline.compute``.
    jacobian: np.ndarray
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _fail(reason: str) -> Lowered:
    """Build a failed :class:`Lowered` that carries *reason*.

    An empty *reason* yields a silent fallback (no warning is emitted).
    """
    failure = Lowered(ok=False, reason=reason)
    return failure
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _model_args(model):
|
|
31
|
+
args = model.get_autodiff_args()
|
|
32
|
+
if args is None:
|
|
33
|
+
return None, _fail("")
|
|
34
|
+
if isinstance(args, str):
|
|
35
|
+
return None, _fail(args)
|
|
36
|
+
return args, None
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _coefs_or_failure(model):
|
|
40
|
+
coefs = np.asarray(model.get_coef(), dtype=float).reshape(-1)
|
|
41
|
+
if np.isnan(coefs).any():
|
|
42
|
+
return None, _fail("models with NA coefficients")
|
|
43
|
+
return coefs, None
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _design_or_failure(X, coefs, n_pred):
|
|
47
|
+
X = np.asarray(X, dtype=float)
|
|
48
|
+
if X.ndim != 2 or X.shape[1] != coefs.size or X.shape[0] != n_pred:
|
|
49
|
+
return None, _fail("this model/data configuration")
|
|
50
|
+
return X, None
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _hypothesis_or_failure(plan):
|
|
54
|
+
if plan.hyp is None:
|
|
55
|
+
return None, None
|
|
56
|
+
if plan.hyp.kind != "matrix":
|
|
57
|
+
return None, _fail("this form of the `hypothesis` argument")
|
|
58
|
+
return np.asarray(plan.hyp.H, dtype=float), None
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _has_nan(x) -> bool:
|
|
62
|
+
return x is not None and np.isnan(np.asarray(x, dtype=float)).any()
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def lower_predictions(plan, model) -> Lowered:
    """Translate a predictions plan into ``pipeline.compute`` keyword arguments.

    The returned kwargs contain the model adapter's autodiff arguments plus
    ``X`` (the validated design matrix), the optional aggregation inputs
    (``agg_segments``, ``agg_num_segments``, ``agg_weights``) built from
    ``plan.agg``, and the optional hypothesis matrix ``H``.

    Any unsupported feature short-circuits with a failed :class:`Lowered`
    whose ``reason`` names the feature; checks run in a fixed order, so the
    first unsupported feature encountered is the one reported.
    """
    args, failure = _model_args(model)
    if failure is not None:
        return failure
    coefs, failure = _coefs_or_failure(model)
    if failure is not None:
        return failure
    X, failure = _design_or_failure(plan.exog, coefs, plan.n_pred)
    if failure is not None:
        return failure
    if plan.align is not None:
        return _fail("models with grouped/multi-equation outcomes")
    if plan.has_na:
        return _fail("missing values in predictions")
    H, failure = _hypothesis_or_failure(plan)
    if failure is not None:
        return failure

    agg_segments = None
    agg_num_segments = None
    agg_weights = None
    if plan.agg is not None:
        # Every row starts as "unassigned" (-1). The previous np.empty left
        # unclaimed rows with uninitialized segment ids, silently sending
        # them to arbitrary groups.
        agg_segments = np.full(plan.n_pred, -1, dtype=np.int32)
        agg_weights = np.ones(plan.n_pred, dtype=float)
        for i, group in enumerate(plan.agg):
            agg_segments[group.idx] = i
            if group.w is not None:
                if _has_nan(group.w):
                    return _fail("missing values in weights")
                agg_weights[group.idx] = np.asarray(group.w, dtype=float)
        if (agg_segments < 0).any():
            # Rows that belong to no aggregation group: fall back rather than
            # guess how the standard pipeline treats them.
            return _fail("this model/data configuration")
        agg_num_segments = len(plan.agg)

    kwargs = {
        **args,
        "X": X,
        "agg_segments": agg_segments,
        "agg_num_segments": agg_num_segments,
        # agg_weights is only built alongside agg_segments, so it is None
        # exactly when no aggregation is requested.
        "agg_weights": agg_weights,
        "H": H,
    }
    return Lowered(True, kwargs=kwargs)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def lower_comparisons(plan, model) -> Lowered:
    """Translate a comparisons plan into ``pipeline.compute`` keyword arguments.

    Design rows of each comparison group are gathered via ``X_hi[order]`` /
    ``X_lo[order]`` into one contiguous block per group, in the order of
    ``plan.groups``. ``ops`` then records, per group, the pipeline op name,
    its row count and (for weighted ops only) the group weights.

    Any unsupported feature short-circuits with a failed :class:`Lowered`
    whose ``reason`` names the feature; checks run in a fixed order, so the
    first unsupported feature encountered is the one reported.
    """
    model_args, failure = _model_args(model)
    if failure is not None:
        return failure

    coefs, failure = _coefs_or_failure(model)
    if failure is not None:
        return failure

    X_hi, failure = _design_or_failure(plan.exog_hi, coefs, plan.n_pred)
    if failure is not None:
        return failure
    X_lo, failure = _design_or_failure(plan.exog_lo, coefs, plan.n_pred)
    if failure is not None:
        return failure

    if plan.align is not None:
        return _fail("models with grouped/multi-equation outcomes")
    if plan.has_na:
        return _fail("missing values in predictions")

    groups = plan.groups
    fun_keys = [group.fun_key for group in groups]
    if any(key is None for key in fun_keys):
        return _fail("custom comparison functions")
    if plan.need_y:
        return _fail("elasticities")
    unsupported = [key for key in fun_keys if key not in COMPARISON_OPS]
    if unsupported:
        return _fail(f"comparison='{unsupported[0]}'")

    H, failure = _hypothesis_or_failure(plan)
    if failure is not None:
        return failure

    if groups:
        order = np.concatenate([group.idx for group in groups]).astype(int)
    else:
        order = np.asarray([], dtype=int)
    if order.size != plan.n_pred:
        return _fail("this model/data configuration")

    ops = []
    for group in groups:
        spec = COMPARISON_OPS[group.fun_key]
        weights = None
        if spec.weighted:
            if _has_nan(group.w):
                return _fail("missing values in weights")
            if group.w is not None:
                weights = np.asarray(group.w, dtype=float)
        ops.append({"op": spec.pipeline_op, "n": len(group.idx), "w": weights})

    return Lowered(
        True,
        kwargs={
            **model_args,
            "X_hi": X_hi[order],
            "X_lo": X_lo[order],
            "ops": ops,
            "H": H,
        },
    )
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def _warn_unsupported(reason):
|
|
166
|
+
if reason:
|
|
167
|
+
warnings.warn(
|
|
168
|
+
"Automatic differentiation does not support "
|
|
169
|
+
f"{reason}. Reverting to finite differences.",
|
|
170
|
+
UserWarning,
|
|
171
|
+
stacklevel=3,
|
|
172
|
+
)
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def autodiff_try(plan, model, V, estimate, kind):
    """Try to compute standard errors with JAX automatic differentiation.

    Parameters
    ----------
    plan
        Prediction or comparison plan; ``None`` disables autodiff.
    model
        Model adapter exposing ``get_autodiff_args`` and ``get_coef``.
    V
        Coefficient variance-covariance matrix; ``None`` disables autodiff.
    estimate
        Estimates from the standard pipeline. The autodiff estimates must
        reproduce them before the autodiff Jacobian is trusted.
    kind
        ``"predictions"`` selects :func:`lower_predictions`; any other value
        selects :func:`lower_comparisons`.

    Returns
    -------
    AutodiffResult or None
        ``None`` whenever autodiff is disabled, unsupported, raises, or
        disagrees with *estimate*. The caller then falls back to finite
        differences. Fallback warnings are only emitted when
        ``is_autodiff_forced()`` is true.
    """
    if plan is None or V is None or not is_autodiff_enabled():
        return None

    warn_on_fallback = is_autodiff_forced()
    lowered = (
        lower_predictions(plan, model)
        if kind == "predictions"
        else lower_comparisons(plan, model)
    )
    if not lowered.ok:
        if warn_on_fallback:
            _warn_unsupported(lowered.reason)
        return None

    try:
        from . import pipeline

        result = pipeline.compute(beta=model.get_coef(), **lowered.kwargs)
    except Exception as exc:
        # Deliberate best-effort boundary: any failure (JAX missing at import,
        # unsupported family/link or op, tracing errors) degrades to finite
        # differences instead of aborting the user's call.
        if warn_on_fallback:
            warnings.warn(
                "Automatic differentiation failed "
                f"({exc}). Reverting to finite differences.",
                UserWarning,
                stacklevel=3,
            )
        return None

    estimate = np.asarray(estimate, dtype=float).reshape(-1)
    ad_estimate = np.asarray(result["estimate"], dtype=float).reshape(-1)
    # Compare shapes first: np.allclose broadcasts, so a length mismatch would
    # otherwise raise here (outside the try) or, with a length-1 side, be
    # silently accepted.
    if ad_estimate.shape != estimate.shape or not np.allclose(
        ad_estimate, estimate, rtol=1e-8, atol=1e-8
    ):
        if warn_on_fallback:
            warnings.warn(
                "Automatic differentiation estimates did not match the standard "
                "pipeline. Reverting to finite differences.",
                UserWarning,
                stacklevel=3,
            )
        return None

    # Coefficient/vcov positional alignment is guaranteed by each model adapter.
    J = np.asarray(result["jacobian"], dtype=float)
    se = get_se(J, V)
    # Exact-zero standard errors are reported as NaN.
    se[se == 0] = np.nan
    return AutodiffResult(std_error=se, jacobian=J)
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""Shared comparison operation registry for autodiff lowering and execution."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Builds an array from a list of scalars (the pipeline passes ``_asarray1``).
ArrayFn = Callable[[list[Any]], Any]
# (values, weights or None) -> mean; the pipeline passes ``_wmean``.
WMeanFn = Callable[[Any, Any | None], Any]
# (hi, lo, weights or None, array_fn, wmean_fn) -> estimate(s) for one group.
EstimateFn = Callable[[Any, Any, Any | None, ArrayFn, WMeanFn], Any]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass(frozen=True)
class PipelineOp:
    """A comparison operation as executed inside the JAX pipeline.

    ``estimate`` maps one group's hi/lo predictions to its estimates;
    ``scalar`` marks ops that reduce the whole group to a single value.
    """

    estimate: EstimateFn
    scalar: bool

    def output_size(self, n: int) -> int:
        """Number of estimates produced for a group of *n* rows."""
        if self.scalar:
            return 1
        return n
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass(frozen=True)
class ComparisonOp:
    """How a user-facing comparison key is lowered onto a pipeline op."""

    # Key into ``PIPELINE_OPS`` that executes this comparison.
    pipeline_op: str
    # Whether lowering forwards the group's weights to the pipeline op.
    weighted: bool
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _difference(hi, lo, _w, _array, _wmean):
|
|
31
|
+
return hi - lo
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _ratio(hi, lo, _w, _array, _wmean):
|
|
35
|
+
return hi / lo
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _differenceavg(hi, lo, w, array, wmean):
|
|
39
|
+
return array([wmean(hi, w) - wmean(lo, w)])
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _ratioavg(hi, lo, w, array, wmean):
|
|
43
|
+
return array([wmean(hi, w) / wmean(lo, w)])
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# Ops executable inside the JAX pipeline, keyed by the names that lowering
# emits in ``ops[i]["op"]``. ``scalar=True`` ops reduce a group to one value.
PIPELINE_OPS = {
    "difference": PipelineOp(estimate=_difference, scalar=False),
    "ratio": PipelineOp(estimate=_ratio, scalar=False),
    "differenceavg": PipelineOp(estimate=_differenceavg, scalar=True),
    "ratioavg": PipelineOp(estimate=_ratioavg, scalar=True),
}
|
|
52
|
+
|
|
53
|
+
# Comparison keys (``group.fun_key``) that autodiff can lower. The ``*wts``
# variants reuse the averaging pipeline ops but forward the group weights.
COMPARISON_OPS = {
    "difference": ComparisonOp(pipeline_op="difference", weighted=False),
    "ratio": ComparisonOp(pipeline_op="ratio", weighted=False),
    "differenceavg": ComparisonOp(pipeline_op="differenceavg", weighted=False),
    "ratioavg": ComparisonOp(pipeline_op="ratioavg", weighted=False),
    "differenceavgwts": ComparisonOp(pipeline_op="differenceavg", weighted=True),
    "ratioavgwts": ComparisonOp(pipeline_op="ratioavg", weighted=True),
}
|
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
"""Composable JAX pipeline for R-side plan lowering."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from functools import partial
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import jax
|
|
9
|
+
import jax.numpy as jnp
|
|
10
|
+
|
|
11
|
+
from .glm.families import Family, Link, linkinv, resolve_link
|
|
12
|
+
from .ops import PIPELINE_OPS
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# GLM family names accepted from model adapters, mapped to ``Family`` members.
# Both "gamma" and "Gamma" spellings are accepted.
_FAMILY = {
    "gaussian": Family.GAUSSIAN,
    "binomial": Family.BINOMIAL,
    "poisson": Family.POISSON,
    "gamma": Family.GAMMA,
    "Gamma": Family.GAMMA,
}
|
|
22
|
+
|
|
23
|
+
# Link-function names accepted from model adapters, mapped to ``Link`` members.
_LINK = {
    "identity": Link.IDENTITY,
    "log": Link.LOG,
    "logit": Link.LOGIT,
    "probit": Link.PROBIT,
    "inverse": Link.INVERSE,
    "sqrt": Link.SQRT,
    "cloglog": Link.CLOGLOG,
}
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _resolve_family_link(family: str | None, link: str | None) -> int | None:
|
|
35
|
+
if family is None and link is None:
|
|
36
|
+
return None
|
|
37
|
+
try:
|
|
38
|
+
family_type = _FAMILY[family]
|
|
39
|
+
link_type = _LINK[link] if link is not None else None
|
|
40
|
+
except KeyError as exc:
|
|
41
|
+
raise ValueError(f"Unsupported GLM family/link: {family}/{link}") from exc
|
|
42
|
+
return resolve_link(family_type, link_type)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _wmean(x, w):
|
|
46
|
+
if w is None:
|
|
47
|
+
return jnp.mean(x)
|
|
48
|
+
w = jnp.asarray(w, dtype=jnp.float64)
|
|
49
|
+
return jnp.sum(x * w) / jnp.sum(w)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _asarray1(x):
|
|
53
|
+
return jnp.asarray(x, dtype=jnp.float64)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _apply_agg(est, segments, num_segments, weights):
|
|
57
|
+
if segments is None:
|
|
58
|
+
return est
|
|
59
|
+
segments = jnp.asarray(segments, dtype=jnp.int32)
|
|
60
|
+
if weights is None:
|
|
61
|
+
weights = jnp.ones_like(est, dtype=jnp.float64)
|
|
62
|
+
else:
|
|
63
|
+
weights = jnp.asarray(weights, dtype=jnp.float64)
|
|
64
|
+
numer = jax.ops.segment_sum(est * weights, segments, num_segments=num_segments)
|
|
65
|
+
denom = jax.ops.segment_sum(weights, segments, num_segments=num_segments)
|
|
66
|
+
return numer / denom
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _comparison_estimates(mu_hi, mu_lo, ops_meta, ops_weights):
|
|
70
|
+
pieces = []
|
|
71
|
+
start = 0
|
|
72
|
+
w_iter = iter(ops_weights)
|
|
73
|
+
for name, n, has_w in ops_meta:
|
|
74
|
+
stop = start + n
|
|
75
|
+
hi = mu_hi[start:stop]
|
|
76
|
+
lo = mu_lo[start:stop]
|
|
77
|
+
w = next(w_iter) if has_w else None
|
|
78
|
+
spec = PIPELINE_OPS[name]
|
|
79
|
+
pieces.append(spec.estimate(hi, lo, w, _asarray1, _wmean))
|
|
80
|
+
start = stop
|
|
81
|
+
if not pieces:
|
|
82
|
+
return jnp.asarray([], dtype=jnp.float64)
|
|
83
|
+
return jnp.concatenate(pieces)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _predict(model_type, link_type, x, b):
|
|
87
|
+
eta = x @ b
|
|
88
|
+
if model_type == "glm":
|
|
89
|
+
return linkinv(link_type, eta)
|
|
90
|
+
return eta
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
# The whole estimate+Jacobian computation is compiled with XLA. Traced arrays
# only affect the cache key through shape/dtype, so repeated calls at the same
# problem size recompile nothing; a new shape or plan structure pays one
# compilation. (Whether an optional array argument is None is part of the
# pytree structure, so it also participates in the cache key.)
@partial(
    jax.jit,
    static_argnames=(
        "model_type",
        "link_type",
        "ops_meta",
        "agg_num_segments",
        "use_fwd",
    ),
)
def _estimate_and_jacobian(
    beta,
    model_type,
    link_type,
    X,
    X_hi,
    X_lo,
    ops_meta,
    ops_weights,
    est_keep,
    agg_segments,
    agg_num_segments,
    agg_weights,
    H,
    use_fwd,
):
    """Evaluate the lowered plan at *beta* and return ``(estimate, jacobian)``.

    The estimate function ``f(b)``:

    1. predictions (``X`` given) via ``_predict``; otherwise comparisons, by
       predicting on ``X_hi``/``X_lo`` and applying ``_comparison_estimates``;
    2. keeps only ``est_keep`` entries when provided;
    3. averages within ``agg_segments`` via ``_apply_agg``;
    4. applies the hypothesis matrix as ``est @ H`` when provided;
    5. returns at least a 1-D array.

    The Jacobian uses forward mode when ``use_fwd`` is true and reverse mode
    otherwise, and is reshaped to ``(estimate.size, beta.size)``.
    """

    def f(b):
        # Python-level None checks are resolved at trace time.
        if X is not None:
            est = _predict(model_type, link_type, X, b)
        else:
            mu_hi = _predict(model_type, link_type, X_hi, b)
            mu_lo = _predict(model_type, link_type, X_lo, b)
            est = _comparison_estimates(mu_hi, mu_lo, ops_meta, ops_weights)

        if est_keep is not None:
            est = est[est_keep]

        est = _apply_agg(est, agg_segments, agg_num_segments, agg_weights)

        if H is not None:
            est = est @ H

        return jnp.atleast_1d(est)

    estimate = f(beta)
    # f is evaluated once here and again inside the Jacobian transform.
    jac_fun = jax.jacfwd if use_fwd else jax.jacrev
    jacobian = jac_fun(f)(beta)
    jacobian = jnp.reshape(jacobian, (estimate.size, beta.size))
    return estimate, jacobian
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def compute(
    beta,
    model_type,
    family=None,
    link=None,
    X=None,
    X_hi=None,
    X_lo=None,
    ops=None,
    est_keep=None,
    agg_segments=None,
    agg_num_segments=None,
    agg_weights=None,
    H=None,
):
    """Return estimate and Jacobian for a lowered autodiff plan.

    Parameters
    ----------
    beta : array-like
        Coefficient vector at which the plan is evaluated and differentiated.
    model_type : {"linear", "glm"}
        ``"glm"`` applies the inverse link resolved from *family*/*link*.
    family, link : str, optional
        GLM family/link names (keys of ``_FAMILY``/``_LINK``); ignored for
        linear models.
    X : array-like, optional
        Design matrix for predictions; takes precedence over *X_hi*/*X_lo*.
    X_hi, X_lo : array-like, optional
        Design matrices for the two sides of a comparison, with rows grouped
        contiguously in the order of *ops*.
    ops : list of dict, optional
        Comparison groups, each ``{"op": name, "n": rows, "w": weights|None}``
        where ``name`` is a key of ``PIPELINE_OPS``.
    est_keep : array-like of int, optional
        Indices of the estimates to keep before aggregation.
    agg_segments : array-like of int, optional
        Segment id of each estimate for (weighted) averaging.
    agg_num_segments : int, optional
        Number of segments; inferred as ``max(agg_segments) + 1`` when
        omitted.
    agg_weights : array-like, optional
        Per-estimate weights for the segment averages (equal when omitted).
    H : array-like, optional
        Hypothesis matrix applied as ``estimate @ H``; a 1-D vector yields a
        single contrast.

    Returns
    -------
    dict
        ``{"estimate": ndarray, "jacobian": ndarray}``; the Jacobian has shape
        ``(estimate.size, beta.size)``.

    Raises
    ------
    ValueError
        For an unsupported *model_type*, GLM family/link, or comparison op.
    """
    beta = jnp.asarray(beta, dtype=jnp.float64)
    if model_type not in ("linear", "glm"):
        raise ValueError(f"Unsupported model_type: {model_type}")
    link_type = _resolve_family_link(family, link) if model_type == "glm" else None

    if X is not None:
        X = jnp.asarray(X, dtype=jnp.float64)
    if X_hi is not None:
        X_hi = jnp.asarray(X_hi, dtype=jnp.float64)
    if X_lo is not None:
        X_lo = jnp.asarray(X_lo, dtype=jnp.float64)
    if est_keep is not None:
        est_keep = jnp.asarray(est_keep, dtype=jnp.int32)
    if agg_segments is not None:
        agg_segments = jnp.asarray(agg_segments, dtype=jnp.int32)
        if agg_num_segments is None:
            # The segment count is a static jit argument, so infer it eagerly
            # from the concrete ids (previously int(None) raised TypeError).
            agg_num_segments = (
                int(jnp.max(agg_segments)) + 1 if agg_segments.size else 0
            )
        agg_num_segments = int(agg_num_segments)
    else:
        agg_num_segments = None
    if agg_weights is not None:
        agg_weights = jnp.asarray(agg_weights, dtype=jnp.float64)
    if H is not None:
        H = jnp.asarray(H, dtype=jnp.float64)

    # Split ops into a hashable static structure (jit cache key) and traced
    # weight arrays, while tallying the pre-keep estimate length.
    ops_meta = ()
    ops_weights = ()
    n_est = X.shape[0] if X is not None else 0
    for op in ops or []:
        name = op["op"]
        if name not in PIPELINE_OPS:
            raise ValueError(f"Unsupported comparison op: {name}")
        spec = PIPELINE_OPS[name]
        n = int(op["n"])
        w = op.get("w")
        ops_meta = ops_meta + ((name, n, w is not None),)
        if w is not None:
            ops_weights = ops_weights + (jnp.asarray(w, dtype=jnp.float64),)
        n_est += spec.output_size(n)

    # Output size is known from shapes alone; it picks forward vs reverse mode
    # before tracing (forward mode scales with the number of coefficients,
    # reverse mode with the number of outputs).
    n_out = n_est
    if est_keep is not None:
        n_out = est_keep.shape[0]
    if agg_segments is not None:
        n_out = agg_num_segments
    if H is not None:
        # A 1-D H contracts the estimates to one value; H.shape[1] would
        # raise IndexError for it.
        n_out = H.shape[1] if H.ndim > 1 else 1
    use_fwd = beta.size <= max(n_out, 1)

    estimate, jacobian = _estimate_and_jacobian(
        beta,
        model_type,
        link_type,
        X,
        X_hi,
        X_lo,
        ops_meta,
        ops_weights,
        est_keep,
        agg_segments,
        agg_num_segments,
        agg_weights,
        H,
        use_fwd,
    )
    return {
        "estimate": np.asarray(estimate, dtype=np.float64),
        "jacobian": np.asarray(jacobian, dtype=np.float64),
    }
|