n4ax 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
n4ax-0.1.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Gragas
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
n4ax-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,132 @@
1
+ Metadata-Version: 2.4
2
+ Name: n4ax
3
+ Version: 0.1.0
4
+ Summary: JAX/GPU N4 bias field correction — a fast drop-in match for ITK N4
5
+ Author: Geoffroy Oudoumanessah, Jacopo Iollo
6
+ Author-email: Gragas <contact@gragas.ai>
7
+ License-Expression: MIT
8
+ Project-URL: Homepage, https://github.com/GragasLab/n4ax
9
+ Project-URL: Repository, https://github.com/GragasLab/n4ax
10
+ Keywords: MRI,bias field,N4,N3,JAX,GPU
11
+ Classifier: Development Status :: 3 - Alpha
12
+ Classifier: Intended Audience :: Science/Research
13
+ Classifier: Programming Language :: Python :: 3
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Classifier: Programming Language :: Python :: 3.13
16
+ Classifier: Topic :: Scientific/Engineering :: Medical Science Apps.
17
+ Requires-Python: >=3.12
18
+ Description-Content-Type: text/markdown
19
+ License-File: LICENSE
20
+ Requires-Dist: jax>=0.8.0
21
+ Requires-Dist: numpy>=1.24.0
22
+ Provides-Extra: cpu
23
+ Requires-Dist: jax[cpu]>=0.8.0; extra == "cpu"
24
+ Provides-Extra: cuda12
25
+ Requires-Dist: jax[cuda12]>=0.8.0; extra == "cuda12"
26
+ Provides-Extra: compare
27
+ Requires-Dist: SimpleITK>=2.3.0; extra == "compare"
28
+ Requires-Dist: matplotlib>=3.7; extra == "compare"
29
+ Requires-Dist: nibabel>=5.0; extra == "compare"
30
+ Provides-Extra: dev
31
+ Requires-Dist: pytest>=7.0; extra == "dev"
32
+ Requires-Dist: pytest-cov; extra == "dev"
33
+ Requires-Dist: ruff>=0.4.0; extra == "dev"
34
+ Requires-Dist: pre-commit>=3.0; extra == "dev"
35
+ Requires-Dist: SimpleITK>=2.3.0; extra == "dev"
36
+ Dynamic: license-file
37
+
38
+ # n4ax
39
+
40
+ **N4 bias field correction in pure JAX** — a fast, GPU-friendly, *drop-in* match for
41
+ ITK / SimpleITK's `N4BiasFieldCorrectionImageFilter`.
42
+
43
+ n4ax reimplements the N4 algorithm (Tustison et al., 2010 — N3 histogram sharpening
44
+ + multi-resolution B-spline) faithfully enough to **match SimpleITK to ~1%** on real
45
+ MRI, while running **~1500× faster on a GPU** and **~20× faster on the same CPU**.
46
+
47
+ ![NKI raw vs n4ax-corrected vs ITK-corrected](assets/nki_sub-0002_gpu.png)
48
+ *Raw NKI T1w (with B1 shading) → n4ax-corrected → ITK-corrected (visually identical) → estimated bias field.*
49
+
50
+ ## Why
51
+
52
+ N4 is the de-facto standard bias correction, but ITK's implementation is CPU-only and
53
+ slow (minutes per volume). In a GPU MRI pipeline it becomes the bottleneck. n4ax gives
54
+ **N4-quality output on the GPU in tens of milliseconds**, with no custom CUDA — just JAX.
55
+
56
+ ## Install
57
+
58
+ ```bash
59
+ uv sync --extra cuda12 # GPU (CUDA 12)
60
+ uv sync --extra cpu # CPU
61
+ uv sync --extra cuda12 --extra dev # + tests/linting
62
+ uv sync --extra cuda12 --extra compare # + SimpleITK/matplotlib for benchmarks
63
+ ```
64
+
65
+ ## Usage
66
+
67
+ ```python
68
+ import nibabel as nib
69
+ import n4ax
70
+
71
+ vol = nib.load("t1w.nii.gz").get_fdata() # 3D (or 2D) array, intensities >= 0
72
+ corrected = n4ax.n4(vol) # Otsu mask computed automatically
73
+ # or pass your own mask, and/or get the log bias field:
74
+ corrected, log_bias = n4ax.n4(vol, mask=mask, return_bias=True)
75
+ ```
76
+
77
+ `corrected == vol / exp(log_bias)`. The default config (`iters=(8,12,12,8)`,
78
+ `over_relax=1.8`) is tuned for speed; for the tightest ITK match use the robust
79
+ fallback `n4ax.n4(vol, iters=(50,50,30,20), over_relax=1.0, conv_threshold=1.5e-3)`.
80
+
81
+ ## Benchmark
82
+
83
+ Real NKI T1w volumes (256×176×256, ~2 M brain voxels), N4 `[50,50,30,20]`, same Otsu mask.
84
+ ITK on an 8-core CPU; n4ax CPU on the same node; n4ax GPU on an NVIDIA A100.
85
+
86
+ | Method | Time / volume | Speedup vs ITK |
87
+ |---|--:|--:|
88
+ | ITK N4 (CPU, 8 cores) | **146 s** | 1× |
89
+ | n4ax (CPU, 8 cores) | **7.7 s** | **~19×** |
90
+ | n4ax (A100 GPU) | **93 ms** | **~1571×** |
91
+
92
+ **Accuracy vs ITK** (corrected image, global scale removed — pipelines intensity-normalise anyway):
93
+ mean **1.15 %**, per-subject 0.79–1.59 % over 6 NKI scans. On a single fitting level n4ax
94
+ matches ITK to **0.4 %**, and a single N4 iteration to **0.1 %** — the building blocks are exact;
95
+ the residual is N4's own iterative crawl (ITK itself only converges after ~30 iters/level).
96
+
97
+ Multiple subjects, raw (top) vs n4ax-corrected (bottom):
98
+
99
+ ![NKI grid](assets/nki_grid_gpu.png)
100
+
101
+ Reproduce: `python scripts/bench_nki.py` (GPU) and `JAX_PLATFORMS=cpu python scripts/bench_nki.py --skip-itk --skip-fig --tag cpu`.
102
+
103
+ ## How it's fast (no custom kernels)
104
+
105
+ - **Separable B-spline fit.** N4's per-iteration B-spline least-squares (Lee MBA) is a
106
+ 94 M-way scatter into a tiny control lattice — brutal atomic contention (~30 ms/iter).
107
+ Because the cubic weights depend only on the per-axis index and the Lee denominator
108
+ factorises, this becomes **3 small dense matmuls per axis** (cuBLAS) — *identical math*,
109
+ 0.1 ms/iter.
110
+ - **Privatised histogram.** The N3 sharpening histogram (1.5 M → 200 bins) is privatised
111
+ over 256 lanes to avoid atomic serialisation.
112
+ - **Over-relaxation.** N4's fixed point is invariant to `B += α·S` (S = 0 there), so
113
+ `α ≈ 1.8` reaches ITK's result in far fewer iterations.
114
+ - The whole solve is one fused, jitted program with a device-side convergence loop.
115
+
116
+ Two things that mattered for *correctness*: zero-padding the sharpening FFT (circular
117
+ wraparound otherwise breaks convergence), and that float32 == float64 here (verified).
118
+
119
+ ## Tests
120
+
121
+ ```bash
122
+ uv run pytest # basic correctness + ground-truth match vs SimpleITK
123
+ ```
124
+
125
+ `tests/test_vs_itk.py` asserts n4ax matches SimpleITK's N4 (the reference) within tolerance
126
+ on a phantom; `tests/test_basic.py` covers shapes, 2D/3D, the `image/exp(bias)` identity,
127
+ bias flattening, and the Otsu mask.
128
+
129
+ ## Status
130
+
131
+ Alpha. The fast defaults are tuned on NKI/phantom data; validate on your own data before
132
+ production (the `iters=(50,50,30,20), over_relax=1.0` fallback is the conservative choice).
n4ax-0.1.0/README.md ADDED
@@ -0,0 +1,95 @@
1
+ # n4ax
2
+
3
+ **N4 bias field correction in pure JAX** — a fast, GPU-friendly, *drop-in* match for
4
+ ITK / SimpleITK's `N4BiasFieldCorrectionImageFilter`.
5
+
6
+ n4ax reimplements the N4 algorithm (Tustison et al., 2010 — N3 histogram sharpening
7
+ + multi-resolution B-spline) faithfully enough to **match SimpleITK to ~1%** on real
8
+ MRI, while running **~1500× faster on a GPU** and **~20× faster on the same CPU**.
9
+
10
+ ![NKI raw vs n4ax-corrected vs ITK-corrected](assets/nki_sub-0002_gpu.png)
11
+ *Raw NKI T1w (with B1 shading) → n4ax-corrected → ITK-corrected (visually identical) → estimated bias field.*
12
+
13
+ ## Why
14
+
15
+ N4 is the de-facto standard bias correction, but ITK's implementation is CPU-only and
16
+ slow (minutes per volume). In a GPU MRI pipeline it becomes the bottleneck. n4ax gives
17
+ **N4-quality output on the GPU in tens of milliseconds**, with no custom CUDA — just JAX.
18
+
19
+ ## Install
20
+
21
+ ```bash
22
+ uv sync --extra cuda12 # GPU (CUDA 12)
23
+ uv sync --extra cpu # CPU
24
+ uv sync --extra cuda12 --extra dev # + tests/linting
25
+ uv sync --extra cuda12 --extra compare # + SimpleITK/matplotlib for benchmarks
26
+ ```
27
+
28
+ ## Usage
29
+
30
+ ```python
31
+ import nibabel as nib
32
+ import n4ax
33
+
34
+ vol = nib.load("t1w.nii.gz").get_fdata() # 3D (or 2D) array, intensities >= 0
35
+ corrected = n4ax.n4(vol) # Otsu mask computed automatically
36
+ # or pass your own binary mask (same shape as vol), and/or get the log bias field:
37
+ corrected, log_bias = n4ax.n4(vol, mask=mask, return_bias=True)
38
+ ```
39
+
40
+ Inside the mask, `corrected == vol / exp(log_bias)`; voxels outside the mask are returned unchanged. The default config (`iters=(8,12,12,8)`,
41
+ `over_relax=1.8`) is tuned for speed; for the tightest ITK match use the robust
42
+ fallback `n4ax.n4(vol, iters=(50,50,30,20), over_relax=1.0, conv_threshold=1.5e-3)`.
43
+
44
+ ## Benchmark
45
+
46
+ Real NKI T1w volumes (256×176×256, ~2 M brain voxels), N4 `[50,50,30,20]`, same Otsu mask.
47
+ ITK on an 8-core CPU; n4ax CPU on the same node; n4ax GPU on an NVIDIA A100.
48
+
49
+ | Method | Time / volume | Speedup vs ITK |
50
+ |---|--:|--:|
51
+ | ITK N4 (CPU, 8 cores) | **146 s** | 1× |
52
+ | n4ax (CPU, 8 cores) | **7.7 s** | **~19×** |
53
+ | n4ax (A100 GPU) | **93 ms** | **~1571×** |
54
+
55
+ **Accuracy vs ITK** (corrected image, global scale removed — pipelines intensity-normalise anyway):
56
+ mean **1.15 %**, per-subject 0.79–1.59 % over 6 NKI scans. On a single fitting level n4ax
57
+ matches ITK to **0.4 %**, and a single N4 iteration to **0.1 %** — the building blocks are exact;
58
+ the residual is N4's own iterative crawl (ITK itself only converges after ~30 iters/level).
59
+
60
+ Multiple subjects, raw (top) vs n4ax-corrected (bottom):
61
+
62
+ ![NKI grid](assets/nki_grid_gpu.png)
63
+
64
+ Reproduce: `python scripts/bench_nki.py` (GPU) and `JAX_PLATFORMS=cpu python scripts/bench_nki.py --skip-itk --skip-fig --tag cpu`.
65
+
66
+ ## How it's fast (no custom kernels)
67
+
68
+ - **Separable B-spline fit.** N4's per-iteration B-spline least-squares (Lee MBA) is a
69
+ 94 M-way scatter into a tiny control lattice — brutal atomic contention (~30 ms/iter).
70
+ Because the cubic weights depend only on the per-axis index and the Lee denominator
71
+ factorises, this becomes **3 small dense matmuls per axis** (cuBLAS) — *identical math*,
72
+ 0.1 ms/iter.
73
+ - **Privatised histogram.** The N3 sharpening histogram (1.5 M → 200 bins) is privatised
74
+ over 256 lanes to avoid atomic serialisation.
75
+ - **Over-relaxation.** N4's fixed point is invariant to `B += α·S` (S = 0 there), so
76
+ `α ≈ 1.8` reaches ITK's result in far fewer iterations.
77
+ - The whole solve is one fused, jitted program with a device-side convergence loop.
78
+
79
+ Two notes on *correctness*: the sharpening FFT must be zero-padded (circular
80
+ wraparound otherwise breaks convergence), and float32 was verified to give the same result as float64.
81
+
82
+ ## Tests
83
+
84
+ ```bash
85
+ uv run pytest # basic correctness + ground-truth match vs SimpleITK
86
+ ```
87
+
88
+ `tests/test_vs_itk.py` asserts n4ax matches SimpleITK's N4 (the reference) within tolerance
89
+ on a phantom; `tests/test_basic.py` covers shapes, 2D/3D, the `image/exp(bias)` identity,
90
+ bias flattening, and the Otsu mask.
91
+
92
+ ## Status
93
+
94
+ Alpha. The fast defaults are tuned on NKI/phantom data; validate on your own data before
95
+ production (the `iters=(50,50,30,20), over_relax=1.0` fallback is the conservative choice).
@@ -0,0 +1,6 @@
1
+ """n4ax — JAX/GPU N4 bias field correction (a fast drop-in match for ITK N4)."""
2
+
3
+ from .core import n4, otsu_mask
4
+
5
+
6
+ __all__ = ["n4", "otsu_mask"]
@@ -0,0 +1,256 @@
1
+ """N4 bias field correction in pure JAX — a fast, GPU-friendly drop-in match for
2
+ ITK / SimpleITK's ``N4BiasFieldCorrectionImageFilter``.
3
+
4
+ Algorithm (Tustison 2010 N4 = N3 histogram sharpening + multi-resolution B-spline):
5
+
6
+ u = log(v) over the mask; B = 0 (log bias field)
7
+ for each fitting level (mesh = 1, 2, 4, 8):
8
+ repeat:
9
+ uc = u - B
10
+ E = sharpen(uc) # N3 histogram deconvolution (Wiener, FFT)
11
+ S = bspline_fit(uc - E) # cubic B-spline least-squares (Lee MBA)
12
+ B = B + over_relax * S
13
+ corrected = exp(u - B) inside the mask (== v / exp(B)); v is passed through outside it
14
+
15
+ Every building block matches ITK's implementation (parametric coords, cubic
16
+ B-spline weights, Lee-MBA accumulation, N3 Wiener deconvolution). Two ideas make
17
+ it fast on a GPU without any custom kernel:
18
+
19
+ * the B-spline fit is **separable** (weights depend only on the per-axis index and
20
+ the Lee denominator factorises), so the per-voxel scatter into the control
21
+ lattice becomes three small dense matmuls per axis — no atomic contention;
22
+ * the sharpening histogram is **privatised** over K lanes so the value-scatter
23
+ doesn't serialise on atomics.
24
+
25
+ ``over_relax > 1`` accelerates N4's (slow, monotone) crawl to the *same* fixed
26
+ point (S = 0 there, so ``B += a*S`` is fixed-point invariant), reaching ITK's
27
+ result in far fewer iterations.
28
+ """
29
+
30
+ from __future__ import annotations
31
+
32
+ import functools
33
+
34
+ import jax
35
+ import jax.numpy as jnp
36
+
37
+
38
+ SPLINE_ORDER = 3
39
+ REAL = jnp.float32 # the iteration is float-precision-insensitive (verified vs float64)
40
+
41
+
42
+ # ----------------------------- Otsu mask ------------------------------------
43
def otsu_mask(volume, nbins: int = 200):
    """Foreground mask from Otsu's threshold on an ``nbins``-bin intensity histogram.

    Mirrors ITK ``OtsuThreshold`` with insideValue=0 / outsideValue=1, i.e. a voxel
    is foreground (1.0) when its intensity is strictly above the threshold, else 0.0.
    The threshold is the centre of the bin maximising the between-class variance.
    """
    data = jnp.asarray(volume, REAL)
    bin_edges = jnp.linspace(jnp.min(data), jnp.max(data), nbins + 1)
    counts = jnp.histogram(data, bins=bin_edges)[0].astype(REAL)
    mids = 0.5 * (bin_edges[:-1] + bin_edges[1:])

    def safe_mean(total, weight):
        # class mean, defined as 0 for an empty class (avoids 0/0)
        return jnp.where(weight > 0, total / jnp.where(weight > 0, weight, 1), 0.0)

    weight_bg = jnp.cumsum(counts)
    weight_fg = weight_bg[-1] - weight_bg
    first_moment = jnp.cumsum(counts * mids)
    mean_bg = safe_mean(first_moment, weight_bg)
    mean_fg = safe_mean(first_moment[-1] - first_moment, weight_fg)
    spread = weight_bg * weight_fg * (mean_bg - mean_fg) ** 2
    threshold = mids[jnp.argmax(spread)]
    return (data > threshold).astype(REAL)
