numerax 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- numerax/__init__.py +37 -0
- numerax/special/__init__.py +3 -0
- numerax/special/gamma.py +217 -0
- numerax/stats/__init__.py +5 -0
- numerax/stats/profile.py +165 -0
- numerax/utils.py +48 -0
- numerax-0.1.0.dist-info/METADATA +110 -0
- numerax-0.1.0.dist-info/RECORD +10 -0
- numerax-0.1.0.dist-info/WHEEL +4 -0
- numerax-0.1.0.dist-info/licenses/LICENSE +21 -0
numerax/__init__.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Statistical and numerical computation functions for JAX, focusing on tools
|
|
3
|
+
not available in the main JAX API.
|
|
4
|
+
|
|
5
|
+
## Overview
|
|
6
|
+
|
|
7
|
+
This package provides JAX-compatible implementations of specialized numerical
|
|
8
|
+
functions with full differentiability support. All functions are designed to
|
|
9
|
+
work seamlessly with JAX's transformations (JIT, grad, vmap, etc.) and follow
|
|
10
|
+
JAX's functional programming paradigms.
|
|
11
|
+
|
|
12
|
+
### Special Functions (`numerax.special`)
|
|
13
|
+
|
|
14
|
+
Mathematical special functions with custom derivative implementations.
|
|
15
|
+
Functions use numerically stable algorithms and provide exact gradients
|
|
16
|
+
through custom JVP rules where standard automatic differentiation would
|
|
17
|
+
be inefficient or unstable.
|
|
18
|
+
|
|
19
|
+
### Statistical Methods (`numerax.stats`)
|
|
20
|
+
|
|
21
|
+
Advanced statistical computation tools for inference problems. Implements
|
|
22
|
+
efficient algorithms for complex statistical models, with particular focus
|
|
23
|
+
on optimization-based methods that benefit from JAX's compilation and
|
|
24
|
+
differentiation capabilities.
|
|
25
|
+
|
|
26
|
+
### Utilities (`numerax.utils`)
|
|
27
|
+
|
|
28
|
+
Development utilities for creating JAX-compatible functions with proper
|
|
29
|
+
documentation support. Includes decorators and helpers for preserving
|
|
30
|
+
function metadata when using JAX's advanced features like custom derivatives.
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
from . import special, stats, utils
|
|
34
|
+
|
|
35
|
+
# Package version string.
__version__ = "0.1.0"

# Public names exported by ``from numerax import *``: the package submodules.
__all__ = [
    "special",
    "stats",
    "utils",
]
|
numerax/special/gamma.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
import jax.numpy as jnp
|
|
3
|
+
import jax.scipy.special as special
|
|
4
|
+
from jaxtyping import ArrayLike
|
|
5
|
+
|
|
6
|
+
# Numerical-stability constants sized to the floating-point precision JAX is
# configured for. NOTE: these are evaluated once, when this module is
# imported, so enabling ``jax_enable_x64`` afterwards does not update them.
if jax.config.jax_enable_x64:
    _DTYPE = jnp.float64
else:
    _DTYPE = jnp.float32

