fireants 0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fireants-0.1/.gitignore +17 -0
- fireants-0.1/PKG-INFO +42 -0
- fireants-0.1/README.md +24 -0
- fireants-0.1/fireants/__init__.py +0 -0
- fireants-0.1/fireants/io/__init__.py +1 -0
- fireants-0.1/fireants/io/image.py +122 -0
- fireants-0.1/fireants/losses/__init__.py +4 -0
- fireants-0.1/fireants/losses/cc.py +323 -0
- fireants-0.1/fireants/losses/mi.py +200 -0
- fireants-0.1/fireants/losses/mse.py +56 -0
- fireants-0.1/fireants/registration/abstract.py +64 -0
- fireants-0.1/fireants/registration/affine.py +141 -0
- fireants-0.1/fireants/registration/deformation/abstract.py +30 -0
- fireants-0.1/fireants/registration/deformation/compositive.py +150 -0
- fireants-0.1/fireants/registration/deformation/geodesic.py +137 -0
- fireants-0.1/fireants/registration/greedy.py +189 -0
- fireants-0.1/fireants/registration/logdemons.py +182 -0
- fireants-0.1/fireants/registration/optimizers/adam.py +126 -0
- fireants-0.1/fireants/registration/optimizers/sgd.py +131 -0
- fireants-0.1/fireants/registration/rigid.py +171 -0
- fireants-0.1/fireants/registration/syn.py +214 -0
- fireants-0.1/fireants/scripts/analyse_raytune.py +51 -0
- fireants-0.1/fireants/scripts/evaluate_metrics.py +43 -0
- fireants-0.1/fireants/scripts/evalutils.py +1 -0
- fireants-0.1/fireants/scripts/lookup_tables.py +1 -0
- fireants-0.1/fireants/scripts/oasis.py +185 -0
- fireants-0.1/fireants/scripts/oasis_test.py +65 -0
- fireants-0.1/fireants/scripts/test_cumc12.py +111 -0
- fireants-0.1/fireants/scripts/test_ibsr.py +108 -0
- fireants-0.1/fireants/scripts/test_lpba40.py +128 -0
- fireants-0.1/fireants/scripts/test_mgh10.py +111 -0
- fireants-0.1/fireants/scripts/tune_empire10.py +254 -0
- fireants-0.1/fireants/scripts/tune_lpba40.py +168 -0
- fireants-0.1/fireants/tests/loadtest.py +4 -0
- fireants-0.1/fireants/types.py +7 -0
- fireants-0.1/fireants/utils/__init__.py +0 -0
- fireants-0.1/fireants/utils/globals.py +1 -0
- fireants-0.1/fireants/utils/imageutils.py +285 -0
- fireants-0.1/fireants/utils/opticalflow.py +91 -0
- fireants-0.1/fireants/utils/util.py +124 -0
- fireants-0.1/pyproject.toml +25 -0
fireants-0.1/.gitignore
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
*egg-info
|
|
2
|
+
*pdf
|
|
3
|
+
*svg
|
|
4
|
+
dist/
|
|
5
|
+
build/
|
|
6
|
+
**pkl
|
|
7
|
+
**/__pycache__
|
|
8
|
+
**/baselines
|
|
9
|
+
**/*.pkl
|
|
10
|
+
**/.ipynb_checkpoints/
|
|
11
|
+
fireants/notebooks/images/
|
|
12
|
+
**/*log.txt
|
|
13
|
+
**/*.nii.gz
|
|
14
|
+
fireants/scripts/misc
|
|
15
|
+
fireants/notebooks
|
|
16
|
+
fireants/baselines
|
|
17
|
+
.req.txt
|
fireants-0.1/PKG-INFO
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: fireants
|
|
3
|
+
Version: 0.1
|
|
4
|
+
Summary: FireANTs: Adaptive Riemannian Optimization for Multi-Scale Diffeomorphic Registration
|
|
5
|
+
Author: Rohit Jena, Pratik Chaudhari, James C. Gee
|
|
6
|
+
Requires-Python: >=3.7
|
|
7
|
+
Requires-Dist: matplotlib
|
|
8
|
+
Requires-Dist: nibabel==4.0.2
|
|
9
|
+
Requires-Dist: numpy
|
|
10
|
+
Requires-Dist: pandas==1.3.5
|
|
11
|
+
Requires-Dist: scikit-image
|
|
12
|
+
Requires-Dist: scipy
|
|
13
|
+
Requires-Dist: simpleitk==2.2.1
|
|
14
|
+
Requires-Dist: torch==1.13.1
|
|
15
|
+
Requires-Dist: tqdm
|
|
16
|
+
Requires-Dist: typing
|
|
17
|
+
Description-Content-Type: text/markdown
|
|
18
|
+
|
|
19
|
+
# :fire: FireANTs: Adaptive Riemannian Optimization for Multi-Scale Diffeomorphic Registration
|
|
20
|
+
|
|
21
|
+
The FireANTs library is a lightweight registration package for Riemannian diffeomorphic registration on GPUs.
|
|
22
|
+
|
|
23
|
+
## Installation
|
|
24
|
+
To use the FireANTs package, you can either clone the repository and install the package locally or install the package directly from PyPI.
|
|
25
|
+
We recommend using a fresh Anaconda/Miniconda environment to install the package.
|
|
26
|
+
```
|
|
27
|
+
conda create -n fireants python=3.7
|
|
28
|
+
```
|
|
29
|
+
|
|
30
|
+
To install FireANTs locally:
|
|
31
|
+
```
|
|
32
|
+
git clone https://github.com/rohitrango/fireants
|
|
33
|
+
cd fireants
|
|
34
|
+
pip install -e .
|
|
35
|
+
```
|
|
36
|
+
|
|
37
|
+
Or to install from PyPI:
|
|
38
|
+
```
|
|
39
|
+
pip install fireants
|
|
40
|
+
```
|
|
41
|
+
|
|
42
|
+
## Tutorial
|
fireants-0.1/README.md
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# :fire: FireANTs: Adaptive Riemannian Optimization for Multi-Scale Diffeomorphic Registration
|
|
2
|
+
|
|
3
|
+
The FireANTs library is a lightweight registration package for Riemannian diffeomorphic registration on GPUs.
|
|
4
|
+
|
|
5
|
+
## Installation
|
|
6
|
+
To use the FireANTs package, you can either clone the repository and install the package locally or install the package directly from PyPI.
|
|
7
|
+
We recommend using a fresh Anaconda/Miniconda environment to install the package.
|
|
8
|
+
```
|
|
9
|
+
conda create -n fireants python=3.7
|
|
10
|
+
```
|
|
11
|
+
|
|
12
|
+
To install FireANTs locally:
|
|
13
|
+
```
|
|
14
|
+
git clone https://github.com/rohitrango/fireants
|
|
15
|
+
cd fireants
|
|
16
|
+
pip install -e .
|
|
17
|
+
```
|
|
18
|
+
|
|
19
|
+
Or to install from PyPI:
|
|
20
|
+
```
|
|
21
|
+
pip install fireants
|
|
22
|
+
```
|
|
23
|
+
|
|
24
|
+
## Tutorial
|
|
File without changes
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from fireants.io.image import Image, BatchedImages
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import SimpleITK as sitk
|
|
3
|
+
import numpy as np
|
|
4
|
+
from typing import Any, Union, List
|
|
5
|
+
from time import time
|
|
6
|
+
from fireants.types import devicetype
|
|
7
|
+
from fireants.utils.imageutils import integer_to_onehot
|
|
8
|
+
|
|
9
|
+
class Image:
    '''Wrap a SimpleITK image as a torch tensor plus coordinate transforms.

    The pixel data is stored in ``self.array`` with shape ``[1, C, ...]``
    (a batch dim and a channel dim are prepended).  Affine matrices mapping
    between the normalized "torch" coordinate frame, pixel indices, and
    physical (scanner) space are precomputed as ``self.torch2phy`` and
    ``self.phy2torch`` (each of shape ``[1, dims+1, dims+1]``).

    Args:
        itk_image: the SimpleITK image to wrap.
        device: torch device the tensors are placed on (default ``'cuda'``).
        is_segmentation: if True, treat the image as an integer label map and
            convert it to a one-hot float tensor via ``integer_to_onehot``.
        max_seg_label: labels greater than this are remapped to
            ``background_seg_label`` (only used when ``is_segmentation``).
        background_seg_label: the label value treated as background.
        seg_preprocessor: callable applied to the raw label tensor before
            one-hot conversion (defaults to identity).
    '''
    def __init__(self, itk_image: sitk.SimpleITK.Image, device: devicetype = 'cuda',
                 is_segmentation=False, max_seg_label=None, background_seg_label=0, seg_preprocessor=lambda x: x) -> None:
        self.itk_image = itk_image
        # check for segmentation parameters
        # if `is_segmentation` is False, then just treat this as a float image
        if not is_segmentation:
            self.array = torch.from_numpy(sitk.GetArrayFromImage(itk_image).astype(float)).to(device).float()
            self.array = self.array[None, None]  # TODO: Change it to support multichannel images, right now just batchify and add a dummy channel to it
            channels = itk_image.GetNumberOfComponentsPerPixel()
            self.channels = channels
            assert channels == 1, "Only single channel images supported"
        else:
            array = torch.from_numpy(sitk.GetArrayFromImage(itk_image).astype(int)).to(device).long()
            # preprocess segmentation if provided by user
            array = seg_preprocessor(array)
            if max_seg_label is not None:
                # clip out-of-range labels to background before one-hot encoding
                array[array > max_seg_label] = background_seg_label
            array = integer_to_onehot(array, background_label=background_seg_label, max_label=max_seg_label)[None]  # []
            self.array = array.float()
            self.channels = array.shape[1]
        # initialize matrix for pixel to physical
        dims = itk_image.GetDimension()
        self.dims = dims
        if dims not in [2, 3]:
            raise NotImplementedError("Image class only supports 2D/3D images.")
        # px2phy: homogeneous affine from pixel indices to physical space,
        # built from origin, direction cosines, and spacing (spacing scales
        # the direction columns).
        px2phy = np.eye(dims+1)
        px2phy[:dims, -1] = itk_image.GetOrigin()
        px2phy[:dims, :dims] = np.array(itk_image.GetDirection()).reshape(dims, dims)
        px2phy[:dims, :dims] = px2phy[:dims, :dims] * np.array(itk_image.GetSpacing())[None]
        # generate mapping from torch to px
        # NOTE(review): GetSize() is in sitk (x, y, z) order while
        # GetArrayFromImage yields (z, y, x) — presumably reconciled by the
        # grid convention used downstream; confirm against the resampling code.
        torch2px = np.eye(dims+1)
        scaleterm = (np.array(itk_image.GetSize())-1)*0.5
        torch2px[:dims, :dims] = np.diag(scaleterm)
        torch2px[:dims, -1] = scaleterm
        # save the mapping from physical to torch and vice versa
        self.torch2phy = torch.from_numpy(np.matmul(px2phy, torch2px)).to(device).float().unsqueeze(0)
        self.phy2torch = torch.inverse(self.torch2phy[0]).float().unsqueeze(0)
        # also save intermediates just in case (as numpy arrays)
        self._torch2px = torch2px
        self._px2phy = px2phy
        self.device = device

    @classmethod
    def load_file(cls, image_path: str, *args, **kwargs) -> 'Image':
        '''Read ``image_path`` with SimpleITK and wrap it in an `Image`.

        Extra positional/keyword arguments are forwarded to ``__init__``.
        '''
        itk_image = sitk.ReadImage(image_path)
        return cls(itk_image, *args, **kwargs)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class BatchedImages:
    '''Container that groups one or more `Image` objects into a batch.

    All images must share the same array shape; ``__call__`` concatenates
    their tensors along the batch dimension.

    Args:
        images: a single `Image` or a list of `Image` objects.

    Raises:
        ValueError: if the list is empty or shapes differ.
        TypeError: if any element is not an `Image`.
    '''
    def __init__(self, images: Union["Image", List["Image"]]) -> None:
        if isinstance(images, Image):
            images = [images]
        self.images = images
        if len(self.images) == 0:
            raise ValueError("BatchedImages must have at least one image")
        for image in self.images:
            if not isinstance(image, Image):
                raise TypeError("All images must be of type Image")
        shapes = [x.array.shape for x in self.images]
        if all(x == shapes[0] for x in shapes):
            # NOTE: this instance attribute shadows the `shape()` method below
            self.shape = shapes[0]
        else:
            raise ValueError("All images must have the same shape")
        self.n_images = len(self.images)
        # BUGFIX: was `self.images[0] == 2`, which compared the Image object
        # itself to 2 (always False, so 'trilinear' was always chosen);
        # the dimensionality lives in `.dims`.
        self.interpolate_mode = 'bilinear' if self.images[0].dims == 2 else 'trilinear'

    def __call__(self):
        # get batch of images: concatenate along the batch dimension
        return torch.cat([x.array for x in self.images], dim=0)

    @property
    def device(self):
        # device of the underlying tensors (all images share one device)
        return self.images[0].device

    @property
    def dims(self):
        # spatial dimensionality (2 or 3) of the images
        return self.images[0].dims

    def size(self):
        '''Return the number of images in the batch.'''
        return self.n_images

    def shape(self):
        '''Return the batched tensor shape as a tuple.

        NOTE: on instances this method is shadowed by the `self.shape`
        attribute set in ``__init__`` and is therefore unreachable through
        normal attribute lookup; kept (and fixed to read ``.array.shape``
        and avoid mutating an immutable ``torch.Size``) for compatibility.
        '''
        shape = list(self.images[0].array.shape)
        shape[0] = self.n_images
        return tuple(shape)

    def get_torch2phy(self):
        '''Stack each image's torch->physical affine into a [B, d+1, d+1] tensor.'''
        return torch.cat([x.torch2phy for x in self.images], dim=0)

    def get_phy2torch(self):
        '''Stack each image's physical->torch affine into a [B, d+1, d+1] tensor.'''
        return torch.cat([x.phy2torch for x in self.images], dim=0)
