numerax 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- numerax-0.1.0/.claude/settings.local.json +13 -0
- numerax-0.1.0/.github/workflows/docs.yml +59 -0
- numerax-0.1.0/.github/workflows/publish.yml +84 -0
- numerax-0.1.0/.github/workflows/test.yml +39 -0
- numerax-0.1.0/.gitignore +5 -0
- numerax-0.1.0/LICENSE +21 -0
- numerax-0.1.0/PKG-INFO +110 -0
- numerax-0.1.0/README.md +89 -0
- numerax-0.1.0/docs/api/index.md +9 -0
- numerax-0.1.0/docs/api/special.md +6 -0
- numerax-0.1.0/docs/api/stats.md +6 -0
- numerax-0.1.0/docs/api/utils.md +6 -0
- numerax-0.1.0/docs/index.md +6 -0
- numerax-0.1.0/docs/stylesheets/extra.css +18 -0
- numerax-0.1.0/mkdocs.yml +73 -0
- numerax-0.1.0/pyproject.toml +133 -0
- numerax-0.1.0/src/numerax/__init__.py +37 -0
- numerax-0.1.0/src/numerax/special/__init__.py +3 -0
- numerax-0.1.0/src/numerax/special/gamma.py +217 -0
- numerax-0.1.0/src/numerax/stats/__init__.py +5 -0
- numerax-0.1.0/src/numerax/stats/profile.py +165 -0
- numerax-0.1.0/src/numerax/utils.py +48 -0
- numerax-0.1.0/tests/__init__.py +1 -0
- numerax-0.1.0/tests/test_gammap_inverse.py +50 -0
- numerax-0.1.0/tests/test_profile_gaussian.py +116 -0
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
{
|
|
2
|
+
"permissions": {
|
|
3
|
+
"allow": [
|
|
4
|
+
"Bash(find:*)",
|
|
5
|
+
"Bash(ruff format:*)",
|
|
6
|
+
"mcp__vscode-tools__get_symbol_definition_code",
|
|
7
|
+
"mcp__vscode-tools__search_symbols_code",
|
|
8
|
+
"Bash(hatch run lint:*)",
|
|
9
|
+
"mcp__vscode-tools__get_document_symbols_code"
|
|
10
|
+
],
|
|
11
|
+
"deny": []
|
|
12
|
+
}
|
|
13
|
+
}
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
name: documentation
|
|
2
|
+
|
|
3
|
+
# build the documentation whenever there are new commits on main
|
|
4
|
+
on:
|
|
5
|
+
push:
|
|
6
|
+
branches:
|
|
7
|
+
- main
|
|
8
|
+
# Alternative: only build for tags.
|
|
9
|
+
# tags:
|
|
10
|
+
# - '*'
|
|
11
|
+
|
|
12
|
+
# security: restrict permissions for CI jobs.
|
|
13
|
+
permissions:
|
|
14
|
+
contents: read
|
|
15
|
+
|
|
16
|
+
jobs:
|
|
17
|
+
# Build the documentation and upload the static HTML files as an artifact.
|
|
18
|
+
build:
|
|
19
|
+
runs-on: ubuntu-latest
|
|
20
|
+
steps:
|
|
21
|
+
- uses: actions/checkout@v4
|
|
22
|
+
with:
|
|
23
|
+
persist-credentials: false
|
|
24
|
+
|
|
25
|
+
# Extract Python version from pyproject.toml
|
|
26
|
+
- name: Get Python version from pyproject.toml
|
|
27
|
+
id: python-version
|
|
28
|
+
run: |
|
|
29
|
+
python_version=$(grep -E "requires-python.*=" pyproject.toml | sed -E 's/.*">=([0-9.]+)".*/\1/')
|
|
30
|
+
echo "version=$python_version" >> $GITHUB_OUTPUT
|
|
31
|
+
|
|
32
|
+
- uses: actions/setup-python@v5
|
|
33
|
+
with:
|
|
34
|
+
python-version: ${{ steps.python-version.outputs.version }}
|
|
35
|
+
|
|
36
|
+
# Install Hatch
|
|
37
|
+
- run: pip install hatch
|
|
38
|
+
|
|
39
|
+
# Build documentation with MkDocs using dev environment
|
|
40
|
+
- run: hatch run mkdocs build
|
|
41
|
+
|
|
42
|
+
- uses: actions/upload-pages-artifact@v3
|
|
43
|
+
with:
|
|
44
|
+
path: site/
|
|
45
|
+
|
|
46
|
+
# Deploy the artifact to GitHub pages.
|
|
47
|
+
# This is a separate job so that only actions/deploy-pages has the necessary permissions.
|
|
48
|
+
deploy:
|
|
49
|
+
needs: build
|
|
50
|
+
runs-on: ubuntu-latest
|
|
51
|
+
permissions:
|
|
52
|
+
pages: write
|
|
53
|
+
id-token: write
|
|
54
|
+
environment:
|
|
55
|
+
name: github-pages
|
|
56
|
+
url: ${{ steps.deployment.outputs.page_url }}
|
|
57
|
+
steps:
|
|
58
|
+
- id: deployment
|
|
59
|
+
uses: actions/deploy-pages@v4
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
name: publish
|
|
2
|
+
|
|
3
|
+
# Publish to PyPI when a new release is created
|
|
4
|
+
on:
|
|
5
|
+
release:
|
|
6
|
+
types: [published]
|
|
7
|
+
|
|
8
|
+
# Security: restrict permissions for CI jobs
|
|
9
|
+
permissions:
|
|
10
|
+
contents: read
|
|
11
|
+
|
|
12
|
+
jobs:
|
|
13
|
+
# Build and test before publishing
|
|
14
|
+
test:
|
|
15
|
+
runs-on: ubuntu-latest
|
|
16
|
+
steps:
|
|
17
|
+
- uses: actions/checkout@v4
|
|
18
|
+
with:
|
|
19
|
+
persist-credentials: false
|
|
20
|
+
|
|
21
|
+
- uses: actions/setup-python@v5
|
|
22
|
+
with:
|
|
23
|
+
python-version: '3.12'
|
|
24
|
+
cache: 'pip'
|
|
25
|
+
|
|
26
|
+
# Install hatch for environment management
|
|
27
|
+
- name: Install Hatch
|
|
28
|
+
uses: pypa/hatch@install
|
|
29
|
+
|
|
30
|
+
# Run code quality checks using hatch
|
|
31
|
+
- name: Run linting and formatting checks
|
|
32
|
+
run: hatch run lint:check
|
|
33
|
+
|
|
34
|
+
# Run tests using hatch
|
|
35
|
+
- name: Run tests
|
|
36
|
+
run: hatch run test
|
|
37
|
+
|
|
38
|
+
# Build the package
|
|
39
|
+
build:
|
|
40
|
+
needs: test
|
|
41
|
+
runs-on: ubuntu-latest
|
|
42
|
+
steps:
|
|
43
|
+
- uses: actions/checkout@v4
|
|
44
|
+
with:
|
|
45
|
+
persist-credentials: false
|
|
46
|
+
|
|
47
|
+
- uses: actions/setup-python@v5
|
|
48
|
+
with:
|
|
49
|
+
python-version: '3.12'
|
|
50
|
+
cache: 'pip'
|
|
51
|
+
|
|
52
|
+
# Install hatch for building
|
|
53
|
+
- name: Install Hatch
|
|
54
|
+
uses: pypa/hatch@install
|
|
55
|
+
|
|
56
|
+
# Build the package
|
|
57
|
+
- name: Build package
|
|
58
|
+
run: hatch build
|
|
59
|
+
|
|
60
|
+
# Upload build artifacts
|
|
61
|
+
- uses: actions/upload-artifact@v4
|
|
62
|
+
with:
|
|
63
|
+
name: dist-artifacts
|
|
64
|
+
path: dist/
|
|
65
|
+
|
|
66
|
+
# Publish to PyPI
|
|
67
|
+
publish:
|
|
68
|
+
needs: build
|
|
69
|
+
runs-on: ubuntu-latest
|
|
70
|
+
permissions:
|
|
71
|
+
id-token: write # Required for trusted publishing
|
|
72
|
+
environment:
|
|
73
|
+
name: pypi
|
|
74
|
+
url: https://pypi.org/p/numerax
|
|
75
|
+
steps:
|
|
76
|
+
- name: Download build artifacts
|
|
77
|
+
uses: actions/download-artifact@v4
|
|
78
|
+
with:
|
|
79
|
+
name: dist-artifacts
|
|
80
|
+
path: dist/
|
|
81
|
+
|
|
82
|
+
# Publish to PyPI using trusted publishing
|
|
83
|
+
- name: Publish to PyPI
|
|
84
|
+
uses: pypa/gh-action-pypi-publish@release/v1
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
name: tests
|
|
2
|
+
|
|
3
|
+
# Run tests on pushes to main and pull requests
|
|
4
|
+
on:
|
|
5
|
+
push:
|
|
6
|
+
branches:
|
|
7
|
+
- main
|
|
8
|
+
pull_request:
|
|
9
|
+
branches:
|
|
10
|
+
- main
|
|
11
|
+
|
|
12
|
+
# Security: restrict permissions for CI jobs
|
|
13
|
+
permissions:
|
|
14
|
+
contents: read
|
|
15
|
+
|
|
16
|
+
jobs:
|
|
17
|
+
test:
|
|
18
|
+
runs-on: ubuntu-latest
|
|
19
|
+
steps:
|
|
20
|
+
- uses: actions/checkout@v4
|
|
21
|
+
with:
|
|
22
|
+
persist-credentials: false
|
|
23
|
+
|
|
24
|
+
- uses: actions/setup-python@v5
|
|
25
|
+
with:
|
|
26
|
+
python-version: '3.12'
|
|
27
|
+
cache: 'pip'
|
|
28
|
+
|
|
29
|
+
# Install hatch for environment management
|
|
30
|
+
- name: Install Hatch
|
|
31
|
+
uses: pypa/hatch@install
|
|
32
|
+
|
|
33
|
+
# Run code quality checks using hatch
|
|
34
|
+
- name: Run linting and formatting checks
|
|
35
|
+
run: hatch run lint:check
|
|
36
|
+
|
|
37
|
+
# Run tests using hatch
|
|
38
|
+
- name: Run tests
|
|
39
|
+
run: hatch run test
|
numerax-0.1.0/.gitignore
ADDED
numerax-0.1.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Juehang Qin
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
numerax-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: numerax
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Project-URL: Documentation, https://github.com/juehang/numerax#readme
|
|
5
|
+
Project-URL: Issues, https://github.com/juehang/numerax/issues
|
|
6
|
+
Project-URL: Source, https://github.com/juehang/numerax
|
|
7
|
+
Author: Juehang Qin
|
|
8
|
+
License-Expression: MIT
|
|
9
|
+
License-File: LICENSE
|
|
10
|
+
Classifier: Development Status :: 4 - Beta
|
|
11
|
+
Classifier: Programming Language :: Python
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
14
|
+
Classifier: Programming Language :: Python :: Implementation :: CPython
|
|
15
|
+
Classifier: Programming Language :: Python :: Implementation :: PyPy
|
|
16
|
+
Requires-Python: >=3.12
|
|
17
|
+
Requires-Dist: jax
|
|
18
|
+
Requires-Dist: jaxtyping
|
|
19
|
+
Requires-Dist: optax
|
|
20
|
+
Description-Content-Type: text/markdown
|
|
21
|
+
|
|
22
|
+
# numerax
|
|
23
|
+
|
|
24
|
+
[![tests](https://github.com/juehang/numerax/actions/workflows/test.yml/badge.svg)](https://github.com/juehang/numerax/actions/workflows/test.yml)
|
|
25
|
+
[![documentation](https://github.com/juehang/numerax/actions/workflows/docs.yml/badge.svg)](https://juehang.github.io/numerax/)
|
|
26
|
+
|
|
27
|
+
Statistical and numerical computation functions for JAX, focusing on tools not available in the main JAX API.
|
|
28
|
+
|
|
29
|
+
**[📖 Documentation](https://juehang.github.io/numerax/)**
|
|
30
|
+
|
|
31
|
+
## Installation
|
|
32
|
+
|
|
33
|
+
```bash
|
|
34
|
+
pip install numerax
|
|
35
|
+
```
|
|
36
|
+
|
|
37
|
+
## Features
|
|
38
|
+
|
|
39
|
+
### Special Functions
|
|
40
|
+
|
|
41
|
+
Inverse regularized incomplete gamma function with differentiability support:
|
|
42
|
+
|
|
43
|
+
```python
|
|
44
|
+
import jax
import jax.numpy as jnp
|
|
45
|
+
import numerax
|
|
46
|
+
|
|
47
|
+
# Compute gamma quantiles (inverse CDF)
|
|
48
|
+
p = jnp.array([0.1, 0.5, 0.9]) # Probabilities
|
|
49
|
+
a = 2.0 # Shape parameter
|
|
50
|
+
|
|
51
|
+
x = numerax.special.gammap_inverse(p, a)
|
|
52
|
+
# Returns quantiles where gammainc(a, x) = p
|
|
53
|
+
|
|
54
|
+
# Fully differentiable with custom JVP
|
|
55
|
+
grad_fn = jax.grad(numerax.special.gammap_inverse)
|
|
56
|
+
dx_dp = grad_fn(0.5, 2.0) # Gradient with respect to probability
|
|
57
|
+
```
|
|
58
|
+
|
|
59
|
+
**Key features:**
|
|
60
|
+
- Halley's method for fast convergence
|
|
61
|
+
- Custom JVP implementation for exact gradients
|
|
62
|
+
- Numerical stability with adaptive precision
|
|
63
|
+
- Equivalent to gamma distribution inverse CDF
|
|
64
|
+
|
|
65
|
+
### Profile Likelihood
|
|
66
|
+
|
|
67
|
+
Efficient profile likelihood computation for statistical inference with nuisance parameters:
|
|
68
|
+
|
|
69
|
+
```python
|
|
70
|
+
import jax.numpy as jnp
|
|
71
|
+
import numerax
|
|
72
|
+
|
|
73
|
+
# Example: Normal distribution with mean inference, variance profiling
|
|
74
|
+
def normal_llh(params, data):
|
|
75
|
+
mu, log_sigma = params
|
|
76
|
+
sigma = jnp.exp(log_sigma)
|
|
77
|
+
return jnp.sum(-0.5 * jnp.log(2 * jnp.pi) - log_sigma
|
|
78
|
+
- 0.5 * ((data - mu) / sigma) ** 2)
|
|
79
|
+
|
|
80
|
+
# Profile over log_sigma, infer mu
|
|
81
|
+
is_nuisance = [False, True] # mu=inference, log_sigma=nuisance
|
|
82
|
+
|
|
83
|
+
def get_initial_log_sigma(data):
|
|
84
|
+
return jnp.array([jnp.log(jnp.std(data))])
|
|
85
|
+
|
|
86
|
+
profile_llh = numerax.stats.make_profile_llh(
|
|
87
|
+
normal_llh, is_nuisance, get_initial_log_sigma
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
# Evaluate profile likelihood
|
|
91
|
+
data = jnp.array([1.2, 0.8, 1.5, 0.9, 1.1])
|
|
92
|
+
llh_val, opt_nuisance, diff, n_iter = profile_llh(jnp.array([1.0]), data)
|
|
93
|
+
```
|
|
94
|
+
|
|
95
|
+
**Key features:**
|
|
96
|
+
- JIT-compiled for performance
|
|
97
|
+
- L-BFGS optimization with convergence diagnostics
|
|
98
|
+
- Configurable tolerance and initial values
|
|
99
|
+
- Handles parameter masking automatically
|
|
100
|
+
|
|
101
|
+
### Utilities
|
|
102
|
+
|
|
103
|
+
Development utilities for creating JAX functions with custom derivatives while ensuring proper documentation support. Includes decorators for preserving function metadata when using JAX's advanced features.
|
|
104
|
+
|
|
105
|
+
## Requirements
|
|
106
|
+
|
|
107
|
+
- Python ≥ 3.12
|
|
108
|
+
- JAX
|
|
109
|
+
- jaxtyping
|
|
110
|
+
- optax
|
numerax-0.1.0/README.md
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
# numerax
|
|
2
|
+
|
|
3
|
+
[![tests](https://github.com/juehang/numerax/actions/workflows/test.yml/badge.svg)](https://github.com/juehang/numerax/actions/workflows/test.yml)
|
|
4
|
+
[![documentation](https://github.com/juehang/numerax/actions/workflows/docs.yml/badge.svg)](https://juehang.github.io/numerax/)
|
|
5
|
+
|
|
6
|
+
Statistical and numerical computation functions for JAX, focusing on tools not available in the main JAX API.
|
|
7
|
+
|
|
8
|
+
**[📖 Documentation](https://juehang.github.io/numerax/)**
|
|
9
|
+
|
|
10
|
+
## Installation
|
|
11
|
+
|
|
12
|
+
```bash
|
|
13
|
+
pip install numerax
|
|
14
|
+
```
|
|
15
|
+
|
|
16
|
+
## Features
|
|
17
|
+
|
|
18
|
+
### Special Functions
|
|
19
|
+
|
|
20
|
+
Inverse regularized incomplete gamma function with differentiability support:
|
|
21
|
+
|
|
22
|
+
```python
|
|
23
|
+
import jax
import jax.numpy as jnp
|
|
24
|
+
import numerax
|
|
25
|
+
|
|
26
|
+
# Compute gamma quantiles (inverse CDF)
|
|
27
|
+
p = jnp.array([0.1, 0.5, 0.9]) # Probabilities
|
|
28
|
+
a = 2.0 # Shape parameter
|
|
29
|
+
|
|
30
|
+
x = numerax.special.gammap_inverse(p, a)
|
|
31
|
+
# Returns quantiles where gammainc(a, x) = p
|
|
32
|
+
|
|
33
|
+
# Fully differentiable with custom JVP
|
|
34
|
+
grad_fn = jax.grad(numerax.special.gammap_inverse)
|
|
35
|
+
dx_dp = grad_fn(0.5, 2.0) # Gradient with respect to probability
|
|
36
|
+
```
|
|
37
|
+
|
|
38
|
+
**Key features:**
|
|
39
|
+
- Halley's method for fast convergence
|
|
40
|
+
- Custom JVP implementation for exact gradients
|
|
41
|
+
- Numerical stability with adaptive precision
|
|
42
|
+
- Equivalent to gamma distribution inverse CDF
|
|
43
|
+
|
|
44
|
+
### Profile Likelihood
|
|
45
|
+
|
|
46
|
+
Efficient profile likelihood computation for statistical inference with nuisance parameters:
|
|
47
|
+
|
|
48
|
+
```python
|
|
49
|
+
import jax.numpy as jnp
|
|
50
|
+
import numerax
|
|
51
|
+
|
|
52
|
+
# Example: Normal distribution with mean inference, variance profiling
|
|
53
|
+
def normal_llh(params, data):
|
|
54
|
+
mu, log_sigma = params
|
|
55
|
+
sigma = jnp.exp(log_sigma)
|
|
56
|
+
return jnp.sum(-0.5 * jnp.log(2 * jnp.pi) - log_sigma
|
|
57
|
+
- 0.5 * ((data - mu) / sigma) ** 2)
|
|
58
|
+
|
|
59
|
+
# Profile over log_sigma, infer mu
|
|
60
|
+
is_nuisance = [False, True] # mu=inference, log_sigma=nuisance
|
|
61
|
+
|
|
62
|
+
def get_initial_log_sigma(data):
|
|
63
|
+
return jnp.array([jnp.log(jnp.std(data))])
|
|
64
|
+
|
|
65
|
+
profile_llh = numerax.stats.make_profile_llh(
|
|
66
|
+
normal_llh, is_nuisance, get_initial_log_sigma
|
|
67
|
+
)
|
|
68
|
+
|
|
69
|
+
# Evaluate profile likelihood
|
|
70
|
+
data = jnp.array([1.2, 0.8, 1.5, 0.9, 1.1])
|
|
71
|
+
llh_val, opt_nuisance, diff, n_iter = profile_llh(jnp.array([1.0]), data)
|
|
72
|
+
```
|
|
73
|
+
|
|
74
|
+
**Key features:**
|
|
75
|
+
- JIT-compiled for performance
|
|
76
|
+
- L-BFGS optimization with convergence diagnostics
|
|
77
|
+
- Configurable tolerance and initial values
|
|
78
|
+
- Handles parameter masking automatically
|
|
79
|
+
|
|
80
|
+
### Utilities
|
|
81
|
+
|
|
82
|
+
Development utilities for creating JAX functions with custom derivatives while ensuring proper documentation support. Includes decorators for preserving function metadata when using JAX's advanced features.
|
|
83
|
+
|
|
84
|
+
## Requirements
|
|
85
|
+
|
|
86
|
+
- Python ≥ 3.12
|
|
87
|
+
- JAX
|
|
88
|
+
- jaxtyping
|
|
89
|
+
- optax
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
# API Reference
|
|
2
|
+
|
|
3
|
+
Comprehensive API documentation for all Numerax modules.
|
|
4
|
+
|
|
5
|
+
## Modules
|
|
6
|
+
|
|
7
|
+
- **[Special Functions](special.md)** - Mathematical special functions with custom derivatives
|
|
8
|
+
- **[Statistics](stats.md)** - Advanced statistical computation tools
|
|
9
|
+
- **[Utilities](utils.md)** - Development utilities for JAX functions
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
/* Custom styles for MkDocs Material theme */
|
|
2
|
+
|
|
3
|
+
/* Improve math rendering */
|
|
4
|
+
.arithmatex {
|
|
5
|
+
padding: 0.5em;
|
|
6
|
+
margin: 1em 0;
|
|
7
|
+
}
|
|
8
|
+
|
|
9
|
+
/* Better code block styling */
|
|
10
|
+
.highlight pre {
|
|
11
|
+
padding: 1em;
|
|
12
|
+
overflow-x: auto;
|
|
13
|
+
}
|
|
14
|
+
|
|
15
|
+
/* Improve navigation */
|
|
16
|
+
.md-nav__item .md-nav__link--active {
|
|
17
|
+
font-weight: bold;
|
|
18
|
+
}
|
numerax-0.1.0/mkdocs.yml
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
site_name: Numerax
|
|
2
|
+
site_description: Statistical and numerical computation functions for JAX
|
|
3
|
+
site_url: https://juehang.github.io/numerax/
|
|
4
|
+
repo_url: https://github.com/juehang/numerax
|
|
5
|
+
repo_name: juehang/numerax
|
|
6
|
+
|
|
7
|
+
theme:
|
|
8
|
+
name: material
|
|
9
|
+
features:
|
|
10
|
+
- navigation.tabs
|
|
11
|
+
- navigation.sections
|
|
12
|
+
- navigation.expand
|
|
13
|
+
- navigation.top
|
|
14
|
+
- search.highlight
|
|
15
|
+
- search.share
|
|
16
|
+
- content.code.copy
|
|
17
|
+
palette:
|
|
18
|
+
- scheme: default
|
|
19
|
+
primary: blue
|
|
20
|
+
accent: blue
|
|
21
|
+
toggle:
|
|
22
|
+
icon: material/brightness-7
|
|
23
|
+
name: Switch to dark mode
|
|
24
|
+
- scheme: slate
|
|
25
|
+
primary: blue
|
|
26
|
+
accent: blue
|
|
27
|
+
toggle:
|
|
28
|
+
icon: material/brightness-4
|
|
29
|
+
name: Switch to light mode
|
|
30
|
+
|
|
31
|
+
plugins:
|
|
32
|
+
- search
|
|
33
|
+
- mkdocstrings:
|
|
34
|
+
handlers:
|
|
35
|
+
python:
|
|
36
|
+
options:
|
|
37
|
+
docstring_style: google
|
|
38
|
+
show_source: false
|
|
39
|
+
show_root_heading: true
|
|
40
|
+
show_root_toc_entry: false
|
|
41
|
+
merge_init_into_class: true
|
|
42
|
+
show_signature_annotations: true
|
|
43
|
+
separate_signature: true
|
|
44
|
+
|
|
45
|
+
markdown_extensions:
|
|
46
|
+
- pymdownx.arithmatex:
|
|
47
|
+
generic: true
|
|
48
|
+
- pymdownx.highlight:
|
|
49
|
+
anchor_linenums: true
|
|
50
|
+
- pymdownx.inlinehilite
|
|
51
|
+
- pymdownx.snippets
|
|
52
|
+
- pymdownx.superfences
|
|
53
|
+
- admonition
|
|
54
|
+
- pymdownx.details
|
|
55
|
+
- pymdownx.tabbed:
|
|
56
|
+
alternate_style: true
|
|
57
|
+
- toc:
|
|
58
|
+
permalink: true
|
|
59
|
+
|
|
60
|
+
# MathJax 3 renders the arithmatex output. The former polyfill.io entry was
# removed: that domain served malicious code in 2024, and MathJax 3 does not
# need an ES6 polyfill on current browsers.
extra_javascript:
  - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js
|
|
63
|
+
|
|
64
|
+
extra_css:
|
|
65
|
+
- stylesheets/extra.css
|
|
66
|
+
|
|
67
|
+
nav:
|
|
68
|
+
- Home: index.md
|
|
69
|
+
- API Reference:
|
|
70
|
+
- Overview: api/index.md
|
|
71
|
+
- Special Functions: api/special.md
|
|
72
|
+
- Statistics: api/stats.md
|
|
73
|
+
- Utilities: api/utils.md
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "numerax"
|
|
7
|
+
dynamic = ["version"]
|
|
8
|
+
description = ""
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.12"
|
|
11
|
+
license = "MIT"
|
|
12
|
+
keywords = []
|
|
13
|
+
authors = [
|
|
14
|
+
{ name = "Juehang Qin"},
|
|
15
|
+
]
|
|
16
|
+
classifiers = [
|
|
17
|
+
"Development Status :: 4 - Beta",
|
|
18
|
+
"Programming Language :: Python",
|
|
19
|
+
"Programming Language :: Python :: 3.12",
|
|
20
|
+
"Programming Language :: Python :: 3.13",
|
|
21
|
+
"Programming Language :: Python :: Implementation :: CPython",
|
|
22
|
+
"Programming Language :: Python :: Implementation :: PyPy",
|
|
23
|
+
]
|
|
24
|
+
dependencies = [
|
|
25
|
+
"jax",
|
|
26
|
+
"jaxtyping",
|
|
27
|
+
"optax"
|
|
28
|
+
]
|
|
29
|
+
|
|
30
|
+
[project.urls]
|
|
31
|
+
Documentation = "https://github.com/juehang/numerax#readme"
|
|
32
|
+
Issues = "https://github.com/juehang/numerax/issues"
|
|
33
|
+
Source = "https://github.com/juehang/numerax"
|
|
34
|
+
|
|
35
|
+
[tool.hatch.version]
|
|
36
|
+
path = "src/numerax/__init__.py"
|
|
37
|
+
|
|
38
|
+
[tool.hatch.envs.default]
|
|
39
|
+
dependencies = [
|
|
40
|
+
"coverage[toml]>=6.5",
|
|
41
|
+
"pytest",
|
|
42
|
+
"ruff",
|
|
43
|
+
"mkdocs-material",
|
|
44
|
+
"mkdocstrings[python]"
|
|
45
|
+
]
|
|
46
|
+
[tool.hatch.envs.default.scripts]
|
|
47
|
+
test = "pytest {args:tests}"
|
|
48
|
+
test-cov = "coverage run -m pytest {args:tests}"
|
|
49
|
+
cov-report = [
|
|
50
|
+
"- coverage combine",
|
|
51
|
+
"coverage report",
|
|
52
|
+
]
|
|
53
|
+
cov = [
|
|
54
|
+
"test-cov",
|
|
55
|
+
"cov-report",
|
|
56
|
+
]
|
|
57
|
+
|
|
58
|
+
[[tool.hatch.envs.all.matrix]]
|
|
59
|
+
python = ["3.12"]
|
|
60
|
+
|
|
61
|
+
[tool.hatch.envs.lint]
|
|
62
|
+
detached = true
|
|
63
|
+
dependencies = [
|
|
64
|
+
"ruff>=0.0.243",
|
|
65
|
+
]
|
|
66
|
+
[tool.hatch.envs.lint.scripts]
|
|
67
|
+
check = [
|
|
68
|
+
"ruff check {args:.}",
|
|
69
|
+
"ruff format --check --diff {args:.}",
|
|
70
|
+
]
|
|
71
|
+
fmt = [
|
|
72
|
+
"ruff format {args:.}",
|
|
73
|
+
"ruff check --fix {args:.}",
|
|
74
|
+
"check",
|
|
75
|
+
]
|
|
76
|
+
|
|
77
|
+
[tool.coverage.run]
|
|
78
|
+
source_pkgs = ["numerax", "tests"]
|
|
79
|
+
branch = true
|
|
80
|
+
parallel = true
|
|
81
|
+
omit = [
|
|
82
|
+
"src/numerax/__about__.py",
|
|
83
|
+
]
|
|
84
|
+
|
|
85
|
+
[tool.coverage.paths]
|
|
86
|
+
numerax = ["src/numerax", "*/numerax/src/numerax"]
|
|
87
|
+
tests = ["tests", "*/numerax/tests"]
|
|
88
|
+
|
|
89
|
+
[tool.coverage.report]
|
|
90
|
+
exclude_lines = [
|
|
91
|
+
"no cov",
|
|
92
|
+
"if __name__ == .__main__.:",
|
|
93
|
+
"if TYPE_CHECKING:",
|
|
94
|
+
]
|
|
95
|
+
|
|
96
|
+
[tool.ruff]
|
|
97
|
+
line-length = 79
|
|
98
|
+
target-version = "py312"
|
|
99
|
+
extend-exclude = ["old_code"]
|
|
100
|
+
|
|
101
|
+
[tool.ruff.lint]
|
|
102
|
+
select = [
|
|
103
|
+
"E", # pycodestyle errors
|
|
104
|
+
"F", # pyflakes
|
|
105
|
+
"UP", # pyupgrade
|
|
106
|
+
"B", # flake8-bugbear
|
|
107
|
+
"SIM", # flake8-simplify
|
|
108
|
+
"I", # isort
|
|
109
|
+
"ASYNC", # flake8-async
|
|
110
|
+
"S", # flake8-bandit
|
|
111
|
+
"ARG", # flake8-unused-arguments
|
|
112
|
+
"Q", # flake8-quotes
|
|
113
|
+
"SIM", # flake8-simplify
|
|
114
|
+
"NPY", # numpy-specific rules
|
|
115
|
+
"PD", # pandas-specific rules
|
|
116
|
+
"N", # pep8-naming
|
|
117
|
+
"W", # warning
|
|
118
|
+
"PLC", # pylint convention
|
|
119
|
+
"PLE", # pylint error
|
|
120
|
+
"PLW", # pylint warning
|
|
121
|
+
"RUF", # ruff-specific rules
|
|
122
|
+
]
|
|
123
|
+
ignore = ["COM812"] # Ignore trailing comma rule that conflicts with formatter
|
|
124
|
+
|
|
125
|
+
[tool.ruff.lint.per-file-ignores]
|
|
126
|
+
"tests/**" = ["S101"] # Allow assert statements in tests
|
|
127
|
+
|
|
128
|
+
# Enable Ruff's formatter
|
|
129
|
+
[tool.ruff.format]
|
|
130
|
+
quote-style = "double"
|
|
131
|
+
indent-style = "space"
|
|
132
|
+
line-ending = "auto"
|
|
133
|
+
docstring-code-format = true
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Statistical and numerical computation functions for JAX, focusing on tools
|
|
3
|
+
not available in the main JAX API.
|
|
4
|
+
|
|
5
|
+
## Overview
|
|
6
|
+
|
|
7
|
+
This package provides JAX-compatible implementations of specialized numerical
|
|
8
|
+
functions with full differentiability support. All functions are designed to
|
|
9
|
+
work seamlessly with JAX's transformations (JIT, grad, vmap, etc.) and follow
|
|
10
|
+
JAX's functional programming paradigms.
|
|
11
|
+
|
|
12
|
+
### Special Functions (`numerax.special`)
|
|
13
|
+
|
|
14
|
+
Mathematical special functions with custom derivative implementations.
|
|
15
|
+
Functions use numerically stable algorithms and provide exact gradients
|
|
16
|
+
through custom JVP rules where standard automatic differentiation would
|
|
17
|
+
be inefficient or unstable.
|
|
18
|
+
|
|
19
|
+
### Statistical Methods (`numerax.stats`)
|
|
20
|
+
|
|
21
|
+
Advanced statistical computation tools for inference problems. Implements
|
|
22
|
+
efficient algorithms for complex statistical models, with particular focus
|
|
23
|
+
on optimization-based methods that benefit from JAX's compilation and
|
|
24
|
+
differentiation capabilities.
|
|
25
|
+
|
|
26
|
+
### Utilities (`numerax.utils`)
|
|
27
|
+
|
|
28
|
+
Development utilities for creating JAX-compatible functions with proper
|
|
29
|
+
documentation support. Includes decorators and helpers for preserving
|
|
30
|
+
function metadata when using JAX's advanced features like custom derivatives.
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
from . import special, stats, utils
|
|
34
|
+
|
|
35
|
+
__version__ = "0.1.0"
|
|
36
|
+
|
|
37
|
+
__all__ = ["special", "stats", "utils"]
|
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
import jax.numpy as jnp
|
|
3
|
+
import jax.scipy.special as special
|
|
4
|
+
from jaxtyping import ArrayLike
|
|
5
|
+
|
|
6
|
+
# Numerical-stability constants, matched to JAX's active float precision.
# They are evaluated once, when this module is imported.
_DTYPE = jnp.float32 if not jax.config.jax_enable_x64 else jnp.float64
_FINFO = jnp.finfo(_DTYPE)
TINY = _FINFO.smallest_normal  # smallest normal float; underflow guard
EPS = _FINFO.eps  # machine epsilon; convergence tolerance
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@jax.custom_jvp
def gammap_inverse(p: ArrayLike, a: float) -> ArrayLike:
    r"""
    Inverse of the regularized incomplete gamma function.

    ## Overview

    Computes the inverse of the regularized incomplete gamma function, finding
    $x$ such that $P(a, x) = p$, where $P(a, x)$ is the regularized incomplete
    gamma function. This is equivalent to computing quantiles of the
    $\text{Gamma}(a, 1)$ distribution. The general strategy and the initial
    guess are based on the methods described in
    Numerical Recipes (Press et al., 2007).

    ## Mathematical Background

    The regularized incomplete gamma function is defined as:

    $$P(a, x) = \frac{\gamma(a, x)}{\Gamma(a)} = \frac{1}{\Gamma(a)}
    \int_0^x t^{a-1} e^{-t} dt$$

    This function solves the inverse problem:

    $$x = P^{-1}(a, p) \quad \text{such that} \quad P(a, x) = p$$

    For a random variable $X \sim \text{Gamma}(a, 1)$, this gives:

    $$x = F^{-1}(p) \quad \text{where} \quad F(x) = P(a, x)$$

    ## Numerical Method

    Uses Halley's method for fast cubic convergence:

    $$x_{n+1} = x_n - \frac{2f(x_n)f'(x_n)}{2[f'(x_n)]^2 - f(x_n)f''(x_n)}$$

    where $f(x) = P(a, x) - p$.

    **Initial guess** based on Numerical Recipes (Press et al., 2007):
    - For $a > 1$: Wilson-Hilferty approximation
    - For $a \leq 1$: Asymptotic expansions

    ## Args

    - **p**: Probability values in $[0, 1]$. Can be scalar or array.
    - **a**: Shape parameter (must be positive). Scalar value.

    ## Returns

    Quantiles $x$ where $P(a, x) = p$. Same shape as input `p`.

    ## Example

    ```python
    import jax
    import jax.numpy as jnp
    import numerax

    # Single quantile
    x = numerax.special.gammap_inverse(0.5, 2.0)  # Median of Gamma(2, 1)

    # Multiple quantiles
    p_vals = jnp.array([0.1, 0.25, 0.5, 0.75, 0.9])
    x_vals = numerax.special.gammap_inverse(p_vals, 3.0)

    # Verify inverse relationship
    from jax.scipy.special import gammainc

    p_recovered = gammainc(2.0, x)  # Should equal original p

    # Differentiable for optimization
    grad_fn = jax.grad(numerax.special.gammap_inverse)
    sensitivity = grad_fn(0.5, 2.0)  # ∂x/∂p at median
    ```

    ## Notes

    - **Convergence**: Typically converges in 3-8 iterations (the loop is
        capped at 12 iterations)
    - **Differentiable**: Custom JVP implementation using implicit function
        theorem
    - **Numerical stability**: Handles edge cases near 0 and 1
    - **Performance**: JIT-compatible, with tolerances that follow the
        active float precision
    - **Domain**: $p \in [0, 1]$ and $a > 0$

    ## References

    Press, W. H., Teukolsky, S. A., Vetterling, W. T., & Flannery, B. P.
    (2007).
    *Numerical Recipes: The Art of Scientific Computing* (3rd ed.).
    Cambridge University Press.
    """

    def objective(x):
        """F(x) = gammainc(a, x) - p"""
        return special.gammainc(a, x) - p

    # Initial guess from Numerical Recipes
    def initial_guess(u_val, a_val):
        # a = dof/2 for chi-squared

        def large_a_guess():
            # For a > 1: use Wilson-Hilferty approximation
            pp = jnp.where(u_val < 0.5, u_val, 1.0 - u_val)
            t = jnp.sqrt(-2.0 * jnp.log(pp))
            x = (2.30753 + t * 0.27061) / (
                1.0 + t * (0.99229 + t * 0.04481)
            ) - t
            x = jnp.where(u_val < 0.5, -x, x)
            # Floor the guess so the iteration starts at a positive x.
            return jnp.fmax(
                1e-3,
                a_val
                * (1.0 - 1.0 / (9.0 * a_val) - x / (3.0 * jnp.sqrt(a_val)))
                ** 3,
            )

        def small_a_guess():
            # For a <= 1: use equations (6.2.8) and (6.2.9)
            t = 1.0 - a_val * (0.253 + a_val * 0.12)
            return jnp.where(
                u_val < t,
                (u_val / t) ** (1.0 / a_val),
                1.0 - jnp.log(1.0 - (u_val - t) / (1.0 - t)),
            )

        # Both branches are evaluated; jnp.where picks per the value of a.
        return jnp.real(
            jnp.where(a_val > 1.0, large_a_guess(), small_a_guess())
        )

    # Derivatives for Halley's method
    # NOTE(review): jax.grad needs a scalar-valued objective and the
    # while_loop condition below must be a scalar boolean, so array-valued
    # `p` presumably has to be mapped with jax.vmap - confirm.
    f = objective
    df_dx = jax.grad(objective)
    d2f_dx2 = jax.grad(df_dx)

    x = initial_guess(p, a)

    # Use while_loop for dynamic convergence
    def cond_fn(state):
        x, step, iteration = state
        # Continue while step is large and we haven't exceeded max iterations
        return (jnp.abs(step) > EPS * jnp.abs(x)) & (iteration < 12)

    def body_fn(state):
        x, _, iteration = state

        f_val = f(x)
        df_val = df_dx(x)
        d2f_val = d2f_dx2(x)

        # Halley's method: x_{n+1} = x_n - 2*f*f' / (2*f'^2 - f*f'')
        numerator = 2 * f_val * df_val
        denominator = 2 * df_val**2 - f_val * d2f_val

        # Keep the denominator away from zero. The replacement sign is
        # chosen explicitly because jnp.sign(0) is 0, which would leave an
        # exactly-zero denominator at zero and still divide by zero.
        sign = jnp.where(denominator < 0, -1.0, 1.0)
        denominator = jnp.where(
            jnp.abs(denominator) < TINY,
            sign * TINY,
            denominator,
        )

        step = numerator / denominator
        x_new = x - step

        # Ensure x stays positive
        x_new = jnp.fmax(x_new, TINY)

        return (x_new, step, iteration + 1)

    # Initial state: (x, step, iteration); the infinite step forces the
    # first iteration to run.
    initial_state = (x, jnp.inf, 0)
    final_state = jax.lax.while_loop(cond_fn, body_fn, initial_state)
    x = final_state[0]

    return x
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
@gammap_inverse.defjvp
def gammap_inverse_jvp(primals, tangents):
    """
    Custom JVP for gammap_inverse using the implicit function theorem.

    For F(x, p) = gammainc(a, x) - p = 0:
    dx/dp = -(∂F/∂p) / (∂F/∂x) = 1 / (∂/∂x gammainc(a, x))

    ## Args

    - **primals**: Tuple `(p, a)` at which the inverse is evaluated
    - **tangents**: Tuple `(p_dot, a_dot)`; only `p_dot` is used, so the
      returned tangent carries no contribution from `a`

    ## Returns

    Tuple `(x, x_dot)` where `x = gammap_inverse(p, a)` and
    `x_dot = p_dot / (∂/∂x gammainc(a, x))`
    """
    p, a = primals
    p_dot, _ = tangents

    # Forward pass
    x = gammap_inverse(p, a)

    # Slope of the forward function at the solution:
    # dx/dp = 1 / (d/dx gammainc(a, x))
    def gammainc_x(x_val):
        return special.gammainc(a, x_val)

    dgammainc_dx = jax.grad(gammainc_x)(x)

    # Keep the slope at least TINY in magnitude before inverting it.
    # jnp.sign(0) is 0, so scaling TINY by the sign would leave an
    # exactly-zero slope at zero and the division below would still
    # produce inf; pick the replacement from the sign bit instead.
    floor = jnp.where(dgammainc_dx < 0, -TINY, TINY)
    dgammainc_dx = jnp.where(
        jnp.abs(dgammainc_dx) < TINY,
        floor,
        dgammainc_dx,
    )

    dx_dp = 1.0 / dgammainc_dx

    # For now, ignore a derivatives (could be added if needed)
    x_dot = dx_dp * p_dot

    return x, x_dot
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
"""Profile likelihood functions for statistical inference."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
|
|
5
|
+
import jax
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
import optax
|
|
8
|
+
from optax import lbfgs
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def make_profile_llh(
|
|
12
|
+
llh_fn: Callable,
|
|
13
|
+
is_nuisance: list[bool] | jnp.ndarray,
|
|
14
|
+
get_initial_nuisance: Callable,
|
|
15
|
+
tol: float = 1e-6,
|
|
16
|
+
initial_value: float = 1e-9,
|
|
17
|
+
initial_diff: float = 1e9,
|
|
18
|
+
) -> Callable:
|
|
19
|
+
r"""
|
|
20
|
+
Factory function for creating profile likelihood functions.
|
|
21
|
+
|
|
22
|
+
## Overview
|
|
23
|
+
|
|
24
|
+
Profile likelihood is a statistical technique used when dealing with
|
|
25
|
+
nuisance parameters that are not of primary interest but are necessary
|
|
26
|
+
for the model. This function creates an optimized profile likelihood
|
|
27
|
+
that maximizes over nuisance parameters while keeping inference
|
|
28
|
+
parameters fixed.
|
|
29
|
+
|
|
30
|
+
## Mathematical Background
|
|
31
|
+
|
|
32
|
+
Given a likelihood function $L(\boldsymbol{\theta}, \boldsymbol{\lambda})$
|
|
33
|
+
where $\boldsymbol{\theta}$ are parameters of interest and
|
|
34
|
+
$\boldsymbol{\lambda}$ are nuisance parameters, the profile likelihood is:
|
|
35
|
+
|
|
36
|
+
$$L_p(\boldsymbol{\theta}) = \max_{\boldsymbol{\lambda}}
|
|
37
|
+
L(\boldsymbol{\theta}, \boldsymbol{\lambda})$$
|
|
38
|
+
|
|
39
|
+
In practice, we work with the log-likelihood
|
|
40
|
+
$\ell(\boldsymbol{\theta}, \boldsymbol{\lambda}) =
|
|
41
|
+
\log L(\boldsymbol{\theta}, \boldsymbol{\lambda})$:
|
|
42
|
+
|
|
43
|
+
$$\ell_p(\boldsymbol{\theta}) = \max_{\boldsymbol{\lambda}}
|
|
44
|
+
\ell(\boldsymbol{\theta}, \boldsymbol{\lambda})$$
|
|
45
|
+
|
|
46
|
+
This function uses L-BFGS optimization to find the maximum likelihood
|
|
47
|
+
estimates of nuisance parameters for each fixed value of inference
|
|
48
|
+
parameters.
|
|
49
|
+
|
|
50
|
+
## Args
|
|
51
|
+
|
|
52
|
+
- **llh_fn**: Log likelihood function taking (params, *args) and
|
|
53
|
+
returning scalar log likelihood value
|
|
54
|
+
- **is_nuisance**: Boolean array where True indicates nuisance
|
|
55
|
+
parameters and False indicates inference parameters
|
|
56
|
+
- **get_initial_nuisance**: Function taking (*args) and returning
|
|
57
|
+
initial values for nuisance parameters
|
|
58
|
+
- **tol**: Convergence tolerance for the optimization (default: 1e-6)
|
|
59
|
+
- **initial_value**: Initial objective value for convergence tracking
|
|
60
|
+
(default: 1e-9)
|
|
61
|
+
- **initial_diff**: Initial difference for convergence tracking
|
|
62
|
+
(default: 1e9)
|
|
63
|
+
|
|
64
|
+
## Returns
|
|
65
|
+
|
|
66
|
+
Profile likelihood function with signature:
|
|
67
|
+
`(inference_values, *args) -> (profile_llh_value, optimal_nuisance,
|
|
68
|
+
convergence_diff, num_iterations)`
|
|
69
|
+
|
|
70
|
+
## Example
|
|
71
|
+
|
|
72
|
+
Consider fitting a normal distribution where we want to infer the mean
|
|
73
|
+
$\mu$ but treat the variance $\sigma^2$ as a nuisance parameter:
|
|
74
|
+
|
|
75
|
+
```python
|
|
76
|
+
import jax.numpy as jnp
|
|
77
|
+
import numerax
|
|
78
|
+
|
|
79
|
+
# Sample data
|
|
80
|
+
data = jnp.array([1.2, 0.8, 1.5, 0.9, 1.1, 1.3, 0.7, 1.4])
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
# Log likelihood for normal distribution
|
|
84
|
+
def normal_llh(params, data):
|
|
85
|
+
mu, log_sigma = params # Use log(sigma) for numerical stability
|
|
86
|
+
sigma = jnp.exp(log_sigma)
|
|
87
|
+
return jnp.sum(
|
|
88
|
+
-0.5 * jnp.log(2 * jnp.pi)
|
|
89
|
+
- log_sigma
|
|
90
|
+
- 0.5 * ((data - mu) / sigma) ** 2
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
# Profile over log_sigma (nuisance), infer mu
|
|
95
|
+
is_nuisance = [False, True] # mu=inference, log_sigma=nuisance
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def get_initial_log_sigma(data):
|
|
99
|
+
# Initialize with log of sample standard deviation
|
|
100
|
+
return jnp.array([jnp.log(jnp.std(data))])
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
profile_llh = numerax.stats.make_profile_llh(
|
|
104
|
+
normal_llh, is_nuisance, get_initial_log_sigma
|
|
105
|
+
)
|
|
106
|
+
|
|
107
|
+
# Evaluate profile likelihood at different mu values
|
|
108
|
+
mu_test = 1.0
|
|
109
|
+
llh_val, opt_log_sigma, diff, n_iter = profile_llh(
|
|
110
|
+
jnp.array([mu_test]), data
|
|
111
|
+
)
|
|
112
|
+
```
|
|
113
|
+
|
|
114
|
+
## Notes
|
|
115
|
+
|
|
116
|
+
- The function is JIT-compiled for performance
|
|
117
|
+
- Uses L-BFGS optimization which is well-suited for smooth likelihood
|
|
118
|
+
surfaces
|
|
119
|
+
- Returns convergence information for diagnostics
|
|
120
|
+
- Handles parameter masking automatically
|
|
121
|
+
- Consider using log-parameterization for positive parameters
|
|
122
|
+
(e.g., $\log \sigma$) to improve numerical stability
|
|
123
|
+
"""
|
|
124
|
+
nuisance_mask = jnp.array(is_nuisance)
|
|
125
|
+
inference_mask = ~nuisance_mask
|
|
126
|
+
|
|
127
|
+
@jax.jit
|
|
128
|
+
def profile_llh(inference_values, *args):
|
|
129
|
+
solver = lbfgs()
|
|
130
|
+
initial_nuisance = get_initial_nuisance(*args)
|
|
131
|
+
opt_state = solver.init(initial_nuisance)
|
|
132
|
+
|
|
133
|
+
def objective(nuisance_params):
|
|
134
|
+
# Reconstruct full parameter vector
|
|
135
|
+
full_params = jnp.zeros(len(nuisance_mask))
|
|
136
|
+
full_params = full_params.at[inference_mask].set(inference_values)
|
|
137
|
+
full_params = full_params.at[nuisance_mask].set(nuisance_params)
|
|
138
|
+
return -llh_fn(full_params, *args)
|
|
139
|
+
|
|
140
|
+
value_and_grad = optax.value_and_grad_from_state(objective)
|
|
141
|
+
|
|
142
|
+
def profile_llh_loopfun(var):
|
|
143
|
+
params, last_value, opt_state, _, n = var
|
|
144
|
+
value, grad = value_and_grad(params, state=opt_state)
|
|
145
|
+
updates, opt_state = solver.update(
|
|
146
|
+
grad,
|
|
147
|
+
opt_state,
|
|
148
|
+
params,
|
|
149
|
+
value=value,
|
|
150
|
+
grad=grad,
|
|
151
|
+
value_fn=objective,
|
|
152
|
+
)
|
|
153
|
+
params = optax.apply_updates(params, updates)
|
|
154
|
+
diff = last_value - value
|
|
155
|
+
return params, value, opt_state, diff, n + 1
|
|
156
|
+
|
|
157
|
+
params, value, opt_state, diff, n = jax.lax.while_loop(
|
|
158
|
+
lambda var: jnp.abs(var[-2]) > jnp.abs(var[1] * tol),
|
|
159
|
+
profile_llh_loopfun,
|
|
160
|
+
(initial_nuisance, initial_value, opt_state, initial_diff, 0),
|
|
161
|
+
)
|
|
162
|
+
|
|
163
|
+
return -value, params, diff, n
|
|
164
|
+
|
|
165
|
+
return profile_llh
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
"""Utility functions for the numerax package."""
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
from collections.abc import Callable
|
|
5
|
+
from typing import TypeVar
|
|
6
|
+
|
|
7
|
+
F = TypeVar("F", bound=Callable)


def preserve_metadata(decorator):
    """
    Adapt a decorator so that its result keeps the metadata of the
    function it was applied to.

    ## Overview

    Some decorators, for example JAX's `@custom_jvp`, return special
    objects that may not carry over `__doc__` and related attributes in a
    way documentation generators such as pdoc can pick up. The decorator
    returned here applies the original decorator and then copies the
    function's metadata onto whatever it produced.

    ## Args

    - **decorator**: The decorator function to wrap

    ## Returns

    A new decorator that applies `decorator` and preserves metadata

    ## Example

    ```python
    import jax
    from numerax.utils import preserve_metadata

    @preserve_metadata(jax.custom_jvp)
    def my_function(x):
        '''This docstring will be preserved for pdoc.'''
        return x
    ```
    """

    def metadata_preserving_decorator(func: F) -> F:
        result = decorator(func)
        # Copy __doc__, __name__, etc. from ``func`` onto the decorated
        # object and hand that same object back.
        return functools.update_wrapper(result, func)

    return metadata_preserving_decorator
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Test configuration for numerax."""
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Test suite for gamma special functions.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import jax
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
import pytest
|
|
8
|
+
from jax.scipy import special
|
|
9
|
+
|
|
10
|
+
from numerax.special import gammap_inverse
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@pytest.mark.parametrize(
    "p,a",
    [
        (0.1, 0.5),
        (0.3, 1.0),
        (0.5, 1.5),
        (0.8, 2.0),
        (0.95, 3.0),
    ],
)
def test_gamma_inverse_correctness(p, a):
    """Round trip: gammainc(a, gammap_inverse(p, a)) must recover p."""
    x = gammap_inverse(p, a)
    recovered_p = special.gammainc(a, x)
    round_trip_error = jnp.abs(recovered_p - p)
    assert round_trip_error < 1e-6
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@pytest.mark.parametrize(
    "p,a",
    [
        (0.2, 0.5),
        (0.4, 1.0),
        (0.6, 1.5),
        (0.8, 2.0),
        (0.9, 2.5),
    ],
)
def test_gamma_inverse_gradients(p, a):
    """Check d/dp gammap_inverse against the inverse-function formula."""

    def forward(x_val):
        return special.gammainc(a, x_val)

    # Point x at which gammainc(a, x) = p
    root = gammap_inverse(p, a)

    # Inverse-function rule: dx/dp = 1 / (d/dx gammainc(a, x))
    expected = 1 / jax.grad(forward)(root)

    # Gradient obtained by differentiating gammap_inverse directly
    actual = jax.grad(gammap_inverse)(p, a)

    assert jnp.abs(expected - actual) < 1e-6
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
"""Test profile likelihood with Gaussian data and analytical validation."""
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
|
|
6
|
+
from numerax.stats import make_profile_llh
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def test_gaussian_profile_analytical_validation():
    r"""Compare the profiled sigma with the closed-form Gaussian MLE.

    For Gaussian data with $\mu$ held fixed, the MLE of $\sigma$ is
    $\hat{\sigma}(\mu) = \sqrt{\Sigma(x_i-\mu)^2/N}$.
    The optimizer behind the profile likelihood should land on the same
    value.
    """
    # Synthetic Gaussian sample from a fixed seed for reproducibility
    true_mu = 2.0
    true_sigma = 1.5
    n_samples = 50
    key = jax.random.PRNGKey(42)
    data = jax.random.normal(key, (n_samples,)) * true_sigma + true_mu

    def normal_llh(params, data):
        # params = (mu, log_sigma)
        mu, log_sigma = params
        standardized = (data - mu) / jnp.exp(log_sigma)
        per_point = (
            -0.5 * jnp.log(2 * jnp.pi) - log_sigma - 0.5 * standardized**2
        )
        return jnp.sum(per_point)

    def get_initial_log_sigma(data):
        # Start from the log of the sample standard deviation
        return jnp.array([jnp.log(jnp.std(data))])

    # mu is the inference parameter, log_sigma is profiled out
    is_nuisance = [False, True]
    profile_llh = make_profile_llh(
        normal_llh, is_nuisance, get_initial_log_sigma
    )

    # Profile at one fixed mu
    test_mu = 1.8
    fixed_mu = jnp.array([test_mu])
    llh_val, opt_log_sigma, diff, n_iter = profile_llh(fixed_mu, data)

    sigma_opt = jnp.exp(opt_log_sigma[0])
    # Closed form: sigma_hat(mu) = sqrt(sum((x_i - mu)^2) / N)
    sigma_analytical = jnp.sqrt(jnp.mean((data - test_mu) ** 2))

    # Convergence diagnostics
    assert n_iter > 0, "Optimization should have run at least one iteration"
    assert jnp.abs(diff) <= jnp.abs(llh_val * 1e-6), (
        f"Should have converged, but diff={diff}"
    )

    # Agreement with the analytical solution
    assert jnp.allclose(sigma_opt, sigma_analytical, atol=1e-3), (
        f"Optimized sigma={sigma_opt:.6f} should match "
        f"analytical sigma={sigma_analytical:.6f}"
    )

    # Sanity checks on the likelihood value itself
    assert jnp.isfinite(llh_val), "Profile likelihood value should be finite"
    assert llh_val < 0, "Log likelihood should be negative"
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def test_gaussian_profile_at_true_mean():
    """Profile sigma with mu fixed at the sample mean of the data."""
    # Synthetic Gaussian sample
    true_mu = 3.0
    true_sigma = 0.8
    n_samples = 30
    key = jax.random.PRNGKey(123)
    data = jax.random.normal(key, (n_samples,)) * true_sigma + true_mu

    def normal_llh(params, data):
        # params = (mu, log_sigma)
        mu, log_sigma = params
        standardized = (data - mu) / jnp.exp(log_sigma)
        per_point = (
            -0.5 * jnp.log(2 * jnp.pi) - log_sigma - 0.5 * standardized**2
        )
        return jnp.sum(per_point)

    def get_initial_log_sigma(data):
        return jnp.array([jnp.log(jnp.std(data))])

    is_nuisance = [False, True]
    profile_llh = make_profile_llh(
        normal_llh, is_nuisance, get_initial_log_sigma
    )

    # Fix mu at the sample mean and profile out log_sigma
    sample_mean = jnp.mean(data)
    outputs = profile_llh(jnp.array([sample_mean]), data)
    llh_val = outputs[0]
    opt_log_sigma = outputs[1]

    sigma_opt = jnp.exp(opt_log_sigma[0])
    sigma_analytical = jnp.sqrt(jnp.mean((data - sample_mean) ** 2))

    # The profiled sigma should match the closed-form value closely
    assert jnp.allclose(sigma_opt, sigma_analytical, atol=1e-4)
    assert jnp.isfinite(llh_val)
|