61
+
62
+
63
+ # ----------------------------- N3 sharpening --------------------------------
64
@functools.partial(jax.jit, static_argnums=(2,))
def _sharpen(uc, mask, nbins=200, fwhm=0.15, wiener=0.01):
    """N3 histogram-deconvolution sharpening of the masked log image ``uc``.

    Builds a Parzen (2-bin linear) histogram of ``uc`` over the mask, Wiener-
    deconvolves it with a Gaussian of full width ``fwhm`` (log-intensity units),
    maps each bin to the conditional mean E[u | bin] of the deconvolved
    distribution, and linearly interpolates that map back onto the voxels.

    Args:
        uc: log image (any shape) to sharpen.
        mask: array shaped like ``uc``; voxels with ``mask > 0.5`` are used.
        nbins: number of histogram bins (static: it fixes array shapes).
        fwhm: full width at half maximum of the Gaussian blur kernel.
        wiener: Wiener-filter noise term regularising the deconvolution.

    Returns:
        Array shaped like ``uc``: sharpened values inside the mask, 0 outside.
    """
    m = mask > 0.5
    vmin = jnp.min(jnp.where(m, uc, jnp.inf))
    vmax = jnp.max(jnp.where(m, uc, -jnp.inf))
    slope = (vmax - vmin) / (nbins - 1)

    # parzen (2-bin linear) histogram, privatised over K lanes to avoid atomic
    # contention (the value-scatter was ~100% of sharpen's cost otherwise).
    cidx = (uc - vmin) / slope
    idx = jnp.floor(cidx).astype(jnp.int32)
    off = cidx - idx
    w = m.astype(REAL).reshape(-1)
    idxf = idx.reshape(-1)
    offf = off.reshape(-1)
    K = 256
    lane = jnp.arange(idxf.shape[0], dtype=jnp.int32) % K
    Hp = jnp.zeros((K, nbins), REAL)
    Hp = Hp.at[lane, jnp.clip(idxf, 0, nbins - 1)].add(w * (1.0 - offf))
    Hp = Hp.at[lane, jnp.clip(idxf + 1, 0, nbins - 1)].add(w * offf)
    H = jnp.sum(Hp, axis=0)

    # zero-padded FFT (npad >= 2*nbins, the same size ITK uses) so the Gaussian
    # deconvolution doesn't suffer circular wraparound (that otherwise breaks
    # convergence).
    npad = 1
    while npad < 2 * nbins:
        npad *= 2
    Hpad = jnp.zeros((npad,), REAL).at[:nbins].set(H)
    k = jnp.arange(npad).astype(REAL)
    ln2 = jnp.log(2.0)
    scaled_fwhm = fwhm / slope
    exp_factor = 4.0 * ln2 / (scaled_fwhm**2)
    scale_factor = 2.0 * jnp.sqrt(ln2 / jnp.pi) / scaled_fwhm
    d = jnp.where(k > npad / 2, k - npad, k)  # signed circular distance to position 0
    F = scale_factor * jnp.exp(-(d**2) * exp_factor)

    Ff = jnp.fft.fft(F)
    Gf = jnp.conj(Ff) / (jnp.abs(Ff) ** 2 + wiener)  # Wiener filter
    Uhat = jnp.clip(jnp.real(jnp.fft.ifft(jnp.fft.fft(Hpad) * Gf)), 0.0, None)

    # The histogram occupies padded positions 0..nbins-1, so the deconvolved
    # tail *below* bin 0 wraps around to the END of the padded buffer. Give each
    # padded position its signed bin index: positions >= npad - offset are bins
    # -offset..-1. This is ITK's layout (histogram centred with this offset)
    # circularly shifted. Using ``vmin + k*slope`` there instead would assign the
    # wrapped low tail centres ~npad*slope too high and drag E upward at the
    # low-intensity end.
    offset = (npad - nbins) // 2
    signed_bin = jnp.where(k >= npad - offset, k - npad, k)
    centers = vmin + signed_bin * slope
    num = jnp.real(jnp.fft.ifft(jnp.fft.fft(Uhat * centers) * Ff))
    den = jnp.real(jnp.fft.ifft(jnp.fft.fft(Uhat) * Ff))
    E = (num / jnp.where(jnp.abs(den) > 1e-10, den, 1e-10))[:nbins]

    # interpolate the per-bin expectation back onto each voxel's bin position
    ci = jnp.clip(cidx, 0.0, nbins - 1.0)
    lo = jnp.floor(ci).astype(jnp.int32)
    fr = ci - lo
    hi = jnp.clip(lo + 1, 0, nbins - 1)
    return jnp.where(m, E[lo] * (1.0 - fr) + E[hi] * fr, 0.0)