# Smallest positive normal number of _DTYPE; a floor that prevents underflow.
TINY = jnp.finfo(_DTYPE).smallest_normal
# Machine epsilon of _DTYPE; intended as a relative convergence tolerance.
EPS = jnp.finfo(_DTYPE).eps
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@jax.custom_jvp
|
|
13
|
+
def gammap_inverse(p: ArrayLike, a: float) -> ArrayLike:
|
|
14
|
+
r"""
|
|
15
|
+
Inverse of the regularized incomplete gamma function.
|
|
16
|
+
|
|
17
|
+
## Overview
|
|
18
|
+
|
|
19
|
+
Computes the inverse of the regularized incomplete gamma function, finding
|
|
20
|
+
$x$ such that $P(a, x) = p$, where $P(a, x)$ is the regularized incomplete
|
|
21
|
+
gamma function. This is equivalent to computing quantiles of the
|
|
22
|
+
$\text{Gamma}(a, 1)$ distribution. The general strategy and the initial
|
|
23
|
+
guess are based on the methods described in
|
|
24
|
+
Numerical Recipes (Press et al., 2007).
|
|
25
|
+
|
|
26
|
+
## Mathematical Background
|
|
27
|
+
|
|
28
|
+
The regularized incomplete gamma function is defined as:
|
|
29
|
+
|
|
30
|
+
$$P(a, x) = \frac{\gamma(a, x)}{\Gamma(a)} = \frac{1}{\Gamma(a)}
|
|
31
|
+
\int_0^x t^{a-1} e^{-t} dt$$
|
|
32
|
+
|
|
33
|
+
This function solves the inverse problem:
|
|
34
|
+
|
|
35
|
+
$$x = P^{-1}(a, p) \quad \text{such that} \quad P(a, x) = p$$
|
|
36
|
+
|
|
37
|
+
For a random variable $X \sim \text{Gamma}(a, 1)$, this gives:
|
|
38
|
+
|
|
39
|
+
$$x = F^{-1}(p) \quad \text{where} \quad F(x) = P(\Gamma(a), x)$$
|
|
40
|
+
|
|
41
|
+
## Numerical Method
|
|
42
|
+
|
|
43
|
+
Uses Halley's method for fast quadratic convergence:
|
|
44
|
+
|
|
45
|
+
$$x_{n+1} = x_n - \frac{2f(x_n)f'(x_n)}{2[f'(x_n)]^2 - f(x_n)f''(x_n)}$$
|
|
46
|
+
|
|
47
|
+
where $f(x) = P(a, x) - p$.
|
|
48
|
+
|
|
49
|
+
**Initial guess** based on Numerical Recipes (Press et al., 2007):
|
|
50
|
+
- For $a > 1$: Wilson-Hilferty approximation
|
|
51
|
+
- For $a \leq 1$: Asymptotic expansions
|
|
52
|
+
|
|
53
|
+
## Args
|
|
54
|
+
|
|
55
|
+
- **p**: Probability values in $[0, 1]$. Can be scalar or array.
|
|
56
|
+
- **a**: Shape parameter (must be positive). Scalar value.
|
|
57
|
+
|
|
58
|
+
## Returns
|
|
59
|
+
|
|
60
|
+
Quantiles $x$ where $P(a, x) = p$. Same shape as input `p`.
|
|
61
|
+
|
|
62
|
+
## Example
|
|
63
|
+
|
|
64
|
+
```python
|
|
65
|
+
import jax.numpy as jnp
|
|
66
|
+
import numerax
|
|
67
|
+
|
|
68
|
+
# Single quantile
|
|
69
|
+
x = numerax.special.gammap_inverse(0.5, 2.0) # Median of Gamma(2, 1)
|
|
70
|
+
|
|
71
|
+
# Multiple quantiles
|
|
72
|
+
p_vals = jnp.array([0.1, 0.25, 0.5, 0.75, 0.9])
|
|
73
|
+
x_vals = numerax.special.gammap_inverse(p_vals, 3.0)
|
|
74
|
+
|
|
75
|
+
# Verify inverse relationship
|
|
76
|
+
from jax.scipy.special import gammainc
|
|
77
|
+
|
|
78
|
+
p_recovered = gammainc(2.0, x) # Should equal original p
|
|
79
|
+
|
|
80
|
+
# Differentiable for optimization
|
|
81
|
+
grad_fn = jax.grad(numerax.special.gammap_inverse)
|
|
82
|
+
sensitivity = grad_fn(0.5, 2.0) # ∂x/∂p at median
|
|
83
|
+
```
|
|
84
|
+
|
|
85
|
+
## Notes
|
|
86
|
+
|
|
87
|
+
- **Convergence**: Typically converges in 3-8 iterations
|
|
88
|
+
- **Differentiable**: Custom JVP implementation using implicit function
|
|
89
|
+
theorem
|
|
90
|
+
- **Numerical stability**: Handles edge cases near 0 and 1
|
|
91
|
+
- **Performance**: JIT-compiled with adaptive precision
|
|
92
|
+
- **Domain**: $p \in [0, 1]$ and $a > 0$
|
|
93
|
+
|
|
94
|
+
## References
|
|
95
|
+
|
|
96
|
+
Press, W. H., Teukolsky, S. A., Vetterling, W. T., & Flannery, B. P.
|
|
97
|
+
(2007).
|
|
98
|
+
*Numerical Recipes: The Art of Scientific Computing* (3rd ed.).
|
|
99
|
+
Cambridge University Press.
|
|
100
|
+
"""
|
|
101
|
+
|
|
102
|
+
def objective(x):
|
|
103
|
+
"""F(x) = gammainc(a, x) - p"""
|
|
104
|
+
return special.gammainc(a, x) - p
|
|
105
|
+
|
|
106
|
+
# Initial guess from Numerical Recipes
|
|
107
|
+
def initial_guess(u_val, a_val):
|
|
108
|
+
# a = dof/2 for chi-squared
|
|
109
|
+
|
|
110
|
+
def large_a_guess():
|
|
111
|
+
# For a > 1: use Wilson-Hilferty approximation
|
|
112
|
+
pp = jnp.where(u_val < 0.5, u_val, 1.0 - u_val)
|
|
113
|
+
t = jnp.sqrt(-2.0 * jnp.log(pp))
|
|
114
|
+
x = (2.30753 + t * 0.27061) / (
|
|
115
|
+
1.0 + t * (0.99229 + t * 0.04481)
|
|
116
|
+
) - t
|
|
117
|
+
x = jnp.where(u_val < 0.5, -x, x)
|
|
118
|
+
return jnp.fmax(
|
|
119
|
+
1e-3,
|
|
120
|
+
a_val
|
|
121
|
+
* (1.0 - 1.0 / (9.0 * a_val) - x / (3.0 * jnp.sqrt(a_val)))
|
|
122
|
+
** 3,
|
|
123
|
+
)
|
|
124
|
+
|
|
125
|
+
def small_a_guess():
|
|
126
|
+
# For a <= 1: use equations (6.2.8) and (6.2.9)
|
|
127
|
+
t = 1.0 - a_val * (0.253 + a_val * 0.12)
|
|
128
|
+
return jnp.where(
|
|
129
|
+
u_val < t,
|
|
130
|
+
(u_val / t) ** (1.0 / a_val),
|
|
131
|
+
1.0 - jnp.log(1.0 - (u_val - t) / (1.0 - t)),
|
|
132
|
+
)
|
|
133
|
+
|
|
134
|
+
return jnp.real(
|
|
135
|
+
jnp.where(a_val > 1.0, large_a_guess(), small_a_guess())
|
|
136
|
+
)
|
|
137
|
+
|
|
138
|
+
# Derivatives for Halley's method
|
|
139
|
+
f = objective
|
|
140
|
+
df_dx = jax.grad(objective)
|
|
141
|
+
d2f_dx2 = jax.grad(df_dx)
|
|
142
|
+
|
|
143
|
+
x = initial_guess(p, a)
|
|
144
|
+
|
|
145
|
+
# Use while_loop for dynamic convergence
|
|
146
|
+
def cond_fn(state):
|
|
147
|
+
x, step, iteration = state
|
|
148
|
+
# Continue while step is large and we haven't exceeded max iterations
|
|
149
|
+
return (jnp.abs(step) > EPS * jnp.abs(x)) & (iteration < 12)
|
|
150
|
+
|
|
151
|
+
def body_fn(state):
|
|
152
|
+
x, _, iteration = state
|
|
153
|
+
|
|
154
|
+
f_val = f(x)
|
|
155
|
+
df_val = df_dx(x)
|
|
156
|
+
d2f_val = d2f_dx2(x)
|
|
157
|
+
|
|
158
|
+
# Halley's method: x_{n+1} = x_n - 2*f*f' / (2*f'^2 - f*f'')
|
|
159
|
+
numerator = 2 * f_val * df_val
|
|
160
|
+
denominator = 2 * df_val**2 - f_val * d2f_val
|
|
161
|
+
|
|
162
|
+
# Avoid division by zero and ensure step is reasonable
|
|
163
|
+
denominator = jnp.where(
|
|
164
|
+
jnp.abs(denominator) < TINY,
|
|
165
|
+
jnp.sign(denominator) * TINY,
|
|
166
|
+
denominator,
|
|
167
|
+
)
|
|
168
|
+
|
|
169
|
+
step = numerator / denominator
|
|
170
|
+
x_new = x - step
|
|
171
|
+
|
|
172
|
+
# Ensure x stays positive
|
|
173
|
+
x_new = jnp.fmax(x_new, TINY)
|
|
174
|
+
|
|
175
|
+
return (x_new, step, iteration + 1)
|
|
176
|
+
|
|
177
|
+
# Initial state: (x, step, iteration)
|
|
178
|
+
initial_state = (x, jnp.inf, 0)
|
|
179
|
+
final_state = jax.lax.while_loop(cond_fn, body_fn, initial_state)
|
|
180
|
+
x = final_state[0]
|
|
181
|
+
|
|
182
|
+
return x
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
@gammap_inverse.defjvp
def gammap_inverse_jvp(primals, tangents):
    """
    Custom JVP for gammap_inverse using the implicit function theorem.

    For F(x, p) = gammainc(a, x) - p = 0:

        dx/dp = -(dF/dp) / (dF/dx) = 1 / (d/dx gammainc(a, x))

    where d/dx gammainc(a, x) = x**(a - 1) * exp(-x) / Gamma(a) is the
    Gamma(a, 1) density. It is evaluated in closed form, elementwise, so the
    rule also works for array-valued ``p`` (``jax.grad`` would require a
    scalar ``x``).

    Only the tangent with respect to ``p`` is propagated. The tangent of
    ``a`` is ignored, so derivatives with respect to ``a`` come out as zero.
    """
    p, a = primals
    p_dot, _ = tangents  # tangent of `a` is deliberately dropped

    # Forward pass
    x = gammap_inverse(p, a)

    a_arr = jnp.asarray(a, dtype=x.dtype)
    # xlogy returns 0 for 0 * log(0), so x = 0 with a = 1 gives density 1.
    density = jnp.exp(special.xlogy(a_arr - 1.0, x) - x - special.gammaln(a_arr))
    # At x = +inf (p = 1) the expression above can be inf - inf = NaN; the
    # density vanishes there.
    density = jnp.where(jnp.isposinf(x), 0.0, density)

    # The density is non-negative, so flooring it at the smallest normal
    # number keeps dx/dp finite (the former jnp.sign(d) * TINY guard was
    # still zero when d == 0). jnp.maximum, unlike fmax, keeps NaN from
    # invalid inputs as NaN.
    density = jnp.maximum(density, jnp.finfo(x.dtype).smallest_normal)
    dx_dp = 1.0 / density

    x_dot = (dx_dp * p_dot).astype(x.dtype)

    return x, x_dot
