marginaleffects 0.1.5__tar.gz → 0.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103)
  1. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/PKG-INFO +3 -1
  2. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/__init__.py +14 -0
  3. marginaleffects-0.2.1/marginaleffects/autodiff/__init__.py +63 -0
  4. marginaleffects-0.2.1/marginaleffects/autodiff/comparisons.py +58 -0
  5. marginaleffects-0.2.1/marginaleffects/autodiff/glm/__init__.py +5 -0
  6. marginaleffects-0.2.1/marginaleffects/autodiff/glm/comparisons.py +193 -0
  7. marginaleffects-0.2.1/marginaleffects/autodiff/glm/families.py +108 -0
  8. marginaleffects-0.2.1/marginaleffects/autodiff/glm/predictions.py +131 -0
  9. marginaleffects-0.2.1/marginaleffects/autodiff/linear/__init__.py +4 -0
  10. marginaleffects-0.2.1/marginaleffects/autodiff/linear/comparisons.py +145 -0
  11. marginaleffects-0.2.1/marginaleffects/autodiff/linear/predictions.py +71 -0
  12. marginaleffects-0.2.1/marginaleffects/autodiff/utils.py +31 -0
  13. marginaleffects-0.2.1/marginaleffects/classes.py +64 -0
  14. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/comparisons.py +71 -27
  15. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/datagrid.py +5 -3
  16. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/docs.py +8 -1
  17. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/formulaic_utils.py +32 -3
  18. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/hypotheses.py +3 -3
  19. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/hypotheses_joint.py +2 -2
  20. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/hypothesis.py +26 -2
  21. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/model_abstract.py +9 -0
  22. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/model_pyfixest.py +8 -3
  23. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/plot_common.py +53 -1
  24. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/plot_predictions.py +19 -1
  25. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/predictions.py +37 -8
  26. marginaleffects-0.2.1/marginaleffects/result.py +218 -0
  27. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/sanity.py +111 -0
  28. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/slopes.py +4 -0
  29. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/utils.py +7 -9
  30. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects.egg-info/PKG-INFO +3 -1
  31. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects.egg-info/SOURCES.txt +21 -0
  32. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects.egg-info/requires.txt +3 -0
  33. marginaleffects-0.2.1/marginaleffects.egg-info/top_level.txt +6 -0
  34. marginaleffects-0.2.1/marginaleffectsAD/build/lib/marginaleffectsAD/__init__.py +9 -0
  35. marginaleffects-0.2.1/marginaleffectsAD/build/lib/marginaleffectsAD/comparisons.py +58 -0
  36. marginaleffects-0.2.1/marginaleffectsAD/build/lib/marginaleffectsAD/glm/__init__.py +5 -0
  37. marginaleffects-0.2.1/marginaleffectsAD/build/lib/marginaleffectsAD/glm/comparisons.py +140 -0
  38. marginaleffects-0.2.1/marginaleffectsAD/build/lib/marginaleffectsAD/glm/families.py +108 -0
  39. marginaleffects-0.2.1/marginaleffectsAD/build/lib/marginaleffectsAD/glm/predictions.py +114 -0
  40. marginaleffects-0.2.1/marginaleffectsAD/build/lib/marginaleffectsAD/linear/__init__.py +4 -0
  41. marginaleffects-0.2.1/marginaleffectsAD/build/lib/marginaleffectsAD/linear/comparisons.py +119 -0
  42. marginaleffects-0.2.1/marginaleffectsAD/build/lib/marginaleffectsAD/linear/predictions.py +62 -0
  43. marginaleffects-0.2.1/marginaleffectsAD/build/lib/marginaleffectsAD/utils.py +31 -0
  44. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/pyproject.toml +9 -1
  45. marginaleffects-0.2.1/tests/test_bugfix.py +24 -0
  46. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_comparisons.py +94 -26
  47. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_datagrid_01.py +25 -0
  48. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_datagrid_02.py +38 -0
  49. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_formula.py +2 -2
  50. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_formulaic_utils.py +3 -2
  51. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_hypotheses.py +3 -1
  52. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_hypotheses_joint.py +3 -3
  53. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_hypothesis.py +69 -2
  54. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_jss.py +36 -29
  55. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_plot_predictions.py +12 -0
  56. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_predictions.py +9 -6
  57. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_pyfixest.py +52 -45
  58. marginaleffects-0.2.1/tests/test_sklearn.py +146 -0
  59. marginaleffects-0.1.5/marginaleffects/classes.py +0 -226
  60. marginaleffects-0.1.5/marginaleffects.egg-info/top_level.txt +0 -3
  61. marginaleffects-0.1.5/tests/test_bugfix.py +0 -15
  62. marginaleffects-0.1.5/tests/test_sklearn.py +0 -38
  63. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/README.md +0 -0
  64. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/by.py +0 -0
  65. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/equivalence.py +0 -0
  66. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/estimands.py +0 -0
  67. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/hypothesis_formula.py +0 -0
  68. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/inject_docs.py +0 -0
  69. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/model_linearmodels.py +0 -0
  70. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/model_sklearn.py +0 -0
  71. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/model_statsmodels.py +0 -0
  72. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/plot_comparisons.py +0 -0
  73. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/plot_slopes.py +0 -0
  74. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/sanitize_model.py +0 -0
  75. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/transform.py +0 -0
  76. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/uncertainty.py +0 -0
  77. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects/validation.py +0 -0
  78. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/marginaleffects.egg-info/dependency_links.txt +0 -0
  79. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/setup.cfg +0 -0
  80. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/__init__.py +0 -0
  81. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/helpers.py +0 -0
  82. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_analytic.py +0 -0
  83. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_by.py +0 -0
  84. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_categorical.py +0 -0
  85. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_equivalence.py +0 -0
  86. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_linearmodels_panelols.py +0 -0
  87. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_missing.py +0 -0
  88. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_newdata.py +0 -0
  89. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_plot_comparisons.py +0 -0
  90. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_plot_slopes.py +0 -0
  91. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_slopes.py +0 -0
  92. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_statsmodels.py +0 -0
  93. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_statsmodels_logit.py +0 -0
  94. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_statsmodels_mixedlm.py +0 -0
  95. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_statsmodels_mnlogit.py +0 -0
  96. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_statsmodels_negativebinomial.py +0 -0
  97. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_statsmodels_ols.py +0 -0
  98. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_statsmodels_poisson.py +0 -0
  99. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_statsmodels_probit.py +0 -0
  100. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_statsmodels_quantreg.py +0 -0
  101. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_statsmodels_wls.py +0 -0
  102. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/test_utils.py +0 -0
  103. {marginaleffects-0.1.5 → marginaleffects-0.2.1}/tests/utilities.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: marginaleffects