115
+
116
+
117
+ # ----------------------------- B-spline fit ---------------------------------
118
def _bspline_w(frac):
    """Uniform cubic B-spline basis evaluated at offset ``frac`` within a span.

    Returns the input's shape plus a trailing axis of length 4: the weights of
    control points ``span .. span+3`` (for frac in [0, 1] they sum to 1).
    """
    t = frac
    t2 = t**2
    t3 = t**3
    basis = (
        (1.0 - t) ** 3 / 6.0,
        (3.0 * t3 - 6.0 * t2 + 4.0) / 6.0,
        (-3.0 * t3 + 3.0 * t2 + 3.0 * t + 1.0) / 6.0,
        t3 / 6.0,
    )
    return jnp.stack(basis, axis=-1)
130
+
131
+
132
def _axis_mats(n, ncp, mesh):
    """Dense (ncp x n) per-axis matrices of cubic B-spline weights at powers 1, 2, 3.

    Voxel j maps to parametric coordinate ``j / (n - 1) * mesh`` (voxel 0 -> 0,
    voxel n-1 -> mesh; a singleton axis maps to 0). Column j holds that voxel's
    four weights on controls ``span .. span+3``. Because the 3D Lee fit is
    separable, these matrices turn the per-voxel scatter into matmuls.
    """
    param = jnp.clip(jnp.arange(n).astype(REAL) / max(n - 1, 1) * mesh, 0.0, float(mesh))
    knot = jnp.clip(jnp.floor(param).astype(jnp.int32), 0, mesh - 1)
    basis = _bspline_w(param - knot)  # (n, 4)
    ctrl_idx = knot[:, None] + jnp.arange(4)[None, :]  # (n, 4) control rows
    voxel_idx = jnp.broadcast_to(jnp.arange(n)[:, None], (n, 4))  # (n, 4) voxel columns

    def lattice(power):
        return jnp.zeros((ncp, n), REAL).at[ctrl_idx, voxel_idx].add(basis**power)

    return tuple(lattice(p) for p in (1, 2, 3))
