marginaleffects 0.1.5__tar.gz → 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/PKG-INFO +3 -1
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/__init__.py +14 -0
- marginaleffects-0.2.1/marginaleffects/autodiff/__init__.py +63 -0
- marginaleffects-0.2.1/marginaleffects/autodiff/comparisons.py +58 -0
- marginaleffects-0.2.1/marginaleffects/autodiff/glm/__init__.py +5 -0
- marginaleffects-0.2.1/marginaleffects/autodiff/glm/comparisons.py +193 -0
- marginaleffects-0.2.1/marginaleffects/autodiff/glm/families.py +108 -0
- marginaleffects-0.2.1/marginaleffects/autodiff/glm/predictions.py +131 -0
- marginaleffects-0.2.1/marginaleffects/autodiff/linear/__init__.py +4 -0
- marginaleffects-0.2.1/marginaleffects/autodiff/linear/comparisons.py +145 -0
- marginaleffects-0.2.1/marginaleffects/autodiff/linear/predictions.py +71 -0
- marginaleffects-0.2.1/marginaleffects/autodiff/utils.py +31 -0
- marginaleffects-0.2.1/marginaleffects/classes.py +64 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/comparisons.py +71 -27
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/datagrid.py +5 -3
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/docs.py +8 -1
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/formulaic_utils.py +32 -3
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/hypotheses.py +3 -3
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/hypotheses_joint.py +2 -2
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/hypothesis.py +26 -2
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/model_abstract.py +9 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/model_pyfixest.py +8 -3
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/plot_common.py +53 -1
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/plot_predictions.py +19 -1
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/predictions.py +37 -8
- marginaleffects-0.2.1/marginaleffects/result.py +218 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/sanity.py +111 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/slopes.py +4 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/utils.py +7 -9
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects.egg-info/PKG-INFO +3 -1
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects.egg-info/SOURCES.txt +21 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects.egg-info/requires.txt +3 -0
- marginaleffects-0.2.1/marginaleffects.egg-info/top_level.txt +6 -0
- marginaleffects-0.2.1/marginaleffectsAD/build/lib/marginaleffectsAD/__init__.py +9 -0
- marginaleffects-0.2.1/marginaleffectsAD/build/lib/marginaleffectsAD/comparisons.py +58 -0
- marginaleffects-0.2.1/marginaleffectsAD/build/lib/marginaleffectsAD/glm/__init__.py +5 -0
- marginaleffects-0.2.1/marginaleffectsAD/build/lib/marginaleffectsAD/glm/comparisons.py +140 -0
- marginaleffects-0.2.1/marginaleffectsAD/build/lib/marginaleffectsAD/glm/families.py +108 -0
- marginaleffects-0.2.1/marginaleffectsAD/build/lib/marginaleffectsAD/glm/predictions.py +114 -0
- marginaleffects-0.2.1/marginaleffectsAD/build/lib/marginaleffectsAD/linear/__init__.py +4 -0
- marginaleffects-0.2.1/marginaleffectsAD/build/lib/marginaleffectsAD/linear/comparisons.py +119 -0
- marginaleffects-0.2.1/marginaleffectsAD/build/lib/marginaleffectsAD/linear/predictions.py +62 -0
- marginaleffects-0.2.1/marginaleffectsAD/build/lib/marginaleffectsAD/utils.py +31 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/pyproject.toml +9 -1
- marginaleffects-0.2.1/tests/test_bugfix.py +24 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_comparisons.py +94 -26
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_datagrid_01.py +25 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_datagrid_02.py +38 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_formula.py +2 -2
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_formulaic_utils.py +3 -2
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_hypotheses.py +3 -1
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_hypotheses_joint.py +3 -3
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_hypothesis.py +69 -2
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_jss.py +36 -29
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_plot_predictions.py +12 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_predictions.py +9 -6
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_pyfixest.py +52 -45
- marginaleffects-0.2.1/tests/test_sklearn.py +146 -0
- marginaleffects-0.1.5/marginaleffects/classes.py +0 -226
- marginaleffects-0.1.5/marginaleffects.egg-info/top_level.txt +0 -3
- marginaleffects-0.1.5/tests/test_bugfix.py +0 -15
- marginaleffects-0.1.5/tests/test_sklearn.py +0 -38
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/README.md +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/by.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/equivalence.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/estimands.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/hypothesis_formula.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/inject_docs.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/model_linearmodels.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/model_sklearn.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/model_statsmodels.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/plot_comparisons.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/plot_slopes.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/sanitize_model.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/transform.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/uncertainty.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/validation.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects.egg-info/dependency_links.txt +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/setup.cfg +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/__init__.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/helpers.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_analytic.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_by.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_categorical.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_equivalence.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_linearmodels_panelols.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_missing.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_newdata.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_plot_comparisons.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_plot_slopes.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_slopes.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_statsmodels.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_statsmodels_logit.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_statsmodels_mixedlm.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_statsmodels_mnlogit.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_statsmodels_negativebinomial.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_statsmodels_ols.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_statsmodels_poisson.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_statsmodels_probit.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_statsmodels_quantreg.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_statsmodels_wls.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_utils.py +0 -0
- {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/utilities.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: marginaleffects
|
|
3
|
-
Version: 0.1.5
|
|
3
|
+
Version: 0.2.1
|
|
4
4
|
Summary: Predictions, counterfactual comparisons, slopes, and hypothesis tests for statistical models.
|
|
5
5
|
License-Expression: GPL-3.0-or-later
|
|
6
6
|
Requires-Python: >=3.10
|
|
@@ -14,6 +14,8 @@ Requires-Dist: pydantic>=2.10.3
|
|
|
14
14
|
Requires-Dist: plotnine>=0.14.5
|
|
15
15
|
Requires-Dist: scipy>=1.14.1
|
|
16
16
|
Requires-Dist: pyarrow>=19.0.1
|
|
17
|
+
Provides-Extra: autodiff
|
|
18
|
+
Requires-Dist: jax>=0.4.0; extra == "autodiff"
|
|
17
19
|
Provides-Extra: test
|
|
18
20
|
Requires-Dist: duckdb>=1.1.2; extra == "test"
|
|
19
21
|
Requires-Dist: matplotlib>=3.7.1; extra == "test"
|
|
@@ -10,6 +10,16 @@ from .plot_slopes import plot_slopes
|
|
|
10
10
|
from .predictions import avg_predictions, predictions
|
|
11
11
|
from .slopes import avg_slopes, slopes
|
|
12
12
|
from .utils import get_dataset
|
|
13
|
+
from .result import MarginaleffectsResult
|
|
14
|
+
|
|
15
|
+
# Conditionally import autodiff module if JAX is available
|
|
16
|
+
try:
|
|
17
|
+
from . import autodiff
|
|
18
|
+
|
|
19
|
+
_AUTODIFF_AVAILABLE = True
|
|
20
|
+
except ImportError:
|
|
21
|
+
_AUTODIFF_AVAILABLE = False
|
|
22
|
+
autodiff = None
|
|
13
23
|
|
|
14
24
|
__all__ = [
|
|
15
25
|
"avg_comparisons",
|
|
@@ -27,4 +37,8 @@ __all__ = [
|
|
|
27
37
|
"fit_sklearn",
|
|
28
38
|
"fit_linearmodels",
|
|
29
39
|
"get_dataset",
|
|
40
|
+
"MarginaleffectsResult",
|
|
30
41
|
]
|
|
42
|
+
|
|
43
|
+
if _AUTODIFF_AVAILABLE:
|
|
44
|
+
__all__.append("autodiff")
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
"""
|
|
2
|
+
JAX-based automatic differentiation module for marginal effects.
|
|
3
|
+
|
|
4
|
+
This module provides high-performance computation of predictions, comparisons,
|
|
5
|
+
and standard errors using JAX's automatic differentiation capabilities.
|
|
6
|
+
|
|
7
|
+
Note: This module requires JAX to be installed. Install with:
|
|
8
|
+
pip install marginaleffects[autodiff]
|
|
9
|
+
or
|
|
10
|
+
pip install jax
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import importlib.util
|
|
14
|
+
|
|
15
|
+
_JAX_AVAILABLE = importlib.util.find_spec("jax") is not None
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _raise_jax_error():
|
|
19
|
+
raise ImportError(
|
|
20
|
+
"The autodiff module requires JAX to be installed. "
|
|
21
|
+
"Install with: pip install marginaleffects[autodiff] or pip install jax"
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
if _JAX_AVAILABLE:
    # Import submodules to make them accessible
    from . import linear as linear
    from . import glm as glm

    # Re-export types for convenience
    from .comparisons import ComparisonType as ComparisonType
    from .glm.families import Family as Family, Link as Link

    __all__ = [
        "linear",
        "glm",
        "ComparisonType",
        "Family",
        "Link",
    ]
else:
    # JAX is missing: expose placeholders under the same names so that any use
    # of them fails with the installation hint instead of a NameError.
    class _DummyModule:
        """Stand-in for a submodule; unknown attribute lookups raise the JAX ImportError."""

        def __getattr__(self, name):
            _raise_jax_error()

    linear = _DummyModule()
    glm = _DummyModule()

    class _DummyEnumMeta(type):
        """Metaclass making class-level member access raise the JAX ImportError.

        The real types are enums whose members are read from the class itself
        (e.g. ``ComparisonType.DIFFERENCE``).  An instance-level
        ``__getattribute__`` never sees that lookup, so without this metaclass
        the user would get a bare AttributeError with no installation hint.
        """

        def __getattr__(cls, name):
            # Leave dunder probes (introspection, copy, pickling helpers) alone
            # so that merely inspecting the placeholder does not raise ImportError.
            if name.startswith("__") and name.endswith("__"):
                raise AttributeError(name)
            _raise_jax_error()

    class _DummyEnum(metaclass=_DummyEnumMeta):
        """Base for the placeholder enums; instance attribute access also raises."""

        def __getattribute__(self, name):
            _raise_jax_error()

    class ComparisonType(_DummyEnum):
        """Placeholder for the JAX-backed ``ComparisonType`` enum."""

    class Family(_DummyEnum):
        """Placeholder for the JAX-backed ``Family`` enum."""

    class Link(_DummyEnum):
        """Placeholder for the JAX-backed ``Link`` enum."""

    __all__ = []
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
"""Comparison types and functions using enum-based approach for JAX compatibility."""
|
|
2
|
+
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
from jax import lax
|
|
5
|
+
from enum import IntEnum
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ComparisonType(IntEnum):
    """Integer codes naming the supported comparison functions.

    The values are plain integers so they can be used as a branch index.
    """

    DIFFERENCE = 0  # hi - lo
    RATIO = 1  # hi / lo
    LNRATIO = 2  # log(hi / lo)
    LNOR = 3  # log odds ratio
    LIFT = 4  # (hi - lo) / lo
    DIFFERENCEAVG = 5  # mean(hi) - mean(lo), aggregated to a scalar
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _compute_comparison_vector(
|
|
20
|
+
comparison_type: int, pred_hi: jnp.ndarray, pred_lo: jnp.ndarray
|
|
21
|
+
) -> jnp.ndarray:
|
|
22
|
+
"""Apply comparison function element-wise (returns N-length array)."""
|
|
23
|
+
return lax.switch(
|
|
24
|
+
comparison_type,
|
|
25
|
+
[
|
|
26
|
+
lambda hi, lo: hi - lo, # difference
|
|
27
|
+
lambda hi, lo: hi / lo, # ratio
|
|
28
|
+
lambda hi, lo: jnp.log(hi / lo), # lnratio
|
|
29
|
+
lambda hi, lo: jnp.log((hi / (1 - hi)) / (lo / (1 - lo))), # lnor
|
|
30
|
+
lambda hi, lo: (hi - lo) / lo, # lift
|
|
31
|
+
],
|
|
32
|
+
pred_hi,
|
|
33
|
+
pred_lo,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _compute_comparison_scalar(
|
|
38
|
+
comparison_type: int, pred_hi: jnp.ndarray, pred_lo: jnp.ndarray
|
|
39
|
+
) -> jnp.ndarray:
|
|
40
|
+
"""Apply comparison function with aggregation (returns scalar)."""
|
|
41
|
+
return lax.switch(
|
|
42
|
+
comparison_type,
|
|
43
|
+
[
|
|
44
|
+
lambda hi, lo: jnp.mean(hi) - jnp.mean(lo), # differenceavg
|
|
45
|
+
],
|
|
46
|
+
pred_hi,
|
|
47
|
+
pred_lo,
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _compute_comparison(
    comparison_type: int, pred_hi: jnp.ndarray, pred_lo: jnp.ndarray
) -> jnp.ndarray:
    """Element-wise comparison entry point (returns N-length array).

    Thin wrapper that forwards its arguments unchanged to
    ``_compute_comparison_vector``.

    Note: DIFFERENCEAVG should use _compute_comparison_scalar directly.
    """
    elementwise = _compute_comparison_vector(comparison_type, pred_hi, pred_lo)
    return elementwise
|
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
import numpy as np
|
|
3
|
+
from jax import jit, jacrev
|
|
4
|
+
from .families import linkinv, resolve_link
|
|
5
|
+
from ..comparisons import (
|
|
6
|
+
_compute_comparison,
|
|
7
|
+
_compute_comparison_scalar,
|
|
8
|
+
ComparisonType,
|
|
9
|
+
)
|
|
10
|
+
from ..utils import group_reducer, standard_errors
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _comparison_core(
    beta: jnp.ndarray,
    X_hi: jnp.ndarray,
    X_lo: jnp.ndarray,
    comparison_type: int,
    family_type: int,
    link_type: int,
) -> jnp.ndarray:
    """Compute the comparison between predictions at ``X_hi`` and ``X_lo``.

    Each design matrix is multiplied by ``beta``, mapped through the inverse
    link, and the two prediction vectors are combined by ``_compute_comparison``.
    ``family_type`` is accepted for signature parity but is not read here.
    """
    eta_hi = X_hi @ beta
    eta_lo = X_lo @ beta
    mu_hi = linkinv(link_type, eta_hi)
    mu_lo = linkinv(link_type, eta_lo)
    return _compute_comparison(comparison_type, mu_hi, mu_lo)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@jit
def _comparison_byT(
    beta: jnp.ndarray,
    X_hi: jnp.ndarray,
    X_lo: jnp.ndarray,
    comparison_type: int,
    family_type: int,
    link_type: int = None,
) -> jnp.ndarray:
    """Return the mean of the element-wise comparison over all rows."""
    elementwise = _comparison_core(
        beta, X_hi, X_lo, comparison_type, family_type, link_type
    )
    overall_mean = jnp.mean(elementwise)
    return overall_mean
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _comparison_byG(
    beta: jnp.ndarray,
    X_hi: jnp.ndarray,
    X_lo: jnp.ndarray,
    groups: jnp.ndarray,
    num_groups: int,
    comparison_type: int,
    family_type: int,
    link_type: int = None,
) -> jnp.ndarray:
    """Compute the element-wise comparison and reduce it per group.

    The reduction itself is delegated to ``group_reducer`` with ``groups`` and
    ``num_groups``.
    """
    elementwise = _comparison_core(
        beta, X_hi, X_lo, comparison_type, family_type, link_type
    )
    per_group = group_reducer(elementwise, groups, num_groups)
    return per_group
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@jit
def _comparisons_core(
    beta: jnp.ndarray,
    X_hi: jnp.ndarray,
    X_lo: jnp.ndarray,
    comparison_type: int,
    family_type: int,
    link_type: int = None,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Return the element-wise comparison and its Jacobian with respect to ``beta``."""

    def _as_function_of_beta(b):
        # Everything except the coefficients is held fixed for differentiation.
        return _comparison_core(
            b, X_hi, X_lo, comparison_type, family_type, link_type
        )

    estimate = _as_function_of_beta(beta)
    jacobian = jacrev(_as_function_of_beta)(beta)
    return estimate, jacobian
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@jit
def _differenceavg_core(
    beta: jnp.ndarray,
    X_hi: jnp.ndarray,
    X_lo: jnp.ndarray,
    link_type: int,
) -> jnp.ndarray:
    """Scalar difference between the mean predictions at ``X_hi`` and ``X_lo``."""
    pred_hi = linkinv(link_type, X_hi @ beta)
    pred_lo = linkinv(link_type, X_lo @ beta)
    # Index 0 is the only branch of the scalar comparison table (differenceavg).
    return _compute_comparison_scalar(0, pred_hi, pred_lo)


def comparisons(
    beta: jnp.ndarray,
    X_hi: jnp.ndarray,
    X_lo: jnp.ndarray,
    vcov: jnp.ndarray,
    comparison_type: int,
    family_type: int,
    link_type: int = None,
) -> dict[str, np.ndarray]:
    """Compute comparisons, their Jacobian and standard errors for a GLM.

    Args:
        beta: Coefficient vector.
        X_hi: Design matrix for the "high" counterfactual.
        X_lo: Design matrix for the "low" counterfactual.
        vcov: Covariance matrix passed to ``standard_errors``.
        comparison_type: ``ComparisonType`` code.  DIFFERENCEAVG yields a
            scalar estimate; every other code is evaluated element-wise.
        family_type: Family code used to resolve and validate the link.
        link_type: Link code, or None to use the family default.

    Returns:
        Dict with keys ``"estimate"``, ``"jacobian"`` and ``"std_error"``.

    Raises:
        ValueError: propagated from ``resolve_link`` for an invalid
            family/link combination.
    """
    link_type = resolve_link(family_type, link_type)

    # Handle DIFFERENCEAVG separately (returns scalar)
    if comparison_type == ComparisonType.DIFFERENCEAVG:
        # The jitted core lives at module level: defining it inside this
        # function would create a new function object, and therefore a new
        # trace and compilation, on every call.
        comp = _differenceavg_core(beta, X_hi, X_lo, link_type)
        jac = jacrev(_differenceavg_core)(beta, X_hi, X_lo, link_type)
        # The Jacobian of a scalar is 1-D; present it as a single row.
        se = standard_errors(jac.reshape(1, -1), vcov)
        return {
            "estimate": np.array(comp),
            "jacobian": np.array(jac),
            "std_error": se[0],
        }

    # Handle element-wise comparisons
    comp, jac = _comparisons_core(
        beta, X_hi, X_lo, comparison_type, family_type, link_type
    )
    se = standard_errors(jac, vcov)
    return {
        "estimate": np.array(comp),
        "jacobian": np.array(jac),
        "std_error": se,
    }
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
@jit
def _comparisons_byT_core(
    beta: jnp.ndarray,
    X_hi: jnp.ndarray,
    X_lo: jnp.ndarray,
    comparison_type: int,
    family_type: int,
    link_type: int = None,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Return the averaged comparison and its Jacobian with respect to ``beta``."""

    def _average_given_beta(b):
        return _comparison_byT(
            b, X_hi, X_lo, comparison_type, family_type, link_type
        )

    estimate = _average_given_beta(beta)
    jacobian = jacrev(_average_given_beta)(beta)
    return estimate, jacobian
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def comparisons_byT(
    beta: jnp.ndarray,
    X_hi: jnp.ndarray,
    X_lo: jnp.ndarray,
    vcov: jnp.ndarray,
    comparison_type: int,
    family_type: int,
    link_type: int = None,
) -> dict[str, np.ndarray]:
    """Averaged comparison with its Jacobian and standard error.

    Returns:
        Dict with keys ``"estimate"``, ``"jacobian"`` and ``"std_error"``;
        the standard error is the first element returned by
        ``standard_errors`` for the single-row Jacobian.
    """
    resolved_link = resolve_link(family_type, link_type)
    estimate, jacobian = _comparisons_byT_core(
        beta, X_hi, X_lo, comparison_type, family_type, resolved_link
    )
    # The Jacobian of the scalar average is reshaped into one row.
    single_row = jacobian.reshape(1, -1)
    std_err = standard_errors(single_row, vcov)
    result = {
        "estimate": np.array(estimate),
        "jacobian": np.array(jacobian),
        "std_error": std_err[0],
    }
    return result
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def _comparisons_byG_core(
    beta: jnp.ndarray,
    X_hi: jnp.ndarray,
    X_lo: jnp.ndarray,
    groups: jnp.ndarray,
    num_groups: int,
    comparison_type: int,
    family_type: int,
    link_type: int = None,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Return the group-wise comparison and its Jacobian with respect to ``beta``.

    Unlike the non-grouped cores in this module, this function is not wrapped
    in ``jit``.
    """

    def _grouped_given_beta(b):
        return _comparison_byG(
            b, X_hi, X_lo, groups, num_groups, comparison_type, family_type, link_type
        )

    estimate = _grouped_given_beta(beta)
    jacobian = jacrev(_grouped_given_beta)(beta)
    return estimate, jacobian
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def comparisons_byG(
    beta: jnp.ndarray,
    X_hi: jnp.ndarray,
    X_lo: jnp.ndarray,
    vcov: jnp.ndarray,
    groups: jnp.ndarray,
    num_groups: int,
    comparison_type: int,
    family_type: int,
    link_type: int = None,
) -> dict[str, np.ndarray]:
    """Group-wise comparisons with their Jacobian and standard errors.

    Returns:
        Dict with keys ``"estimate"``, ``"jacobian"`` and ``"std_error"``.
    """
    resolved_link = resolve_link(family_type, link_type)
    estimate, jacobian = _comparisons_byG_core(
        beta,
        X_hi,
        X_lo,
        groups,
        num_groups,
        comparison_type,
        family_type,
        resolved_link,
    )
    std_err = standard_errors(jacobian, vcov)
    result = {
        "estimate": np.array(estimate),
        "jacobian": np.array(jacobian),
        "std_error": std_err,
    }
    return result
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
"""GLM families and link functions using enum-based approach for JAX compatibility."""
|
|
2
|
+
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
from jax import lax
|
|
5
|
+
from jax.scipy.stats import norm
|
|
6
|
+
from enum import IntEnum
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Family(IntEnum):
    """Integer codes for the supported GLM families."""

    GAUSSIAN = 0
    BINOMIAL = 1
    POISSON = 2
    GAMMA = 3
    INVERSE_GAUSSIAN = 4
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class Link(IntEnum):
    """Integer codes for the supported link functions.

    The values are plain integers so they can be used as a branch index.
    """

    IDENTITY = 0
    LOG = 1
    LOGIT = 2
    PROBIT = 3
    INVERSE = 4
    SQRT = 5
    CLOGLOG = 6
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# Default links for each family
|
|
32
|
+
DEFAULT_LINKS = {
    # family -> link used when the caller does not specify one
    Family.GAUSSIAN: Link.IDENTITY,
    Family.BINOMIAL: Link.LOGIT,
    Family.POISSON: Link.LOG,
    Family.GAMMA: Link.INVERSE,
    # NOTE(review): R and statsmodels appear to default to the inverse-squared
    # link for the inverse Gaussian family; Link has no such member - confirm
    # that INVERSE is the intended default here.
    Family.INVERSE_GAUSSIAN: Link.INVERSE,
}
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def linkinv(link_type: int, eta: jnp.ndarray) -> jnp.ndarray:
    """Inverse link function: transform linear predictor to mean.

    Args:
        link_type: ``Link`` code selecting the branch (0-6).
        eta: Linear predictor values.

    Returns:
        Array of the same shape as ``eta`` holding the mean-scale values.
    """
    # Local import keeps this block self-contained; jax.scipy is already a
    # dependency of this module.
    from jax.scipy.special import expit

    return lax.switch(
        link_type,
        [
            lambda x: x,  # identity
            lambda x: jnp.exp(x),  # log
            # expit is the logistic function with a stable derivative; the
            # hand-written 1 / (1 + exp(-x)) differentiates to inf / inf = NaN
            # once exp(-x) overflows, which poisons the Jacobian.
            lambda x: expit(x),  # logit
            lambda x: norm.cdf(x),  # probit
            lambda x: 1.0 / x,  # inverse
            lambda x: x**2,  # sqrt
            # -expm1(-exp(x)) equals 1 - exp(-exp(x)) but keeps precision when
            # exp(x) is tiny, where the subtraction from 1 rounds to zero.
            lambda x: -jnp.expm1(-jnp.exp(x)),  # cloglog
        ],
        eta,
    )
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def linkfun(link_type: int, mu: jnp.ndarray) -> jnp.ndarray:
    """Link function: transform mean to linear predictor.

    ``link_type`` selects one of seven branches, listed in the same order as
    the ``Link`` codes.
    """

    def _identity(m):
        return m

    def _log(m):
        return jnp.log(m)

    def _logit(m):
        return jnp.log(m / (1 - m))

    def _probit(m):
        return norm.ppf(m)

    def _inverse(m):
        return 1.0 / m

    def _sqrt(m):
        return jnp.sqrt(m)

    def _cloglog(m):
        return jnp.log(-jnp.log(1 - m))

    branches = [_identity, _log, _logit, _probit, _inverse, _sqrt, _cloglog]
    return lax.switch(link_type, branches, mu)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# Valid link functions for each family
|
|
76
|
+
VALID_LINKS = {
    # family -> links accepted for that family
    Family.GAUSSIAN: [
        Link.IDENTITY,
        Link.LOG,
        Link.INVERSE,
    ],
    Family.BINOMIAL: [
        Link.LOGIT,
        Link.PROBIT,
        Link.CLOGLOG,
        Link.LOG,
    ],
    Family.POISSON: [
        Link.LOG,
        Link.IDENTITY,
        Link.SQRT,
    ],
    Family.GAMMA: [
        Link.INVERSE,
        Link.IDENTITY,
        Link.LOG,
    ],
    Family.INVERSE_GAUSSIAN: [
        Link.INVERSE,
        Link.IDENTITY,
        Link.LOG,
    ],
}
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def validate_family_link(family_type: int, link_type: int) -> bool:
    """Return True when ``link_type`` is listed in ``VALID_LINKS`` for ``family_type``.

    A family that has no entry in ``VALID_LINKS`` is never valid.
    """
    allowed_links = VALID_LINKS.get(family_type)
    if allowed_links is None:
        return False
    return link_type in allowed_links
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def resolve_link(family_type: int, link_type: int | None = None) -> int:
    """Resolve link type, using default if None, and validate the combination.

    Args:
        family_type: ``Family`` code.
        link_type: ``Link`` code, or None to request the family's default.

    Returns:
        ``link_type`` unchanged when it is valid for the family; otherwise,
        when ``link_type`` is None, the entry of ``DEFAULT_LINKS`` for the
        family (``Link.IDENTITY`` if the family has no entry).

    Raises:
        ValueError: if an explicit ``link_type`` is not valid for the family.
    """
    default = DEFAULT_LINKS.get(family_type, Link.IDENTITY)
    if link_type is None:
        return default
    if not validate_family_link(family_type, link_type):
        # The default is reported for guidance only: this path raises, it does
        # not fall back, so the message must not claim the default is "used".
        raise ValueError(
            f"Invalid link {link_type} for family {family_type}. "
            f"Valid links: {VALID_LINKS.get(family_type, [])}. "
            f"Default link for this family: {default}"
        )
    return link_type
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
# Lowercase convenience aliases; each name is bound to the enum class itself,
# not to an instance of it.
family = Family
link = Link
|
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
import numpy as np
|
|
3
|
+
from jax import jacfwd, jacrev, jit
|
|
4
|
+
from .families import linkinv, resolve_link
|
|
5
|
+
from ..utils import group_reducer, standard_errors
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _predict_core(
    beta: jnp.ndarray,
    X: jnp.ndarray,
    family_type: int,
    link_type: int,
) -> jnp.ndarray:
    """Map the linear predictor ``X @ beta`` through the inverse link.

    ``family_type`` is accepted for signature parity but is not read here.
    """
    return linkinv(link_type, X @ beta)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@jit
def _predict(
    beta: jnp.ndarray, X: jnp.ndarray, family_type: int, link_type: int = None
) -> jnp.ndarray:
    """Jit-compiled wrapper around ``_predict_core`` (row-level predictions)."""
    fitted = _predict_core(beta, X, family_type, link_type)
    return fitted
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@jit
def _predict_byT(
    beta: jnp.ndarray, X: jnp.ndarray, family_type: int, link_type: int = None
) -> jnp.ndarray:
    """Return the mean of the row-level predictions."""
    fitted = _predict_core(beta, X, family_type, link_type)
    overall_mean = jnp.mean(fitted)
    return overall_mean
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _predict_byG(
    beta: jnp.ndarray,
    X: jnp.ndarray,
    groups: jnp.ndarray,
    num_groups: int,
    family_type: int,
    link_type: int = None,
) -> jnp.ndarray:
    """Compute row-level predictions and reduce them per group.

    The reduction itself is delegated to ``group_reducer`` with ``groups`` and
    ``num_groups``.
    """
    fitted = _predict_core(beta, X, family_type, link_type)
    per_group = group_reducer(fitted, groups, num_groups)
    return per_group
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@jit
def _predictions_core(
    beta: jnp.ndarray, X: jnp.ndarray, family_type: int, link_type: int = None
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Return row-level predictions and their Jacobian with respect to ``beta``.

    The Jacobian is taken in forward mode over the first argument of ``_predict``.
    """
    jacobian_of_beta = jacfwd(_predict, argnums=0)
    fitted = _predict_core(beta, X, family_type, link_type)
    jacobian = jacobian_of_beta(beta, X, family_type, link_type)
    return fitted, jacobian
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def predictions(
    beta: jnp.ndarray,
    X: jnp.ndarray,
    vcov: jnp.ndarray,
    family_type: int,
    link_type: int = None,
) -> dict[str, np.ndarray]:
    """Row-level GLM predictions with their Jacobian and standard errors.

    Returns:
        Dict with keys ``"estimate"``, ``"jacobian"`` and ``"std_error"``.
    """
    resolved_link = resolve_link(family_type, link_type)
    fitted, jacobian = _predictions_core(beta, X, family_type, resolved_link)
    std_err = standard_errors(jacobian, vcov)
    result = {
        "estimate": np.array(fitted),
        "jacobian": np.array(jacobian),
        "std_error": std_err,
    }
    return result
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@jit
def _predictions_byT_core(
    beta: jnp.ndarray, X: jnp.ndarray, family_type: int, link_type: int = None
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Return the mean prediction and its Jacobian with respect to ``beta``.

    The Jacobian is taken in reverse mode over the first argument of
    ``_predict_byT``.
    """
    jacobian_of_beta = jacrev(_predict_byT, argnums=0)
    mean_fitted = _predict_byT(beta, X, family_type, link_type)
    jacobian = jacobian_of_beta(beta, X, family_type, link_type)
    return mean_fitted, jacobian
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def predictions_byT(
    beta: jnp.ndarray,
    X: jnp.ndarray,
    vcov: jnp.ndarray,
    family_type: int,
    link_type: int = None,
) -> dict[str, np.ndarray]:
    """Mean GLM prediction with its Jacobian and standard error.

    Returns:
        Dict with keys ``"estimate"``, ``"jacobian"`` and ``"std_error"``;
        the standard error is the first element returned by
        ``standard_errors`` for the single-row Jacobian.
    """
    resolved_link = resolve_link(family_type, link_type)
    mean_fitted, jacobian = _predictions_byT_core(
        beta, X, family_type, resolved_link
    )
    # The Jacobian of the scalar mean is reshaped into one row.
    single_row = jacobian.reshape(1, -1)
    std_err = standard_errors(single_row, vcov)
    result = {
        "estimate": np.array(mean_fitted),
        "jacobian": np.array(jacobian),
        "std_error": std_err[0],
    }
    return result
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _predictions_byG_core(
    beta: jnp.ndarray,
    X: jnp.ndarray,
    groups: jnp.ndarray,
    num_groups: int,
    family_type: int,
    link_type: int = None,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Return group-wise predictions and their Jacobian with respect to ``beta``.

    Unlike the non-grouped cores in this module, this function is not wrapped
    in ``jit``.
    """

    def _grouped_given_beta(b):
        return _predict_byG(b, X, groups, num_groups, family_type, link_type)

    per_group = _predict_byG(beta, X, groups, num_groups, family_type, link_type)
    jacobian = jacrev(_grouped_given_beta)(beta)
    return per_group, jacobian
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def predictions_byG(
    beta: jnp.ndarray,
    X: jnp.ndarray,
    vcov: jnp.ndarray,
    groups: jnp.ndarray,
    num_groups: int,
    family_type: int,
    link_type: int = None,
) -> dict[str, np.ndarray]:
    """Group-wise GLM predictions with their Jacobian and standard errors.

    Returns:
        Dict with keys ``"estimate"``, ``"jacobian"`` and ``"std_error"``.
    """
    resolved_link = resolve_link(family_type, link_type)
    per_group, jacobian = _predictions_byG_core(
        beta,
        X,
        groups,
        num_groups,
        family_type,
        resolved_link,
    )
    std_err = standard_errors(jacobian, vcov)
    result = {
        "estimate": np.array(per_group),
        "jacobian": np.array(jacobian),
        "std_error": std_err,
    }
    return result
|