|
numerax/stats/profile.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
"""Profile likelihood functions for statistical inference."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
|
|
5
|
+
import jax
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
import optax
|
|
8
|
+
from optax import lbfgs
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def make_profile_llh(
    llh_fn: Callable,
    is_nuisance: list[bool] | jnp.ndarray,
    get_initial_nuisance: Callable,
    tol: float = 1e-6,
    initial_value: float = 1e-9,
    initial_diff: float = 1e9,
    max_iter: int | None = None,
) -> Callable:
    r"""
    Factory function for creating profile likelihood functions.

    ## Overview

    Profile likelihood is a statistical technique used when dealing with
    nuisance parameters that are not of primary interest but are necessary
    for the model. This function creates an optimized profile likelihood
    that maximizes over nuisance parameters while keeping inference
    parameters fixed.

    ## Mathematical Background

    Given a likelihood function $L(\boldsymbol{\theta}, \boldsymbol{\lambda})$
    where $\boldsymbol{\theta}$ are parameters of interest and
    $\boldsymbol{\lambda}$ are nuisance parameters, the profile likelihood is:

    $$L_p(\boldsymbol{\theta}) = \max_{\boldsymbol{\lambda}}
    L(\boldsymbol{\theta}, \boldsymbol{\lambda})$$

    In practice, we work with the log-likelihood
    $\ell(\boldsymbol{\theta}, \boldsymbol{\lambda}) =
    \log L(\boldsymbol{\theta}, \boldsymbol{\lambda})$:

    $$\ell_p(\boldsymbol{\theta}) = \max_{\boldsymbol{\lambda}}
    \ell(\boldsymbol{\theta}, \boldsymbol{\lambda})$$

    This function uses L-BFGS optimization to find the maximum likelihood
    estimates of nuisance parameters for each fixed value of inference
    parameters. Iteration stops once the decrease of the negative
    log likelihood between successive iterations is at most
    `tol * |negative log likelihood|` (or after `max_iter` iterations, if
    given).

    ## Args

    - **llh_fn**: Log likelihood function taking (params, *args) and
      returning scalar log likelihood value. `params` is a 1-D array holding
      inference and nuisance parameters in the positions given by
      `is_nuisance`.
    - **is_nuisance**: Boolean array where True indicates nuisance
      parameters and False indicates inference parameters (truthy/falsy
      values are converted to booleans)
    - **get_initial_nuisance**: Function taking (*args) and returning
      initial values for nuisance parameters
    - **tol**: Relative convergence tolerance for the optimization
      (default: 1e-6)
    - **initial_value**: Initial objective value for convergence tracking
      (default: 1e-9)
    - **initial_diff**: Initial difference for convergence tracking
      (default: 1e9)
    - **max_iter**: Optional cap on the number of L-BFGS iterations.
      `None` (default) iterates until the tolerance is met.

    ## Returns

    Profile likelihood function with signature:
    `(inference_values, *args) -> (profile_llh_value, optimal_nuisance,
    convergence_diff, num_iterations)`.
    `profile_llh_value` is the log likelihood evaluated at the returned
    `optimal_nuisance`.

    ## Example

    Consider fitting a normal distribution where we want to infer the mean
    $\mu$ but treat the variance $\sigma^2$ as a nuisance parameter:

    ```python
    import jax.numpy as jnp
    import numerax

    # Sample data
    data = jnp.array([1.2, 0.8, 1.5, 0.9, 1.1, 1.3, 0.7, 1.4])


    # Log likelihood for normal distribution
    def normal_llh(params, data):
        mu, log_sigma = params  # Use log(sigma) for numerical stability
        sigma = jnp.exp(log_sigma)
        return jnp.sum(
            -0.5 * jnp.log(2 * jnp.pi)
            - log_sigma
            - 0.5 * ((data - mu) / sigma) ** 2
        )


    # Profile over log_sigma (nuisance), infer mu
    is_nuisance = [False, True]  # mu=inference, log_sigma=nuisance


    def get_initial_log_sigma(data):
        # Initialize with log of sample standard deviation
        return jnp.array([jnp.log(jnp.std(data))])


    profile_llh = numerax.stats.make_profile_llh(
        normal_llh, is_nuisance, get_initial_log_sigma
    )

    # Evaluate profile likelihood at different mu values
    mu_test = 1.0
    llh_val, opt_log_sigma, diff, n_iter = profile_llh(
        jnp.array([mu_test]), data
    )
    ```

    ## Notes

    - The function is JIT-compiled for performance
    - Uses L-BFGS optimization which is well-suited for smooth likelihood
      surfaces
    - Returns convergence information for diagnostics
    - Handles parameter masking automatically
    - Consider using log-parameterization for positive parameters
      (e.g., $\log \sigma$) to improve numerical stability
    """
    nuisance_mask = jnp.asarray(is_nuisance, dtype=bool)
    n_params = nuisance_mask.shape[0]
    # Positions are resolved once, outside jit, so the scatters inside the
    # objective use concrete integer indices rather than boolean masks.
    nuisance_idx = jnp.flatnonzero(nuisance_mask)
    inference_idx = jnp.flatnonzero(~nuisance_mask)

    @jax.jit
    def profile_llh(inference_values, *args):
        solver = lbfgs()
        initial_nuisance = get_initial_nuisance(*args)
        opt_state = solver.init(initial_nuisance)

        def objective(nuisance_params):
            # Negative log likelihood of the full parameter vector, with the
            # inference parameters held fixed.
            full_params = (
                jnp.zeros(n_params)
                .at[inference_idx]
                .set(inference_values)
                .at[nuisance_idx]
                .set(nuisance_params)
            )
            return -llh_fn(full_params, *args)

        value_and_grad = optax.value_and_grad_from_state(objective)

        def lbfgs_step(carry):
            params, last_value, opt_state, _, n = carry
            value, grad = value_and_grad(params, state=opt_state)
            updates, opt_state = solver.update(
                grad,
                opt_state,
                params,
                value=value,
                grad=grad,
                value_fn=objective,
            )
            params = optax.apply_updates(params, updates)
            # Decrease of the objective since the previous iteration.
            diff = last_value - value
            return params, value, opt_state, diff, n + 1

        def keep_going(carry):
            _, value, _, diff, n = carry
            running = jnp.abs(diff) > jnp.abs(value * tol)
            if max_iter is not None:  # static Python check, resolved at trace time
                running = running & (n < max_iter)
            return running

        params, _, _, diff, n = jax.lax.while_loop(
            keep_going,
            lbfgs_step,
            (initial_nuisance, initial_value, opt_state, initial_diff, 0),
        )

        # The loop's tracked value belongs to the parameters *before* the last
        # update (or is `initial_value` if no iteration ran), so re-evaluate
        # at the returned parameters to keep the two consistent.
        return -objective(params), params, diff, n

    return profile_llh
|
numerax/utils.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
"""Utility functions for the numerax package."""
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
from collections.abc import Callable
|
|
5
|
+
from typing import TypeVar
|
|
6
|
+
|
|
7
|
+
F = TypeVar("F", bound=Callable)


def preserve_metadata(decorator):
    """
    Wrap a decorator so the object it produces carries the original
    function's metadata for documentation tools.

    ## Overview

    Some JAX decorators, such as `@custom_jvp`, return special objects that
    may not expose `__doc__` and related attributes in the way documentation
    generators like pdoc expect. The returned decorator applies `decorator`
    and then copies the standard wrapper attributes (per
    `functools.update_wrapper`) from the original function onto the result.

    ## Args

    - **decorator**: The decorator function to wrap

    ## Returns

    A new decorator that preserves metadata

    ## Example

    ```python
    import jax
    from numerax.utils import preserve_metadata

    @preserve_metadata(jax.custom_jvp)
    def my_function(x):
        \"\"\"This docstring will be preserved for pdoc.\"\"\"
        return x
    ```
    """

    def _decorate(func: F) -> F:
        # Build the decorated object first ...
        wrapped_obj = decorator(func)
        # ... then copy the original function's metadata onto it in place.
        functools.update_wrapper(wrapped_obj, func)
        return wrapped_obj

    return _decorate
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: numerax
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Project-URL: Documentation, https://github.com/juehang/numerax#readme
|
|
5
|
+
Project-URL: Issues, https://github.com/juehang/numerax/issues
|
|
6
|
+
Project-URL: Source, https://github.com/juehang/numerax
|
|
7
|
+
Author: Juehang Qin
|
|
8
|
+
License-Expression: MIT
|
|
9
|
+
License-File: LICENSE
|
|
10
|
+
Classifier: Development Status :: 4 - Beta
|
|
11
|
+
Classifier: Programming Language :: Python
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
14
|
+
Classifier: Programming Language :: Python :: Implementation :: CPython
|
|
15
|
+
Classifier: Programming Language :: Python :: Implementation :: PyPy
|
|
16
|
+
Requires-Python: >=3.12
|
|
17
|
+
Requires-Dist: jax
|
|
18
|
+
Requires-Dist: jaxtyping
|
|
19
|
+
Requires-Dist: optax
|
|
20
|
+
Description-Content-Type: text/markdown
|
|
21
|
+
|
|
22
|
+
# numerax
|
|
23
|
+
|
|
24
|
+
[Tests](https://github.com/juehang/numerax/actions/workflows/test.yml)
|
|
25
|
+
[Documentation](https://juehang.github.io/numerax/)
|
|
26
|
+
|
|
27
|
+
Statistical and numerical computation functions for JAX, focusing on tools not available in the main JAX API.
|
|
28
|
+
|
|
29
|
+
**[📖 Documentation](https://juehang.github.io/numerax/)**
|
|
30
|
+
|
|
31
|
+
## Installation
|
|
32
|
+
|
|
33
|
+
```bash
|
|
34
|
+
pip install numerax
|
|
35
|
+
```
|
|
36
|
+
|
|
37
|
+
## Features
|
|
38
|
+
|
|
39
|
+
### Special Functions
|
|
40
|
+
|
|
41
|
+
Inverse regularized incomplete gamma function with differentiability support:
|
|
42
|
+
|
|
43
|
+
```python
|
|
44
|
+
import jax
import jax.numpy as jnp
|
|
45
|
+
import numerax
|
|
46
|
+
|
|
47
|
+
# Compute gamma quantiles (inverse CDF)
|
|
48
|
+
p = jnp.array([0.1, 0.5, 0.9]) # Probabilities
|
|
49
|
+
a = 2.0 # Shape parameter
|
|
50
|
+
|
|
51
|
+
x = numerax.special.gammap_inverse(p, a)
|
|
52
|
+
# Returns quantiles where gammainc(a, x) = p
|
|
53
|
+
|
|
54
|
+
# Fully differentiable with custom JVP
|
|
55
|
+
grad_fn = jax.grad(numerax.special.gammap_inverse)
|
|
56
|
+
dx_dp = grad_fn(0.5, 2.0) # Gradient with respect to probability
|
|
57
|
+
```
|
|
58
|
+
|
|
59
|
+
**Key features:**
|
|
60
|
+
- Halley's method for fast convergence
|
|
61
|
+
- Custom JVP implementation for exact gradients
|
|
62
|
+
- Numerically stable, with tolerances matched to the working floating-point precision (float32 or float64)
|
|
63
|
+
- Equivalent to gamma distribution inverse CDF
|
|
64
|
+
|
|
65
|
+
### Profile Likelihood
|
|
66
|
+
|
|
67
|
+
Efficient profile likelihood computation for statistical inference with nuisance parameters:
|
|
68
|
+
|
|
69
|
+
```python
|
|
70
|
+
import jax.numpy as jnp
|
|
71
|
+
import numerax
|
|
72
|
+
|
|
73
|
+
# Example: Normal distribution with mean inference, variance profiling
|
|
74
|
+
def normal_llh(params, data):
|
|
75
|
+
mu, log_sigma = params
|
|
76
|
+
sigma = jnp.exp(log_sigma)
|
|
77
|
+
return jnp.sum(-0.5 * jnp.log(2 * jnp.pi) - log_sigma
|
|
78
|
+
- 0.5 * ((data - mu) / sigma) ** 2)
|
|
79
|
+
|
|
80
|
+
# Profile over log_sigma, infer mu
|
|
81
|
+
is_nuisance = [False, True] # mu=inference, log_sigma=nuisance
|
|
82
|
+
|
|
83
|
+
def get_initial_log_sigma(data):
|
|
84
|
+
return jnp.array([jnp.log(jnp.std(data))])
|
|
85
|
+
|
|
86
|
+
profile_llh = numerax.stats.make_profile_llh(
|
|
87
|
+
normal_llh, is_nuisance, get_initial_log_sigma
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
# Evaluate profile likelihood
|
|
91
|
+
data = jnp.array([1.2, 0.8, 1.5, 0.9, 1.1])
|
|
92
|
+
llh_val, opt_nuisance, diff, n_iter = profile_llh(jnp.array([1.0]), data)
|
|
93
|
+
```
|
|
94
|
+
|
|
95
|
+
**Key features:**
|
|
96
|
+
- JIT-compiled for performance
|
|
97
|
+
- L-BFGS optimization with convergence diagnostics
|
|
98
|
+
- Configurable tolerance and initial values
|
|
99
|
+
- Handles parameter masking automatically
|
|
100
|
+
|
|
101
|
+
### Utilities
|
|
102
|
+
|
|
103
|
+
Development utilities for creating JAX functions with custom derivatives while ensuring proper documentation support. Includes decorators for preserving function metadata when using JAX's advanced features.
|
|
104
|
+
|
|
105
|
+
## Requirements
|
|
106
|
+
|
|
107
|
+
- Python ≥ 3.12
|
|
108
|
+
- JAX
|
|
109
|
+
- jaxtyping
|
|
110
|
+
- optax
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
numerax/__init__.py,sha256=y9LIfr6Fo2w7s9y1Y9QmpyNRfHEaGQuEY6OTs4jKbkw,1340
|
|
2
|
+
numerax/utils.py,sha256=oDEjLGKNJv7_lJQ5JVAZ-pXkYnqzrBjYYirLyj1CSOM,1197
|
|
3
|
+
numerax/special/__init__.py,sha256=sG_dsAOoNynwF1DI_Jn_JlYcA3hWDxMNvgSsE7M0pBs,64
|
|
4
|
+
numerax/special/gamma.py,sha256=SdA3Bb_Pe-NHl02rpDGO2D0B436eGJCSLbJ2lWR7pmY,6443
|
|
5
|
+
numerax/stats/__init__.py,sha256=0FyAqMs_xPWXTlQO-LI4y2RZ0wFQNq00uI_8ZhM_XzE,111
|
|
6
|
+
numerax/stats/profile.py,sha256=KfYG0XyRCguiARrTv3C2N-u-p9GMSS9tIPPA5C9YYj4,5540
|
|
7
|
+
numerax-0.1.0.dist-info/METADATA,sha256=BjwCUYfgtWWQZDDuUX_41zDR9PtZVnSxFptiAGgir2Y,3341
|
|
8
|
+
numerax-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
9
|
+
numerax-0.1.0.dist-info/licenses/LICENSE,sha256=q_748ZfuhHCem3AhXv-C9g5u17674OyDlE74DVDu_Ec,1068
|
|
10
|
+
numerax-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Juehang Qin
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|