n4ax 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- n4ax-0.1.0/LICENSE +21 -0
- n4ax-0.1.0/PKG-INFO +132 -0
- n4ax-0.1.0/README.md +95 -0
- n4ax-0.1.0/n4ax/__init__.py +6 -0
- n4ax-0.1.0/n4ax/core.py +256 -0
- n4ax-0.1.0/n4ax.egg-info/PKG-INFO +132 -0
- n4ax-0.1.0/n4ax.egg-info/SOURCES.txt +12 -0
- n4ax-0.1.0/n4ax.egg-info/dependency_links.txt +1 -0
- n4ax-0.1.0/n4ax.egg-info/requires.txt +20 -0
- n4ax-0.1.0/n4ax.egg-info/top_level.txt +1 -0
- n4ax-0.1.0/pyproject.toml +78 -0
- n4ax-0.1.0/setup.cfg +4 -0
- n4ax-0.1.0/tests/test_basic.py +56 -0
- n4ax-0.1.0/tests/test_vs_itk.py +49 -0
n4ax-0.1.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Gragas
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
n4ax-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: n4ax
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: JAX/GPU N4 bias field correction — a fast drop-in match for ITK N4
|
|
5
|
+
Author: Geoffroy Oudoumanessah, Jacopo Iollo
|
|
6
|
+
Author-email: Gragas <contact@gragas.ai>
|
|
7
|
+
License-Expression: MIT
|
|
8
|
+
Project-URL: Homepage, https://github.com/GragasLab/n4ax
|
|
9
|
+
Project-URL: Repository, https://github.com/GragasLab/n4ax
|
|
10
|
+
Keywords: MRI,bias field,N4,N3,JAX,GPU
|
|
11
|
+
Classifier: Development Status :: 3 - Alpha
|
|
12
|
+
Classifier: Intended Audience :: Science/Research
|
|
13
|
+
Classifier: Programming Language :: Python :: 3
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
16
|
+
Classifier: Topic :: Scientific/Engineering :: Medical Science Apps.
|
|
17
|
+
Requires-Python: >=3.12
|
|
18
|
+
Description-Content-Type: text/markdown
|
|
19
|
+
License-File: LICENSE
|
|
20
|
+
Requires-Dist: jax>=0.8.0
|
|
21
|
+
Requires-Dist: numpy>=1.24.0
|
|
22
|
+
Provides-Extra: cpu
|
|
23
|
+
Requires-Dist: jax[cpu]>=0.8.0; extra == "cpu"
|
|
24
|
+
Provides-Extra: cuda12
|
|
25
|
+
Requires-Dist: jax[cuda12]>=0.8.0; extra == "cuda12"
|
|
26
|
+
Provides-Extra: compare
|
|
27
|
+
Requires-Dist: SimpleITK>=2.3.0; extra == "compare"
|
|
28
|
+
Requires-Dist: matplotlib>=3.7; extra == "compare"
|
|
29
|
+
Requires-Dist: nibabel>=5.0; extra == "compare"
|
|
30
|
+
Provides-Extra: dev
|
|
31
|
+
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
32
|
+
Requires-Dist: pytest-cov; extra == "dev"
|
|
33
|
+
Requires-Dist: ruff>=0.4.0; extra == "dev"
|
|
34
|
+
Requires-Dist: pre-commit>=3.0; extra == "dev"
|
|
35
|
+
Requires-Dist: SimpleITK>=2.3.0; extra == "dev"
|
|
36
|
+
Dynamic: license-file
|
|
37
|
+
|
|
38
|
+
# n4ax
|
|
39
|
+
|
|
40
|
+
**N4 bias field correction in pure JAX** — a fast, GPU-friendly, *drop-in* match for
|
|
41
|
+
ITK / SimpleITK's `N4BiasFieldCorrectionImageFilter`.
|
|
42
|
+
|
|
43
|
+
n4ax reimplements the N4 algorithm (Tustison et al., 2010 — N3 histogram sharpening
|
|
44
|
+
plus multi-resolution B-spline) faithfully enough to **match SimpleITK to ~1%** on real
|
|
45
|
+
MRI, while running **~1500× faster on a GPU** and **~20× faster on the same CPU**.
|
|
46
|
+
|
|
47
|
+

|
|
48
|
+
*Raw NKI T1w (with B1 shading) → n4ax-corrected → ITK-corrected (visually identical) → estimated bias field.*
|
|
49
|
+
|
|
50
|
+
## Why
|
|
51
|
+
|
|
52
|
+
N4 is the de facto standard for bias field correction, but ITK's implementation is CPU-only and
|
|
53
|
+
slow (minutes per volume). In a GPU MRI pipeline it becomes the bottleneck. n4ax gives
|
|
54
|
+
**N4-quality output on the GPU in tens of milliseconds**, with no custom CUDA — just JAX.
|
|
55
|
+
|
|
56
|
+
## Install
|
|
57
|
+
|
|
58
|
+
```bash
|
|
59
|
+
uv sync --extra cuda12 # GPU (CUDA 12)
|
|
60
|
+
uv sync --extra cpu # CPU
|
|
61
|
+
uv sync --extra cuda12 --extra dev # + tests/linting
|
|
62
|
+
uv sync --extra cuda12 --extra compare # + SimpleITK/matplotlib for benchmarks
|
|
63
|
+
```
|
|
64
|
+
|
|
65
|
+
## Usage
|
|
66
|
+
|
|
67
|
+
```python
|
|
68
|
+
import nibabel as nib
|
|
69
|
+
import n4ax
|
|
70
|
+
|
|
71
|
+
vol = nib.load("t1w.nii.gz").get_fdata() # 3D (or 2D) array, intensities >= 0
|
|
72
|
+
corrected = n4ax.n4(vol) # Otsu mask computed automatically
|
|
73
|
+
# or pass your own mask, and/or get the log bias field:
|
|
74
|
+
corrected, log_bias = n4ax.n4(vol, mask=mask, return_bias=True)
|
|
75
|
+
```
|
|
76
|
+
|
|
77
|
+
`corrected == vol / exp(log_bias)`. The default config (`iters=(8,12,12,8)`,
|
|
78
|
+
`over_relax=1.8`) is tuned for speed; for the tightest ITK match use the robust
|
|
79
|
+
fallback `n4ax.n4(vol, iters=(50,50,30,20), over_relax=1.0, conv_threshold=1.5e-3)`.
|
|
80
|
+
|
|
81
|
+
## Benchmark
|
|
82
|
+
|
|
83
|
+
Real NKI T1w volumes (256×176×256, ~2 M brain voxels), N4 `[50,50,30,20]`, same Otsu mask.
|
|
84
|
+
ITK on an 8-core CPU; n4ax CPU on the same node; n4ax GPU on an NVIDIA A100.
|
|
85
|
+
|
|
86
|
+
| Method | Time / volume | Speedup vs ITK |
|
|
87
|
+
|---|--:|--:|
|
|
88
|
+
| ITK N4 (CPU, 8 cores) | **146 s** | 1× |
|
|
89
|
+
| n4ax (CPU, 8 cores) | **7.7 s** | **~19×** |
|
|
90
|
+
| n4ax (A100 GPU) | **93 ms** | **~1571×** |
|
|
91
|
+
|
|
92
|
+
**Accuracy vs ITK** (corrected image, global scale removed — pipelines intensity-normalise anyway):
|
|
93
|
+
mean **1.15 %**, per-subject 0.79–1.59 % over 6 NKI scans. On a single fitting level n4ax
|
|
94
|
+
matches ITK to **0.4 %**, and a single N4 iteration to **0.1 %** — the building blocks are exact;
|
|
95
|
+
the residual is N4's own iterative crawl (ITK itself only converges after ~30 iters/level).
|
|
96
|
+
|
|
97
|
+
Multiple subjects, raw (top) vs n4ax-corrected (bottom):
|
|
98
|
+
|
|
99
|
+