146
+
147
+
148
@functools.partial(jax.jit, static_argnums=(2, 3))
def _bspline_fit(r, mask, ncp_shape, mesh):
    """Cubic B-spline Lee-MBA fit of residual ``r`` over the mask, evaluated at every voxel.

    Each voxel's tensor-product weight is wz*wy*wx and its sum of squared weights
    factorises per axis, so Lee's numerator/denominator accumulation and the final
    evaluation are each three per-axis contractions (same math as the scatter).

    Args:
        r: residual volume, shape (s, h, w).
        mask: same shape; voxels with ``mask > 0.5`` contribute.
        ncp_shape: control-lattice size per axis (static).
        mesh: number of B-spline spans per axis (static).

    Returns:
        The fitted field sampled on the (s, h, w) voxel grid.
    """
    depth, height, width = r.shape
    ncp_z, ncp_y, ncp_x = ncp_shape
    Az1, Az2, Az3 = _axis_mats(depth, ncp_z, mesh)
    Ay1, Ay2, Ay3 = _axis_mats(height, ncp_y, mesh)
    Ax1, Ax2, Ax3 = _axis_mats(width, ncp_x, mesh)

    def to_lattice(wz, wy, wx, vol):
        # contract z, then y, then x: (s, h, w) -> (ncp_z, ncp_y, ncp_x)
        acc = jnp.einsum("az,zyx->ayx", wz, vol)
        acc = jnp.einsum("by,ayx->abx", wy, acc)
        return jnp.einsum("cx,abx->abc", wx, acc)

    inside = (mask > 0.5).astype(REAL)
    # per-voxel sum of squared weights over all 64 controls (factorised Lee denominator)
    w2_sum = jnp.sum(Az2, 0)[:, None, None] * jnp.sum(Ay2, 0)[None, :, None] * jnp.sum(Ax2, 0)[None, None, :]
    scaled = (r * inside) / w2_sum

    numer = to_lattice(Az3, Ay3, Ax3, scaled)
    denom = to_lattice(Az2, Ay2, Ax2, inside)
    phi = numer / jnp.where(denom > 1e-12, denom, 1e-12)

    # evaluate the lattice back onto the voxel grid
    field = jnp.einsum("az,abc->zbc", Az1, phi)
    field = jnp.einsum("by,zbc->zyc", Ay1, field)
    return jnp.einsum("cx,zyc->zyx", Ax1, field)
