numerax 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,13 @@
1
+ {
2
+ "permissions": {
3
+ "allow": [
4
+ "Bash(find:*)",
5
+ "Bash(ruff format:*)",
6
+ "mcp__vscode-tools__get_symbol_definition_code",
7
+ "mcp__vscode-tools__search_symbols_code",
8
+ "Bash(hatch run lint:*)",
9
+ "mcp__vscode-tools__get_document_symbols_code"
10
+ ],
11
+ "deny": []
12
+ }
13
+ }
@@ -0,0 +1,59 @@
1
+ name: documentation
2
+
3
+ # build the documentation whenever there are new commits on main
4
+ on:
5
+ push:
6
+ branches:
7
+ - main
8
+ # Alternative: only build for tags.
9
+ # tags:
10
+ # - '*'
11
+
12
+ # security: restrict permissions for CI jobs.
13
+ permissions:
14
+ contents: read
15
+
16
+ jobs:
17
+ # Build the documentation and upload the static HTML files as an artifact.
18
+ build:
19
+ runs-on: ubuntu-latest
20
+ steps:
21
+ - uses: actions/checkout@v4
22
+ with:
23
+ persist-credentials: false
24
+
25
+ # Extract Python version from pyproject.toml
26
+ - name: Get Python version from pyproject.toml
27
+ id: python-version
28
+ run: |
29
+ python_version=$(grep -E "requires-python.*=" pyproject.toml | sed -E 's/.*">=([0-9.]+)".*/\1/')
30
+ echo "version=$python_version" >> $GITHUB_OUTPUT
31
+
32
+ - uses: actions/setup-python@v5
33
+ with:
34
+ python-version: ${{ steps.python-version.outputs.version }}
35
+
36
+ # Install Hatch
37
+ - run: pip install hatch
38
+
39
+ # Build documentation with MkDocs using dev environment
40
+ - run: hatch run mkdocs build
41
+
42
+ - uses: actions/upload-pages-artifact@v3
43
+ with:
44
+ path: site/
45
+
46
+ # Deploy the artifact to GitHub pages.
47
+ # This is a separate job so that only actions/deploy-pages has the necessary permissions.
48
+ deploy:
49
+ needs: build
50
+ runs-on: ubuntu-latest
51
+ permissions:
52
+ pages: write
53
+ id-token: write
54
+ environment:
55
+ name: github-pages
56
+ url: ${{ steps.deployment.outputs.page_url }}
57
+ steps:
58
+ - id: deployment
59
+ uses: actions/deploy-pages@v4
@@ -0,0 +1,84 @@
1
+ name: publish
2
+
3
+ # Publish to PyPI when a new release is created
4
+ on:
5
+ release:
6
+ types: [published]
7
+
8
+ # Security: restrict permissions for CI jobs
9
+ permissions:
10
+ contents: read
11
+
12
+ jobs:
13
+ # Build and test before publishing
14
+ test:
15
+ runs-on: ubuntu-latest
16
+ steps:
17
+ - uses: actions/checkout@v4
18
+ with:
19
+ persist-credentials: false
20
+
21
+ - uses: actions/setup-python@v5
22
+ with:
23
+ python-version: '3.12'
24
+ cache: 'pip'
25
+
26
+ # Install hatch for environment management
27
+ - name: Install Hatch
28
+ uses: pypa/hatch@install
29
+
30
+ # Run code quality checks using hatch
31
+ - name: Run linting and formatting checks
32
+ run: hatch run lint:check
33
+
34
+ # Run tests using hatch
35
+ - name: Run tests
36
+ run: hatch run test
37
+
38
+ # Build the package
39
+ build:
40
+ needs: test
41
+ runs-on: ubuntu-latest
42
+ steps:
43
+ - uses: actions/checkout@v4
44
+ with:
45
+ persist-credentials: false
46
+
47
+ - uses: actions/setup-python@v5
48
+ with:
49
+ python-version: '3.12'
50
+ cache: 'pip'
51
+
52
+ # Install hatch for building
53
+ - name: Install Hatch
54
+ uses: pypa/hatch@install
55
+
56
+ # Build the package
57
+ - name: Build package
58
+ run: hatch build
59
+
60
+ # Upload build artifacts
61
+ - uses: actions/upload-artifact@v4
62
+ with:
63
+ name: dist-artifacts
64
+ path: dist/
65
+
66
+ # Publish to PyPI
67
+ publish:
68
+ needs: build
69
+ runs-on: ubuntu-latest
70
+ permissions:
71
+ id-token: write # Required for trusted publishing
72
+ environment:
73
+ name: pypi
74
+ url: https://pypi.org/p/numerax
75
+ steps:
76
+ - name: Download build artifacts
77
+ uses: actions/download-artifact@v4
78
+ with:
79
+ name: dist-artifacts
80
+ path: dist/
81
+
82
+ # Publish to PyPI using trusted publishing
83
+ - name: Publish to PyPI
84
+ uses: pypa/gh-action-pypi-publish@release/v1
@@ -0,0 +1,39 @@
1
+ name: tests
2
+
3
+ # Run tests on pushes to main and pull requests
4
+ on:
5
+ push:
6
+ branches:
7
+ - main
8
+ pull_request:
9
+ branches:
10
+ - main
11
+
12
+ # Security: restrict permissions for CI jobs
13
+ permissions:
14
+ contents: read
15
+
16
+ jobs:
17
+ test:
18
+ runs-on: ubuntu-latest
19
+ steps:
20
+ - uses: actions/checkout@v4
21
+ with:
22
+ persist-credentials: false
23
+
24
+ - uses: actions/setup-python@v5
25
+ with:
26
+ python-version: '3.12'
27
+ cache: 'pip'
28
+
29
+ # Install hatch for environment management
30
+ - name: Install Hatch
31
+ uses: pypa/hatch@install
32
+
33
+ # Run code quality checks using hatch
34
+ - name: Run linting and formatting checks
35
+ run: hatch run lint:check
36
+
37
+ # Run tests using hatch
38
+ - name: Run tests
39
+ run: hatch run test
@@ -0,0 +1,5 @@
1
+ .ruff_cache/*
2
+ CLAUDE.md
3
+ *.pyc
4
+ __pycache__/
5
+ *.pyo
numerax-0.1.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Juehang Qin
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
numerax-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,110 @@
1
+ Metadata-Version: 2.4
2
+ Name: numerax
3
+ Version: 0.1.0
4
+ Project-URL: Documentation, https://github.com/juehang/numerax#readme
5
+ Project-URL: Issues, https://github.com/juehang/numerax/issues
6
+ Project-URL: Source, https://github.com/juehang/numerax
7
+ Author: Juehang Qin
8
+ License-Expression: MIT
9
+ License-File: LICENSE
10
+ Classifier: Development Status :: 4 - Beta
11
+ Classifier: Programming Language :: Python
12
+ Classifier: Programming Language :: Python :: 3.12
13
+ Classifier: Programming Language :: Python :: 3.13
14
+ Classifier: Programming Language :: Python :: Implementation :: CPython
15
+ Classifier: Programming Language :: Python :: Implementation :: PyPy
16
+ Requires-Python: >=3.12
17
+ Requires-Dist: jax
18
+ Requires-Dist: jaxtyping
19
+ Requires-Dist: optax
20
+ Description-Content-Type: text/markdown
21
+
22
+ # numerax
23
+
24
+ [![tests](https://github.com/juehang/numerax/actions/workflows/test.yml/badge.svg)](https://github.com/juehang/numerax/actions/workflows/test.yml)
25
+ [![docs](https://github.com/juehang/numerax/actions/workflows/docs.yml/badge.svg)](https://juehang.github.io/numerax/)
26
+
27
+ Statistical and numerical computation functions for JAX, focusing on tools not available in the main JAX API.
28
+
29
+ **[📖 Documentation](https://juehang.github.io/numerax/)**
30
+
31
+ ## Installation
32
+
33
+ ```bash
34
+ pip install numerax
35
+ ```
36
+
37
+ ## Features
38
+
39
+ ### Special Functions
40
+
41
+ Inverse regularized incomplete gamma function with differentiability support:
42
+
43
+ ```python
44
+ import jax
+ import jax.numpy as jnp
45
+ import numerax
46
+
47
+ # Compute gamma quantiles (inverse CDF)
48
+ p = jnp.array([0.1, 0.5, 0.9]) # Probabilities
49
+ a = 2.0 # Shape parameter
50
+
51
+ x = numerax.special.gammap_inverse(p, a)
52
+ # Returns quantiles where gammainc(a, x) = p
53
+
54
+ # Fully differentiable with custom JVP
55
+ grad_fn = jax.grad(numerax.special.gammap_inverse)
56
+ dx_dp = grad_fn(0.5, 2.0) # Gradient with respect to probability
57
+ ```
58
+
59
+ **Key features:**
60
+ - Halley's method for fast convergence
61
+ - Custom JVP implementation for exact gradients
62
+ - Numerical stability with adaptive precision
63
+ - Equivalent to gamma distribution inverse CDF
64
+
65
+ ### Profile Likelihood
66
+
67
+ Efficient profile likelihood computation for statistical inference with nuisance parameters:
68
+
69
+ ```python
70
+ import jax.numpy as jnp
71
+ import numerax
72
+
73
+ # Example: Normal distribution with mean inference, variance profiling
74
+ def normal_llh(params, data):
75
+ mu, log_sigma = params
76
+ sigma = jnp.exp(log_sigma)
77
+ return jnp.sum(-0.5 * jnp.log(2 * jnp.pi) - log_sigma
78
+ - 0.5 * ((data - mu) / sigma) ** 2)
79
+
80
+ # Profile over log_sigma, infer mu
81
+ is_nuisance = [False, True] # mu=inference, log_sigma=nuisance
82
+
83
+ def get_initial_log_sigma(data):
84
+ return jnp.array([jnp.log(jnp.std(data))])
85
+
86
+ profile_llh = numerax.stats.make_profile_llh(
87
+ normal_llh, is_nuisance, get_initial_log_sigma
88
+ )
89
+
90
+ # Evaluate profile likelihood
91
+ data = jnp.array([1.2, 0.8, 1.5, 0.9, 1.1])
92
+ llh_val, opt_nuisance, diff, n_iter = profile_llh(jnp.array([1.0]), data)
93
+ ```
94
+
95
+ **Key features:**
96
+ - JIT-compiled for performance
97
+ - L-BFGS optimization with convergence diagnostics
98
+ - Configurable tolerance and initial values
99
+ - Handles parameter masking automatically
100
+
101
+ ### Utilities
102
+
103
+ Development utilities for creating JAX functions with custom derivatives while ensuring proper documentation support. Includes decorators for preserving function metadata when using JAX's advanced features.
104
+
105
+ ## Requirements
106
+
107
+ - Python ≥ 3.12
108
+ - JAX
109
+ - jaxtyping
110
+ - optax
@@ -0,0 +1,89 @@
1
+ # numerax
2
+
3
+ [![tests](https://github.com/juehang/numerax/actions/workflows/test.yml/badge.svg)](https://github.com/juehang/numerax/actions/workflows/test.yml)
4
+ [![docs](https://github.com/juehang/numerax/actions/workflows/docs.yml/badge.svg)](https://juehang.github.io/numerax/)
5
+
6
+ Statistical and numerical computation functions for JAX, focusing on tools not available in the main JAX API.
7
+
8
+ **[📖 Documentation](https://juehang.github.io/numerax/)**
9
+
10
+ ## Installation
11
+
12
+ ```bash
13
+ pip install numerax
14
+ ```
15
+
16
+ ## Features
17
+
18
+ ### Special Functions
19
+
20
+ Inverse regularized incomplete gamma function with differentiability support:
21
+
22
+ ```python
23
+ import jax
+ import jax.numpy as jnp
24
+ import numerax
25
+
26
+ # Compute gamma quantiles (inverse CDF)
27
+ p = jnp.array([0.1, 0.5, 0.9]) # Probabilities
28
+ a = 2.0 # Shape parameter
29
+
30
+ x = numerax.special.gammap_inverse(p, a)
31
+ # Returns quantiles where gammainc(a, x) = p
32
+
33
+ # Fully differentiable with custom JVP
34
+ grad_fn = jax.grad(numerax.special.gammap_inverse)
35
+ dx_dp = grad_fn(0.5, 2.0) # Gradient with respect to probability
36
+ ```
37
+
38
+ **Key features:**
39
+ - Halley's method for fast convergence
40
+ - Custom JVP implementation for exact gradients
41
+ - Numerical stability with adaptive precision
42
+ - Equivalent to gamma distribution inverse CDF
43
+
44
+ ### Profile Likelihood
45
+
46
+ Efficient profile likelihood computation for statistical inference with nuisance parameters:
47
+
48
+ ```python
49
+ import jax.numpy as jnp
50
+ import numerax
51
+
52
+ # Example: Normal distribution with mean inference, variance profiling
53
+ def normal_llh(params, data):
54
+ mu, log_sigma = params
55
+ sigma = jnp.exp(log_sigma)
56
+ return jnp.sum(-0.5 * jnp.log(2 * jnp.pi) - log_sigma
57
+ - 0.5 * ((data - mu) / sigma) ** 2)
58
+
59
+ # Profile over log_sigma, infer mu
60
+ is_nuisance = [False, True] # mu=inference, log_sigma=nuisance
61
+
62
+ def get_initial_log_sigma(data):
63
+ return jnp.array([jnp.log(jnp.std(data))])
64
+
65
+ profile_llh = numerax.stats.make_profile_llh(
66
+ normal_llh, is_nuisance, get_initial_log_sigma
67
+ )
68
+
69
+ # Evaluate profile likelihood
70
+ data = jnp.array([1.2, 0.8, 1.5, 0.9, 1.1])
71
+ llh_val, opt_nuisance, diff, n_iter = profile_llh(jnp.array([1.0]), data)
72
+ ```
73
+
74
+ **Key features:**
75
+ - JIT-compiled for performance
76
+ - L-BFGS optimization with convergence diagnostics
77
+ - Configurable tolerance and initial values
78
+ - Handles parameter masking automatically
79
+
80
+ ### Utilities
81
+
82
+ Development utilities for creating JAX functions with custom derivatives while ensuring proper documentation support. Includes decorators for preserving function metadata when using JAX's advanced features.
83
+
84
+ ## Requirements
85
+
86
+ - Python ≥ 3.12
87
+ - JAX
88
+ - jaxtyping
89
+ - optax
@@ -0,0 +1,9 @@
1
+ # API Reference
2
+
3
+ Comprehensive API documentation for all Numerax modules.
4
+
5
+ ## Modules
6
+
7
+ - **[Special Functions](special.md)** - Mathematical special functions with custom derivatives
8
+ - **[Statistics](stats.md)** - Advanced statistical computation tools
9
+ - **[Utilities](utils.md)** - Development utilities for JAX functions
@@ -0,0 +1,6 @@
1
+ # Special Functions
2
+
3
+ ::: numerax.special
4
+ options:
5
+ show_source: false
6
+ heading_level: 2
@@ -0,0 +1,6 @@
1
+ # Statistics
2
+
3
+ ::: numerax.stats
4
+ options:
5
+ show_source: false
6
+ heading_level: 2
@@ -0,0 +1,6 @@
1
+ # Utilities
2
+
3
+ ::: numerax.utils
4
+ options:
5
+ show_source: false
6
+ heading_level: 2
@@ -0,0 +1,6 @@
1
+ # Numerax
2
+
3
+ ::: numerax
4
+ options:
5
+ show_source: false
6
+ heading_level: 1
@@ -0,0 +1,18 @@
1
+ /* Custom styles for MkDocs Material theme */
2
+
3
+ /* Improve math rendering */
4
+ .arithmatex {
5
+ padding: 0.5em;
6
+ margin: 1em 0;
7
+ }
8
+
9
+ /* Better code block styling */
10
+ .highlight pre {
11
+ padding: 1em;
12
+ overflow-x: auto;
13
+ }
14
+
15
+ /* Improve navigation */
16
+ .md-nav__item .md-nav__link--active {
17
+ font-weight: bold;
18
+ }
@@ -0,0 +1,73 @@
1
+ site_name: Numerax
2
+ site_description: Statistical and numerical computation functions for JAX
3
+ site_url: https://juehang.github.io/numerax/
4
+ repo_url: https://github.com/juehang/numerax
5
+ repo_name: juehang/numerax
6
+
7
+ theme:
8
+ name: material
9
+ features:
10
+ - navigation.tabs
11
+ - navigation.sections
12
+ - navigation.expand
13
+ - navigation.top
14
+ - search.highlight
15
+ - search.share
16
+ - content.code.copy
17
+ palette:
18
+ - scheme: default
19
+ primary: blue
20
+ accent: blue
21
+ toggle:
22
+ icon: material/brightness-7
23
+ name: Switch to dark mode
24
+ - scheme: slate
25
+ primary: blue
26
+ accent: blue
27
+ toggle:
28
+ icon: material/brightness-4
29
+ name: Switch to light mode
30
+
31
+ plugins:
32
+ - search
33
+ - mkdocstrings:
34
+ handlers:
35
+ python:
36
+ options:
37
+ docstring_style: google
38
+ show_source: false
39
+ show_root_heading: true
40
+ show_root_toc_entry: false
41
+ merge_init_into_class: true
42
+ show_signature_annotations: true
43
+ separate_signature: true
44
+
45
+ markdown_extensions:
46
+ - pymdownx.arithmatex:
47
+ generic: true
48
+ - pymdownx.highlight:
49
+ anchor_linenums: true
50
+ - pymdownx.inlinehilite
51
+ - pymdownx.snippets
52
+ - pymdownx.superfences
53
+ - admonition
54
+ - pymdownx.details
55
+ - pymdownx.tabbed:
56
+ alternate_style: true
57
+ - toc:
58
+ permalink: true
59
+
60
+ extra_javascript:
61
+ # polyfill.io removed: the domain was compromised in a 2024 supply-chain attack
62
+ - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js
63
+
64
+ extra_css:
65
+ - stylesheets/extra.css
66
+
67
+ nav:
68
+ - Home: index.md
69
+ - API Reference:
70
+ - Overview: api/index.md
71
+ - Special Functions: api/special.md
72
+ - Statistics: api/stats.md
73
+ - Utilities: api/utils.md
@@ -0,0 +1,133 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "numerax"
7
+ dynamic = ["version"]
8
+ description = "Statistical and numerical computation functions for JAX"
9
+ readme = "README.md"
10
+ requires-python = ">=3.12"
11
+ license = "MIT"
12
+ keywords = []
13
+ authors = [
14
+ { name = "Juehang Qin"},
15
+ ]
16
+ classifiers = [
17
+ "Development Status :: 4 - Beta",
18
+ "Programming Language :: Python",
19
+ "Programming Language :: Python :: 3.12",
20
+ "Programming Language :: Python :: 3.13",
21
+ "Programming Language :: Python :: Implementation :: CPython",
22
+ "Programming Language :: Python :: Implementation :: PyPy",
23
+ ]
24
+ dependencies = [
25
+ "jax",
26
+ "jaxtyping",
27
+ "optax"
28
+ ]
29
+
30
+ [project.urls]
31
+ Documentation = "https://juehang.github.io/numerax/"
32
+ Issues = "https://github.com/juehang/numerax/issues"
33
+ Source = "https://github.com/juehang/numerax"
34
+
35
+ [tool.hatch.version]
36
+ path = "src/numerax/__init__.py"
37
+
38
+ [tool.hatch.envs.default]
39
+ dependencies = [
40
+ "coverage[toml]>=6.5",
41
+ "pytest",
42
+ "ruff",
43
+ "mkdocs-material",
44
+ "mkdocstrings[python]"
45
+ ]
46
+ [tool.hatch.envs.default.scripts]
47
+ test = "pytest {args:tests}"
48
+ test-cov = "coverage run -m pytest {args:tests}"
49
+ cov-report = [
50
+ "- coverage combine",
51
+ "coverage report",
52
+ ]
53
+ cov = [
54
+ "test-cov",
55
+ "cov-report",
56
+ ]
57
+
58
+ [[tool.hatch.envs.all.matrix]]
59
+ python = ["3.12"]
60
+
61
+ [tool.hatch.envs.lint]
62
+ detached = true
63
+ dependencies = [
64
+ "ruff>=0.0.243",
65
+ ]
66
+ [tool.hatch.envs.lint.scripts]
67
+ check = [
68
+ "ruff check {args:.}",
69
+ "ruff format --check --diff {args:.}",
70
+ ]
71
+ fmt = [
72
+ "ruff format {args:.}",
73
+ "ruff check --fix {args:.}",
74
+ "check",
75
+ ]
76
+
77
+ [tool.coverage.run]
78
+ source_pkgs = ["numerax", "tests"]
79
+ branch = true
80
+ parallel = true
81
+ omit = [
82
+ "src/numerax/__about__.py",
83
+ ]
84
+
85
+ [tool.coverage.paths]
86
+ numerax = ["src/numerax", "*/numerax/src/numerax"]
87
+ tests = ["tests", "*/numerax/tests"]
88
+
89
+ [tool.coverage.report]
90
+ exclude_lines = [
91
+ "no cov",
92
+ "if __name__ == .__main__.:",
93
+ "if TYPE_CHECKING:",
94
+ ]
95
+
96
+ [tool.ruff]
97
+ line-length = 79
98
+ target-version = "py312"
99
+ extend-exclude = ["old_code"]
100
+
101
+ [tool.ruff.lint]
102
+ select = [
103
+ "E", # pycodestyle errors
104
+ "F", # pyflakes
105
+ "UP", # pyupgrade
106
+ "B", # flake8-bugbear
107
+ "SIM", # flake8-simplify
108
+ "I", # isort
109
+ "ASYNC", # flake8-async
110
+ "S", # flake8-bandit
111
+ "ARG", # flake8-unused-arguments
112
+ "Q", # flake8-quotes
113
+ "SIM", # flake8-simplify
114
+ "NPY", # numpy-specific rules
115
+ "PD", # pandas-specific rules
116
+ "N", # pep8-naming
117
+ "W", # warning
118
+ "PLC", # pylint convention
119
+ "PLE", # pylint error
120
+ "PLW", # pylint warning
121
+ "RUF", # ruff-specific rules
122
+ ]
123
+ ignore = ["COM812"] # Ignore trailing comma rule that conflicts with formatter
124
+
125
+ [tool.ruff.lint.per-file-ignores]
126
+ "tests/**" = ["S101"] # Allow assert statements in tests
127
+
128
+ # Enable Ruff's formatter
129
+ [tool.ruff.format]
130
+ quote-style = "double"
131
+ indent-style = "space"
132
+ line-ending = "auto"
133
+ docstring-code-format = true
@@ -0,0 +1,37 @@
1
+ """
2
+ Statistical and numerical computation functions for JAX, focusing on tools
3
+ not available in the main JAX API.
4
+
5
+ ## Overview
6
+
7
+ This package provides JAX-compatible implementations of specialized numerical
8
+ functions with full differentiability support. All functions are designed to
9
+ work seamlessly with JAX's transformations (JIT, grad, vmap, etc.) and follow
10
+ JAX's functional programming paradigms.
11
+
12
+ ### Special Functions (`numerax.special`)
13
+
14
+ Mathematical special functions with custom derivative implementations.
15
+ Functions use numerically stable algorithms and provide exact gradients
16
+ through custom JVP rules where standard automatic differentiation would
17
+ be inefficient or unstable.
18
+
19
+ ### Statistical Methods (`numerax.stats`)
20
+
21
+ Advanced statistical computation tools for inference problems. Implements
22
+ efficient algorithms for complex statistical models, with particular focus
23
+ on optimization-based methods that benefit from JAX's compilation and
24
+ differentiation capabilities.
25
+
26
+ ### Utilities (`numerax.utils`)
27
+
28
+ Development utilities for creating JAX-compatible functions with proper
29
+ documentation support. Includes decorators and helpers for preserving
30
+ function metadata when using JAX's advanced features like custom derivatives.
31
+ """
32
+
33
+ from . import special, stats, utils
34
+
35
# Package version. pyproject.toml's [tool.hatch.version] reads the version
# from this file, so this assignment is the single place to bump it.
__version__ = "0.1.0"