3
- Version: 0.1.5
3
+ Version: 0.2.1
4
4
  Summary: Predictions, counterfactual comparisons, slopes, and hypothesis tests for statistical models.
5
5
  License-Expression: GPL-3.0-or-later
6
6
  Requires-Python: >=3.10
@@ -14,6 +14,8 @@ Requires-Dist: pydantic>=2.10.3
14
14
  Requires-Dist: plotnine>=0.14.5
15
15
  Requires-Dist: scipy>=1.14.1
16
16
  Requires-Dist: pyarrow>=19.0.1
17
+ Provides-Extra: autodiff
18
+ Requires-Dist: jax>=0.4.0; extra == "autodiff"
17
19
  Provides-Extra: test
18
20
  Requires-Dist: duckdb>=1.1.2; extra == "test"
19
21
  Requires-Dist: matplotlib>=3.7.1; extra == "test"
@@ -10,6 +10,16 @@ from .plot_slopes import plot_slopes
10
10
  from .predictions import avg_predictions, predictions
11
11
  from .slopes import avg_slopes, slopes
12
12
  from .utils import get_dataset
13
+ from .result import MarginaleffectsResult
14
+
15
# Conditionally import the autodiff subpackage.
#
# ``marginaleffects.autodiff`` itself imports fine without JAX (it then
# exposes placeholders that raise an ImportError with install hints when
# used), so a successful import does not by itself mean JAX is present.
# The flag therefore mirrors the subpackage's own ``_JAX_AVAILABLE`` check;
# the ``except`` branch covers imports that fail outright (for example, JAX
# is installed but cannot be imported).
try:
    from . import autodiff
except ImportError:
    _AUTODIFF_AVAILABLE = False
    autodiff = None
else:
    _AUTODIFF_AVAILABLE = bool(getattr(autodiff, "_JAX_AVAILABLE", True))
13
23
 