|
|
100
|
+
|
|
101
|
+
Reproduce: `python scripts/bench_nki.py` (GPU) and `JAX_PLATFORMS=cpu python scripts/bench_nki.py --skip-itk --skip-fig --tag cpu`.
|
|
102
|
+
|
|
103
|
+
## How it's fast (no custom kernels)
|
|
104
|
+
|
|
105
|
+
- **Separable B-spline fit.** N4's per-iteration B-spline least-squares (Lee MBA) is a
|
|
106
|
+
94 M-way scatter into a tiny control lattice — brutal atomic contention (~30 ms/iter).
|
|
107
|
+
Because the cubic weights depend only on the per-axis index and the Lee denominator
|
|
108
|
+
factorises, this becomes **3 small dense matmuls per axis** (cuBLAS) — *identical math*,
|
|
109
|
+
0.1 ms/iter.
|
|
110
|
+
- **Privatised histogram.** The N3 sharpening histogram (1.5 M → 200 bins) is privatised
|
|
111
|
+
over 256 lanes to avoid atomic serialisation.
|
|
112
|
+
- **Over-relaxation.** N4's fixed point is invariant to `B += α·S` (S = 0 there), so
|
|
113
|
+
`α ≈ 1.8` reaches ITK's result in far fewer iterations.
|
|
114
|
+
- The whole solve is one fused, jitted program with a device-side convergence loop.
|
|
115
|
+
|
|
116
|
+
Two things that mattered for *correctness*: zero-padding the sharpening FFT (circular
|
|
117
|
+
wraparound otherwise breaks convergence), and confirming that float32 gives the same results as float64 here.
|
|
118
|
+
|
|
119
|
+
## Tests
|
|
120
|
+
|
|
121
|
+
```bash
|
|
122
|
+
uv run pytest # basic correctness + ground-truth match vs SimpleITK
|
|
123
|
+
```
|
|
124
|
+
|
|
125
|
+
`tests/test_vs_itk.py` asserts n4ax matches SimpleITK's N4 (the reference) within tolerance
|
|
126
|
+
on a phantom; `tests/test_basic.py` covers shapes, 2D/3D, the `image/exp(bias)` identity,
|
|
127
|
+
bias flattening, and the Otsu mask.
|
|
128
|
+
|
|
129
|
+
## Status
|
|
130
|
+
|
|
131
|
+
Alpha. The fast defaults are tuned on NKI/phantom data; validate on your own data before
|
|
132
|
+
production (the `iters=(50,50,30,20), over_relax=1.0` fallback is the conservative choice).
|
n4ax-0.1.0/README.md
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
# n4ax
|
|
2
|
+
|
|
3
|
+
**N4 bias field correction in pure JAX** — a fast, GPU-friendly, *drop-in* match for
|
|
4
|
+
ITK / SimpleITK's `N4BiasFieldCorrectionImageFilter`.
|
|
5
|
+
|
|
6
|
+
n4ax reimplements the N4 algorithm (Tustison et al., 2010 — N3 histogram sharpening
|
|
7
|
+
plus multi-resolution B-spline) faithfully enough to **match SimpleITK to ~1%** on real
|
|
8
|
+
MRI, while running **~1500× faster on a GPU** and **~20× faster on the same CPU**.
|
|
9
|
+
|
|
10
|
+

|
|
11
|
+
*Raw NKI T1w (with B1 shading) → n4ax-corrected → ITK-corrected (visually identical) → estimated bias field.*
|
|
12
|
+
|
|
13
|
+
## Why
|
|
14
|
+
|
|
15
|
+
N4 is the de facto standard for bias field correction, but ITK's implementation is CPU-only and
|
|
16
|
+
slow (minutes per volume). In a GPU MRI pipeline it becomes the bottleneck. n4ax gives
|
|
17
|
+
**N4-quality output on the GPU in tens of milliseconds**, with no custom CUDA — just JAX.
|
|
18
|
+
|
|
19
|
+
## Install
|
|
20
|
+
|
|
21
|
+
```bash
|
|
22
|
+
uv sync --extra cuda12 # GPU (CUDA 12)
|
|
23
|
+
uv sync --extra cpu # CPU
|
|
24
|
+
uv sync --extra cuda12 --extra dev # + tests/linting
|
|
25
|
+
uv sync --extra cuda12 --extra compare # + SimpleITK/matplotlib for benchmarks
|
|
26
|
+
```
|
|
27
|
+
|
|
28
|
+
## Usage
|
|
29
|
+
|
|
30
|
+
```python
|
|
31
|
+
import nibabel as nib
|
|
32
|
+
import n4ax
|
|
33
|
+
|
|
34
|
+
vol = nib.load("t1w.nii.gz").get_fdata() # 3D (or 2D) array, intensities >= 0
|
|
35
|
+
corrected = n4ax.n4(vol) # Otsu mask computed automatically
|
|
36
|
+
# or pass your own mask, and/or get the log bias field:
|
|
37
|
+
corrected, log_bias = n4ax.n4(vol, mask=mask, return_bias=True)
|
|
38
|
+
```
|
|
39
|
+
|
|
40
|
+
`corrected == vol / exp(log_bias)`. The default config (`iters=(8,12,12,8)`,
|
|
41
|
+
`over_relax=1.8`) is tuned for speed; for the tightest ITK match use the robust
|
|
42
|
+
fallback `n4ax.n4(vol, iters=(50,50,30,20), over_relax=1.0, conv_threshold=1.5e-3)`.
|
|
43
|
+
|
|
44
|
+
## Benchmark
|
|
45
|
+
|
|
46
|
+
Real NKI T1w volumes (256×176×256, ~2 M brain voxels), N4 `[50,50,30,20]`, same Otsu mask.
|
|
47
|
+
ITK on an 8-core CPU; n4ax CPU on the same node; n4ax GPU on an NVIDIA A100.
|
|
48
|
+
|
|
49
|
+
| Method | Time / volume | Speedup vs ITK |
|
|
50
|
+
|---|--:|--:|
|
|
51
|
+
| ITK N4 (CPU, 8 cores) | **146 s** | 1× |
|
|
52
|
+
| n4ax (CPU, 8 cores) | **7.7 s** | **~19×** |
|
|
53
|
+
| n4ax (A100 GPU) | **93 ms** | **~1571×** |
|
|
54
|
+
|
|
55
|
+
**Accuracy vs ITK** (corrected image, global scale removed — pipelines intensity-normalise anyway):
|
|
56
|
+
mean **1.15 %**, per-subject 0.79–1.59 % over 6 NKI scans. On a single fitting level n4ax
|
|
57
|
+
matches ITK to **0.4 %**, and a single N4 iteration to **0.1 %** — the building blocks are exact;
|
|
58
|
+
the residual is N4's own iterative crawl (ITK itself only converges after ~30 iters/level).
|
|
59
|
+
|
|
60
|
+
Multiple subjects, raw (top) vs n4ax-corrected (bottom):
|
|
61
|
+
|
|
62
|
+