# Public names exported by ``from numerax import *``.
__all__ = ["special", "stats", "utils"]
@@ -0,0 +1,3 @@
1
"""Special functions submodule for numerax."""

from .gamma import gammap_inverse

__all__ = ["gammap_inverse"]
@@ -0,0 +1,217 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import jax.scipy.special as special
4
+ from jaxtyping import ArrayLike
5
+
6
# Global constants for numerical stability, sized for JAX's default float
# precision. NOTE: these are evaluated once, when this module is imported;
# enabling ``jax_enable_x64`` after the import does not change them.
_DTYPE = jnp.float64 if jax.config.jax_enable_x64 else jnp.float32
# Smallest positive normal number of _DTYPE (for preventing underflow).
TINY = jnp.finfo(_DTYPE).smallest_normal
# Machine epsilon of _DTYPE (for relative convergence tolerances).
EPS = jnp.finfo(_DTYPE).eps
10
+
11
+
12
@jax.custom_jvp
def gammap_inverse(p: ArrayLike, a: float) -> ArrayLike:
    r"""
    Inverse of the regularized incomplete gamma function.

    ## Overview

    Computes the inverse of the regularized incomplete gamma function, finding
    $x$ such that $P(a, x) = p$, where $P(a, x)$ is the regularized incomplete
    gamma function. This is equivalent to computing quantiles of the
    $\text{Gamma}(a, 1)$ distribution. The general strategy and the initial
    guess are based on the methods described in
    Numerical Recipes (Press et al., 2007).

    ## Mathematical Background

    The regularized incomplete gamma function is defined as:

    $$P(a, x) = \frac{\gamma(a, x)}{\Gamma(a)} = \frac{1}{\Gamma(a)}
    \int_0^x t^{a-1} e^{-t} dt$$

    This function solves the inverse problem:

    $$x = P^{-1}(a, p) \quad \text{such that} \quad P(a, x) = p$$

    For a random variable $X \sim \text{Gamma}(a, 1)$, this gives:

    $$x = F^{-1}(p) \quad \text{where} \quad F(x) = P(a, x)$$

    ## Numerical Method

    Uses Halley's method, which converges cubically near the root:

    $$x_{n+1} = x_n - \frac{2f(x_n)f'(x_n)}{2[f'(x_n)]^2 - f(x_n)f''(x_n)}$$

    where $f(x) = P(a, x) - p$. The iteration runs elementwise, so array
    inputs are supported. An element stops updating once its step is below
    machine epsilon (for the working dtype) relative to its current value.
    The loop ends when every element has stopped or after 12 iterations.

    **Initial guess** based on Numerical Recipes (Press et al., 2007):
    - For $a > 1$: Wilson-Hilferty approximation
    - For $a \leq 1$: Asymptotic expansions

    ## Args

    - **p**: Probability values in $[0, 1]$. Can be scalar or array.
    - **a**: Shape parameter (must be positive). Usually a scalar; arrays
      that broadcast against `p` are handled elementwise.

    ## Returns

    Quantiles $x$ where $P(a, x) = p$, with the broadcast shape of `p` and
    `a`. Returns $0$ for $p = 0$, $+\infty$ for $p = 1$, and NaN for $p$
    outside $[0, 1]$.

    ## Example

    ```python
    import jax
    import jax.numpy as jnp
    import numerax

    # Single quantile
    x = numerax.special.gammap_inverse(0.5, 2.0)  # Median of Gamma(2, 1)

    # Multiple quantiles
    p_vals = jnp.array([0.1, 0.25, 0.5, 0.75, 0.9])
    x_vals = numerax.special.gammap_inverse(p_vals, 3.0)

    # Verify inverse relationship
    from jax.scipy.special import gammainc

    p_recovered = gammainc(2.0, x)  # Should equal original p

    # Differentiable for optimization
    grad_fn = jax.grad(numerax.special.gammap_inverse)
    sensitivity = grad_fn(0.5, 2.0)  # ∂x/∂p at median
    ```

    ## Notes

    - **Convergence**: Usually a few iterations; capped at 12
    - **Differentiable**: Custom JVP with respect to `p` using the implicit
      function theorem; derivatives with respect to `a` are not
      implemented and are treated as zero
    - **Numerical stability**: Near-zero Halley denominators are clamped
      and iterates are kept strictly positive; tolerances follow the
      dtype of the computation
    - **Performance**: Compatible with `jax.jit` and `jax.vmap`
    - **Domain**: $p \in [0, 1]$ and $a > 0$

    ## References

    Press, W. H., Teukolsky, S. A., Vetterling, W. T., & Flannery, B. P.
    (2007).
    *Numerical Recipes: The Art of Scientific Computing* (3rd ed.).
    Cambridge University Press.
    """

    def objective(x):
        """Residual f(x) = gammainc(a, x) - p, evaluated elementwise."""
        return special.gammainc(a, x) - p

    # Initial guess from Numerical Recipes
    def initial_guess(u_val, a_val):
        # (For a chi-squared distribution, a = dof / 2.)

        def large_a_guess():
            # For a > 1: use Wilson-Hilferty approximation
            pp = jnp.where(u_val < 0.5, u_val, 1.0 - u_val)
            t = jnp.sqrt(-2.0 * jnp.log(pp))
            x = (2.30753 + t * 0.27061) / (
                1.0 + t * (0.99229 + t * 0.04481)
            ) - t
            x = jnp.where(u_val < 0.5, -x, x)
            return jnp.fmax(
                1e-3,
                a_val
                * (1.0 - 1.0 / (9.0 * a_val) - x / (3.0 * jnp.sqrt(a_val)))
                ** 3,
            )

        def small_a_guess():
            # For a <= 1: use equations (6.2.8) and (6.2.9)
            t = 1.0 - a_val * (0.253 + a_val * 0.12)
            return jnp.where(
                u_val < t,
                (u_val / t) ** (1.0 / a_val),
                1.0 - jnp.log(1.0 - (u_val - t) / (1.0 - t)),
            )

        # Both branches are evaluated; jnp.where picks one per element.
        return jnp.real(
            jnp.where(a_val > 1.0, large_a_guess(), small_a_guess())
        )

    # jax.grad only accepts scalar-output functions. The residual is
    # elementwise in x, so the gradient of its sum is exactly the
    # elementwise derivative (and equals the plain grad for scalars).
    def df_dx(x):
        return jax.grad(lambda y: jnp.sum(objective(y)))(x)

    def d2f_dx2(x):
        return jax.grad(lambda y: jnp.sum(df_dx(y)))(x)

    x0 = initial_guess(p, a)

    # Safeguards follow the dtype actually being computed in, rather than a
    # global chosen at import time.
    finfo = jnp.finfo(x0.dtype)
    tiny = finfo.smallest_normal
    eps = finfo.eps
    max_iter = 12

    def is_active(x, step):
        # An element still needs work while its last step is not negligible.
        # A NaN step compares False, which also stops that element.
        return jnp.abs(step) > eps * jnp.abs(x)

    def cond_fn(state):
        x, step, iteration = state
        return jnp.any(is_active(x, step)) & (iteration < max_iter)

    def body_fn(state):
        x, step, iteration = state
        active = is_active(x, step)

        f_val = objective(x)
        df_val = df_dx(x)
        d2f_val = d2f_dx2(x)

        # Halley's method: x_{n+1} = x_n - 2*f*f' / (2*f'^2 - f*f'')
        numerator = 2 * f_val * df_val
        denominator = 2 * df_val**2 - f_val * d2f_val

        # Keep |denominator| >= tiny while preserving its sign. The sign is
        # chosen explicitly because jnp.sign(0) == 0 would leave an exact
        # zero denominator in place.
        denominator = jnp.where(
            jnp.abs(denominator) < tiny,
            jnp.where(denominator < 0, -tiny, tiny),
            denominator,
        )

        new_step = numerator / denominator
        # Ensure x stays positive
        new_x = jnp.fmax(x - new_step, tiny)

        # Freeze elements that have already stopped, so every element
        # follows the same trajectory it would have as a scalar input.
        return (
            jnp.where(active, new_x, x),
            jnp.where(active, new_step, step),
            iteration + 1,
        )

    # Initial state: (x, step, iteration); an infinite step marks every
    # element as active for the first iteration.
    initial_state = (x0, jnp.full_like(x0, jnp.inf), 0)
    x, _, _ = jax.lax.while_loop(cond_fn, body_fn, initial_state)

    # Exact values at the ends of the domain (as scipy.special.gammaincinv):
    # P(a, 0) = 0 and P(a, x) -> 1 only as x -> inf. Probabilities outside
    # [0, 1] have no inverse.
    x = jnp.where(p == 0.0, 0.0, x)
    x = jnp.where(p == 1.0, jnp.inf, x)
    return jnp.where((p < 0.0) | (p > 1.0), jnp.nan, x)