|
107
|
+
|
|
108
|
+
|
|
109
|
+
if __name__ == '__main__':
    # Manual smoke test: loads an intensity volume and its segmentation from
    # hard-coded, machine-specific paths. For interactive debugging only.
    # image = Image.load_file('/data/BRATS2021/training/BraTS2021_00598/BraTS2021_00598_t1.nii.gz')
    # print(image.torch2phy)
    # image2 = Image.load_file('/data/BRATS2021/training/BraTS2021_00599/BraTS2021_00599_t1.nii.gz')
    # batch = BatchedImages([image, image2])
    # print(batch().shape)
    # print(batch.get_torch2phy().shape)
    from glob import glob
    files = sorted(glob("/data/IBSR_braindata/IBSR_01/*nii.gz"))
    # files[2] is presumably the intensity image — depends on sort order of
    # the dataset's filenames; verify against the IBSR layout.
    image = Image.load_file(files[2])
    print(image.array.shape, image.array.min(), image.array.max())
    # get label
    label = Image.load_file(files[-1], is_segmentation=True)
    print(label.array.shape, label.array.min(), label.array.max())
|
@@ -0,0 +1,323 @@
|
|
|
1
|
+
'''
|
|
2
|
+
Cross correlation
|
|
3
|
+
'''
|
|
4
|
+
from time import time, sleep
|
|
5
|
+
import torch
|
|
6
|
+
from torch.utils.checkpoint import checkpoint
|
|
7
|
+
from torch import nn
|
|
8
|
+
from torch.nn import functional as F
|
|
9
|
+
from typing import Union, Tuple, List, Optional, Dict, Any, Callable
|
|
10
|
+
from fireants.types import ItemOrList
|
|
11
|
+
|
|
12
|
+
@torch.jit.script
def gaussian_1d(
    sigma: torch.Tensor, truncated: float = 4.0, approx: str = "erf", normalize: bool = True
) -> torch.Tensor:
    """
    Build a one-dimensional Gaussian kernel.

    Args:
        sigma: standard deviation of the kernel (scalar tensor).
        truncated: how many sigmas of tail to keep on each side.
        approx: discretization scheme for the Gaussian:
            - ``erf`` integrates the density over each pixel via the error function;
            - ``sampled`` evaluates the density at integer offsets.
            (``scalespace`` is documented upstream but not implemented here.)
        normalize: if True, divide by the kernel sum so taps add to 1.

    Raises:
        ValueError: when ``truncated`` is non-positive.
        NotImplementedError: for an unknown ``approx`` option.

    Returns:
        1D torch tensor of kernel taps.
    """
    sigma = torch.as_tensor(sigma, dtype=torch.float, device=sigma.device if isinstance(sigma, torch.Tensor) else None)
    dev = sigma.device
    if truncated <= 0.0:
        raise ValueError(f"truncated must be positive, got {truncated}.")
    # half-width of the support, at least 1 tap on each side
    radius = int(max(float(sigma) * truncated, 0.5) + 0.5)
    if approx.lower() == "erf":
        grid = torch.arange(-radius, radius + 1, dtype=torch.float, device=dev)
        scale = 0.70710678 / torch.abs(sigma)  # 1 / (sqrt(2) * sigma)
        kernel = 0.5 * ((scale * (grid + 0.5)).erf() - (scale * (grid - 0.5)).erf())
        kernel = kernel.clamp(min=0)
    elif approx.lower() == "sampled":
        grid = torch.arange(-radius, radius + 1, dtype=torch.float, device=dev)
        kernel = torch.exp(-0.5 / (sigma * sigma) * grid ** 2)
        if not normalize:  # apply the analytic 1/(sqrt(2*pi)*sigma) factor instead
            kernel = kernel / (2.5066282 * sigma)
    else:
        raise NotImplementedError(f"Unsupported option: approx='{approx}'.")
    return kernel / kernel.sum() if normalize else kernel
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@torch.jit.script
def make_rectangular_kernel(kernel_size: int) -> torch.Tensor:
    """Return a flat (box) kernel of `kernel_size` unit taps."""
    return torch.full((kernel_size,), 1.0)