169
+
170
+
171
def _conv_cv(b_prev, b, maskb, cnt):
    """ITK's N4 convergence measure over the mask.

    Returns the coefficient of variation (sample std with an N - 1 denominator,
    divided by the mean) of ``exp(b_prev - b)`` over ``maskb``; ``cnt`` is the
    number of mask voxels.
    """
    ratio = jnp.exp(jnp.where(maskb, b_prev - b, 0.0))
    mean = jnp.where(maskb, ratio, 0.0).sum() / cnt
    sq_dev = jnp.where(maskb, (ratio - mean) ** 2, 0.0)
    std = jnp.sqrt(sq_dev.sum() / (cnt - 1.0))
    return std / mean
177
+
178
+
179
+ # ----------------------------- driver ---------------------------------------
180
@functools.cache
def _compiled(iters, nbins):
    """Build (and memoise per ``(iters, nbins)``) the jitted N4 solver.

    ``iters`` is unrolled into one while-loop per fitting level (level ``l`` uses
    mesh ``2**l``) and ``nbins`` fixes histogram shapes, so each distinct pair
    gets its own compiled program. The returned ``core`` maps
    ``(v, mask, fwhm, wiener, threshold, tiny, over_relax)`` to
    ``(corrected, log_bias, iterations_used_per_level)``.
    """

    @jax.jit
    def core(v, mask, fwhm, wiener, threshold, tiny, over_relax):
        inside = mask > 0.5
        count = jnp.sum(inside.astype(REAL))
        log_v = jnp.log(jnp.clip(v, tiny, None))

        def run_level(log_bias, max_iter, ncp, mesh):
            # iterate sharpen -> B-spline fit -> relaxed update until max_iter
            # iterations or the ITK CV measure drops to <= threshold
            def keep_going(carry):
                _, it, cv = carry
                return (it < max_iter) & (cv > threshold)

            def step(carry):
                prev, it, _ = carry
                resid_in = log_v - prev
                sharp = _sharpen(resid_in, mask, nbins, fwhm, wiener)
                update = _bspline_fit(jnp.where(inside, resid_in - sharp, 0.0), mask, ncp, mesh)
                nxt = prev + over_relax * update
                return (nxt, it + 1, _conv_cv(prev, nxt, inside, count))

            start = (log_bias, 0, jnp.array(jnp.inf, REAL))
            out_bias, n_used, _ = jax.lax.while_loop(keep_going, step, start)
            return out_bias, n_used

        log_bias = jnp.zeros_like(log_v)
        counts = []
        for level, max_iter in enumerate(iters):
            mesh = 2**level
            log_bias, n_used = run_level(log_bias, max_iter, (mesh + SPLINE_ORDER,) * 3, mesh)
            counts.append(n_used)
        corrected = jnp.where(inside, jnp.exp(log_v - log_bias), v)
        return corrected, log_bias, jnp.stack(counts)

    return core
211
+
212
+
213
def n4(
    image,
    mask=None,
    *,
    iters=(8, 12, 12, 8),
    over_relax: float = 1.8,
    conv_threshold: float = 0.0,
    nbins: int = 200,
    fwhm: float = 0.15,
    wiener: float = 0.01,
    tiny: float = 1e-6,
    return_bias: bool = False,
):
    """N4 bias field correction.

    Args:
        image: 2D or 3D array of intensities (>= 0).
        mask: optional binary foreground mask (reshaped to ``image``'s shape);
            if ``None``, an Otsu mask (``otsu_mask``) is used.
        iters: max iterations per fitting level, one entry per level; level ``l``
            uses mesh ``2**l`` (mesh 1, 2, 4, 8 for four levels).
        over_relax: relaxation factor (>= 1) accelerating the crawl to the fixed point.
        conv_threshold: ITK-style convergence threshold on the CV of
            ``exp(B_prev - B)`` over the mask; a level stops early once the CV is
            <= this value. ``0.0`` (default) runs the full ``iters`` counts.
        nbins/fwhm/wiener: N3 sharpening parameters (ITK defaults).
        tiny: intensities are clipped to at least this value before the log.
        return_bias: if true, also return the log bias field.

    Returns:
        The corrected image: inside the mask ``corrected == image / exp(bias)``
        (for intensities >= ``tiny``); outside the mask the (float32) input is
        passed through unchanged. With ``return_bias``, a ``(corrected, bias)``
        tuple where ``bias`` is the log bias field evaluated at every voxel.

    Raises:
        ValueError: if ``image`` is not 2D or 3D.
    """
    image = jnp.asarray(image, REAL)
    if image.ndim not in (2, 3):
        raise ValueError(f"n4 expects a 2D or 3D image, got ndim={image.ndim} (shape {image.shape})")
    in3d = image.ndim == 3
    v = image if in3d else image[None]  # 2D -> singleton leading axis; the solver is 3D
    mask = otsu_mask(v) if mask is None else jnp.asarray(mask, REAL).reshape(v.shape)
    core = _compiled(tuple(iters), nbins)
    corrected, bias, _ = core(
        v,
        mask,
        jnp.asarray(fwhm, REAL),
        jnp.asarray(wiener, REAL),
        jnp.asarray(conv_threshold, REAL),
        jnp.asarray(tiny, REAL),
        jnp.asarray(over_relax, REAL),
    )
    if not in3d:
        corrected, bias = corrected[0], bias[0]
    return (corrected, bias) if return_bias else corrected