183
+
184
+
185
@gammap_inverse.defjvp
def gammap_inverse_jvp(primals, tangents):
    """
    Custom JVP for gammap_inverse using the implicit function theorem.

    For F(x, p) = gammainc(a, x) - p = 0:
        dx/dp = -(∂F/∂p) / (∂F/∂x) = 1 / (∂/∂x gammainc(a, x))

    The tangent of ``a`` is ignored, so derivatives with respect to the
    shape parameter come out as zero.

    Args:
        primals: ``(p, a)`` as passed to ``gammap_inverse``.
        tangents: ``(p_dot, a_dot)``; only ``p_dot`` is used.

    Returns:
        ``(x, x_dot)`` with ``x = gammap_inverse(p, a)`` and
        ``x_dot = dx/dp * p_dot`` evaluated elementwise.
    """
    p, a = primals
    p_dot, _ = tangents

    # Forward pass
    x = gammap_inverse(p, a)

    # d/dx gammainc(a, x), elementwise. jax.grad needs a scalar output;
    # gammainc is elementwise in x, so the gradient of the sum is the
    # elementwise derivative (identical to the plain grad for scalars).
    def summed_gammainc(x_val):
        return jnp.sum(special.gammainc(a, x_val))

    dgammainc_dx = jax.grad(summed_gammainc)(x)

    # Avoid division by zero. The sign is chosen explicitly because
    # jnp.sign(0) == 0 would leave an exact zero unguarded.
    tiny = jnp.finfo(dgammainc_dx.dtype).smallest_normal
    dgammainc_dx = jnp.where(
        jnp.abs(dgammainc_dx) < tiny,
        jnp.where(dgammainc_dx < 0, -tiny, tiny),
        dgammainc_dx,
    )

    dx_dp = 1.0 / dgammainc_dx

    # Derivatives with respect to `a` are not implemented.
    x_dot = dx_dp * p_dot

    return x, x_dot