|
|
63
|
+
|
|
64
|
+
Reproduce: `python scripts/bench_nki.py` (GPU) and `JAX_PLATFORMS=cpu python scripts/bench_nki.py --skip-itk --skip-fig --tag cpu`.
|
|
65
|
+
|
|
66
|
+
## How it's fast (no custom kernels)
|
|
67
|
+
|
|
68
|
+
- **Separable B-spline fit.** N4's per-iteration B-spline least-squares (Lee MBA) is a
|
|
69
|
+
94 M-way scatter into a tiny control lattice — brutal atomic contention (~30 ms/iter).
|
|
70
|
+
Because the cubic weights depend only on the per-axis index and the Lee denominator
|
|
71
|
+
factorises, this becomes **3 small dense matmuls per axis** (cuBLAS) — *identical math*,
|
|
72
|
+
0.1 ms/iter.
|
|
73
|
+
- **Privatised histogram.** The N3 sharpening histogram (1.5 M → 200 bins) is privatised
|
|
74
|
+
over 256 lanes to avoid atomic serialisation.
|
|
75
|
+
- **Over-relaxation.** N4's fixed point is invariant to `B += α·S` (S = 0 there), so
|
|
76
|
+
`α ≈ 1.8` reaches ITK's result in far fewer iterations.
|
|
77
|
+
- The whole solve is one fused, jitted program with a device-side convergence loop.
|
|
78
|
+
|
|
79
|
+
Two things that mattered for *correctness*: zero-padding the sharpening FFT (circular
|
|
80
|
+
wraparound otherwise breaks convergence), and confirming that float32 gives the same results as float64 here.
|
|
81
|
+
|
|
82
|
+
## Tests
|
|
83
|
+
|
|
84
|
+
```bash
|
|
85
|
+
uv run pytest # basic correctness + ground-truth match vs SimpleITK
|
|
86
|
+
```
|
|
87
|
+
|
|
88
|
+
`tests/test_vs_itk.py` asserts n4ax matches SimpleITK's N4 (the reference) within tolerance
|
|
89
|
+
on a phantom; `tests/test_basic.py` covers shapes, 2D/3D, the `image/exp(bias)` identity,
|
|
90
|
+
bias flattening, and the Otsu mask.
|
|
91
|
+
|
|
92
|
+
## Status
|
|
93
|
+
|
|
94
|
+
Alpha. The fast defaults are tuned on NKI/phantom data; validate on your own data before
|
|
95
|
+
production (the `iters=(50,50,30,20), over_relax=1.0` fallback is the conservative choice).
|
n4ax-0.1.0/n4ax/core.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
"""N4 bias field correction in pure JAX — a fast, GPU-friendly drop-in match for
|
|
2
|
+
ITK / SimpleITK's ``N4BiasFieldCorrectionImageFilter``.
|
|
3
|
+
|
|
4
|
+
Algorithm (Tustison 2010 N4 = N3 histogram sharpening + multi-resolution B-spline):
|
|
5
|
+
|
|
6
|
+
u = log(v) over the mask; B = 0 (log bias field)
|
|
7
|
+
for each fitting level (mesh = 1, 2, 4, 8):
|
|
8
|
+
repeat:
|
|
9
|
+
uc = u - B
|
|
10
|
+
E = sharpen(uc) # N3 histogram deconvolution (Wiener, FFT)
|
|
11
|
+
S = bspline_fit(uc - E) # cubic B-spline least-squares (Lee MBA)
|
|
12
|
+
B = B + over_relax * S
|
|
13
|
+
corrected = exp(u - B) # == v / exp(B)
|
|
14
|
+
|
|
15
|
+
Every building block matches ITK's implementation (parametric coords, cubic
|
|
16
|
+
B-spline weights, Lee-MBA accumulation, N3 Wiener deconvolution). Two ideas make
|
|
17
|
+
it fast on a GPU without any custom kernel:
|
|
18
|
+
|
|
19
|
+
* the B-spline fit is **separable** (weights depend only on the per-axis index and
|
|
20
|
+
the Lee denominator factorises), so the per-voxel scatter into the control
|
|
21
|
+
lattice becomes three small dense matmuls per axis — no atomic contention;
|
|
22
|
+
* the sharpening histogram is **privatised** over K lanes so the value-scatter
|
|
23
|
+
doesn't serialise on atomics.
|
|
24
|
+
|
|
25
|
+
``over_relax > 1`` accelerates N4's (slow, monotone) crawl to the *same* fixed
|
|
26
|
+
point (S = 0 there, so ``B += a*S`` is fixed-point invariant), reaching ITK's
|
|
27
|
+
result in far fewer iterations.
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
from __future__ import annotations
|
|
31
|
+
|
|
32
|
+
import functools
|
|
33
|
+
|
|
34
|
+
import jax
|
|
35
|
+
import jax.numpy as jnp
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Degree of the B-spline basis used for the bias field (cubic). The driver adds
# it to the mesh size to get the number of control points per axis.
SPLINE_ORDER = 3

# Working floating-point dtype for the whole solve. The N4 iteration is
# insensitive to float32 vs float64 here (checked against a float64 run).
REAL = jnp.float32
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# ----------------------------- Otsu mask ------------------------------------
|
|
43
|
+
def otsu_mask(volume, nbins: int = 200):
    """Binary foreground mask via Otsu's threshold.

    Mirrors ITK ``OtsuThreshold`` with insideValue=0 / outsideValue=1, i.e.
    foreground = intensity strictly above the threshold. As with ITK's default
    (bin midpoint *not* returned), the threshold is the upper edge of the last
    background bin. Every voxel that was counted in the background class when
    maximising the between-class variance is therefore also labelled background.

    Args:
        volume: array of intensities (any shape).
        nbins: number of equal-width histogram bins spanning ``[min, max]``.

    Returns:
        ``REAL`` array shaped like ``volume``: 1 for foreground, 0 for background.
    """
    v = jnp.asarray(volume, REAL)
    vmin, vmax = jnp.min(v), jnp.max(v)
    edges = jnp.linspace(vmin, vmax, nbins + 1)
    hist, _ = jnp.histogram(v, bins=edges)
    hist = hist.astype(REAL)
    centers = 0.5 * (edges[:-1] + edges[1:])

    # Candidate split after bin i: background = bins [0, i], foreground = the rest.
    wb = jnp.cumsum(hist)  # background voxel count per split
    wf = wb[-1] - wb  # foreground voxel count per split
    csum = jnp.cumsum(hist * centers)
    # Class means; empty classes get mean 0 (their weight zeroes the score anyway).
    mb = jnp.where(wb > 0, csum / jnp.where(wb > 0, wb, 1), 0.0)
    mf = jnp.where(wf > 0, (csum[-1] - csum) / jnp.where(wf > 0, wf, 1), 0.0)
    between = wb * wf * (mb - mf) ** 2  # (unnormalised) between-class variance

    # Splitting after bin i puts everything up to edges[i + 1] in the background.
    # Thresholding at the bin centre would mislabel the upper half of bin i.
    thr = edges[1:][jnp.argmax(between)]
    return (v > thr).astype(REAL)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
# ----------------------------- N3 sharpening --------------------------------
|
|
64
|
+
@functools.partial(jax.jit, static_argnums=(2,))
def _sharpen(uc, mask, nbins=200, fwhm=0.15, wiener=0.01):
    """Sharpen the masked log image ``uc`` by N3 histogram deconvolution.

    Steps:
    1. Build a Parzen (two-bin linear) histogram of the masked values.
    2. Deconvolve it with a Gaussian of width ``fwhm`` using a Wiener filter in
       the Fourier domain.
    3. Map every masked voxel to the per-bin expectation of the sharpened
       distribution, linearly interpolated at the voxel's bin coordinate.

    Voxels outside the mask get 0.

    Args:
        uc: log-intensity image (current bias-corrected estimate).
        mask: foreground mask; values > 0.5 count as inside.
        nbins: histogram bin count (static under ``jit``).
        fwhm: Gaussian FWHM, in the same units as ``uc``.
        wiener: Wiener filter noise term.

    Returns:
        Array shaped like ``uc`` with the sharpened value at each masked voxel.
    """
    inside = mask > 0.5
    lo_val = jnp.min(jnp.where(inside, uc, jnp.inf))
    hi_val = jnp.max(jnp.where(inside, uc, -jnp.inf))
    bin_width = (hi_val - lo_val) / (nbins - 1)

    # Fractional bin coordinate of every voxel; reused for the final lookup.
    pos = (uc - lo_val) / bin_width
    base = jnp.floor(pos).astype(jnp.int32)
    frac = pos - base

    # Parzen histogram, privatised over lanes: voxel j scatters into row j % lanes,
    # so simultaneous adds rarely hit the same cell; the rows are then summed.
    n_lanes = 256
    weight = inside.astype(REAL).reshape(-1)
    flat_base = base.reshape(-1)
    flat_frac = frac.reshape(-1)
    lane = jnp.arange(flat_base.shape[0], dtype=jnp.int32) % n_lanes
    private = jnp.zeros((n_lanes, nbins), REAL)
    private = private.at[lane, jnp.clip(flat_base, 0, nbins - 1)].add(weight * (1.0 - flat_frac))
    private = private.at[lane, jnp.clip(flat_base + 1, 0, nbins - 1)].add(weight * flat_frac)
    hist = jnp.sum(private, axis=0)

    # Zero-pad to the smallest power of two >= 2 * nbins so the circular FFT
    # convolution cannot wrap mass around (wraparound breaks convergence).
    padded_len = 1 << (2 * nbins - 1).bit_length()
    padded_hist = jnp.zeros((padded_len,), REAL).at[:nbins].set(hist)
    bins = jnp.arange(padded_len).astype(REAL)

    # Unit-area Gaussian blur kernel (FWHM expressed in bins), centred on bin 0
    # in wrapped (signed) coordinates.
    ln2 = jnp.log(2.0)
    fwhm_bins = fwhm / bin_width
    exp_factor = 4.0 * ln2 / (fwhm_bins**2)
    scale_factor = 2.0 * jnp.sqrt(ln2 / jnp.pi) / fwhm_bins
    signed = jnp.where(bins > padded_len / 2, bins - padded_len, bins)
    kernel = scale_factor * jnp.exp(-(signed**2) * exp_factor)

    fft, ifft = jnp.fft.fft, jnp.fft.ifft
    kernel_f = fft(kernel)
    wiener_f = jnp.conj(kernel_f) / (jnp.abs(kernel_f) ** 2 + wiener)
    sharpened = jnp.clip(jnp.real(ifft(fft(padded_hist) * wiener_f)), 0.0, None)

    # Per-bin expectation: blur(U * value) / blur(U), truncated to the real bins.
    bin_values = lo_val + bins * bin_width
    numer = jnp.real(ifft(fft(sharpened * bin_values) * kernel_f))
    denom = jnp.real(ifft(fft(sharpened) * kernel_f))
    expect = (numer / jnp.where(jnp.abs(denom) > 1e-10, denom, 1e-10))[:nbins]

    # Linear interpolation of the expectation table at each voxel's bin coordinate.
    clipped = jnp.clip(pos, 0.0, nbins - 1.0)
    left = jnp.floor(clipped).astype(jnp.int32)
    t = clipped - left
    right = jnp.clip(left + 1, 0, nbins - 1)
    return jnp.where(inside, expect[left] * (1.0 - t) + expect[right] * t, 0.0)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
# ----------------------------- B-spline fit ---------------------------------
|
|
118
|
+
def _bspline_w(frac):
    """Return the four uniform cubic B-spline basis weights for ``frac``.

    ``frac`` is the offset inside the current knot span (expected in [0, 1]).
    The weights of controls ``span .. span+3`` are stacked on a new last axis,
    so the result has shape ``frac.shape + (4,)``.
    """
    t = frac
    t_sq = t**2
    t_cu = t**3
    w_first = (1.0 - t) ** 3 / 6.0
    w_second = (3.0 * t_cu - 6.0 * t_sq + 4.0) / 6.0
    w_third = (-3.0 * t_cu + 3.0 * t_sq + 3.0 * t + 1.0) / 6.0
    w_fourth = t_cu / 6.0
    return jnp.stack([w_first, w_second, w_third, w_fourth], axis=-1)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _axis_mats(n, ncp, mesh):
    """Build the three per-axis B-spline weight matrices for one image axis.

    Each matrix has shape ``(ncp, n)``. Column ``j`` holds the cubic weights of
    voxel ``j`` at rows ``span_j .. span_j+3``, raised to the power 1, 2 or 3;
    all other entries are 0. Because the Lee fit is separable, products of these
    matrices replace the per-voxel scatter with dense matmuls.

    Args:
        n: number of voxels along the axis.
        ncp: number of control points along the axis (needs ``ncp >= mesh + 3``).
        mesh: number of B-spline spans along the axis.

    Returns:
        Tuple ``(W1, W2, W3)`` of ``(ncp, n)`` arrays.
    """
    denom = max(n - 1, 1)  # singleton axes (2D input) would otherwise divide by 0
    coord = jnp.clip(jnp.arange(n).astype(REAL) / denom * mesh, 0.0, float(mesh))
    span = jnp.clip(jnp.floor(coord).astype(jnp.int32), 0, mesh - 1)
    weights = _bspline_w(coord - span)  # (n, 4)
    row_idx = (span[:, None] + jnp.arange(4)[None, :]).reshape(-1)
    col_idx = jnp.repeat(jnp.arange(n), 4)
    empty = jnp.zeros((ncp, n), REAL)
    return tuple(
        empty.at[row_idx, col_idx].add((weights**power).reshape(-1)) for power in (1, 2, 3)
    )
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
@functools.partial(jax.jit, static_argnums=(2, 3))
def _bspline_fit(r, mask, ncp_shape, mesh):
    """Fit a cubic B-spline to residual ``r`` over the mask (Lee MBA) and evaluate it densely.

    This is the separable formulation of the per-voxel scatter, with the same math.
    Each lattice value is ``sum(w^3 * r / sum_c w_c^2) / sum(w^2)`` over the masked voxels.
    Every factor splits into per-axis weight matrices, so both sums become three
    axis contractions.

    Args:
        r: 3D residual volume ``(z, y, x)``.
        mask: foreground mask; values > 0.5 count as inside.
        ncp_shape: control points per axis ``(z, y, x)`` (static under ``jit``).
        mesh: spans per axis (static under ``jit``).

    Returns:
        The fitted field evaluated at every voxel, shaped like ``r``.
    """
    nz, ny, nx = r.shape
    cz, cy, cx = ncp_shape
    Az1, Az2, Az3 = _axis_mats(nz, cz, mesh)
    Ay1, Ay2, Ay3 = _axis_mats(ny, cy, mesh)
    Ax1, Ax2, Ax3 = _axis_mats(nx, cx, mesh)

    def to_lattice(mz, my, mx, vol):
        # Contract voxel axes z, then y, then x against (ncp, n) weight matrices.
        t = jnp.einsum("az,zyx->ayx", mz, vol)
        t = jnp.einsum("by,ayx->abx", my, t)
        return jnp.einsum("cx,abx->abc", mx, t)

    inside = (mask > 0.5).astype(REAL)
    # Lee's per-voxel normaliser sum_c w_c^2 factorises into per-axis sums.
    norm_z = jnp.sum(Az2, 0)
    norm_y = jnp.sum(Ay2, 0)
    norm_x = jnp.sum(Ax2, 0)
    norm = norm_z[:, None, None] * norm_y[None, :, None] * norm_x[None, None, :]
    scaled = (r * inside) / norm

    numer = to_lattice(Az3, Ay3, Ax3, scaled)
    denom = to_lattice(Az2, Ay2, Ax2, inside)
    phi = numer / jnp.where(denom > 1e-12, denom, 1e-12)

    # Evaluate the lattice back on the voxel grid (transpose contractions, W^1).
    t = jnp.einsum("az,abc->zbc", Az1, phi)
    t = jnp.einsum("by,zbc->zyc", Ay1, t)
    return jnp.einsum("cx,zyc->zyx", Ax1, t)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def _conv_cv(b_prev, b, maskb, cnt):
    """Coefficient of variation of ``exp(b_prev - b)`` over the mask.

    This is the ITK-style N4 convergence measure.

    Args:
        b_prev: log bias field before the update.
        b: log bias field after the update.
        maskb: boolean mask of voxels to include.
        cnt: number of True entries in ``maskb``, as a float.

    Returns:
        Sample standard deviation (``cnt - 1`` denominator) divided by the mean.
    """
    ratio = jnp.exp(jnp.where(maskb, b_prev - b, 0.0))
    mean = jnp.where(maskb, ratio, 0.0).sum() / cnt
    sq_dev = jnp.where(maskb, (ratio - mean) ** 2, 0.0)
    std = jnp.sqrt(sq_dev.sum() / (cnt - 1.0))
    return std / mean
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
# ----------------------------- driver ---------------------------------------
|
|
180
|
+
@functools.cache
def _compiled(iters, nbins):
    """Build (and memoise per ``(iters, nbins)``) the jitted N4 solver.

    ``iters`` (a tuple, one entry per fitting level) and ``nbins`` fix the
    program structure, so each distinct pair is traced once. All other
    parameters are passed as traced scalars, so changing them does not recompile.

    The returned ``core(v, mask, fwhm, wiener, threshold, tiny, over_relax)``
    yields ``(corrected, log_bias, iterations_used_per_level)``, where
    ``corrected == v / exp(log_bias)`` over the whole grid.
    """

    @jax.jit
    def core(v, mask, fwhm, wiener, threshold, tiny, over_relax):
        maskb = mask > 0.5
        cnt = jnp.sum(maskb.astype(REAL))
        u = jnp.log(jnp.clip(v, tiny, None))
        b = jnp.zeros_like(u)
        used = []
        for lvl, nit in enumerate(iters):
            # Level lvl fits a B-spline with 2**lvl spans per axis.
            mesh = 2**lvl
            ncp = (mesh + SPLINE_ORDER,) * 3

            # Default args bind this level's values (avoids late-binding closures).
            def cond(c, _nit=nit):
                _, i, cv = c
                return (i < _nit) & (cv > threshold)

            def body(c, _ncp=ncp, _mesh=mesh):
                bp, i, _ = c
                uc = u - bp
                e = _sharpen(uc, mask, nbins, fwhm, wiener)
                s = _bspline_fit(jnp.where(maskb, uc - e, 0.0), mask, _ncp, _mesh)
                bn = bp + over_relax * s
                return (bn, i + 1, _conv_cv(bp, bn, maskb, cnt))

            b, i, _ = jax.lax.while_loop(cond, body, (b, 0, jnp.array(jnp.inf, REAL)))
            used.append(i)

        # The B-spline field b is evaluated on the whole grid, so divide the raw
        # intensities everywhere. This keeps the documented identity
        # corrected == v / exp(b) exact outside the mask too, and for voxels <= tiny
        # (where exp(u - b) would return tiny / exp(b) instead of v / exp(b)).
        corrected = v / jnp.exp(b)
        # An empty ``iters`` means no levels ran: report an empty count array.
        counts = jnp.stack(used) if used else jnp.zeros((0,), jnp.int32)
        return corrected, b, counts

    return core
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def n4(
    image,
    mask=None,
    *,
    iters=(8, 12, 12, 8),
    over_relax: float = 1.8,
    conv_threshold: float = 0.0,
    nbins: int = 200,
    fwhm: float = 0.15,
    wiener: float = 0.01,
    tiny: float = 1e-6,
    return_bias: bool = False,
):
    """N4 bias field correction.

    Args:
        image: 2D or 3D array of intensities (>= 0).
        mask: optional binary foreground mask. It must have as many elements as
            ``image`` (it is reshaped to the image's shape); values > 0.5 are
            foreground. If ``None``, an Otsu mask is computed from ``image``.
        iters: max iterations per fitting level. Level ``k`` uses ``2**k`` B-spline
            spans per axis, so ``len(iters)`` sets the number of levels (the
            default 4 gives meshes 1, 2, 4, 8).
        over_relax: relaxation factor (>= 1) accelerating the crawl to the fixed point.
        conv_threshold: ITK-style convergence threshold on the coefficient of
            variation of ``exp(B_prev - B)``. A level stops once the CV is no longer
            above this value. The default 0.0 effectively runs the fixed counts in
            ``iters`` (with over-relaxation, for speed).
        nbins/fwhm/wiener: N3 sharpening parameters (histogram bins, Gaussian FWHM
            in log-intensity units, Wiener noise term; ITK defaults).
        tiny: intensities are clipped to at least this value before taking the log.
        return_bias: if true, also return the log bias field.

    Returns:
        The corrected image (``image / exp(bias)``), same shape as ``image``; or
        ``(corrected, log_bias)`` if ``return_bias`` is true.

    Raises:
        ValueError: if ``image`` is not 2D or 3D.
    """
    image = jnp.asarray(image, REAL)
    if image.ndim not in (2, 3):
        raise ValueError(
            f"n4 expects a 2D or 3D image, got an array with shape {tuple(image.shape)}"
        )
    in3d = image.ndim == 3
    # The solver works on 3D volumes; lift a 2D image to a single slice.
    v = image if in3d else image[None]
    mask = otsu_mask(v) if mask is None else jnp.asarray(mask, REAL).reshape(v.shape)
    core = _compiled(tuple(iters), nbins)  # traced/compiled once per (iters, nbins)
    corrected, bias, _ = core(
        v,
        mask,
        jnp.asarray(fwhm, REAL),
        jnp.asarray(wiener, REAL),
        jnp.asarray(conv_threshold, REAL),
        jnp.asarray(tiny, REAL),
        jnp.asarray(over_relax, REAL),
    )
    if not in3d:
        corrected, bias = corrected[0], bias[0]
    return (corrected, bias) if return_bias else corrected
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: n4ax
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: JAX/GPU N4 bias field correction — a fast drop-in match for ITK N4
|
|
5
|
+
Author: Geoffroy Oudoumanessah, Jacopo Iollo
|
|
6
|
+
Author-email: Gragas <contact@gragas.ai>
|
|
7
|
+
License-Expression: MIT
|
|
8
|
+
Project-URL: Homepage, https://github.com/GragasLab/n4ax
|
|
9
|
+
Project-URL: Repository, https://github.com/GragasLab/n4ax
|
|
10
|
+
Keywords: MRI,bias field,N4,N3,JAX,GPU
|
|
11
|
+
Classifier: Development Status :: 3 - Alpha
|
|
12
|
+
Classifier: Intended Audience :: Science/Research
|
|
13
|
+
Classifier: Programming Language :: Python :: 3
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
16
|
+
Classifier: Topic :: Scientific/Engineering :: Medical Science Apps.
|
|
17
|
+
Requires-Python: >=3.12
|
|
18
|
+
Description-Content-Type: text/markdown
|
|
19
|
+
License-File: LICENSE
|
|
20
|
+
Requires-Dist: jax>=0.8.0
|
|
21
|
+
Requires-Dist: numpy>=1.24.0
|
|
22
|
+
Provides-Extra: cpu
|
|
23
|
+
Requires-Dist: jax[cpu]>=0.8.0; extra == "cpu"
|
|
24
|
+
Provides-Extra: cuda12
|
|
25
|
+
Requires-Dist: jax[cuda12]>=0.8.0; extra == "cuda12"
|
|
26
|
+
Provides-Extra: compare
|
|
27
|
+
Requires-Dist: SimpleITK>=2.3.0; extra == "compare"
|
|
28
|
+
Requires-Dist: matplotlib>=3.7; extra == "compare"
|
|
29
|
+
Requires-Dist: nibabel>=5.0; extra == "compare"
|
|
30
|
+
Provides-Extra: dev
|
|
31
|
+
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
32
|
+
Requires-Dist: pytest-cov; extra == "dev"
|
|
33
|
+
Requires-Dist: ruff>=0.4.0; extra == "dev"
|
|
34
|
+
Requires-Dist: pre-commit>=3.0; extra == "dev"
|
|
35
|
+
Requires-Dist: SimpleITK>=2.3.0; extra == "dev"
|
|
36
|
+
Dynamic: license-file
|
|
37
|
+
|
|
38
|
+
# n4ax
|
|
39
|
+
|
|
40
|
+
**N4 bias field correction in pure JAX** — a fast, GPU-friendly, *drop-in* match for
|
|
41
|
+
ITK / SimpleITK's `N4BiasFieldCorrectionImageFilter`.
|
|
42
|
+
|
|
43
|
+
n4ax reimplements the N4 algorithm (Tustison et al., 2010 — N3 histogram sharpening
|
|
44
|
+
plus multi-resolution B-spline) faithfully enough to **match SimpleITK to ~1%** on real
|
|
45
|
+
MRI, while running **~1500× faster on a GPU** and **~20× faster on the same CPU**.
|
|
46
|
+
|
|
47
|
+

|
|
48
|
+
*Raw NKI T1w (with B1 shading) → n4ax-corrected → ITK-corrected (visually identical) → estimated bias field.*
|
|
49
|
+
|
|
50
|
+
## Why
|
|
51
|
+
|
|
52
|
+
N4 is the de facto standard for bias field correction, but ITK's implementation is CPU-only and
|
|
53
|
+
slow (minutes per volume). In a GPU MRI pipeline it becomes the bottleneck. n4ax gives
|
|
54
|
+
**N4-quality output on the GPU in tens of milliseconds**, with no custom CUDA — just JAX.
|
|
55
|
+
|
|
56
|
+
## Install
|
|
57
|
+
|
|
58
|
+
```bash
|
|
59
|
+
uv sync --extra cuda12 # GPU (CUDA 12)
|
|
60
|
+
uv sync --extra cpu # CPU
|
|
61
|
+
uv sync --extra cuda12 --extra dev # + tests/linting
|
|
62
|
+
uv sync --extra cuda12 --extra compare # + SimpleITK/matplotlib for benchmarks
|
|
63
|
+
```
|
|
64
|
+
|
|
65
|
+
## Usage
|
|
66
|
+
|
|
67
|
+
```python
|
|
68
|
+
import nibabel as nib
|
|
69
|
+
import n4ax
|
|
70
|
+
|
|
71
|
+
vol = nib.load("t1w.nii.gz").get_fdata() # 3D (or 2D) array, intensities >= 0
|
|
72
|
+
corrected = n4ax.n4(vol) # Otsu mask computed automatically
|
|
73
|
+
# or pass your own mask, and/or get the log bias field:
|
|
74
|
+
corrected, log_bias = n4ax.n4(vol, mask=mask, return_bias=True)
|
|
75
|
+
```
|
|
76
|
+
|
|
77
|
+
`corrected == vol / exp(log_bias)`. The default config (`iters=(8,12,12,8)`,
|
|
78
|
+
`over_relax=1.8`) is tuned for speed; for the tightest ITK match use the robust
|
|
79
|
+
fallback `n4ax.n4(vol, iters=(50,50,30,20), over_relax=1.0, conv_threshold=1.5e-3)`.
|
|
80
|
+
|
|
81
|
+
## Benchmark
|
|
82
|
+
|
|
83
|
+
Real NKI T1w volumes (256×176×256, ~2 M brain voxels), N4 `[50,50,30,20]`, same Otsu mask.
|
|
84
|
+
ITK on an 8-core CPU; n4ax CPU on the same node; n4ax GPU on an NVIDIA A100.
|
|
85
|
+
|
|
86
|
+
| Method | Time / volume | Speedup vs ITK |
|
|
87
|
+
|---|--:|--:|
|
|
88
|
+
| ITK N4 (CPU, 8 cores) | **146 s** | 1× |
|
|
89
|
+
| n4ax (CPU, 8 cores) | **7.7 s** | **~19×** |
|
|
90
|
+
| n4ax (A100 GPU) | **93 ms** | **~1571×** |
|
|
91
|
+
|
|
92
|
+
**Accuracy vs ITK** (corrected image, global scale removed — pipelines intensity-normalise anyway):
|
|
93
|
+
mean **1.15 %**, per-subject 0.79–1.59 % over 6 NKI scans. On a single fitting level n4ax
|
|
94
|
+
matches ITK to **0.4 %**, and a single N4 iteration to **0.1 %** — the building blocks are exact;
|
|
95
|
+
the residual is N4's own iterative crawl (ITK itself only converges after ~30 iters/level).
|
|
96
|
+
|
|
97
|
+
Multiple subjects, raw (top) vs n4ax-corrected (bottom):
|
|
98
|
+
|
|
99
|
+

|
|
100
|
+
|
|
101
|
+
Reproduce: `python scripts/bench_nki.py` (GPU) and `JAX_PLATFORMS=cpu python scripts/bench_nki.py --skip-itk --skip-fig --tag cpu`.
|
|
102
|
+
|
|
103
|
+
## How it's fast (no custom kernels)
|
|
104
|
+
|
|
105
|
+
- **Separable B-spline fit.** N4's per-iteration B-spline least-squares (Lee MBA) is a
|
|
106
|
+
94 M-way scatter into a tiny control lattice — brutal atomic contention (~30 ms/iter).
|
|
107
|
+
Because the cubic weights depend only on the per-axis index and the Lee denominator
|
|
108
|
+
factorises, this becomes **3 small dense matmuls per axis** (cuBLAS) — *identical math*,
|
|
109
|
+
0.1 ms/iter.
|
|
110
|
+
- **Privatised histogram.** The N3 sharpening histogram (1.5 M → 200 bins) is privatised
|
|
111
|
+
over 256 lanes to avoid atomic serialisation.
|
|
112
|
+
- **Over-relaxation.** N4's fixed point is invariant to `B += α·S` (S = 0 there), so
|
|
113
|
+
`α ≈ 1.8` reaches ITK's result in far fewer iterations.
|
|
114
|
+
- The whole solve is one fused, jitted program with a device-side convergence loop.
|
|
115
|
+
|
|
116
|
+
Two things that mattered for *correctness*: zero-padding the sharpening FFT (circular
|
|
117
|
+
wraparound otherwise breaks convergence), and confirming that float32 gives the same results as float64 here.
|
|
118
|
+
|
|
119
|
+
## Tests
|
|
120
|
+
|
|
121
|
+
```bash
|
|
122
|
+
uv run pytest # basic correctness + ground-truth match vs SimpleITK
|
|
123
|
+
```
|
|
124
|
+
|
|
125
|
+
`tests/test_vs_itk.py` asserts n4ax matches SimpleITK's N4 (the reference) within tolerance
|
|
126
|
+
on a phantom; `tests/test_basic.py` covers shapes, 2D/3D, the `image/exp(bias)` identity,
|
|
127
|
+
bias flattening, and the Otsu mask.
|
|
128
|
+
|
|
129
|
+
## Status
|
|
130
|
+
|
|
131
|
+
Alpha. The fast defaults are tuned on NKI/phantom data; validate on your own data before
|
|
132
|
+
production (the `iters=(50,50,30,20), over_relax=1.0` fallback is the conservative choice).
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
pyproject.toml
|
|
4
|
+
n4ax/__init__.py
|
|
5
|
+
n4ax/core.py
|
|
6
|
+
n4ax.egg-info/PKG-INFO
|
|
7
|
+
n4ax.egg-info/SOURCES.txt
|
|
8
|
+
n4ax.egg-info/dependency_links.txt
|
|
9
|
+
n4ax.egg-info/requires.txt
|
|
10
|
+
n4ax.egg-info/top_level.txt
|
|
11
|
+
tests/test_basic.py
|
|
12
|
+
tests/test_vs_itk.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
jax>=0.8.0
|
|
2
|
+
numpy>=1.24.0
|
|
3
|
+
|
|
4
|
+
[compare]
|
|
5
|
+
SimpleITK>=2.3.0
|
|
6
|
+
matplotlib>=3.7
|
|
7
|
+
nibabel>=5.0
|
|
8
|
+
|
|
9
|
+
[cpu]
|
|
10
|
+
jax[cpu]>=0.8.0
|
|
11
|
+
|
|
12
|
+
[cuda12]
|
|
13
|
+
jax[cuda12]>=0.8.0
|
|
14
|
+
|
|
15
|
+
[dev]
|
|
16
|
+
pytest>=7.0
|
|
17
|
+
pytest-cov
|
|
18
|
+
ruff>=0.4.0
|
|
19
|
+
pre-commit>=3.0
|
|
20
|
+
SimpleITK>=2.3.0
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
n4ax
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=77.0", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "n4ax"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
authors = [
|
|
9
|
+
{name = "Gragas", email = "contact@gragas.ai"},
|
|
10
|
+
{name = "Geoffroy Oudoumanessah"},
|
|
11
|
+
{name = "Jacopo Iollo"},
|
|
12
|
+
]
|
|
13
|
+
description = "JAX/GPU N4 bias field correction — a fast drop-in match for ITK N4"
|
|
14
|
+
readme = "README.md"
|
|
15
|
+
license = "MIT"
|
|
16
|
+
requires-python = ">=3.12"
|
|
17
|
+
classifiers = [
|
|
18
|
+
"Development Status :: 3 - Alpha",
|
|
19
|
+
"Intended Audience :: Science/Research",
|
|
20
|
+
"Programming Language :: Python :: 3",
|
|
21
|
+
"Programming Language :: Python :: 3.12",
|
|
22
|
+
"Programming Language :: Python :: 3.13",
|
|
23
|
+
"Topic :: Scientific/Engineering :: Medical Science Apps.",
|
|
24
|
+
]
|
|
25
|
+
keywords = ["MRI", "bias field", "N4", "N3", "JAX", "GPU"]
|
|
26
|
+
dependencies = [
|
|
27
|
+
"jax>=0.8.0",
|
|
28
|
+
"numpy>=1.24.0",
|
|
29
|
+
]
|
|
30
|
+
|
|
31
|
+
[project.optional-dependencies]
|
|
32
|
+
cpu = ["jax[cpu]>=0.8.0"]
|
|
33
|
+
cuda12 = ["jax[cuda12]>=0.8.0"]
|
|
34
|
+
# `compare` pulls SimpleITK (the reference N4) + plotting/IO for tests & benchmarks.
|
|
35
|
+
compare = ["SimpleITK>=2.3.0", "matplotlib>=3.7", "nibabel>=5.0"]
|
|
36
|
+
dev = [
|
|
37
|
+
"pytest>=7.0",
|
|
38
|
+
"pytest-cov",
|
|
39
|
+
"ruff>=0.4.0",
|
|
40
|
+
"pre-commit>=3.0",
|
|
41
|
+
"SimpleITK>=2.3.0",
|
|
42
|
+
]
|
|
43
|
+
|
|
44
|
+
[project.urls]
|
|
45
|
+
Homepage = "https://github.com/GragasLab/n4ax"
|
|
46
|
+
Repository = "https://github.com/GragasLab/n4ax"
|
|
47
|
+
|
|
48
|
+
[tool.setuptools.packages.find]
|
|
49
|
+
include = ["n4ax*"]
|
|
50
|
+
|
|
51
|
+
[tool.pytest.ini_options]
|
|
52
|
+
testpaths = ["tests"]
|
|
53
|
+
python_files = ["test_*.py"]
|
|
54
|
+
addopts = "-v --tb=short"
|
|
55
|
+
pythonpath = [".", "tests"]
|
|
56
|
+
|
|
57
|
+
[tool.ruff]
|
|
58
|
+
target-version = "py312"
|
|
59
|
+
line-length = 119
|
|
60
|
+
|
|
61
|
+
[tool.ruff.lint]
|
|
62
|
+
select = ["E", "F", "I", "W", "UP", "FURB", "SIM", "S110", "C4", "RUF013", "PERF102", "PLC1802", "PLC0208", "PIE794"]
|
|
63
|
+
ignore = ["E501", "E741", "SIM1", "SIM905", "UP015", "UP031"]
|
|
64
|
+
extend-safe-fixes = ["UP006"]
|
|
65
|
+
|
|
66
|
+
[tool.ruff.lint.per-file-ignores]
|
|
67
|
+
"__init__.py" = ["E402", "F401", "F403", "F811"]
|
|
68
|
+
"*.ipynb" = ["E402", "E731", "B007", "N816"]
|
|
69
|
+
|
|
70
|
+
[tool.ruff.lint.isort]
|
|
71
|
+
lines-after-imports = 2
|
|
72
|
+
known-first-party = ["n4ax"]
|
|
73
|
+
|
|
74
|
+
[tool.ruff.format]
|
|
75
|
+
quote-style = "double"
|
|
76
|
+
indent-style = "space"
|
|
77
|
+
skip-magic-trailing-comma = false
|
|
78
|
+
line-ending = "auto"
|
n4ax-0.1.0/setup.cfg
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""Basic correctness: import, shapes, finiteness, 2D/3D, near-identity on no bias."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
import n4ax
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def test_import():
    """The public entry points ``n4`` and ``otsu_mask`` are exposed as callables."""
    for entry_point in (n4ax.n4, n4ax.otsu_mask):
        assert callable(entry_point)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def test_3d_shape_and_finite(phantom):
    """Correcting the 3-D phantom keeps its shape and yields finite, non-negative values."""
    observed, _bias, _mask = phantom
    corrected = np.asarray(n4ax.n4(observed))
    assert corrected.shape == observed.shape
    assert np.all(np.isfinite(corrected))
    assert np.all(corrected >= 0)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def test_2d_runs():
    """``n4`` also accepts a 2-D image and returns a finite array of the same shape."""
    generator = np.random.default_rng(0)
    image = generator.uniform(0.5, 1.5, size=(64, 64)).astype(np.float32)
    # Zero the first five rows so part of the image looks like background.
    image[:5] = 0.0
    corrected = np.asarray(n4ax.n4(image))
    assert corrected.shape == image.shape
    assert np.all(np.isfinite(corrected))
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def test_return_bias(phantom):
    """With ``return_bias=True`` the corrected image equals ``image / exp(bias)`` inside the mask."""
    observed, _, mask = phantom
    corrected, log_bias = n4ax.n4(observed, mask=mask.astype(np.float32), return_bias=True)
    corrected = np.asarray(corrected)
    log_bias = np.asarray(log_bias)
    # The returned bias is exponentiated before dividing, i.e. it lives in log space.
    expected = observed / np.exp(log_bias)
    max_abs_err = np.abs(corrected[mask] - expected[mask]).max()
    assert max_abs_err < 1e-3
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def test_recovers_bias(phantom):
    """N4 should flatten a known smooth bias.

    Inside the mask, the corrected tissue must be more uniform than the observed
    biased image, measured by the coefficient of variation (std / mean).
    """
    observed, _, mask = phantom

    def coeff_of_variation(values):
        return values.std() / values.mean()

    corrected = np.asarray(n4ax.n4(observed, mask=mask.astype(np.float32)))
    assert coeff_of_variation(corrected[mask]) < coeff_of_variation(observed[mask])
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def test_otsu_mask(phantom):
    """The Otsu foreground agrees with the phantom's true mask on more than 95% of voxels."""
    observed, _, true_mask = phantom
    predicted = np.asarray(n4ax.otsu_mask(observed)) > 0.5
    agreement = np.mean(predicted == true_mask)
    assert agreement > 0.95
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
"""Ground-truth test: n4ax must match SimpleITK's N4 (the reference implementation).
|
|
2
|
+
|
|
3
|
+
Uses the SAME mask for both (so we compare the N4 solve, not the masking), and
|
|
4
|
+
compares the corrected image with its global scale removed (N4's bias field is
|
|
5
|
+
defined up to a constant; downstream pipelines intensity-normalise anyway)."""
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import pytest
|
|
9
|
+
|
|
10
|
+
import n4ax
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
sitk = pytest.importorskip("SimpleITK")
|
|
14
|
+
from _phantom import make_phantom # noqa: E402
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _itk_n4(obs, mask, iters=(50, 50, 30, 20)):
    """Run SimpleITK's N4 (the reference implementation) on ``obs`` restricted to ``mask``.

    Args:
        obs: image array; cast to float32 before handing it to SimpleITK.
        mask: foreground mask; cast to uint8 so True voxels become label 1.
        iters: sequence passed to ``SetMaximumNumberOfIterations`` (one entry per
            fitting level), each element converted to a Python int.

    Returns:
        The corrected image as a float64 NumPy array.
    """
    corrector = sitk.N4BiasFieldCorrectionImageFilter()
    corrector.SetMaximumNumberOfIterations(list(map(int, iters)))
    image = sitk.GetImageFromArray(obs.astype(np.float32))
    label_image = sitk.GetImageFromArray(mask.astype(np.uint8))
    corrected = corrector.Execute(image, label_image)
    return sitk.GetArrayFromImage(corrected).astype(np.float64)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _scaled_relerr(a, b, m):
|
|
27
|
+
a, b = np.asarray(a, np.float64), np.asarray(b, np.float64)
|
|
28
|
+
ratio = a[m] / np.clip(b[m], 1e-6, None)
|
|
29
|
+
rel = np.abs(a[m] / np.median(ratio) - b[m]) / np.clip(np.abs(b[m]), 1e-6, None)
|
|
30
|
+
return rel
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def test_matches_simpleitk():
    """n4ax agrees with SimpleITK's N4 on the default phantom, up to a global scale.

    Thresholds: mean scaled relative error below 1.5% and 95th percentile below 5%.
    """
    observed, _, mask = make_phantom()
    reference = _itk_n4(observed, mask)
    ours = np.asarray(n4ax.n4(observed, mask=mask.astype(np.float32)))
    errors = _scaled_relerr(ours, reference, mask)
    mean_err = errors.mean()
    p95_err = np.percentile(errors, 95)
    assert mean_err < 0.015, f"mean rel-err {mean_err * 100:.2f}% too high"
    assert p95_err < 0.05, f"p95 rel-err {p95_err * 100:.2f}% too high"
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def test_closer_to_itk_than_uncorrected():
    """n4ax's correction must be much closer to ITK's than doing nothing.

    "Much closer" means its mean scaled relative error (vs. ITK) is below a quarter
    of the uncorrected image's error.
    """
    observed, _, mask = make_phantom(seed=1)
    reference = _itk_n4(observed, mask)
    ours = np.asarray(n4ax.n4(observed, mask=mask.astype(np.float32)))
    ours_err = _scaled_relerr(ours, reference, mask).mean()
    baseline_err = _scaled_relerr(observed, reference, mask).mean()
    assert ours_err < 0.25 * baseline_err
|