@@ -0,0 +1,132 @@
1
+ Metadata-Version: 2.4
2
+ Name: n4ax
3
+ Version: 0.1.0
4
+ Summary: JAX/GPU N4 bias field correction — a fast drop-in match for ITK N4
5
+ Author: Geoffroy Oudoumanessah, Jacopo Iollo
6
+ Author-email: Gragas <contact@gragas.ai>
7
+ License-Expression: MIT
8
+ Project-URL: Homepage, https://github.com/GragasLab/n4ax
9
+ Project-URL: Repository, https://github.com/GragasLab/n4ax
10
+ Keywords: MRI,bias field,N4,N3,JAX,GPU
11
+ Classifier: Development Status :: 3 - Alpha
12
+ Classifier: Intended Audience :: Science/Research
13
+ Classifier: Programming Language :: Python :: 3
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Classifier: Programming Language :: Python :: 3.13
16
+ Classifier: Topic :: Scientific/Engineering :: Medical Science Apps.
17
+ Requires-Python: >=3.12
18
+ Description-Content-Type: text/markdown
19
+ License-File: LICENSE
20
+ Requires-Dist: jax>=0.8.0
21
+ Requires-Dist: numpy>=1.24.0
22
+ Provides-Extra: cpu
23
+ Requires-Dist: jax[cpu]>=0.8.0; extra == "cpu"
24
+ Provides-Extra: cuda12
25
+ Requires-Dist: jax[cuda12]>=0.8.0; extra == "cuda12"
26
+ Provides-Extra: compare
27
+ Requires-Dist: SimpleITK>=2.3.0; extra == "compare"
28
+ Requires-Dist: matplotlib>=3.7; extra == "compare"
29
+ Requires-Dist: nibabel>=5.0; extra == "compare"
30
+ Provides-Extra: dev
31
+ Requires-Dist: pytest>=7.0; extra == "dev"
32
+ Requires-Dist: pytest-cov; extra == "dev"
33
+ Requires-Dist: ruff>=0.4.0; extra == "dev"
34
+ Requires-Dist: pre-commit>=3.0; extra == "dev"
35
+ Requires-Dist: SimpleITK>=2.3.0; extra == "dev"
36
+ Dynamic: license-file
37
+
38
+ # n4ax
39
+
40
+ **N4 bias field correction in pure JAX** — a fast, GPU-friendly, *drop-in* match for
41
+ ITK / SimpleITK's `N4BiasFieldCorrectionImageFilter`.
42
+
43
+ n4ax reimplements the N4 algorithm (Tustison et al., 2010 — N3 histogram sharpening
44
+ + multi-resolution B-spline) faithfully enough to **match SimpleITK to ~1%** on real
45
+ MRI, while running **~1500× faster on a GPU** and **~20× faster on the same CPU**.
46
+
47
+ ![NKI raw vs n4ax-corrected vs ITK-corrected](assets/nki_sub-0002_gpu.png)
48
+ *Raw NKI T1w (with B1 shading) → n4ax-corrected → ITK-corrected (visually identical) → estimated bias field.*
49
+
50
+ ## Why
51
+
52
+ N4 is the de-facto standard bias correction, but ITK's implementation is CPU-only and
53
+ slow (minutes per volume). In a GPU MRI pipeline it becomes the bottleneck. n4ax gives
54
+ **N4-quality output on the GPU in tens of milliseconds**, with no custom CUDA — just JAX.
55
+
56
+ ## Install
57
+
58
+ ```bash
59
+ uv sync --extra cuda12 # GPU (CUDA 12)
60
+ uv sync --extra cpu # CPU
61
+ uv sync --extra cuda12 --extra dev # + tests/linting
62
+ uv sync --extra cuda12 --extra compare # + SimpleITK/matplotlib for benchmarks
63
+ ```
64
+
65
+ ## Usage
66
+
67
+ ```python
68
+ import nibabel as nib
69
+ import n4ax
70
+
71
+ vol = nib.load("t1w.nii.gz").get_fdata() # 3D (or 2D) array, intensities >= 0
72
+ corrected = n4ax.n4(vol) # Otsu mask computed automatically
73
+ # or pass your own mask, and/or get the log bias field:
74
+ corrected, log_bias = n4ax.n4(vol, mask=mask, return_bias=True)
75
+ ```
76
+
77
+ `corrected == vol / exp(log_bias)`. The default config (`iters=(8,12,12,8)`,
78
+ `over_relax=1.8`) is tuned for speed; for the tightest ITK match use the robust
79
+ fallback `n4ax.n4(vol, iters=(50,50,30,20), over_relax=1.0, conv_threshold=1.5e-3)`.
80
+
81
+ ## Benchmark
82
+
83
+ Real NKI T1w volumes (256×176×256, ~2 M brain voxels), N4 `[50,50,30,20]`, same Otsu mask.
84
+ ITK on an 8-core CPU; n4ax CPU on the same node; n4ax GPU on an NVIDIA A100.
85
+
86
+ | Method | Time / volume | Speedup vs ITK |
87
+ |---|--:|--:|
88
+ | ITK N4 (CPU, 8 cores) | **146 s** | 1× |
89
+ | n4ax (CPU, 8 cores) | **7.7 s** | **~19×** |
90
+ | n4ax (A100 GPU) | **93 ms** | **~1571×** |
91
+
92
+ **Accuracy vs ITK** (corrected image, global scale removed — pipelines intensity-normalise anyway):
93
+ mean **1.15 %**, per-subject 0.79–1.59 % over 6 NKI scans. On a single fitting level n4ax
94
+ matches ITK to **0.4 %**, and a single N4 iteration to **0.1 %** — the building blocks are exact;
95
+ the residual is N4's own iterative crawl (ITK itself only converges after ~30 iters/level).
96
+
97
+ Multiple subjects, raw (top) vs n4ax-corrected (bottom):
98
+
99
+ ![NKI grid](assets/nki_grid_gpu.png)
100
+
101
+ Reproduce: `python scripts/bench_nki.py` (GPU) and `JAX_PLATFORMS=cpu python scripts/bench_nki.py --skip-itk --skip-fig --tag cpu`.
102
+
103
+ ## How it's fast (no custom kernels)
104
+
105
+ - **Separable B-spline fit.** N4's per-iteration B-spline least-squares (Lee MBA) is a
106
+ 94 M-way scatter into a tiny control lattice — brutal atomic contention (~30 ms/iter).
107
+ Because the cubic weights depend only on the per-axis index and the Lee denominator
108
+ factorises, this becomes **3 small dense matmuls per axis** (cuBLAS) — *identical math*,
109
+ 0.1 ms/iter.
110
+ - **Privatised histogram.** The N3 sharpening histogram (1.5 M → 200 bins) is privatised
111
+ over 256 lanes to avoid atomic serialisation.
112
+ - **Over-relaxation.** N4's fixed point is invariant to `B += α·S` (S = 0 there), so
113
+ `α ≈ 1.8` reaches ITK's result in far fewer iterations.
114
+ - The whole solve is one fused, jitted program with a device-side convergence loop.
115
+
116
+ Two things that mattered for *correctness*: zero-padding the sharpening FFT (circular
117
+ wraparound otherwise breaks convergence), and that float32 == float64 here (verified).
118
+
119
+ ## Tests
120
+
121
+ ```bash
122
+ uv run pytest # basic correctness + ground-truth match vs SimpleITK
123
+ ```
124
+
125
+ `tests/test_vs_itk.py` asserts n4ax matches SimpleITK's N4 (the reference) within tolerance
126
+ on a phantom; `tests/test_basic.py` covers shapes, 2D/3D, the `image/exp(bias)` identity,
127
+ bias flattening, and the Otsu mask.
128
+
129
+ ## Status
130
+
131
+ Alpha. The fast defaults are tuned on NKI/phantom data; validate on your own data before
132
+ production (the `iters=(50,50,30,20), over_relax=1.0` fallback is the conservative choice).
@@ -0,0 +1,12 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ n4ax/__init__.py
5
+ n4ax/core.py
6
+ n4ax.egg-info/PKG-INFO
7
+ n4ax.egg-info/SOURCES.txt
8
+ n4ax.egg-info/dependency_links.txt
9
+ n4ax.egg-info/requires.txt
10
+ n4ax.egg-info/top_level.txt
11
+ tests/test_basic.py
12
+ tests/test_vs_itk.py
@@ -0,0 +1,20 @@
1
+ jax>=0.8.0
2
+ numpy>=1.24.0
3
+
4
+ [compare]
5
+ SimpleITK>=2.3.0
6
+ matplotlib>=3.7
7
+ nibabel>=5.0
8
+
9
+ [cpu]
10
+ jax[cpu]>=0.8.0
11
+
12
+ [cuda12]
13
+ jax[cuda12]>=0.8.0
14
+
15
+ [dev]
16
+ pytest>=7.0
17
+ pytest-cov
18
+ ruff>=0.4.0
19
+ pre-commit>=3.0
20
+ SimpleITK>=2.3.0
@@ -0,0 +1 @@
1
+ n4ax
@@ -0,0 +1,78 @@
1
+ [build-system]
2
+ requires = ["setuptools>=77.0", "wheel"]  # >=77 needed for the SPDX `license = "MIT"` string (PEP 639)
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "n4ax"
7
+ version = "0.1.0"
8
+ authors = [
9
+ {name = "Gragas", email = "contact@gragas.ai"},
10
+ {name = "Geoffroy Oudoumanessah"},
11
+ {name = "Jacopo Iollo"},
12
+ ]
13
+ description = "JAX/GPU N4 bias field correction — a fast drop-in match for ITK N4"
14
+ readme = "README.md"
15
+ license = "MIT"
16
+ requires-python = ">=3.12"
17
+ classifiers = [
18
+ "Development Status :: 3 - Alpha",
19
+ "Intended Audience :: Science/Research",
20
+ "Programming Language :: Python :: 3",
21
+ "Programming Language :: Python :: 3.12",
22
+ "Programming Language :: Python :: 3.13",
23
+ "Topic :: Scientific/Engineering :: Medical Science Apps.",
24
+ ]
25
+ keywords = ["MRI", "bias field", "N4", "N3", "JAX", "GPU"]
26
+ dependencies = [
27
+ "jax>=0.8.0",
28
+ "numpy>=1.24.0",
29
+ ]
30
+
31
+ [project.optional-dependencies]
32
+ cpu = ["jax[cpu]>=0.8.0"]
33
+ cuda12 = ["jax[cuda12]>=0.8.0"]
34
+ # `compare` pulls SimpleITK (the reference N4) + plotting/IO for tests & benchmarks.
35
+ compare = ["SimpleITK>=2.3.0", "matplotlib>=3.7", "nibabel>=5.0"]
36
+ dev = [
37
+ "pytest>=7.0",
38
+ "pytest-cov",
39
+ "ruff>=0.4.0",
40
+ "pre-commit>=3.0",
41
+ "SimpleITK>=2.3.0",
42
+ ]
43
+
44
+ [project.urls]
45
+ Homepage = "https://github.com/GragasLab/n4ax"
46
+ Repository = "https://github.com/GragasLab/n4ax"
47
+
48
+ [tool.setuptools.packages.find]
49
+ include = ["n4ax*"]
50
+
51
+ [tool.pytest.ini_options]
52
+ testpaths = ["tests"]
53
+ python_files = ["test_*.py"]
54
+ addopts = "-v --tb=short"
55
+ pythonpath = [".", "tests"]
56
+
57
+ [tool.ruff]
58
+ target-version = "py312"
59
+ line-length = 119
60
+
61
+ [tool.ruff.lint]
62
+ select = ["E", "F", "I", "W", "UP", "FURB", "SIM", "S110", "C4", "RUF013", "PERF102", "PLC1802", "PLC0208", "PIE794"]
63
+ ignore = ["E501", "E741", "SIM1", "SIM905", "UP015", "UP031"]
64
+ extend-safe-fixes = ["UP006"]
65
+
66
+ [tool.ruff.lint.per-file-ignores]
67
+ "__init__.py" = ["E402", "F401", "F403", "F811"]
68
+ "*.ipynb" = ["E402", "E731", "B007", "N816"]
69
+
70
+ [tool.ruff.lint.isort]
71
+ lines-after-imports = 2
72
+ known-first-party = ["n4ax"]
73
+
74
+ [tool.ruff.format]
75
+ quote-style = "double"
76
+ indent-style = "space"
77
+ skip-magic-trailing-comma = false
78
+ line-ending = "auto"
n4ax-0.1.0/setup.cfg ADDED
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,56 @@
1
+ """Basic correctness: import, shapes, finiteness, 2D/3D, the image/exp(bias) identity, bias flattening, Otsu mask."""
2
+
3
+ import numpy as np
4
+
5
+ import n4ax
6
+
7
+
8
def test_import():
    # the public API must expose both entry points as callables
    for entry_point in (n4ax.n4, n4ax.otsu_mask):
        assert callable(entry_point)