@@ -0,0 +1,5 @@
1
+ """Statistics submodule for numerax."""
2
+
3
+ from .profile import make_profile_llh
4
+
5
+ __all__ = ["make_profile_llh"]
@@ -0,0 +1,165 @@
1
+ """Profile likelihood functions for statistical inference."""
2
+
3
+ from collections.abc import Callable
4
+
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import optax
8
+ from optax import lbfgs
9
+
10
+
11
def make_profile_llh(
    llh_fn: Callable,
    is_nuisance: list[bool] | jnp.ndarray,
    get_initial_nuisance: Callable,
    tol: float = 1e-6,
    initial_value: float = 1e-9,
    initial_diff: float = 1e9,
    max_iter: int = 1000,
) -> Callable:
    r"""
    Factory function for creating profile likelihood functions.

    ## Overview

    Profile likelihood is a statistical technique used when dealing with
    nuisance parameters that are not of primary interest but are necessary
    for the model. This function creates an optimized profile likelihood
    that maximizes over nuisance parameters while keeping inference
    parameters fixed.

    ## Mathematical Background

    Given a likelihood function $L(\boldsymbol{\theta}, \boldsymbol{\lambda})$
    where $\boldsymbol{\theta}$ are parameters of interest and
    $\boldsymbol{\lambda}$ are nuisance parameters, the profile likelihood is:

    $$L_p(\boldsymbol{\theta}) = \max_{\boldsymbol{\lambda}}
    L(\boldsymbol{\theta}, \boldsymbol{\lambda})$$

    In practice, we work with the log-likelihood
    $\ell(\boldsymbol{\theta}, \boldsymbol{\lambda}) =
    \log L(\boldsymbol{\theta}, \boldsymbol{\lambda})$:

    $$\ell_p(\boldsymbol{\theta}) = \max_{\boldsymbol{\lambda}}
    \ell(\boldsymbol{\theta}, \boldsymbol{\lambda})$$

    This function uses L-BFGS optimization to find the maximum likelihood
    estimates of nuisance parameters for each fixed value of inference
    parameters.

    ## Args

    - **llh_fn**: Log likelihood function taking (params, *args) and
      returning scalar log likelihood value
    - **is_nuisance**: Boolean array where True indicates nuisance
      parameters and False indicates inference parameters
    - **get_initial_nuisance**: Function taking (*args) and returning
      initial values for nuisance parameters
    - **tol**: Relative convergence tolerance for the optimization
      (default: 1e-6). Iteration stops once the absolute change in the
      negative log likelihood between consecutive iterations is at most
      `tol` times its absolute value.
    - **initial_value**: Initial objective value for convergence tracking
      (default: 1e-9)
    - **initial_diff**: Initial difference for convergence tracking
      (default: 1e9). With the defaults, the first iteration always runs.
    - **max_iter**: Maximum number of L-BFGS iterations (default: 1000).
      The optimization stops after this many iterations even if the `tol`
      criterion has not been met; a returned iteration count equal to
      `max_iter` signals this.

    ## Returns

    Profile likelihood function with signature:
    `(inference_values, *args) -> (profile_llh_value, optimal_nuisance,
    convergence_diff, num_iterations)`

    `profile_llh_value` is the log likelihood evaluated at the nuisance
    parameters *before* the final L-BFGS update, while `optimal_nuisance`
    is the parameter vector *after* it; at convergence the two differ
    negligibly.

    ## Example

    Consider fitting a normal distribution where we want to infer the mean
    $\mu$ but treat the variance $\sigma^2$ as a nuisance parameter:

    ```python
    import jax.numpy as jnp
    import numerax

    # Sample data
    data = jnp.array([1.2, 0.8, 1.5, 0.9, 1.1, 1.3, 0.7, 1.4])


    # Log likelihood for normal distribution
    def normal_llh(params, data):
        mu, log_sigma = params  # Use log(sigma) for numerical stability
        sigma = jnp.exp(log_sigma)
        return jnp.sum(
            -0.5 * jnp.log(2 * jnp.pi)
            - log_sigma
            - 0.5 * ((data - mu) / sigma) ** 2
        )


    # Profile over log_sigma (nuisance), infer mu
    is_nuisance = [False, True]  # mu=inference, log_sigma=nuisance


    def get_initial_log_sigma(data):
        # Initialize with log of sample standard deviation
        return jnp.array([jnp.log(jnp.std(data))])


    profile_llh = numerax.stats.make_profile_llh(
        normal_llh, is_nuisance, get_initial_log_sigma
    )

    # Evaluate profile likelihood at different mu values
    mu_test = 1.0
    llh_val, opt_log_sigma, diff, n_iter = profile_llh(
        jnp.array([mu_test]), data
    )
    ```

    ## Notes

    - The returned function is JIT-compiled for performance
    - Uses L-BFGS optimization which is well-suited for smooth likelihood
      surfaces
    - Returns convergence information for diagnostics
    - Handles parameter masking automatically
    - Consider using log-parameterization for positive parameters
      (e.g., $\log \sigma$) to improve numerical stability
    """
    nuisance_mask = jnp.array(is_nuisance)
    inference_mask = ~nuisance_mask

    @jax.jit
    def profile_llh(inference_values, *args):
        solver = lbfgs()
        initial_nuisance = get_initial_nuisance(*args)
        opt_state = solver.init(initial_nuisance)

        def objective(nuisance_params):
            # Reconstruct the full parameter vector: inference values are
            # fixed, nuisance values are the optimization variables.
            full_params = jnp.zeros(len(nuisance_mask))
            full_params = full_params.at[inference_mask].set(inference_values)
            full_params = full_params.at[nuisance_mask].set(nuisance_params)
            # L-BFGS minimizes, so optimize the negative log likelihood.
            return -llh_fn(full_params, *args)

        value_and_grad = optax.value_and_grad_from_state(objective)

        def keep_iterating(carry):
            _, value, _, diff, n = carry
            # Continue while the relative change in the objective is above
            # `tol` and the iteration budget is not exhausted.
            still_changing = jnp.abs(diff) > jnp.abs(value * tol)
            return still_changing & (n < max_iter)

        def profile_llh_loopfun(carry):
            params, last_value, opt_state, _, n = carry
            value, grad = value_and_grad(params, state=opt_state)
            updates, opt_state = solver.update(
                grad,
                opt_state,
                params,
                value=value,
                grad=grad,
                value_fn=objective,
            )
            params = optax.apply_updates(params, updates)
            diff = last_value - value
            return params, value, opt_state, diff, n + 1

        params, value, _, diff, n = jax.lax.while_loop(
            keep_iterating,
            profile_llh_loopfun,
            (initial_nuisance, initial_value, opt_state, initial_diff, 0),
        )

        return -value, params, diff, n

    return profile_llh
