amjax 0.0.1__tar.gz

@@ -0,0 +1,23 @@
+ name: Release
+
+ on:
+   push:
+     branches:
+       - main
+
+ jobs:
+   build:
+     runs-on: ubuntu-latest
+     steps:
+       - name: Release
+         uses: patrick-kidger/action_update_python_project@v8
+         with:
+           python-version: "3.11"
+           test-script: |
+             cp -r ${{ github.workspace }}/tests ./tests
+             cp ${{ github.workspace }}/pyproject.toml ./pyproject.toml
+             uv sync --extra tests --no-install-project --inexact
+             uv run --no-sync pytest
+           pypi-token: ${{ secrets.pypi_token }}
+           github-user: vboussange
+           github-token: ${{ github.token }}  # automatically created token
@@ -0,0 +1,14 @@
+ name: Tests
+
+ on: [push, pull_request]
+
+ jobs:
+   test:
+     runs-on: ubuntu-latest
+     steps:
+       - uses: actions/checkout@v4
+       - uses: astral-sh/setup-uv@v3
+       - name: Install dependencies
+         run: uv sync --extra tests
+       - name: Run tests
+         run: uv run pytest tests/ -v
amjax-0.0.1/.gitignore ADDED
@@ -0,0 +1,216 @@
+ # macOS
+ .DS_Store
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[codz]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py.cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # UV
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ #uv.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+ #poetry.toml
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
+ #pdm.lock
+ #pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # pixi
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
+ #pixi.lock
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
+ # in the .venv directory. It is recommended not to include this directory in version control.
+ .pixi
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .envrc
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # Abstra
+ # Abstra is an AI-powered process automation framework.
+ # Ignore directories containing user credentials, local state, and settings.
+ # Learn more at https://abstra.io/docs
+ .abstra/
+
+ # Visual Studio Code
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
+ # you could uncomment the following to ignore the entire vscode folder
+ # .vscode/
+
+ # Ruff stuff:
+ .ruff_cache/
+
+ # PyPI configuration file
+ .pypirc
+
+ # Cursor
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
+ # refer to https://docs.cursor.com/context/ignore-files
+ .cursorignore
+ .cursorindexingignore
+
+ # Marimo
+ marimo/_static/
+ marimo/_lsp/
+ __marimo__/
+
+ # AMJax - local benchmark config (personal, not for version control)
+ benchmarks/config_local.yaml
+
+ benchmarks/results/
+ benchmarks/results_pinv/
amjax-0.0.1/LICENSE ADDED
@@ -0,0 +1,9 @@
+ MIT License
+
+ Copyright (c) 2026-present Victor Boussange <vic.boussange@gmail.com>
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
amjax-0.0.1/PKG-INFO ADDED
@@ -0,0 +1,128 @@
+ Metadata-Version: 2.4
+ Name: amjax
+ Version: 0.0.1
+ Summary: JAX Algebraic Multigrid Solvers in Python
+ Project-URL: Documentation, https://github.com/vboussange/AMJax#readme
+ Project-URL: Source, https://github.com/vboussange/AMJax
+ Author-email: Victor Boussange <vic.boussange@gmail.com>, Fanny Missillier <fanny.missillier@epfl.ch>
+ Maintainer-email: Victor Boussange <vic.boussange@gmail.com>, Fanny Missillier <fanny.missillier@epfl.ch>
+ License-File: LICENSE
+ Requires-Python: >=3.11
+ Requires-Dist: jax>=0.6.1
+ Requires-Dist: matplotlib>=3.0.0
+ Requires-Dist: numpy>2.4.0
+ Requires-Dist: pyamg>=4.0.0
+ Requires-Dist: scipy>=1.11.0
+ Provides-Extra: tests
+ Requires-Dist: pytest; extra == 'tests'
+ Description-Content-Type: text/markdown
+
+ # AMJax
+
+ AMJax bridges PyAMG and JAX for algebraic multigrid (AMG) solvers: it converts PyAMG-constructed hierarchies into `jax.{jit,grad,vmap}`-compatible multilevel solvers and preconditioners for large sparse linear systems.
+
+ ## Installation
+
+ Install directly from GitHub (PyPI release coming soon):
+
+ ```bash
+ uv add git+https://github.com/vboussange/AMJax.git
+ ```
+
+ ## Usage
+
+ ### Direct solve
+
+ ```python
+ import pyamg
+ import jax
+ import jax.numpy as jnp
+ from amjax import AMJAXSolver
+
+ A = pyamg.gallery.poisson((100, 100), format="csr")
+ b = jnp.ones(A.shape[0])
+
+ ml = AMJAXSolver.from_pyamg(pyamg.ruge_stuben_solver(A))
+
+ solve = jax.jit(lambda b: ml.solve(b, tol=1e-10, maxiter=100))
+ x = solve(b)
+ ```
+
+ ### Preconditioning
+
+ `ml.aspreconditioner()` returns a preconditioner that can be passed as `M` to JAX Krylov solvers such as `jax.scipy.sparse.linalg.cg`:
+
+ ```python
+ from jax.experimental import sparse as jsparse
+
+ A_jax = jsparse.BCOO.from_scipy_sparse(A)
+ M = ml.aspreconditioner(cycle='V')
+
+ x, info = jax.scipy.sparse.linalg.cg(A_jax, b, M=M, tol=1e-10, maxiter=30)
+ ```
+
+ ### Batched solve with `jax.vmap`
+
+ ```python
+ import numpy as np
+
+ B = jnp.array(np.random.rand(4, A.shape[0])) # (n_rhs, n)
+ solve_batch = jax.jit(jax.vmap(lambda b: ml.solve(b, tol=1e-8, maxiter=100)))
+ X = solve_batch(B)
+ ```
+
+ ### Differentiating through the solve with `jax.grad`
+
+ ```python
+ f = lambda b: jnp.sum(ml.solve(b, tol=1e-10, maxiter=100))
+ grad = jax.grad(f)(b)
+ ```
+
+ ### Differentiating through a preconditioned Krylov solve
+
+ ```python
+ f = lambda b: jnp.sum(jax.scipy.sparse.linalg.cg(A_jax, b, M=M, tol=1e-10)[0])
+ grad = jax.grad(f)(b)
+ ```
+
+ ## Features
+
+ - V, W and F cycles compiled with `jax.jit`
+ - Coarse solvers: `jacobi`, `lu`, `qr`, `pinv`
+ - Smoothers: `jacobi`
+ - AMG preconditioning for JAX Krylov solvers (e.g. `jax.scipy.sparse.linalg.cg`)
+ - `jax.vmap` support for batched right-hand sides
+ - `jax.grad` support through both direct solve and preconditioned Krylov solvers
+
+ ## Solvers
+
+ `AMJAXSolver.from_pyamg` accepts any hierarchy produced by a PyAMG factory:
+
+ | Factory | Intended for |
+ |---------|--------------|
+ | `pyamg.smoothed_aggregation_solver` | SPD systems, standard aggregation AMG |
+ | `pyamg.rootnode_solver` | SPD systems, robust for anisotropic problems |
+ | `pyamg.pairwise_solver` | SPD systems, fast setup, weaker convergence |
+ | `pyamg.ruge_stuben_solver` | General SPD systems, classical C/F splitting |
+ | `pyamg.air_solver` | Non-symmetric systems |
+
+ **Current limitations:** V-cycle only. `jacobi` coarse solver only.
+
+ ## Benchmark
+
+ The full benchmark can be run in Colab:
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vboussange/AMJax/blob/main/benchmarks/benchmark.ipynb)
+ <!-- And results are reported at XXX -->
+
+ Selected **speedups of AMJax over its PyAMG counterpart**:
+
+
+ | Scenario | Method | CPU | GPU |
+ |----------|--------|----:|----:|
+ | Single solve ($Ax=b$, $b \in \mathbb{R}^n$) | AMJax | - | ~16× |
+ | Single solve ($Ax=b$, $b \in \mathbb{R}^n$) | AMJax + CG | - | ~17× |
+ | Batched solve ($AX=B$, $B \in \mathbb{R}^{n \times K}$, $K=64$, `jax.vmap`) | AMJax | 0.7× | ~21× |
+ | Batched solve ($AX=B$, $B \in \mathbb{R}^{n \times K}$, $K=64$, `jax.vmap`) | AMJax + CG | - | ~23× |
+
+
+ > Settings: Ruge-Stüben hierarchy, V-cycle, Jacobi smoother, `pinv` coarse solver, $n = 1{,}000$, f64, rtol $= 10^{-10}$, max 100 iterations. JAX times exclude JIT compilation. GPU speedup is relative to the PyAMG CPU counterpart.
amjax-0.0.1/README.md ADDED
@@ -0,0 +1,109 @@
+ # AMJax
+
+ AMJax bridges PyAMG and JAX for algebraic multigrid (AMG) solvers: it converts PyAMG-constructed hierarchies into `jax.{jit,grad,vmap}`-compatible multilevel solvers and preconditioners for large sparse linear systems.
+
+ ## Installation
+
+ Install directly from GitHub (PyPI release coming soon):
+
+ ```bash
+ uv add git+https://github.com/vboussange/AMJax.git
+ ```
+
+ ## Usage
+
+ ### Direct solve
+
+ ```python
+ import pyamg
+ import jax
+ import jax.numpy as jnp
+ from amjax import AMJAXSolver
+
+ A = pyamg.gallery.poisson((100, 100), format="csr")
+ b = jnp.ones(A.shape[0])
+
+ ml = AMJAXSolver.from_pyamg(pyamg.ruge_stuben_solver(A))
+
+ solve = jax.jit(lambda b: ml.solve(b, tol=1e-10, maxiter=100))
+ x = solve(b)
+ ```
+
+ ### Preconditioning
+
+ `ml.aspreconditioner()` returns a preconditioner that can be passed as `M` to JAX Krylov solvers such as `jax.scipy.sparse.linalg.cg`:
+
+ ```python
+ from jax.experimental import sparse as jsparse
+
+ A_jax = jsparse.BCOO.from_scipy_sparse(A)
+ M = ml.aspreconditioner(cycle='V')
+
+ x, info = jax.scipy.sparse.linalg.cg(A_jax, b, M=M, tol=1e-10, maxiter=30)
+ ```
+
+ ### Batched solve with `jax.vmap`
+
+ ```python
+ import numpy as np
+
+ B = jnp.array(np.random.rand(4, A.shape[0])) # (n_rhs, n)
+ solve_batch = jax.jit(jax.vmap(lambda b: ml.solve(b, tol=1e-8, maxiter=100)))
+ X = solve_batch(B)
+ ```
+
+ ### Differentiating through the solve with `jax.grad`
+
+ ```python
+ f = lambda b: jnp.sum(ml.solve(b, tol=1e-10, maxiter=100))
+ grad = jax.grad(f)(b)
+ ```
+
+ ### Differentiating through a preconditioned Krylov solve
+
+ ```python
+ f = lambda b: jnp.sum(jax.scipy.sparse.linalg.cg(A_jax, b, M=M, tol=1e-10)[0])
+ grad = jax.grad(f)(b)
+ ```
+
+ ## Features
+
+ - V, W and F cycles compiled with `jax.jit`
+ - Coarse solvers: `jacobi`, `lu`, `qr`, `pinv`
+ - Smoothers: `jacobi`
+ - AMG preconditioning for JAX Krylov solvers (e.g. `jax.scipy.sparse.linalg.cg`)
+ - `jax.vmap` support for batched right-hand sides
+ - `jax.grad` support through both direct solve and preconditioned Krylov solvers
+
+ ## Solvers
+
+ `AMJAXSolver.from_pyamg` accepts any hierarchy produced by a PyAMG factory:
+
+ | Factory | Intended for |
+ |---------|--------------|
+ | `pyamg.smoothed_aggregation_solver` | SPD systems, standard aggregation AMG |
+ | `pyamg.rootnode_solver` | SPD systems, robust for anisotropic problems |
+ | `pyamg.pairwise_solver` | SPD systems, fast setup, weaker convergence |
+ | `pyamg.ruge_stuben_solver` | General SPD systems, classical C/F splitting |
+ | `pyamg.air_solver` | Non-symmetric systems |
+
+ **Current limitations:** V-cycle only. `jacobi` coarse solver only.
+
+ ## Benchmark
+
+ The full benchmark can be run in Colab:
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vboussange/AMJax/blob/main/benchmarks/benchmark.ipynb)
+ <!-- And results are reported at XXX -->
+
+ Selected **speedups of AMJax over its PyAMG counterpart**:
+
+
+ | Scenario | Method | CPU | GPU |
+ |----------|--------|----:|----:|
+ | Single solve ($Ax=b$, $b \in \mathbb{R}^n$) | AMJax | - | ~16× |
+ | Single solve ($Ax=b$, $b \in \mathbb{R}^n$) | AMJax + CG | - | ~17× |
+ | Batched solve ($AX=B$, $B \in \mathbb{R}^{n \times K}$, $K=64$, `jax.vmap`) | AMJax | 0.7× | ~21× |
+ | Batched solve ($AX=B$, $B \in \mathbb{R}^{n \times K}$, $K=64$, `jax.vmap`) | AMJax + CG | - | ~23× |
+
+
+ > Settings: Ruge-Stüben hierarchy, V-cycle, Jacobi smoother, `pinv` coarse solver, $n = 1{,}000$, f64, rtol $= 10^{-10}$, max 100 iterations. JAX times exclude JIT compilation. GPU speedup is relative to the PyAMG CPU counterpart.
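The benchmark configuration combines three ingredients: a V-cycle, a Jacobi smoother, and a `pinv` coarse solve. The sketch below makes them concrete with a two-level cycle in NumPy. It rests on several assumptions: a dense 1D Poisson matrix, a hand-built linear-interpolation prolongation `P` standing in for the one a PyAMG factory would construct, the Galerkin coarse operator `P.T @ A @ P`, and damped Jacobi with weight 2/3. It is not AMJax's JAX implementation, and `poisson_1d`, `linear_interpolation`, `jacobi` and `two_grid_cycle` are hypothetical names.

```python
import numpy as np


def poisson_1d(n: int) -> np.ndarray:
    """Dense 1D Poisson matrix tridiag(-1, 2, -1) of size n x n."""
    return 2.0 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)


def linear_interpolation(n_coarse: int) -> np.ndarray:
    """Prolongation from n_coarse to 2 * n_coarse + 1 fine points by linear interpolation."""
    n_fine = 2 * n_coarse + 1
    P = np.zeros((n_fine, n_coarse))
    for j in range(n_coarse):
        i = 2 * j + 1  # fine index of coarse point j
        P[i, j] = 1.0
        P[i - 1, j] = 0.5
        P[i + 1, j] = 0.5
    return P


def jacobi(A, x, b, omega=2.0 / 3.0, sweeps=2):
    """Damped Jacobi smoothing: x <- x + omega * D^{-1} (b - A x)."""
    d_inv = 1.0 / np.diag(A)
    for _ in range(sweeps):
        x = x + omega * d_inv * (b - A @ x)
    return x


def two_grid_cycle(A, P, A_c_pinv, x, b):
    """One two-level V-cycle: pre-smooth, coarse-grid correction, post-smooth."""
    x = jacobi(A, x, b)                # pre-smoothing damps oscillatory error
    r_c = P.T @ (b - A @ x)            # restrict the residual with R = P^T
    x = x + P @ (A_c_pinv @ r_c)       # coarse solve via pseudo-inverse, then prolong
    return jacobi(A, x, b)             # post-smoothing


n_coarse = 31
A = poisson_1d(2 * n_coarse + 1)
P = linear_interpolation(n_coarse)
A_c_pinv = np.linalg.pinv(P.T @ A @ P)  # Galerkin coarse operator, "pinv"-style coarse solver
b = np.ones(A.shape[0])
x = np.zeros_like(b)
for _ in range(10):
    x = two_grid_cycle(A, P, A_c_pinv, x, b)
rel_res = np.linalg.norm(b - A @ x) / np.linalg.norm(b)
```

A deeper V-cycle replaces the pseudo-inverse with a recursive call on `P.T @ A @ P` until the coarsest level is small. Running plain `jacobi` for 40 sweeps on the same system leaves most of the smooth residual in place, and that smooth error is exactly what the coarse correction removes.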
@@ -0,0 +1,10 @@
+ from .multilevel import MultilevelSolver
+ from .relaxation.smoothing import change_smoothers
+ from .relaxation.relaxation import jacobi, inverse_diagonal
+
+ __all__ = [
+     "MultilevelSolver",
+     "change_smoothers",
+     "jacobi",
+     "inverse_diagonal",
+ ]
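This module re-exports `jacobi` and `inverse_diagonal` from `relaxation.relaxation`, but their signatures are not part of this diff. The sketch below is a hedged illustration of the computation such functions typically perform on a CSR matrix, written in NumPy. The per-row segment sum used for the matrix-vector product is the pattern `jax.ops.segment_sum` expresses in JAX. `inverse_diagonal_csr`, `csr_matvec` and `jacobi_sweeps` are hypothetical names, not AMJax's API.

```python
import numpy as np


def csr_row_ids(indptr: np.ndarray) -> np.ndarray:
    """Expand CSR row pointers into one row index per stored entry."""
    return np.repeat(np.arange(len(indptr) - 1), np.diff(indptr))


def inverse_diagonal_csr(indptr, indices, data):
    """1 / diag(A) for a CSR matrix; raises on a zero or missing diagonal entry."""
    rows = csr_row_ids(indptr)
    diag = np.zeros(len(indptr) - 1)
    on_diag = rows == indices
    # np.add.at accumulates duplicates, matching CSR semantics for repeated entries
    np.add.at(diag, rows[on_diag], data[on_diag])
    if np.any(diag == 0):
        raise ValueError("zero or missing diagonal entry")
    return 1.0 / diag


def csr_matvec(indptr, indices, data, x):
    """y = A @ x computed as a per-row segment sum of data * x[indices]."""
    rows = csr_row_ids(indptr)
    return np.bincount(rows, weights=data * x[indices], minlength=len(indptr) - 1)


def jacobi_sweeps(indptr, indices, data, x, b, d_inv, omega=1.0, sweeps=1):
    """Weighted Jacobi: x <- x + omega * D^{-1} (b - A x), repeated `sweeps` times."""
    for _ in range(sweeps):
        x = x + omega * d_inv * (b - csr_matvec(indptr, indices, data, x))
    return x


# Usage: the diagonally dominant matrix [[4, -1, 0], [-1, 4, -1], [0, -1, 4]] in CSR form
indptr = np.array([0, 2, 5, 7])
indices = np.array([0, 1, 0, 1, 2, 1, 2])
data = np.array([4.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0])
x_true = np.array([1.0, 2.0, 3.0])
b = csr_matvec(indptr, indices, data, x_true)
d_inv = inverse_diagonal_csr(indptr, indices, data)
x = jacobi_sweeps(indptr, indices, data, np.zeros(3), b, d_inv, sweeps=60)
```

Precomputing `d_inv` once and reusing it across sweeps is why a separate inverse-diagonal helper is useful: the diagonal of each level's matrix does not change during a solve.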