11
+
12
+
13
def test_3d_shape_and_finite(phantom):
    """3D correction keeps the shape and yields finite, non-negative values."""
    observed, _bias, _mask = phantom
    corrected = np.asarray(n4ax.n4(observed))
    assert corrected.shape == observed.shape
    assert np.all(np.isfinite(corrected))
    assert np.all(corrected >= 0)
19
+
20
+
21
def test_2d_runs():
    """A 2D image goes through the solver and comes back the same shape, finite."""
    img = np.random.default_rng(0).uniform(0.5, 1.5, size=(64, 64)).astype(np.float32)
    img[:5] = 0.0  # background
    out = np.asarray(n4ax.n4(img))
    assert out.shape == img.shape
    assert np.isfinite(out).all()
28
+
29
+
30
def test_return_bias(phantom):
    """With return_bias, corrected == image / exp(bias) inside the mask."""
    obs, _, mask = phantom
    corr, bias = (np.asarray(a) for a in n4ax.n4(obs, mask=mask.astype(np.float32), return_bias=True))
    expected = obs / np.exp(bias)
    assert np.abs(corr[mask] - expected[mask]).max() < 1e-3
38
+
39
+
40
def test_recovers_bias(phantom):
    """N4 flattens a known smooth bias: the coefficient of variation of tissue
    intensities drops from the observed image to the corrected one."""
    obs, _true_bias, mask = phantom
    corr = np.asarray(n4ax.n4(obs, mask=mask.astype(np.float32)))

    def coef_var(x):
        vals = x[mask]
        return vals.std() / vals.mean()

    assert coef_var(corr) < coef_var(obs)