|
|
56
|
+
|
|
57
|
+
@torch.jit.script
def make_triangular_kernel(kernel_size: int) -> torch.Tensor:
    """Return a triangular kernel obtained by convolving a box filter with itself."""
    # base box width: roughly half the kernel, forced to be odd
    width = (kernel_size + 1) // 2
    if width % 2 == 0:
        width -= 1
    box = torch.ones((1, 1, width), dtype=torch.float) / width
    # pad so that the self-convolution comes out exactly `kernel_size` long
    pad = (kernel_size - width) // 2 + width // 2
    return F.conv1d(box, box, padding=pad).reshape(-1)
|
|
65
|
+
|
|
66
|
+
@torch.jit.script
def make_gaussian_kernel(kernel_size: int) -> torch.Tensor:
    """Return an unnormalized sampled Gaussian kernel (peak tap = 1) of length `kernel_size`."""
    sigma = torch.tensor(kernel_size / 3.0)
    # gaussian_1d with normalize=False divides by sqrt(2*pi)*sigma;
    # multiplying it back leaves pure exp(-x^2 / (2*sigma^2)) samples.
    norm = 2.5066282 * sigma
    kernel = gaussian_1d(sigma=sigma, truncated=(kernel_size // 2) * 1.0, approx="sampled", normalize=False) * norm
    # clip in case the support came out one tap longer than requested
    return kernel[:kernel_size]
|
|
73
|
+
|
|
74
|
+
@torch.jit.script
def _separable_filtering_conv(
    input_: torch.Tensor,
    kernels: List[torch.Tensor],
    pad_mode: str,
    spatial_dims: int,
    paddings: List[int],
    num_channels: int,
) -> torch.Tensor:
    """Apply one depthwise 1-D convolution per spatial dimension of `input_`.

    Each `kernels[d]` is reshaped so it only extends along spatial dim `d`,
    the input is padded by `paddings[d]` on that dim, and a grouped
    (per-channel) convNd is applied.  Dims whose kernel is the scalar 1 are
    skipped.  Returns a tensor of the same shape as `input_`.
    """
    # re-write from recursive to non-recursive for torch.jit to work
    # for d in range(spatial_dims-1, -1, -1):
    for d in range(spatial_dims):
        # view the 1-D kernel as [1, ..., k, ..., 1] along spatial dim d
        s = [1] * len(input_.shape)
        s[d + 2] = -1
        _kernel = kernels[d].reshape(s)
        # if filter kernel is unity, don't convolve
        if _kernel.numel() == 1 and _kernel[0] == 1:
            continue

        # depthwise: replicate the kernel across the channel groups
        _kernel = _kernel.repeat([num_channels, 1] + [1] * spatial_dims)
        _padding = [0] * spatial_dims
        _padding[d] = paddings[d]
        # F.pad expects sizes for the LAST dim first, hence the reversal
        _reversed_padding = _padding[::-1]

        # translate padding for input to torch.nn.functional.pad
        # ([p, p] per dim, flattened; written as an explicit loop because
        # sum(list_of_lists, []) is not TorchScript-friendly)
        _reversed_padding_repeated_twice: list[list[int]] = [[p, p] for p in _reversed_padding]
        _sum_reversed_padding_repeated_twice: list[int] = []
        for p in _reversed_padding_repeated_twice:
            _sum_reversed_padding_repeated_twice.extend(p)
        # _sum_reversed_padding_repeated_twice: list[int] = sum(_reversed_padding_repeated_twice, [])

        padded_input = F.pad(input_, _sum_reversed_padding_repeated_twice, mode=pad_mode)
        # update input: groups=num_channels makes the conv per-channel (depthwise)
        if spatial_dims == 1:
            input_ = F.conv1d(input=padded_input, weight=_kernel, groups=num_channels)
        elif spatial_dims == 2:
            input_ = F.conv2d(input=padded_input, weight=_kernel, groups=num_channels)
        elif spatial_dims == 3:
            input_ = F.conv3d(input=padded_input, weight=_kernel, groups=num_channels)
        else:
            raise NotImplementedError(f"Unsupported spatial_dims: {spatial_dims}.")
    return input_
|
|
117
|
+
|
|
118
|
+
@torch.jit.script
def separable_filtering(x: torch.Tensor, kernels: ItemOrList[torch.Tensor], mode: str = "zeros") -> torch.Tensor:
    """
    Filter `x` with a separable kernel: one 1-D convolution per spatial dim.

    Args:
        x: input image of shape (batch, channels, H[, W, ...]).
        kernels: either one 1-D kernel (reused for every spatial dimension)
            or a list with one kernel per spatial dimension.
        mode: padding mode — ``'zeros'``, ``'reflect'``, ``'replicate'`` or
            ``'circular'`` (same semantics as ``torch.nn.Conv1d``).

    Raises:
        TypeError: when ``x`` is not a ``torch.Tensor``.

    Returns:
        Filtered tensor with the same shape as ``x``.

    Examples:
        .. code-block:: python
            >>> import torch
            >>> img = torch.randn(2, 4, 32, 32)
            >>> out = separable_filtering(img, torch.tensor((-1., 0., 1.)))
            >>> out = separable_filtering(img, [torch.tensor((-1., 0., 1.)), torch.tensor((1., 0., -1.))])
    """
    if not isinstance(x, torch.Tensor):
        raise TypeError(f"x must be a torch.Tensor but is {type(x).__name__}.")

    n_spatial = len(x.shape) - 2
    # a single tensor kernel is broadcast to every spatial dimension
    if isinstance(kernels, torch.Tensor):
        kernels = [kernels] * n_spatial
    kernel_list = [k.to(x) for k in kernels]
    pad_sizes = [(k.shape[0] - 1) // 2 for k in kernel_list]
    channels = x.shape[1]
    padding_mode = "constant" if mode == "zeros" else mode
    return _separable_filtering_conv(x, kernel_list, padding_mode, n_spatial, pad_sizes, channels)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
# maps the `kernel_type` names accepted by LocalNormalizedCrossCorrelationLoss
# to their 1-D kernel factory functions
kernel_dict = {
    "rectangular": make_rectangular_kernel,
    "triangular": make_triangular_kernel,
    "gaussian": make_gaussian_kernel,
}
|
|
162
|
+
|
|
163
|
+
class LocalNormalizedCrossCorrelationLoss(nn.Module):
    """
    Local squared zero-normalized cross-correlation.
    The loss is based on a moving kernel/window over the y_true/y_pred,
    within the window the square of zncc is calculated.
    The kernel can be a rectangular / triangular / gaussian window.
    The final loss is the averaged loss over all windows.
    Adapted from:
        https://github.com/voxelmorph/voxelmorph/blob/legacy/src/losses.py
        DeepReg (https://github.com/DeepRegNet/DeepReg)
    """

    def __init__(
        self,
        spatial_dims: int = 3,
        kernel_size: int = 3,
        kernel_type: str = "rectangular",
        reduction: str = "mean",
        smooth_nr: float = 1e-5,
        smooth_dr: float = 1e-5,
        unsigned: bool = True,
        checkpointing: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions, {``1``, ``2``, ``3``}. Defaults to 3.
            kernel_size: kernel spatial size, must be odd.
            kernel_type: {``"rectangular"``, ``"triangular"``, ``"gaussian"``}. Defaults to ``"rectangular"``.
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.
                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.
            smooth_nr: a small constant added to the numerator to avoid nan.
            smooth_dr: a small constant added to the denominator to avoid nan.
            unsigned: if True, use squared NCC (cross^2 / (var_t * var_p));
                if False, use signed NCC (cross / (std_t * std_p)).
            checkpointing: if True, wrap the window computation in
                ``torch.utils.checkpoint.checkpoint`` to trade compute for memory.
        """
        super().__init__()
        self.ndim = spatial_dims
        if self.ndim not in {1, 2, 3}:
            raise ValueError(f"Unsupported ndim: {self.ndim}-d, only 1-d, 2-d, and 3-d inputs are supported")
        self.reduction = reduction
        self.unsigned = unsigned

        self.kernel_size = kernel_size
        if self.kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be odd, got {self.kernel_size}")

        # _kernel = look_up_option(kernel_type, kernel_dict)
        _kernel = kernel_dict[kernel_type]
        self.kernel = _kernel(self.kernel_size)
        # kernel is a fixed filter, not a learnable parameter
        self.kernel.requires_grad = False
        self.kernel_nd, self.kernel_vol = self.get_kernel_vol()  # get nD kernel and its volume
        self.smooth_nr = float(smooth_nr)
        self.smooth_dr = float(smooth_dr)
        self.checkpointing = checkpointing

    def get_kernel_vol(self):
        """Return the separable nD kernel (outer product of the 1D kernel) and its sum."""
        vol = self.kernel
        for _ in range(self.ndim - 1):
            vol = torch.matmul(vol.unsqueeze(-1), self.kernel.unsqueeze(0))
        return vol, torch.sum(vol)

    def forward(self, pred: torch.Tensor, target: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            pred: the shape should be BNH[WD].
            target: the shape should be BNH[WD].
            mask: optional weighting mask, broadcastable to the NCC map; it is
                normalized by its per-(batch, channel) mean before weighting.
        Raises:
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
        Returns:
            Negated (masked, reduced) local NCC — lower is better for optimization.
        """
        if pred.ndim - 2 != self.ndim:
            raise ValueError(f"expecting pred with {self.ndim} spatial dimensions, got pred of shape {pred.shape}")
        if target.shape != pred.shape:
            raise ValueError(f"ground truth has differing shape ({target.shape}) from pred ({pred.shape})")

        # sum over kernel
        def cc_checkpoint_fn(target, pred, kernel, kernel_vol):
            '''
            This function is used to compute the intermediate results of the loss.
            (Separated out so it can optionally be gradient-checkpointed.)
            '''
            t2, p2, tp = target * target, pred * pred, target * pred
            kernel, kernel_vol = kernel.to(pred), kernel_vol.to(pred)
            # kernel_nd = self.kernel_nd.to(pred)
            kernels = [kernel] * self.ndim
            kernels_t = kernels_p = kernels
            kernel_vol_t = kernel_vol_p = kernel_vol
            # compute intermediates: windowed sums via separable filtering
            t_sum = separable_filtering(target, kernels=kernels_t)
            p_sum = separable_filtering(pred, kernels=kernels_p)
            t2_sum = separable_filtering(t2, kernels=kernels_t)
            p2_sum = separable_filtering(p2, kernels=kernels_p)
            tp_sum = separable_filtering(tp, kernels=kernels_t)  # use target device's output
            # average over kernel
            t_avg = t_sum / kernel_vol_t
            p_avg = p_sum / kernel_vol_p
            # normalized cross correlation between t and p
            # sum[(t - mean[t]) * (p - mean[p])] / std[t] / std[p]
            # denoted by num / denom
            # assume we sum over N values
            # num = sum[t * p - mean[t] * p - t * mean[p] + mean[t] * mean[p]]
            #     = sum[t*p] - sum[t] * sum[p] / N * 2 + sum[t] * sum[p] / N
            #     = sum[t*p] - sum[t] * sum[p] / N
            #     = sum[t*p] - sum[t] * mean[p] = cross
            # the following is actually squared ncc
            cross = (tp_sum.to(pred) - p_avg * t_sum.to(pred))  # on pred device
            # variances are clamped from below by smooth_dr to avoid division blow-up
            t_var = torch.max(
                t2_sum - t_avg * t_sum, torch.as_tensor(self.smooth_dr, dtype=t2_sum.dtype, device=t2_sum.device)
            ).to(pred)
            p_var = torch.max(
                p2_sum - p_avg * p_sum, torch.as_tensor(self.smooth_dr, dtype=p2_sum.dtype, device=p2_sum.device)
            )
            if self.unsigned:
                ncc: torch.Tensor = (cross * cross + self.smooth_nr) / ((t_var * p_var) + self.smooth_dr)
            else:
                ncc: torch.Tensor = (cross + self.smooth_nr) / ((torch.sqrt(t_var) * torch.sqrt(p_var)) + self.smooth_dr)
            return ncc

        if self.checkpointing:
            ncc = checkpoint(cc_checkpoint_fn, target, pred, self.kernel, self.kernel_vol)
        else:
            ncc = cc_checkpoint_fn(target, pred, self.kernel, self.kernel_vol)

        if mask is not None:
            # normalize the mask so that the weighted mean stays comparable
            # across different mask coverages
            # NOTE(review): an all-zero mask would divide by zero here — confirm
            # callers guarantee a nonempty mask.
            maskmean = mask.flatten(2).mean(2)  # [B, N]
            for _ in range(self.ndim):
                maskmean = maskmean.unsqueeze(-1)  # [B, N, 1, 1, ...]
            ncc = ncc * mask / maskmean

        if self.reduction == 'sum':
            return torch.sum(ncc).neg()  # sum over the batch, channel and spatial ndims
        if self.reduction == 'none':
            return ncc.neg()
        if self.reduction == 'mean':
            return torch.mean(ncc).neg()  # average over the batch, channel and spatial ndims
        raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
if __name__ == '__main__':
    # Manual benchmark: time 200 LNCC evaluations on random volumes.
    # Requires a CUDA device.
    N = 64
    img1 = torch.rand(1, 1, N, N, N).cuda()
    img2 = torch.rand(1, 1, N, N, N).cuda()
    # loss = torch.jit.script(LocalNormalizedCrossCorrelationLoss(3, kernel_type='rectangular', reduction='mean')).cuda()
    loss = LocalNormalizedCrossCorrelationLoss(3, kernel_type='rectangular', reduction='mean').cuda()
    total = 0
    # NOTE(review): this scripted helper captures the module-level `loss`
    # (a plain nn.Module) — torch.jit.script may fail to compile it at
    # decoration time; it is unused below (the call is commented out). Verify.
    @torch.jit.script
    def train(img1: torch.Tensor, img2: torch.Tensor, n: int) -> float:
        total = 0.0
        for i in range(n):
            out = loss(img1, img2)
            total += out.item()
        return total

    a = time()
    # total = train(img1, img2, 200)
    for i in range(200):
        out = loss(img1, img2)
        total += out.item()
    print(time() - a)
    print(total / 200)