14
24
  __all__ = [
15
25
  "avg_comparisons",
@@ -27,4 +37,8 @@ __all__ = [
27
37
  "fit_sklearn",
28
38
  "fit_linearmodels",
29
39
  "get_dataset",
40
+ "MarginaleffectsResult",
30
41
  ]
42
+
43
# List the autodiff subpackage in the public API only when
# ``_AUTODIFF_AVAILABLE`` is set.
if _AUTODIFF_AVAILABLE:
    __all__.extend(["autodiff"])
@@ -0,0 +1,63 @@
1
+ """
2
+ JAX-based automatic differentiation module for marginal effects.
3
+
4
+ This module provides high-performance computation of predictions, comparisons,
5
+ and standard errors using JAX's automatic differentiation capabilities.
6
+
7
+ Note: This module requires JAX to be installed. Install with:
8
+ pip install marginaleffects[autodiff]
9
+ or
10
+ pip install jax
11
+ """
12
+
13
+ import importlib.util
14
+
15
# True when the ``jax`` package can be located (it is not imported here).
_JAX_AVAILABLE = importlib.util.find_spec("jax") is not None


def _raise_jax_error():
    """Raise an ImportError telling the user how to install the JAX extra."""
    message = (
        "The autodiff module requires JAX to be installed. "
        "Install with: pip install marginaleffects[autodiff] or pip install jax"
    )
    raise ImportError(message)
23
+
24
+
25
if _JAX_AVAILABLE:
    # Import submodules to make them accessible
    from . import linear as linear
    from . import glm as glm

    # Re-export types for convenience
    from .comparisons import ComparisonType as ComparisonType
    from .glm.families import Family as Family, Link as Link

    __all__ = [
        "linear",
        "glm",
        "ComparisonType",
        "Family",
        "Link",
    ]
else:
    # JAX is missing: expose placeholders so that importing this package
    # still succeeds, while any real use raises the ImportError from
    # ``_raise_jax_error`` with installation hints.

    def _is_dunder(name):
        # Dunder lookups come from Python itself and from introspection
        # tools (e.g. ``hasattr(obj, "__wrapped__")``).  They must keep
        # raising AttributeError, because ``hasattr`` only swallows that.
        return name.startswith("__") and name.endswith("__")

    class _DummyModule:
        """Stand-in for the ``linear`` / ``glm`` submodules."""

        def __getattr__(self, name):
            if _is_dunder(name):
                raise AttributeError(name)
            _raise_jax_error()

    class _MissingJAXMeta(type):
        """Metaclass that makes class-level lookups raise the JAX error.

        ``__getattribute__`` defined in a class body only intercepts lookups
        on *instances*.  Member access on the class itself, such as
        ``Family.BINOMIAL``, is resolved through the metaclass, so the
        helpful error must be raised here.  ``__getattr__`` only runs after
        normal lookup fails, so attributes that do exist keep working.
        """

        def __getattr__(cls, name):
            if _is_dunder(name):
                raise AttributeError(name)
            _raise_jax_error()

    linear = _DummyModule()
    glm = _DummyModule()

    # Placeholders for the enums that are re-exported when JAX is available.
    class ComparisonType(metaclass=_MissingJAXMeta):
        def __getattribute__(self, name):
            _raise_jax_error()

    class Family(metaclass=_MissingJAXMeta):
        def __getattribute__(self, name):
            _raise_jax_error()

    class Link(metaclass=_MissingJAXMeta):
        def __getattribute__(self, name):
            _raise_jax_error()

    __all__ = []
@@ -0,0 +1,58 @@
1
+ """Comparison types and functions using enum-based approach for JAX compatibility."""
2
+
3
+ import jax.numpy as jnp
4
+ from jax import lax
5
+ from enum import IntEnum
6
+
7
+
8
class ComparisonType(IntEnum):
    """Integer codes for the ways two sets of predictions can be contrasted.

    Codes 0-4 are element-wise contrasts of ``hi`` against ``lo``.
    ``DIFFERENCEAVG`` contrasts the averages and yields a single value.
    """

    DIFFERENCE = 0  # hi - lo
    RATIO = 1  # hi / lo
    LNRATIO = 2  # log(hi / lo)
    LNOR = 3  # log of the odds ratio
    LIFT = 4  # (hi - lo) / lo
    DIFFERENCEAVG = 5  # mean(hi) - mean(lo)
17
+
18
+
19
+ def _compute_comparison_vector(
20
+ comparison_type: int, pred_hi: jnp.ndarray, pred_lo: jnp.ndarray
21
+ ) -> jnp.ndarray:
22
+ """Apply comparison function element-wise (returns N-length array)."""
23
+ return lax.switch(
24
+ comparison_type,
25
+ [
26
+ lambda hi, lo: hi - lo, # difference
27
+ lambda hi, lo: hi / lo, # ratio
28
+ lambda hi, lo: jnp.log(hi / lo), # lnratio
29
+ lambda hi, lo: jnp.log((hi / (1 - hi)) / (lo / (1 - lo))), # lnor
30
+ lambda hi, lo: (hi - lo) / lo, # lift
31
+ ],
32
+ pred_hi,
33
+ pred_lo,
34
+ )
35
+
36
+
37
+ def _compute_comparison_scalar(
38
+ comparison_type: int, pred_hi: jnp.ndarray, pred_lo: jnp.ndarray
39
+ ) -> jnp.ndarray:
40
+ """Apply comparison function with aggregation (returns scalar)."""
41
+ return lax.switch(
42
+ comparison_type,
43
+ [
44
+ lambda hi, lo: jnp.mean(hi) - jnp.mean(lo), # differenceavg
45
+ ],
46
+ pred_hi,
47
+ pred_lo,
48
+ )
49
+
50
+
51
def _compute_comparison(
    comparison_type: int, pred_hi: jnp.ndarray, pred_lo: jnp.ndarray
) -> jnp.ndarray:
    """Element-wise contrast of ``pred_hi`` against ``pred_lo``.

    Thin alias for ``_compute_comparison_vector``.  ``DIFFERENCEAVG`` is an
    aggregated contrast and should go through
    ``_compute_comparison_scalar`` instead.
    """
    return _compute_comparison_vector(
        comparison_type=comparison_type, pred_hi=pred_hi, pred_lo=pred_lo
    )
@@ -0,0 +1,5 @@
1
+ # GLM model implementations with switchable link functions
2
+
3
+ from . import predictions as predictions
4
+ from . import comparisons as comparisons
5
+ from . import families as families
@@ -0,0 +1,193 @@
1
+ import jax.numpy as jnp
2
+ import numpy as np
3
+ from jax import jit, jacrev
4
+ from .families import linkinv, resolve_link
5
+ from ..comparisons import (
6
+ _compute_comparison,
7
+ _compute_comparison_scalar,
8
+ ComparisonType,
9
+ )
10
+ from ..utils import group_reducer, standard_errors
11
+
12
+
13
def _comparison_core(
    beta: jnp.ndarray,
    X_hi: jnp.ndarray,
    X_lo: jnp.ndarray,
    comparison_type: int,
    family_type: int,
    link_type: int,
) -> jnp.ndarray:
    """Row-wise contrast between two design matrices under a GLM.

    Each design matrix is multiplied by ``beta`` and mapped to the response
    scale with the inverse link selected by ``link_type``.  The two results
    are then contrasted element-wise according to ``comparison_type``.
    ``family_type`` is accepted for a uniform signature but is not used.
    """
    mu_hi, mu_lo = (linkinv(link_type, design @ beta) for design in (X_hi, X_lo))
    return _compute_comparison(comparison_type, mu_hi, mu_lo)
25
+
26
+
27
@jit
def _comparison_byT(
    beta: jnp.ndarray,
    X_hi: jnp.ndarray,
    X_lo: jnp.ndarray,
    comparison_type: int,
    family_type: int,
    link_type: int = None,
) -> jnp.ndarray:
    """Jitted mean of the row-wise contrasts (a single value)."""
    row_contrasts = _comparison_core(
        beta, X_hi, X_lo, comparison_type, family_type, link_type
    )
    return row_contrasts.mean()
38
+
39
+
40
def _comparison_byG(
    beta: jnp.ndarray,
    X_hi: jnp.ndarray,
    X_lo: jnp.ndarray,
    groups: jnp.ndarray,
    num_groups: int,
    comparison_type: int,
    family_type: int,
    link_type: int = None,
) -> jnp.ndarray:
    """Row-wise contrasts aggregated into ``num_groups`` values.

    The aggregation is delegated to ``group_reducer``.  Presumably
    ``groups`` holds one group index per row; verify against that helper.
    """
    row_contrasts = _comparison_core(
        beta, X_hi, X_lo, comparison_type, family_type, link_type
    )
    return group_reducer(row_contrasts, groups, num_groups)
52
+
53
+
54
@jit
def _comparisons_core(
    beta: jnp.ndarray,
    X_hi: jnp.ndarray,
    X_lo: jnp.ndarray,
    comparison_type: int,
    family_type: int,
    link_type: int = None,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Row-wise contrasts and their Jacobian with respect to ``beta``.

    Returns ``(estimates, jacobian)``.  The Jacobian comes from
    reverse-mode differentiation (``jacrev``) of ``_comparison_core`` in
    its first argument.
    """
    fixed_args = (X_hi, X_lo, comparison_type, family_type, link_type)
    estimates = _comparison_core(beta, *fixed_args)
    jacobian = jacrev(_comparison_core, argnums=0)(beta, *fixed_args)
    return estimates, jacobian
70
+
71
+
72
@jit
def _comparisons_avg_core(
    beta: jnp.ndarray,
    X_hi: jnp.ndarray,
    X_lo: jnp.ndarray,
    link_type: int,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Difference of mean predictions and its gradient with respect to ``beta``.

    Defined once at module level so JAX compiles it once per input
    signature.  A ``@jit`` function created inside ``comparisons`` would be
    a new object on every call and would be retraced and recompiled each
    time.
    """

    def _difference_of_means(b):
        pred_hi = linkinv(link_type, X_hi @ b)
        pred_lo = linkinv(link_type, X_lo @ b)
        # Index 0 selects the (only) difference-of-means scalar branch.
        return _compute_comparison_scalar(0, pred_hi, pred_lo)

    return _difference_of_means(beta), jacrev(_difference_of_means)(beta)


def comparisons(
    beta: jnp.ndarray,
    X_hi: jnp.ndarray,
    X_lo: jnp.ndarray,
    vcov: jnp.ndarray,
    comparison_type: int,
    family_type: int,
    link_type: int = None,
) -> dict[str, np.ndarray]:
    """GLM contrasts between two design matrices, with standard errors.

    Parameters
    ----------
    beta : coefficient vector multiplied by ``X_hi`` and ``X_lo``.
    X_hi, X_lo : design matrices for the two scenarios being contrasted.
    vcov : passed to ``standard_errors`` with the Jacobian (presumably the
        coefficient covariance matrix).
    comparison_type : a ``ComparisonType`` code.
    family_type : a ``Family`` code.
    link_type : a ``Link`` code, or ``None`` for the family default.

    Returns
    -------
    dict with ``"estimate"``, ``"jacobian"`` and ``"std_error"``.  For
    ``ComparisonType.DIFFERENCEAVG`` the estimate is a single value (the
    difference of mean predictions); otherwise it is row-wise.

    Raises
    ------
    ValueError
        From ``resolve_link`` when ``link_type`` is invalid for the family.
    """
    link_type = resolve_link(family_type, link_type)

    # DIFFERENCEAVG aggregates before contrasting, so it yields one value.
    if comparison_type == ComparisonType.DIFFERENCEAVG:
        comp, jac = _comparisons_avg_core(beta, X_hi, X_lo, link_type)
        se = standard_errors(jac.reshape(1, -1), vcov)
        return {
            "estimate": np.array(comp),
            "jacobian": np.array(jac),
            "std_error": se[0],
        }

    # Element-wise comparisons: one estimate and one Jacobian row per row.
    comp, jac = _comparisons_core(
        beta, X_hi, X_lo, comparison_type, family_type, link_type
    )
    se = standard_errors(jac, vcov)
    return {
        "estimate": np.array(comp),
        "jacobian": np.array(jac),
        "std_error": se,
    }
111
+
112
+
113
@jit
def _comparisons_byT_core(
    beta: jnp.ndarray,
    X_hi: jnp.ndarray,
    X_lo: jnp.ndarray,
    comparison_type: int,
    family_type: int,
    link_type: int = None,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Mean contrast over all rows and its gradient with respect to ``beta``."""
    fixed_args = (X_hi, X_lo, comparison_type, family_type, link_type)
    average = _comparison_byT(beta, *fixed_args)
    gradient = jacrev(_comparison_byT, argnums=0)(beta, *fixed_args)
    return average, gradient
129
+
130
+
131
def comparisons_byT(
    beta: jnp.ndarray,
    X_hi: jnp.ndarray,
    X_lo: jnp.ndarray,
    vcov: jnp.ndarray,
    comparison_type: int,
    family_type: int,
    link_type: int = None,
) -> dict[str, np.ndarray]:
    """Average GLM contrast over all rows, with its standard error.

    ``link_type=None`` selects the family default; ``resolve_link`` raises
    ``ValueError`` for an invalid family/link pair.  Returns a dict with
    ``"estimate"``, ``"jacobian"`` and ``"std_error"``.

    ``ComparisonType.DIFFERENCEAVG`` is supported.  Averaged over all rows
    it equals the average of row-wise differences, so it is computed as
    ``DIFFERENCE``.
    """
    link_type = resolve_link(family_type, link_type)
    if comparison_type == ComparisonType.DIFFERENCEAVG:
        # The row-wise helpers have no meaningful DIFFERENCEAVG branch, but
        # mean(hi - lo) == mean(hi) - mean(lo), so DIFFERENCE gives the same
        # value here (``_comparison_byT`` takes the plain mean).
        comparison_type = ComparisonType.DIFFERENCE
    comp, jac = _comparisons_byT_core(
        beta, X_hi, X_lo, comparison_type, family_type, link_type
    )
    # The gradient of a scalar is 1-D; standard_errors gets it as one row.
    se = standard_errors(jac.reshape(1, -1), vcov)
    return {
        "estimate": np.array(comp),
        "jacobian": np.array(jac),
        "std_error": se[0],
    }
150
+
151
+
152
def _comparisons_byG_core(
    beta: jnp.ndarray,
    X_hi: jnp.ndarray,
    X_lo: jnp.ndarray,
    groups: jnp.ndarray,
    num_groups: int,
    comparison_type: int,
    family_type: int,
    link_type: int = None,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Group-level contrasts and their Jacobian with respect to ``beta``.

    Not jitted.  Everything except ``beta`` is captured in a closure, so
    ``num_groups`` stays a plain Python value while differentiating.
    """

    def grouped(b):
        return _comparison_byG(
            b, X_hi, X_lo, groups, num_groups, comparison_type, family_type, link_type
        )

    return grouped(beta), jacrev(grouped)(beta)
171
+
172
+
173
def comparisons_byG(
    beta: jnp.ndarray,
    X_hi: jnp.ndarray,
    X_lo: jnp.ndarray,
    vcov: jnp.ndarray,
    groups: jnp.ndarray,
    num_groups: int,
    comparison_type: int,
    family_type: int,
    link_type: int = None,
) -> dict[str, np.ndarray]:
    """Group-level GLM contrasts with standard errors.

    ``link_type=None`` selects the family default via ``resolve_link``.
    Returns a dict with ``"estimate"``, ``"jacobian"`` and ``"std_error"``;
    the standard errors come from ``standard_errors(jacobian, vcov)``.
    """
    resolved_link = resolve_link(family_type, link_type)
    estimates, jacobian = _comparisons_byG_core(
        beta,
        X_hi,
        X_lo,
        groups,
        num_groups,
        comparison_type,
        family_type,
        resolved_link,
    )
    std_error = standard_errors(jacobian, vcov)
    return dict(
        estimate=np.array(estimates),
        jacobian=np.array(jacobian),
        std_error=std_error,
    )
@@ -0,0 +1,108 @@
1
+ """GLM families and link functions using enum-based approach for JAX compatibility."""
2
+
3
+ import jax.numpy as jnp
4
+ from jax import lax
5
+ from jax.scipy.stats import norm
6
+ from enum import IntEnum
7
+
8
+
9
class Family(IntEnum):
    """Integer codes identifying the supported GLM distribution families."""

    GAUSSIAN = 0
    BINOMIAL = 1
    POISSON = 2
    GAMMA = 3
    INVERSE_GAUSSIAN = 4
+
18
+
19
class Link(IntEnum):
    """Integer codes for link functions.

    Each value is the branch index used by ``linkinv`` and ``linkfun``.
    """

    IDENTITY = 0  # mu = eta
    LOG = 1  # mu = exp(eta)
    LOGIT = 2  # mu = logistic(eta)
    PROBIT = 3  # mu = standard normal CDF of eta
    INVERSE = 4  # mu = 1 / eta
    SQRT = 5  # mu = eta ** 2
    CLOGLOG = 6  # mu = 1 - exp(-exp(eta))
+
30
+
31
# Link used for each family when the caller passes ``link_type=None``.
# NOTE(review): the canonical link of the inverse Gaussian family is
# 1/mu**2, which ``Link`` does not define; INVERSE (1/mu) is used instead.
# Confirm this matches the fitted models being wrapped.
DEFAULT_LINKS = dict(
    [
        (Family.GAUSSIAN, Link.IDENTITY),
        (Family.BINOMIAL, Link.LOGIT),
        (Family.POISSON, Link.LOG),
        (Family.GAMMA, Link.INVERSE),
        (Family.INVERSE_GAUSSIAN, Link.INVERSE),
    ]
)
39
+
40
+
41
def linkinv(link_type: int, eta: jnp.ndarray) -> jnp.ndarray:
    """Inverse link function: transform linear predictor to mean.

    ``link_type`` is a ``Link`` code that indexes the branch list below.  It
    may be a traced value under ``jit``; ``lax.switch`` clamps out-of-range
    indices instead of raising.

    This function is differentiated with JAX to build Jacobians, so two
    branches use numerically stable forms:

    * logit: ``jax.nn.sigmoid`` instead of ``1 / (1 + exp(-x))``.  The naive
      form overflows ``exp(-x)`` for large negative ``x`` (below about -88
      in float32), and its derivative then evaluates to ``0 * inf = nan``.
    * cloglog: ``-expm1(-exp(x))`` instead of ``1 - exp(-exp(x))``.  This
      keeps precision when ``exp(x)`` is tiny instead of rounding to 0.
    """
    from jax.nn import sigmoid

    return lax.switch(
        link_type,
        [
            lambda x: x,  # identity
            lambda x: jnp.exp(x),  # log
            lambda x: sigmoid(x),  # logit
            lambda x: norm.cdf(x),  # probit
            lambda x: 1.0 / x,  # inverse
            lambda x: x**2,  # sqrt
            lambda x: -jnp.expm1(-jnp.exp(x)),  # cloglog
        ],
        eta,
    )
56
+
57
+
58
def linkfun(link_type: int, mu: jnp.ndarray) -> jnp.ndarray:
    """Link function: map a mean ``mu`` onto the linear-predictor scale.

    ``link_type`` is a ``Link`` code selecting one entry of the table below,
    in the same order as the branches of ``linkinv``.
    """
    to_linear_predictor = [
        lambda m: m,  # IDENTITY
        jnp.log,  # LOG
        lambda m: jnp.log(m / (1 - m)),  # LOGIT
        norm.ppf,  # PROBIT
        lambda m: 1.0 / m,  # INVERSE
        jnp.sqrt,  # SQRT
        lambda m: jnp.log(-jnp.log(1 - m)),  # CLOGLOG
    ]
    return lax.switch(link_type, to_linear_predictor, mu)
73
+
74
+
75
# Link functions accepted for each family.  The first entry of every list
# is the family's default link in DEFAULT_LINKS.
VALID_LINKS = dict(
    [
        (Family.GAUSSIAN, [Link.IDENTITY, Link.LOG, Link.INVERSE]),
        (Family.BINOMIAL, [Link.LOGIT, Link.PROBIT, Link.CLOGLOG, Link.LOG]),
        (Family.POISSON, [Link.LOG, Link.IDENTITY, Link.SQRT]),
        (Family.GAMMA, [Link.INVERSE, Link.IDENTITY, Link.LOG]),
        (Family.INVERSE_GAUSSIAN, [Link.INVERSE, Link.IDENTITY, Link.LOG]),
    ]
)
83
+
84
+
85
def validate_family_link(family_type: int, link_type: int) -> bool:
    """Return True when ``link_type`` is an accepted link for ``family_type``.

    Families missing from ``VALID_LINKS`` accept no link at all.
    """
    allowed = VALID_LINKS.get(family_type, ())
    return link_type in allowed
90
+
91
+
92
def resolve_link(family_type: int, link_type: int | None = None) -> int:
    """Resolve link type, using default if None, and validate the combination.

    Parameters
    ----------
    family_type : a ``Family`` code.
    link_type : a ``Link`` code, or ``None`` to use the family's default
        (``Link.IDENTITY`` when the family has no entry in ``DEFAULT_LINKS``).

    Returns
    -------
    The link code to use.

    Raises
    ------
    ValueError
        If ``link_type`` is given but is not listed in ``VALID_LINKS`` for
        ``family_type``.  No fallback to the default is applied.
    """
    default = DEFAULT_LINKS.get(family_type, Link.IDENTITY)
    if link_type is None:
        return default
    if not validate_family_link(family_type, link_type):
        # Nothing falls back here, so the message names the default as a
        # suggestion rather than claiming it is being used.
        raise ValueError(
            f"Invalid link {link_type} for family {family_type}. "
            f"Valid links: {VALID_LINKS.get(family_type, [])}. "
            f"The default link for this family is {default}; "
            f"pass link_type=None to use it."
        )
    return link_type
104
+
105
+
106
# Lower-case aliases for the enum classes themselves (not instances).
family, link = Family, Link
@@ -0,0 +1,131 @@
1
+ import jax.numpy as jnp
2
+ import numpy as np
3
+ from jax import jacfwd, jacrev, jit
4
+ from .families import linkinv, resolve_link
5
+ from ..utils import group_reducer, standard_errors
6
+
7
+
8
def _predict_core(
    beta: jnp.ndarray,
    X: jnp.ndarray,
    family_type: int,
    link_type: int,
) -> jnp.ndarray:
    """Response-scale predictions for each row of ``X``.

    Computes the linear predictor ``X @ beta`` and applies the inverse link
    chosen by ``link_type``.  ``family_type`` is part of the shared helper
    signature but does not affect the result.
    """
    eta = X @ beta
    mu = linkinv(link_type, eta)
    return mu
17
+
18
+
19
@jit
def _predict(
    beta: jnp.ndarray, X: jnp.ndarray, family_type: int, link_type: int = None
) -> jnp.ndarray:
    """Jitted row-wise predictions on the response scale."""
    mu = _predict_core(beta, X, family_type, link_type)
    return mu
24
+
25
+
26
@jit
def _predict_byT(
    beta: jnp.ndarray, X: jnp.ndarray, family_type: int, link_type: int = None
) -> jnp.ndarray:
    """Jitted mean prediction over all rows (a single value)."""
    return _predict_core(beta, X, family_type, link_type).mean()
32
+
33
+
34
def _predict_byG(
    beta: jnp.ndarray,
    X: jnp.ndarray,
    groups: jnp.ndarray,
    num_groups: int,
    family_type: int,
    link_type: int = None,
) -> jnp.ndarray:
    """Row-wise predictions aggregated into ``num_groups`` values.

    The aggregation is delegated to ``group_reducer``.
    """
    row_predictions = _predict_core(beta, X, family_type, link_type)
    return group_reducer(row_predictions, groups, num_groups)
44
+
45
+
46
@jit
def _predictions_core(
    beta: jnp.ndarray, X: jnp.ndarray, family_type: int, link_type: int = None
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Row-wise predictions and their Jacobian with respect to ``beta``.

    The Jacobian uses forward-mode differentiation (``jacfwd``) of the
    jitted ``_predict`` in its first argument.
    """
    jacobian_of_predict = jacfwd(_predict, argnums=0)
    mu = _predict_core(beta, X, family_type, link_type)
    jacobian = jacobian_of_predict(beta, X, family_type, link_type)
    return mu, jacobian
53
+
54
+
55
def predictions(
    beta: jnp.ndarray,
    X: jnp.ndarray,
    vcov: jnp.ndarray,
    family_type: int,
    link_type: int = None,
) -> dict[str, np.ndarray]:
    """Row-wise GLM predictions with standard errors.

    ``link_type=None`` selects the family default via ``resolve_link``,
    which raises ``ValueError`` for an invalid family/link pair.  Returns a
    dict with ``"estimate"``, ``"jacobian"`` and ``"std_error"``; the
    standard errors come from ``standard_errors(jacobian, vcov)``.
    """
    resolved_link = resolve_link(family_type, link_type)
    mu, jacobian = _predictions_core(beta, X, family_type, resolved_link)
    std_error = standard_errors(jacobian, vcov)
    return dict(
        estimate=np.array(mu),
        jacobian=np.array(jacobian),
        std_error=std_error,
    )
70
+
71
+
72
@jit
def _predictions_byT_core(
    beta: jnp.ndarray, X: jnp.ndarray, family_type: int, link_type: int = None
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Mean prediction over all rows and its gradient with respect to ``beta``."""
    call_args = (beta, X, family_type, link_type)
    mean_prediction = _predict_byT(*call_args)
    gradient = jacrev(_predict_byT, argnums=0)(*call_args)
    return mean_prediction, gradient
79
+
80
+
81
def predictions_byT(
    beta: jnp.ndarray,
    X: jnp.ndarray,
    vcov: jnp.ndarray,
    family_type: int,
    link_type: int = None,
) -> dict[str, np.ndarray]:
    """Mean GLM prediction over all rows, with its standard error.

    ``link_type=None`` selects the family default via ``resolve_link``.
    Returns a dict with ``"estimate"``, ``"jacobian"`` and ``"std_error"``.
    """
    resolved_link = resolve_link(family_type, link_type)
    mean_prediction, gradient = _predictions_byT_core(
        beta, X, family_type, resolved_link
    )
    # standard_errors receives the gradient as a single-row Jacobian.
    row_jacobian = gradient.reshape(1, -1)
    std_error = standard_errors(row_jacobian, vcov)[0]
    return {
        "estimate": np.array(mean_prediction),
        "jacobian": np.array(gradient),
        "std_error": std_error,
    }
96
+
97
+
98
def _predictions_byG_core(
    beta: jnp.ndarray,
    X: jnp.ndarray,
    groups: jnp.ndarray,
    num_groups: int,
    family_type: int,
    link_type: int = None,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Group-level predictions and their Jacobian with respect to ``beta``.

    Not jitted.  Everything except ``beta`` is captured in a closure, so
    ``num_groups`` stays a plain Python value while differentiating.
    """

    def grouped(b):
        return _predict_byG(b, X, groups, num_groups, family_type, link_type)

    return grouped(beta), jacrev(grouped)(beta)
111
+
112
+
113
def predictions_byG(
    beta: jnp.ndarray,
    X: jnp.ndarray,
    vcov: jnp.ndarray,
    groups: jnp.ndarray,
    num_groups: int,
    family_type: int,
    link_type: int = None,
) -> dict[str, np.ndarray]:
    """Group-level GLM predictions with standard errors.

    ``link_type=None`` selects the family default via ``resolve_link``.
    Returns a dict with ``"estimate"``, ``"jacobian"`` and ``"std_error"``;
    the standard errors come from ``standard_errors(jacobian, vcov)``.
    """
    resolved_link = resolve_link(family_type, link_type)
    grouped, jacobian = _predictions_byG_core(
        beta, X, groups, num_groups, family_type, resolved_link
    )
    std_error = standard_errors(jacobian, vcov)
    return dict(
        estimate=np.array(grouped),
        jacobian=np.array(jacobian),
        std_error=std_error,
    )
@@ -0,0 +1,4 @@
1
+ # Linear model implementations
2
+
3
+ from . import predictions as predictions
4
+ from . import comparisons as comparisons