49
+
50
+
51
def test_otsu_mask(phantom):
    """The Otsu foreground agrees with the phantom's true mask on >95% of voxels."""
    obs, _, true_mask = phantom
    otsu = np.asarray(n4ax.otsu_mask(obs)) > 0.5
    assert np.mean(otsu == true_mask) > 0.95
@@ -0,0 +1,49 @@
1
+ """Ground-truth test: n4ax must match SimpleITK's N4 (the reference implementation).
2
+
3
+ Uses the SAME mask for both (so we compare the N4 solve, not the masking), and
4
+ compares the corrected image with its global scale removed (N4's bias field is
5
+ defined up to a constant; downstream pipelines intensity-normalise anyway)."""
6
+
7
+ import numpy as np
8
+ import pytest
9
+
10
+ import n4ax
11
+
12
+
13
+ sitk = pytest.importorskip("SimpleITK")
14
+ from _phantom import make_phantom # noqa: E402
15
+
16
+
17
def _itk_n4(obs, mask, iters=(50, 50, 30, 20)):
    """Run SimpleITK's N4 (the reference) on ``obs`` under ``mask``; returns float64."""
    corrector = sitk.N4BiasFieldCorrectionImageFilter()
    corrector.SetMaximumNumberOfIterations(list(map(int, iters)))
    image = sitk.GetImageFromArray(obs.astype(np.float32))
    mask_image = sitk.GetImageFromArray(mask.astype(np.uint8))
    return sitk.GetArrayFromImage(corrector.Execute(image, mask_image)).astype(np.float64)
24
+
25
+
26
def _scaled_relerr(a, b, m):
    """Per-voxel relative error of ``a`` vs reference ``b`` inside mask ``m``,
    after dividing ``a`` by the median ratio a/b (removes a global scale)."""
    a = np.asarray(a, np.float64)
    b = np.asarray(b, np.float64)
    a_in, b_in = a[m], b[m]
    scale = np.median(a_in / np.clip(b_in, 1e-6, None))
    return np.abs(a_in / scale - b_in) / np.clip(np.abs(b_in), 1e-6, None)
31
+
32
+
33
def test_matches_simpleitk():
    """Same mask for both solvers; n4ax must track ITK to <1.5% mean, <5% p95."""
    obs, _, mask = make_phantom()
    reference = _itk_n4(obs, mask)
    ours = np.asarray(n4ax.n4(obs, mask=mask.astype(np.float32)))
    rel = _scaled_relerr(ours, reference, mask)
    mean_err = rel.mean()
    p95_err = np.percentile(rel, 95)
    assert mean_err < 0.015, f"mean rel-err {rel.mean() * 100:.2f}% too high"
    assert p95_err < 0.05, f"p95 rel-err {np.percentile(rel, 95) * 100:.2f}% too high"
40
+
41
+
42
def test_closer_to_itk_than_uncorrected():
    """n4ax's correction must be much closer to ITK's than doing nothing."""
    obs, _, mask = make_phantom(seed=1)
    reference = _itk_n4(obs, mask)
    ours = np.asarray(n4ax.n4(obs, mask=mask.astype(np.float32)))
    ours_err = _scaled_relerr(ours, reference, mask).mean()
    baseline_err = _scaled_relerr(obs, reference, mask).mean()
    assert ours_err < 0.25 * baseline_err