nlls-gram 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,33 @@
1
+ name: CI
2
+
3
+ on:
4
+ push:
5
+ branches: [main]
6
+ pull_request:
7
+
8
+ jobs:
9
+ lint:
10
+ runs-on: ubuntu-latest
11
+ steps:
12
+ - uses: actions/checkout@v4
13
+ - uses: astral-sh/setup-uv@v5
14
+ with:
15
+ enable-cache: true
16
+ - run: uv sync --locked
17
+ - run: uv run ruff check .
18
+ - run: uv run ruff format --check .
19
+
20
+ test:
21
+ runs-on: ubuntu-latest
22
+ strategy:
23
+ fail-fast: false
24
+ matrix:
25
+ python-version: ["3.11", "3.12", "3.13"]
26
+ steps:
27
+ - uses: actions/checkout@v4
28
+ - uses: astral-sh/setup-uv@v5
29
+ with:
30
+ enable-cache: true
31
+ python-version: ${{ matrix.python-version }}
32
+ - run: uv sync --locked
33
+ - run: uv run pytest -q
@@ -0,0 +1,39 @@
1
+ name: Docs
2
+
3
+ on:
4
+ push:
5
+ branches: [main]
6
+ workflow_dispatch:
7
+
8
+ permissions:
9
+ contents: read
10
+ pages: write
11
+ id-token: write
12
+
13
+ concurrency:
14
+ group: pages
15
+ cancel-in-progress: false
16
+
17
+ jobs:
18
+ build:
19
+ runs-on: ubuntu-latest
20
+ steps:
21
+ - uses: actions/checkout@v4
22
+ - uses: astral-sh/setup-uv@v5
23
+ with:
24
+ enable-cache: true
25
+ - run: uv sync --group docs
26
+ - run: uv run mkdocs build --strict
27
+ - uses: actions/upload-pages-artifact@v3
28
+ with:
29
+ path: site
30
+
31
+ deploy:
32
+ needs: build
33
+ runs-on: ubuntu-latest
34
+ environment:
35
+ name: github-pages
36
+ url: ${{ steps.deployment.outputs.page_url }}
37
+ steps:
38
+ - id: deployment
39
+ uses: actions/deploy-pages@v4
@@ -0,0 +1,36 @@
1
+ name: Publish
2
+
3
+ on:
4
+ release:
5
+ types: [published]
6
+
7
+ permissions:
8
+ contents: read
9
+
10
+ jobs:
11
+ build:
12
+ runs-on: ubuntu-latest
13
+ steps:
14
+ - uses: actions/checkout@v4
15
+ - uses: astral-sh/setup-uv@v5
16
+ - run: uv build
17
+ - run: uvx twine check dist/*
18
+ - uses: actions/upload-artifact@v4
19
+ with:
20
+ name: dist
21
+ path: dist/
22
+
23
+ publish:
24
+ needs: build
25
+ runs-on: ubuntu-latest
26
+ environment:
27
+ name: pypi
28
+ url: https://pypi.org/p/nlls-gram
29
+ permissions:
30
+ id-token: write
31
+ steps:
32
+ - uses: actions/download-artifact@v4
33
+ with:
34
+ name: dist
35
+ path: dist/
36
+ - uses: pypa/gh-action-pypi-publish@release/v1
@@ -0,0 +1,23 @@
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *.egg-info/
5
+ .eggs/
6
+
7
+ # Build artifacts
8
+ build/
9
+ dist/
10
+
11
+ # Virtual environments
12
+ .venv/
13
+
14
+ # Docs build
15
+ site/
16
+
17
+ # Tooling caches
18
+ .pytest_cache/
19
+ .ruff_cache/
20
+ .mypy_cache/
21
+
22
+ # OS
23
+ .DS_Store
@@ -0,0 +1 @@
1
+ 3.13
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 HighDimensionalEconLab
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,104 @@
1
+ Metadata-Version: 2.4
2
+ Name: nlls-gram
3
+ Version: 0.1.0
4
+ Summary: Gram/dual-form Levenberg-Marquardt nonlinear least-squares solvers for JAX/Flax NNX models
5
+ Project-URL: Homepage, https://github.com/HighDimensionalEconLab/nlls_gram
6
+ Project-URL: Repository, https://github.com/HighDimensionalEconLab/nlls_gram
7
+ Project-URL: Documentation, https://highdimensionaleconlab.github.io/nlls_gram/
8
+ Project-URL: Issues, https://github.com/HighDimensionalEconLab/nlls_gram/issues
9
+ Author-email: Jesse Perla <jesseperla@gmail.com>
10
+ License-Expression: MIT
11
+ License-File: LICENSE
12
+ Keywords: flax,jax,levenberg-marquardt,nnx,nonlinear-least-squares,optimization
13
+ Classifier: Development Status :: 3 - Alpha
14
+ Classifier: Intended Audience :: Science/Research
15
+ Classifier: Programming Language :: Python :: 3.11
16
+ Classifier: Programming Language :: Python :: 3.12
17
+ Classifier: Programming Language :: Python :: 3.13
18
+ Classifier: Topic :: Scientific/Engineering :: Mathematics
19
+ Requires-Python: >=3.11
20
+ Requires-Dist: flax>=0.10.7
21
+ Requires-Dist: jax>=0.7.0
22
+ Requires-Dist: optax>=0.2.4
23
+ Description-Content-Type: text/markdown
24
+
25
+ # nlls_gram
26
+
27
+ [![CI](https://github.com/HighDimensionalEconLab/nlls_gram/actions/workflows/ci.yml/badge.svg)](https://github.com/HighDimensionalEconLab/nlls_gram/actions/workflows/ci.yml)
28
+ [![Docs](https://github.com/HighDimensionalEconLab/nlls_gram/actions/workflows/docs.yml/badge.svg)](https://highdimensionaleconlab.github.io/nlls_gram/)
29
+ [![PyPI](https://img.shields.io/pypi/v/nlls-gram.svg)](https://pypi.org/project/nlls-gram/)
30
+ [![Python versions](https://img.shields.io/pypi/pyversions/nlls-gram.svg)](https://pypi.org/project/nlls-gram/)
31
+ [![License: MIT](https://img.shields.io/github/license/HighDimensionalEconLab/nlls_gram)](https://github.com/HighDimensionalEconLab/nlls_gram/blob/main/LICENSE)
32
+ [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
33
+
34
+ Gram/dual-form Levenberg-Marquardt nonlinear least-squares solvers for JAX/Flax
35
+ NNX models.
36
+
37
+ `GramLevenbergMarquardt` minimizes `||r(theta)||^2` for a residual defined over an
38
+ `nnx.Module`, following the optax/nnx `init`/`update` protocol so that steps apply
39
+ through `nnx.Optimizer(model, optax.identity(), wrt=...)`. For overparameterized
40
+ systems (many more parameters `p` than residual rows `n`) it factors the small
41
+ `n x n` gram (dual) system instead of the `p x p` normal equations.
42
+
43
+ ## Install
44
+
45
+ ```bash
46
+ pip install nlls-gram
47
+ ```
48
+
49
+ ## Minimal example
50
+
51
+ Fit `y = a * exp(b * x)` to noise-free data generated from `(a, b) = (2, -1)`:
52
+
53
+ ```python
54
+ import jax
55
+ import jax.numpy as jnp
56
+ import optax
57
+ from flax import nnx
58
+
59
+ from nlls_gram import GramLevenbergMarquardt
60
+
61
+ jax.config.update("jax_enable_x64", True)
62
+
63
+
64
+ class ExpModel(nnx.Module):
65
+ def __init__(self, a, b):
66
+ self.a = nnx.Param(jnp.asarray(a))
67
+ self.b = nnx.Param(jnp.asarray(b))
68
+
69
+ def __call__(self, x):
70
+ return self.a * jnp.exp(self.b * x)
71
+
72
+
73
+ x = jnp.linspace(0.0, 2.0, 20)
74
+ y = 2.0 * jnp.exp(-1.0 * x)
75
+
76
+ model = ExpModel(a=1.0, b=0.0)
77
+ solver = GramLevenbergMarquardt(lambda m, batch: m(batch[0]) - batch[1])
78
+ optimizer = nnx.Optimizer(model, optax.identity(), wrt=nnx.Param)
79
+ lm_state = solver.init()
80
+
81
+
82
+ @jax.jit
83
+ def train_step(graphdef, state, lm_state, batch):
84
+ m, opt = nnx.merge(graphdef, state)
85
+ updates, lm_state, info = solver.update(m, lm_state, batch)
86
+ opt.update(m, updates)
87
+ return lm_state, info, nnx.state((m, opt))
88
+
89
+
90
+ graphdef, state = nnx.split((model, optimizer))
91
+ for _ in range(50):
92
+ lm_state, info, state = train_step(graphdef, state, lm_state, (x, y))
93
+ nnx.update((model, optimizer), state)
94
+
95
+ print(model.a[...], model.b[...]) # ~2.0, ~-1.0
96
+ ```
97
+
98
+ ## Documentation
99
+
100
+ Full docs: https://highdimensionaleconlab.github.io/nlls_gram/
101
+
102
+ ## License
103
+
104
+ MIT
@@ -0,0 +1,80 @@
1
+ # nlls_gram
2
+
3
+ [![CI](https://github.com/HighDimensionalEconLab/nlls_gram/actions/workflows/ci.yml/badge.svg)](https://github.com/HighDimensionalEconLab/nlls_gram/actions/workflows/ci.yml)
4
+ [![Docs](https://github.com/HighDimensionalEconLab/nlls_gram/actions/workflows/docs.yml/badge.svg)](https://highdimensionaleconlab.github.io/nlls_gram/)
5
+ [![PyPI](https://img.shields.io/pypi/v/nlls-gram.svg)](https://pypi.org/project/nlls-gram/)
6
+ [![Python versions](https://img.shields.io/pypi/pyversions/nlls-gram.svg)](https://pypi.org/project/nlls-gram/)
7
+ [![License: MIT](https://img.shields.io/github/license/HighDimensionalEconLab/nlls_gram)](https://github.com/HighDimensionalEconLab/nlls_gram/blob/main/LICENSE)
8
+ [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
9
+
10
+ Gram/dual-form Levenberg-Marquardt nonlinear least-squares solvers for JAX/Flax
11
+ NNX models.
12
+
13
+ `GramLevenbergMarquardt` minimizes `||r(theta)||^2` for a residual defined over an
14
+ `nnx.Module`, following the optax/nnx `init`/`update` protocol so that steps apply
15
+ through `nnx.Optimizer(model, optax.identity(), wrt=...)`. For overparameterized
16
+ systems (many more parameters `p` than residual rows `n`) it factors the small
17
+ `n x n` gram (dual) system instead of the `p x p` normal equations.
18
+
19
+ ## Install
20
+
21
+ ```bash
22
+ pip install nlls-gram
23
+ ```
24
+
25
+ ## Minimal example
26
+
27
+ Fit `y = a * exp(b * x)` to noise-free data generated from `(a, b) = (2, -1)`:
28
+
29
+ ```python
30
+ import jax
31
+ import jax.numpy as jnp
32
+ import optax
33
+ from flax import nnx
34
+
35
+ from nlls_gram import GramLevenbergMarquardt
36
+
37
+ jax.config.update("jax_enable_x64", True)
38
+
39
+
40
+ class ExpModel(nnx.Module):
41
+ def __init__(self, a, b):
42
+ self.a = nnx.Param(jnp.asarray(a))
43
+ self.b = nnx.Param(jnp.asarray(b))
44
+
45
+ def __call__(self, x):
46
+ return self.a * jnp.exp(self.b * x)
47
+
48
+
49
+ x = jnp.linspace(0.0, 2.0, 20)
50
+ y = 2.0 * jnp.exp(-1.0 * x)
51
+
52
+ model = ExpModel(a=1.0, b=0.0)
53
+ solver = GramLevenbergMarquardt(lambda m, batch: m(batch[0]) - batch[1])
54
+ optimizer = nnx.Optimizer(model, optax.identity(), wrt=nnx.Param)
55
+ lm_state = solver.init()
56
+
57
+
58
+ @jax.jit
59
+ def train_step(graphdef, state, lm_state, batch):
60
+ m, opt = nnx.merge(graphdef, state)
61
+ updates, lm_state, info = solver.update(m, lm_state, batch)
62
+ opt.update(m, updates)
63
+ return lm_state, info, nnx.state((m, opt))
64
+
65
+
66
+ graphdef, state = nnx.split((model, optimizer))
67
+ for _ in range(50):
68
+ lm_state, info, state = train_step(graphdef, state, lm_state, (x, y))
69
+ nnx.update((model, optimizer), state)
70
+
71
+ print(model.a[...], model.b[...]) # ~2.0, ~-1.0
72
+ ```
73
+
74
+ ## Documentation
75
+
76
+ Full docs: https://highdimensionaleconlab.github.io/nlls_gram/
77
+
78
+ ## License
79
+
80
+ MIT
@@ -0,0 +1,78 @@
1
+ # nlls_gram
2
+
3
+ Gram/dual-form Levenberg-Marquardt nonlinear least-squares solvers for JAX/Flax
4
+ NNX models.
5
+
6
+ `GramLevenbergMarquardt` minimizes `||r(theta)||^2` for a residual defined over an
7
+ `nnx.Module`, following the optax/nnx `init`/`update` protocol so that steps apply
8
+ through `nnx.Optimizer(model, optax.identity(), wrt=...)`. For overparameterized
9
+ systems (many more parameters `p` than residual rows `n`) it factors the small
10
+ `n x n` gram (dual) system instead of the `p x p` normal equations.
11
+
12
+ ## Install
13
+
14
+ ```bash
15
+ pip install nlls-gram
16
+ ```
17
+
18
+ ## Minimal example
19
+
20
+ Fit `y = a * exp(b * x)` to noise-free data generated from `(a, b) = (2, -1)`:
21
+
22
+ ```python
23
+ import jax
24
+ import jax.numpy as jnp
25
+ import optax
26
+ from flax import nnx
27
+
28
+ from nlls_gram import GramLevenbergMarquardt
29
+
30
+ jax.config.update("jax_enable_x64", True)
31
+
32
+
33
+ class ExpModel(nnx.Module):
34
+ def __init__(self, a, b):
35
+ self.a = nnx.Param(jnp.asarray(a))
36
+ self.b = nnx.Param(jnp.asarray(b))
37
+
38
+ def __call__(self, x):
39
+ return self.a * jnp.exp(self.b * x)
40
+
41
+
42
+ # residual_fn(model, batch) -> residual array; the solver minimizes its sum of squares
43
+ def residual_fn(model, batch):
44
+ x, y = batch
45
+ return model(x) - y
46
+
47
+
48
+ x = jnp.linspace(0.0, 2.0, 20)
49
+ y = 2.0 * jnp.exp(-1.0 * x)
50
+
51
+ model = ExpModel(a=1.0, b=0.0)
52
+ solver = GramLevenbergMarquardt(residual_fn, init_damping=1e-2)
53
+ optimizer = nnx.Optimizer(model, optax.identity(), wrt=nnx.Param)
54
+ lm_state = solver.init()
55
+
56
+
57
+ # The solver does not jit internally; wrap the train step yourself.
58
+ @jax.jit
59
+ def train_step(graphdef, state, lm_state, batch):
60
+ m, opt = nnx.merge(graphdef, state)
61
+ updates, lm_state, info = solver.update(m, lm_state, batch)
62
+ opt.update(m, updates)
63
+ return lm_state, info, nnx.state((m, opt))
64
+
65
+
66
+ graphdef, state = nnx.split((model, optimizer))
67
+ for _ in range(50):
68
+ lm_state, info, state = train_step(graphdef, state, lm_state, (x, y))
69
+ nnx.update((model, optimizer), state)
70
+
71
+ print(model.a[...], model.b[...]) # ~2.0, ~-1.0
72
+ ```
73
+
74
+ ## API reference
75
+
76
+ ::: nlls_gram.GramLevenbergMarquardt
77
+
78
+ ::: nlls_gram.flat_residual
@@ -0,0 +1,40 @@
1
+ site_name: nlls_gram
2
+ site_description: Gram/dual-form Levenberg-Marquardt nonlinear least-squares for JAX/Flax NNX models
3
+ site_url: https://highdimensionaleconlab.github.io/nlls_gram/
4
+ repo_url: https://github.com/HighDimensionalEconLab/nlls_gram
5
+ repo_name: HighDimensionalEconLab/nlls_gram
6
+
7
+ theme:
8
+ name: material
9
+ features:
10
+ - content.code.copy
11
+ - navigation.top
12
+ palette:
13
+ - scheme: default
14
+ primary: indigo
15
+ toggle:
16
+ icon: material/brightness-7
17
+ name: Switch to dark mode
18
+ - scheme: slate
19
+ primary: indigo
20
+ toggle:
21
+ icon: material/brightness-4
22
+ name: Switch to light mode
23
+
24
+ plugins:
25
+ - search
26
+ - mkdocstrings:
27
+ handlers:
28
+ python:
29
+ options:
30
+ show_root_heading: true
31
+ show_source: true
32
+
33
+ markdown_extensions:
34
+ - pymdownx.highlight:
35
+ anchor_linenums: true
36
+ - pymdownx.superfences
37
+ - admonition
38
+
39
+ nav:
40
+ - Home: index.md
@@ -0,0 +1,63 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "nlls-gram"
7
+ version = "0.1.0"
8
+ description = "Gram/dual-form Levenberg-Marquardt nonlinear least-squares solvers for JAX/Flax NNX models"
9
+ readme = "README.md"
10
+ license = "MIT"
11
+ license-files = ["LICENSE"]
12
+ requires-python = ">=3.11"
13
+ authors = [{ name = "Jesse Perla", email = "jesseperla@gmail.com" }]
14
+ keywords = [
15
+ "levenberg-marquardt",
16
+ "nonlinear-least-squares",
17
+ "jax",
18
+ "flax",
19
+ "nnx",
20
+ "optimization",
21
+ ]
22
+ classifiers = [
23
+ "Development Status :: 3 - Alpha",
24
+ "Intended Audience :: Science/Research",
25
+ "Programming Language :: Python :: 3.11",
26
+ "Programming Language :: Python :: 3.12",
27
+ "Programming Language :: Python :: 3.13",
28
+ "Topic :: Scientific/Engineering :: Mathematics",
29
+ ]
30
+ dependencies = [
31
+ "flax>=0.10.7",
32
+ "jax>=0.7.0",
33
+ "optax>=0.2.4",
34
+ ]
35
+
36
+ [project.urls]
37
+ Homepage = "https://github.com/HighDimensionalEconLab/nlls_gram"
38
+ Repository = "https://github.com/HighDimensionalEconLab/nlls_gram"
39
+ Documentation = "https://highdimensionaleconlab.github.io/nlls_gram/"
40
+ Issues = "https://github.com/HighDimensionalEconLab/nlls_gram/issues"
41
+
42
+ [tool.hatch.build.targets.wheel]
43
+ packages = ["src/nlls_gram"]
44
+
45
+ [tool.ruff]
46
+ line-length = 88
47
+ target-version = "py311"
48
+
49
+ [tool.ruff.lint]
50
+ select = ["E", "F", "I", "UP", "B"]
51
+
52
+ [tool.pytest.ini_options]
53
+ testpaths = ["tests"]
54
+
55
+ [dependency-groups]
56
+ dev = [
57
+ "pytest>=8",
58
+ "ruff>=0.15.17",
59
+ ]
60
+ docs = [
61
+ "mkdocs-material>=9.5",
62
+ "mkdocstrings[python]>=0.26",
63
+ ]
@@ -0,0 +1,16 @@
1
+ """Gram/dual-form Levenberg-Marquardt nonlinear least-squares for JAX/Flax NNX models.
2
+
3
+ GramLevenbergMarquardt minimizes ||r(theta)||^2 for an nnx.Module-defined residual,
4
+ following the optax/nnx init/update protocol so steps apply through
5
+ nnx.Optimizer(model, optax.identity(), wrt=...). For overparameterized systems
6
+ (p parameters >> n residual rows) it factors the small n x n gram (dual) system.
7
+ """
8
+
9
+ from nlls_gram.gram_lm import (
10
+ GramLevenbergMarquardt,
11
+ LMInfo,
12
+ LMState,
13
+ flat_residual,
14
+ )
15
+
16
+ __all__ = ["GramLevenbergMarquardt", "LMState", "LMInfo", "flat_residual"]
@@ -0,0 +1,94 @@
1
+ from typing import NamedTuple
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ from flax import nnx
6
+ from jax.flatten_util import ravel_pytree
7
+
8
+ # optax/nnx-style least-squares and root-finding solvers: each class exposes
9
+ # init() -> state and update(model, state, batch) -> (updates, state, info), where
10
+ # updates are parameter *increments* structured like nnx.split(model, wrt, ...)[1],
11
+ # meant to be applied through nnx.Optimizer(model, optax.identity(), wrt=...).
12
+ # Classes do not jit internally -- wrap the caller's train step in jax.jit. All
13
+ # hyperparameters are static Python scalars; all data-dependent control flow is
14
+ # traced (jnp.where / lax.while_loop), so a rejected step returns zero updates
15
+ # rather than branching.
16
+
17
+
18
def flat_residual(model, wrt, residual_fn, batch):
    """Functionalize ``model`` into a flat parameter vector and a residual of it.

    The model is split with ``nnx.split(model, wrt, ...)``; the ``wrt``-filtered
    state is raveled into a single vector and everything else is captured and
    passed through unchanged whenever the model is rebuilt.

    Args:
        model: the ``nnx.Module`` (or graph node) to functionalize.
        wrt: nnx filter selecting the state that becomes the flat vector.
        residual_fn: callable ``residual_fn(model, batch)`` returning residuals.
        batch: data forwarded untouched to ``residual_fn``.

    Returns:
        A tuple ``(theta, unravel, residual_flat)``: the raveled ``wrt`` state,
        the function mapping a flat vector back to that state's structure, and
        ``residual_flat(theta)`` which returns the raveled residual.
    """
    graph_definition, wrt_state, other_state = nnx.split(model, wrt, ...)
    flat_theta, unravel = ravel_pytree(wrt_state)

    def residual_flat(flat_params):
        # Rebuild a model from the candidate vector plus the frozen remainder,
        # then flatten whatever shape the user residual has into 1-D.
        rebuilt = nnx.merge(graph_definition, unravel(flat_params), other_state)
        residual = residual_fn(rebuilt, batch)
        return jnp.ravel(residual)

    return flat_theta, unravel, residual_flat
30
+
31
+
32
class LMState(NamedTuple):
    """Solver state carried between Levenberg-Marquardt updates."""

    # Damping value used when forming the next step's linear system.
    damping: jax.Array
34
+
35
+
36
class LMInfo(NamedTuple):
    """Diagnostics returned alongside the updates of one solver step."""

    # Sum of squared residuals: the smaller of the old and the trial value.
    loss: jax.Array
    # Whether the trial step was accepted.
    accepted: jax.Array
    # Damping after this update (already decreased or increased).
    damping: jax.Array
40
+
41
+
42
class GramLevenbergMarquardt:
    """Levenberg-Marquardt solver with classic Marquardt damping.

    Minimizes ``||R(theta)||^2``. A trial step is accepted iff the sum of squared
    residuals decreases; the damping is multiplied by ``damping_decrease`` on
    acceptance and by ``damping_increase`` on rejection.

    ``solve_method`` picks which linear system is factored for the identical step:

    - ``"gram"``: ``step = -J' (J J' + damping I_n)^{-1} R``
      (n x n dual system; right for n << p)
    - ``"normal"``: ``step = -(J'J + damping I_p)^{-1} J' R``
      (p x p system; right for p <~ n)

    Args:
        residual_fn: callable ``residual_fn(model, batch)`` returning the residual
            array; it is raveled to 1-D before use.
        init_damping: damping stored in the state returned by ``init``.
        damping_decrease: factor applied to the damping after an accepted step.
        damping_increase: factor applied to the damping after a rejected step.
        solve_method: ``"gram"`` or ``"normal"`` (see above).
        wrt: nnx filter selecting the state that is differentiated and updated.

    Raises:
        ValueError: if ``solve_method`` is neither ``"gram"`` nor ``"normal"``.
    """

    def __init__(
        self,
        residual_fn,
        *,
        init_damping=1e-3,
        damping_decrease=0.5,
        damping_increase=4.0,
        solve_method="gram",
        wrt=nnx.Param,
    ):
        if solve_method not in ("gram", "normal"):
            raise ValueError(f"unknown solve_method: {solve_method}")
        self.residual_fn = residual_fn
        self.init_damping = init_damping
        self.damping_decrease = damping_decrease
        self.damping_increase = damping_increase
        self.solve_method = solve_method
        self.wrt = wrt

    def init(self):
        """Return the initial ``LMState`` holding ``init_damping`` as an array."""
        return LMState(jnp.asarray(self.init_damping))

    def update(self, model, state, batch):
        """Compute one damped step and the updated solver state.

        Args:
            model: the model whose ``wrt``-filtered state is being fitted.
            state: the current ``LMState``.
            batch: data forwarded to ``residual_fn``.

        Returns:
            ``(updates, new_state, info)`` where ``updates`` is the parameter
            increment structured like the ``wrt`` state (all zeros when the
            step is rejected), ``new_state`` carries the adjusted damping, and
            ``info`` is an ``LMInfo`` for this step.
        """
        theta, unravel, residual_flat = flat_residual(
            model, self.wrt, self.residual_fn, batch
        )
        resid = residual_flat(theta)
        J = jax.jacrev(residual_flat)(theta)
        if self.solve_method == "gram":
            # Dual form: factor the n x n system, then map back with J'.
            step = -J.T @ jnp.linalg.solve(
                J @ J.T + state.damping * jnp.eye(J.shape[0]), resid
            )
        else:
            # Primal form: factor the p x p damped normal equations.
            step = -jnp.linalg.solve(
                J.T @ J + state.damping * jnp.eye(J.shape[1]), J.T @ resid
            )
        resid_new = residual_flat(theta + step)
        loss_old = jnp.sum(resid**2)
        loss_new = jnp.sum(resid_new**2)
        # A NaN trial loss compares False here, so such a step is rejected.
        improved = loss_new < loss_old
        # Selection is traced (no Python branch): rejected steps give zeros.
        updates_flat = jnp.where(improved, step, jnp.zeros_like(step))
        damping = jnp.where(
            improved,
            state.damping * self.damping_decrease,
            state.damping * self.damping_increase,
        )
        # Report the loss at the parameters actually kept. jnp.where is used
        # instead of jnp.minimum because minimum propagates a NaN trial loss
        # even though that step was rejected; for finite values both agree.
        loss = jnp.where(improved, loss_new, loss_old)
        return unravel(updates_flat), LMState(damping), LMInfo(loss, improved, damping)
@@ -0,0 +1,50 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import optax
4
+ from flax import nnx
5
+
6
+ from nlls_gram import GramLevenbergMarquardt
7
+
8
+ jax.config.update("jax_enable_x64", True)
9
+
10
+
11
class ExpModel(nnx.Module):
    """Two-parameter exponential curve ``a * exp(b * x)`` fitted by the test."""

    def __init__(self, a, b):
        # Store both coefficients as nnx.Param so the solver's default
        # ``wrt=nnx.Param`` filter picks them up.
        amplitude = jnp.asarray(a)
        rate = jnp.asarray(b)
        self.a = nnx.Param(amplitude)
        self.b = nnx.Param(rate)

    def __call__(self, x):
        exponent = self.b * x
        return self.a * jnp.exp(exponent)
18
+
19
+
20
def residual_fn(model, batch):
    """Return prediction minus target for a ``(inputs, targets)`` batch."""
    inputs, targets = batch
    predictions = model(inputs)
    return predictions - targets
23
+
24
+
25
def test_recovers_known_parameters():
    """Fit noise-free exponential data and check (a, b) = (2, -1) is recovered."""
    true_a = 2.0
    true_b = -1.0
    grid = jnp.linspace(0.0, 2.0, 20)
    targets = true_a * jnp.exp(true_b * grid)
    data = (grid, targets)

    model = ExpModel(a=1.0, b=0.0)
    solver = GramLevenbergMarquardt(residual_fn, init_damping=1e-2)
    optimizer = nnx.Optimizer(model, optax.identity(), wrt=nnx.Param)
    lm_state = solver.init()

    def lm_step(graphdef, state, lm_state, batch):
        # Rebuild model and optimizer inside the traced function, take one
        # solver step, and hand the mutated state back out as a pytree.
        merged_model, merged_opt = nnx.merge(graphdef, state)
        updates, lm_state, info = solver.update(merged_model, lm_state, batch)
        merged_opt.update(merged_model, updates)
        return lm_state, info, nnx.state((merged_model, merged_opt))

    train_step = jax.jit(lm_step)

    graphdef, state = nnx.split((model, optimizer))
    info = None
    for _ in range(50):
        lm_state, info, state = train_step(graphdef, state, lm_state, data)
    # Write the final traced state back into the live Python objects.
    nnx.update((model, optimizer), state)

    assert float(info.loss) < 1e-10
    assert jnp.allclose(model.a[...], true_a, atol=1e-4)
    assert jnp.allclose(model.b[...], true_b, atol=1e-4)