@@ -0,0 +1,48 @@
1
+ """Utility functions for the numerax package."""
2
+
3
+ import functools
4
+ from collections.abc import Callable
5
+ from typing import TypeVar
6
+
7
+ F = TypeVar("F", bound=Callable)
8
+
9
+
10
+ def preserve_metadata(decorator):
11
+ """
12
+ Wrapper that ensures a decorator preserves function metadata for
13
+ documentation tools.
14
+
15
+ ## Overview
16
+
17
+ This is particularly useful for JAX decorators like `@custom_jvp` that
18
+ create special objects which may not preserve `__doc__` and other metadata
19
+ properly for documentation generators like pdoc.
20
+
21
+ ## Args
22
+
23
+ - **decorator**: The decorator function to wrap
24
+
25
+ ## Returns
26
+
27
+ A new decorator that preserves metadata
28
+
29
+ ## Example
30
+
31
+ ```python
32
+ import jax
33
+ from numerax.utils import preserve_metadata
34
+
35
+ @preserve_metadata(jax.custom_jvp)
36
+ def my_function(x):
37
+ \"\"\"This docstring will be preserved for pdoc.\"\"\"
38
+ return x
39
+ ```
40
+ """
41
+
42
+ def metadata_preserving_decorator(func: F) -> F:
43
+ # Apply the original decorator
44
+ decorated = decorator(func)
45
+ # Ensure metadata is preserved using functools.wraps pattern
46
+ return functools.wraps(func)(decorated)
47
+
48
+ return metadata_preserving_decorator
@@ -0,0 +1 @@
1
+ """Test configuration for numerax."""
@@ -0,0 +1,50 @@
1
+ """
2
+ Test suite for gamma special functions.
3
+ """
4
+
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import pytest
8
+ from jax.scipy import special
9
+
10
+ from numerax.special import gammap_inverse
11
+
12
+
13
+ @pytest.mark.parametrize(
14
+ "p,a",
15
+ [
16
+ (0.1, 0.5),
17
+ (0.3, 1.0),
18
+ (0.5, 1.5),
19
+ (0.8, 2.0),
20
+ (0.95, 3.0),
21
+ ],
22
+ )
23
+ def test_gamma_inverse_correctness(p, a):
24
+ """Test that gammap_inverse correctly inverts gammainc."""
25
+ result = special.gammainc(a, gammap_inverse(p, a))
26
+ assert jnp.abs(result - p) < 1e-6
27
+
28
+
29
+ @pytest.mark.parametrize(
30
+ "p,a",
31
+ [
32
+ (0.2, 0.5),
33
+ (0.4, 1.0),
34
+ (0.6, 1.5),
35
+ (0.8, 2.0),
36
+ (0.9, 2.5),
37
+ ],
38
+ )
39
+ def test_gamma_inverse_gradients(p, a):
40
+ """Test that gradients of gammap_inverse are computed correctly."""
41
+ # Compute x such that gammainc(a, x) = p
42
+ x = gammap_inverse(p, a)
43
+
44
+ # Manual gradient: 1 / d/dx gammainc(a, x)
45
+ manual_grad = 1 / jax.grad(lambda x_val: special.gammainc(a, x_val))(x)
46
+
47
+ # Autodiff gradient
48
+ auto_grad = jax.grad(gammap_inverse)(p, a)
49
+
50
+ assert jnp.abs(manual_grad - auto_grad) < 1e-6
@@ -0,0 +1,116 @@
1
+ """Test profile likelihood with Gaussian data and analytical validation."""
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+
6
+ from numerax.stats import make_profile_llh
7
+
8
+
9
+ def test_gaussian_profile_analytical_validation():
10
+ r"""Test profile likelihood against analytical Gaussian MLE solution.
11
+
12
+ For Gaussian data with fixed $\mu$, the MLE of $\sigma$ is:
13
+ $\hat{\sigma}(\mu) = \sqrt{\Sigma(x_i-\mu)^2/N}$
14
+ This test validates that our profile likelihood optimization finds the same
15
+ result as this analytical solution.
16
+ """
17
+ # Set random seed for reproducible test
18
+ key = jax.random.PRNGKey(42)
19
+
20
+ # Generate synthetic Gaussian data
21
+ true_mu = 2.0
22
+ true_sigma = 1.5
23
+ n_samples = 50
24
+
25
+ data = jax.random.normal(key, (n_samples,)) * true_sigma + true_mu
26
+
27
+ # Define normal log likelihood function
28
+ def normal_llh(params, data):
29
+ mu, log_sigma = params
30
+ sigma = jnp.exp(log_sigma)
31
+ return jnp.sum(
32
+ -0.5 * jnp.log(2 * jnp.pi)
33
+ - log_sigma
34
+ - 0.5 * ((data - mu) / sigma) ** 2
35
+ )
36
+
37
+ # Set up profile likelihood: mu=inference, log_sigma=nuisance
38
+ is_nuisance = [False, True] # mu=inference, log_sigma=nuisance
39
+
40
+ def get_initial_log_sigma(data):
41
+ # Initialize with log of sample standard deviation
42
+ return jnp.array([jnp.log(jnp.std(data))])
43
+
44
+ # Create profile likelihood function
45
+ profile_llh = make_profile_llh(
46
+ normal_llh, is_nuisance, get_initial_log_sigma
47
+ )
48
+
49
+ # Test at a fixed mu value
50
+ test_mu = 1.8
51
+ llh_val, opt_log_sigma, diff, n_iter = profile_llh(
52
+ jnp.array([test_mu]), data
53
+ )
54
+
55
+ # Extract optimized sigma
56
+ sigma_opt = jnp.exp(opt_log_sigma[0])
57
+
58
+ # Calculate analytical solution: sigma_hat(mu) = sqrt(sum((x_i-mu)^2)/N)
59
+ sigma_analytical = jnp.sqrt(jnp.mean((data - test_mu) ** 2))
60
+
61
+ # Validate convergence
62
+ assert n_iter > 0, "Optimization should have run at least one iteration"
63
+ assert jnp.abs(diff) <= jnp.abs(llh_val * 1e-6), (
64
+ f"Should have converged, but diff={diff}"
65
+ )
66
+
67
+ # Validate against analytical solution
68
+ assert jnp.allclose(sigma_opt, sigma_analytical, atol=1e-3), (
69
+ f"Optimized sigma={sigma_opt:.6f} should match "
70
+ f"analytical sigma={sigma_analytical:.6f}"
71
+ )
72
+
73
+ # Validate likelihood value is finite and reasonable
74
+ assert jnp.isfinite(llh_val), "Profile likelihood value should be finite"
75
+ assert llh_val < 0, "Log likelihood should be negative"
76
+
77
+
78
+ def test_gaussian_profile_at_true_mean():
79
+ """Test profile likelihood when mu equals the true data-generating mean."""
80
+ key = jax.random.PRNGKey(123)
81
+
82
+ # Generate data
83
+ true_mu = 3.0
84
+ true_sigma = 0.8
85
+ n_samples = 30
86
+
87
+ data = jax.random.normal(key, (n_samples,)) * true_sigma + true_mu
88
+
89
+ def normal_llh(params, data):
90
+ mu, log_sigma = params
91
+ sigma = jnp.exp(log_sigma)
92
+ return jnp.sum(
93
+ -0.5 * jnp.log(2 * jnp.pi)
94
+ - log_sigma
95
+ - 0.5 * ((data - mu) / sigma) ** 2
96
+ )
97
+
98
+ is_nuisance = [False, True]
99
+
100
+ def get_initial_log_sigma(data):
101
+ return jnp.array([jnp.log(jnp.std(data))])
102
+
103
+ profile_llh = make_profile_llh(
104
+ normal_llh, is_nuisance, get_initial_log_sigma
105
+ )
106
+
107
+ # Test at sample mean (should be close to optimal)
108
+ sample_mean = jnp.mean(data)
109
+ llh_val, opt_log_sigma, _, _ = profile_llh(jnp.array([sample_mean]), data)
110
+
111
+ sigma_opt = jnp.exp(opt_log_sigma[0])
112
+ sigma_analytical = jnp.sqrt(jnp.mean((data - sample_mean) ** 2))
113
+
114
+ # Should converge quickly and match analytical solution well
115
+ assert jnp.allclose(sigma_opt, sigma_analytical, atol=1e-4)
116
+ assert jnp.isfinite(llh_val)