fireants 0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. fireants-0.1/.gitignore +17 -0
  2. fireants-0.1/PKG-INFO +42 -0
  3. fireants-0.1/README.md +24 -0
  4. fireants-0.1/fireants/__init__.py +0 -0
  5. fireants-0.1/fireants/io/__init__.py +1 -0
  6. fireants-0.1/fireants/io/image.py +122 -0
  7. fireants-0.1/fireants/losses/__init__.py +4 -0
  8. fireants-0.1/fireants/losses/cc.py +323 -0
  9. fireants-0.1/fireants/losses/mi.py +200 -0
  10. fireants-0.1/fireants/losses/mse.py +56 -0
  11. fireants-0.1/fireants/registration/abstract.py +64 -0
  12. fireants-0.1/fireants/registration/affine.py +141 -0
  13. fireants-0.1/fireants/registration/deformation/abstract.py +30 -0
  14. fireants-0.1/fireants/registration/deformation/compositive.py +150 -0
  15. fireants-0.1/fireants/registration/deformation/geodesic.py +137 -0
  16. fireants-0.1/fireants/registration/greedy.py +189 -0
  17. fireants-0.1/fireants/registration/logdemons.py +182 -0
  18. fireants-0.1/fireants/registration/optimizers/adam.py +126 -0
  19. fireants-0.1/fireants/registration/optimizers/sgd.py +131 -0
  20. fireants-0.1/fireants/registration/rigid.py +171 -0
  21. fireants-0.1/fireants/registration/syn.py +214 -0
  22. fireants-0.1/fireants/scripts/analyse_raytune.py +51 -0
  23. fireants-0.1/fireants/scripts/evaluate_metrics.py +43 -0
  24. fireants-0.1/fireants/scripts/evalutils.py +1 -0
  25. fireants-0.1/fireants/scripts/lookup_tables.py +1 -0
  26. fireants-0.1/fireants/scripts/oasis.py +185 -0
  27. fireants-0.1/fireants/scripts/oasis_test.py +65 -0
  28. fireants-0.1/fireants/scripts/test_cumc12.py +111 -0
  29. fireants-0.1/fireants/scripts/test_ibsr.py +108 -0
  30. fireants-0.1/fireants/scripts/test_lpba40.py +128 -0
  31. fireants-0.1/fireants/scripts/test_mgh10.py +111 -0
  32. fireants-0.1/fireants/scripts/tune_empire10.py +254 -0
  33. fireants-0.1/fireants/scripts/tune_lpba40.py +168 -0
  34. fireants-0.1/fireants/tests/loadtest.py +4 -0
  35. fireants-0.1/fireants/types.py +7 -0
  36. fireants-0.1/fireants/utils/__init__.py +0 -0
  37. fireants-0.1/fireants/utils/globals.py +1 -0
  38. fireants-0.1/fireants/utils/imageutils.py +285 -0
  39. fireants-0.1/fireants/utils/opticalflow.py +91 -0
  40. fireants-0.1/fireants/utils/util.py +124 -0
  41. fireants-0.1/pyproject.toml +25 -0
@@ -0,0 +1,17 @@
1
+ *egg-info
2
+ *pdf
3
+ *svg
4
+ dist/
5
+ build/
6
+ **pkl
7
+ **/__pycache__
8
+ **/baselines
9
+ **/*.pkl
10
+ **/.ipynb_checkpoints/
11
+ fireants/notebooks/images/
12
+ **/*log.txt
13
+ **/*.nii.gz
14
+ fireants/scripts/misc
15
+ fireants/notebooks
16
+ fireants/baselines
17
+ .req.txt
fireants-0.1/PKG-INFO ADDED
@@ -0,0 +1,42 @@
1
+ Metadata-Version: 2.1
2
+ Name: fireants
3
+ Version: 0.1
4
+ Summary: FireANTs: Adaptive Riemannian Optimization for Multi-Scale Diffeomorphic Registration
5
+ Author: Rohit Jena, Pratik Chaudhari, James C. Gee
6
+ Requires-Python: >=3.7
7
+ Requires-Dist: matplotlib
8
+ Requires-Dist: nibabel==4.0.2
9
+ Requires-Dist: numpy
10
+ Requires-Dist: pandas==1.3.5
11
+ Requires-Dist: scikit-image
12
+ Requires-Dist: scipy
13
+ Requires-Dist: simpleitk==2.2.1
14
+ Requires-Dist: torch==1.13.1
15
+ Requires-Dist: tqdm
16
+ Requires-Dist: typing
17
+ Description-Content-Type: text/markdown
18
+
19
+ # :fire: FireANTs: Adaptive Riemannian Optimization for Multi-Scale Diffeomorphic Registration
20
+
21
+ The FireANTs library is a lightweight registration package for Riemannian diffeomorphic registration on GPUs.
22
+
23
+ ## Installation
24
+ To use the FireANTs package, you can either clone the repository and install the package locally or install the package directly from PyPI.
25
+ We recommend using a fresh Anaconda/Miniconda environment to install the package.
26
+ ```
27
+ conda create -n fireants python=3.7
28
+ ```
29
+
30
+ To install FireANTs locally:
31
+ ```
32
+ git clone https://github.com/rohitrango/fireants
33
+ cd fireants
34
+ pip install -e .
35
+ ```
36
+
37
+ Or to install from PyPI:
38
+ ```
39
+ pip install fireants
40
+ ```
41
+
42
+ ## Tutorial
fireants-0.1/README.md ADDED
@@ -0,0 +1,24 @@
1
+ # :fire: FireANTs: Adaptive Riemannian Optimization for Multi-Scale Diffeomorphic Registration
2
+
3
+ The FireANTs library is a lightweight registration package for Riemannian diffeomorphic registration on GPUs.
4
+
5
+ ## Installation
6
+ To use the FireANTs package, you can either clone the repository and install the package locally or install the package directly from PyPI.
7
+ We recommend using a fresh Anaconda/Miniconda environment to install the package.
8
+ ```
9
+ conda create -n fireants python=3.7
10
+ ```
11
+
12
+ To install FireANTs locally:
13
+ ```
14
+ git clone https://github.com/rohitrango/fireants
15
+ cd fireants
16
+ pip install -e .
17
+ ```
18
+
19
+ Or to install from PyPI:
20
+ ```
21
+ pip install fireants
22
+ ```
23
+
24
+ ## Tutorial
File without changes
@@ -0,0 +1 @@
1
+ from fireants.io.image import Image, BatchedImages
@@ -0,0 +1,122 @@
1
+ import torch
2
+ import SimpleITK as sitk
3
+ import numpy as np
4
+ from typing import Any, Union, List
5
+ from time import time
6
+ from fireants.types import devicetype
7
+ from fireants.utils.imageutils import integer_to_onehot
8
+
9
class Image:
    '''
    Wrapper around a SimpleITK image that exposes the voxel data as a torch
    tensor together with the homogeneous affine maps between torch's
    normalized coordinate grid and physical (scanner) space.

    Args:
        itk_image: SimpleITK image to wrap (2D or 3D only).
        device: torch device the tensors are placed on.
        is_segmentation: if True, the voxels are treated as integer labels and
            converted to a one-hot float tensor via `integer_to_onehot`.
        max_seg_label: labels greater than this are reset to
            `background_seg_label` before one-hot encoding (optional).
        background_seg_label: label value treated as background.
        seg_preprocessor: callable applied to the raw integer label tensor
            before any other segmentation processing.
    '''
    def __init__(self, itk_image: sitk.SimpleITK.Image, device: devicetype = 'cuda',
                 is_segmentation=False, max_seg_label=None, background_seg_label=0, seg_preprocessor=lambda x: x) -> None:
        self.itk_image = itk_image
        # check for segmentation parameters
        # if `is_segmentation` is False, then just treat this as a float image
        if not is_segmentation:
            self.array = torch.from_numpy(sitk.GetArrayFromImage(itk_image).astype(float)).to(device).float()
            self.array = self.array[None, None]  # TODO: Change it to support multichannel images, right now just batchify and add a dummy channel to it
            channels = itk_image.GetNumberOfComponentsPerPixel()
            self.channels = channels
            assert channels == 1, "Only single channel images supported"
        else:
            array = torch.from_numpy(sitk.GetArrayFromImage(itk_image).astype(int)).to(device).long()
            # preprocess segmentation if provided by user
            array = seg_preprocessor(array)
            if max_seg_label is not None:
                array[array > max_seg_label] = background_seg_label
            array = integer_to_onehot(array, background_label=background_seg_label, max_label=max_seg_label)[None]  # add batch dim
            self.array = array.float()
            self.channels = array.shape[1]
        # initialize matrix for pixel to physical
        dims = itk_image.GetDimension()
        self.dims = dims
        if dims not in [2, 3]:
            raise NotImplementedError("Image class only supports 2D/3D images.")
        # px2phy: (dims+1 x dims+1) homogeneous map pixel index -> physical space,
        # assembled from the ITK origin, direction cosines and voxel spacing
        px2phy = np.eye(dims+1)
        px2phy[:dims, -1] = itk_image.GetOrigin()
        px2phy[:dims, :dims] = np.array(itk_image.GetDirection()).reshape(dims, dims)
        px2phy[:dims, :dims] = px2phy[:dims, :dims] * np.array(itk_image.GetSpacing())[None]
        # generate mapping from torch to px
        # the (size-1)/2 scale+shift maps [-1, 1] to [0, size-1]; presumably the
        # normalized grid convention of grid_sample(align_corners=True) — confirm
        # against the resampling code that consumes torch2phy
        torch2px = np.eye(dims+1)
        scaleterm = (np.array(itk_image.GetSize())-1)*0.5
        torch2px[:dims, :dims] = np.diag(scaleterm)
        torch2px[:dims, -1] = scaleterm
        # save the mapping from physical to torch and vice versa
        self.torch2phy = torch.from_numpy(np.matmul(px2phy, torch2px)).to(device).float().unsqueeze(0)
        self.phy2torch = torch.inverse(self.torch2phy[0]).float().unsqueeze(0)
        # also save intermediates just in case (as numpy arrays)
        self._torch2px = torch2px
        self._px2phy = px2phy
        self.device = device

    @classmethod
    def load_file(cls, image_path: str, *args, **kwargs) -> 'Image':
        '''Read the file at `image_path` with SimpleITK and wrap it in an Image.'''
        itk_image = sitk.ReadImage(image_path)
        return cls(itk_image, *args, **kwargs)
59
+
60
+
61
class BatchedImages:
    '''
    Container that groups one or more `Image` objects of identical shape into
    a batch; calling the instance returns the images concatenated along the
    batch dimension.

    Args:
        images: a single `Image` or a non-empty list of `Image`s, all sharing
            the same array shape.

    Raises:
        ValueError: if the list is empty or shapes differ.
        TypeError: if any element is not an `Image`.
    '''
    def __init__(self, images: Union[Image, List[Image]]) -> None:
        if isinstance(images, Image):
            images = [images]
        self.images = images
        if len(self.images) == 0:
            raise ValueError("BatchedImages must have at least one image")
        for image in self.images:
            if not isinstance(image, Image):
                raise TypeError("All images must be of type Image")
        shapes = [x.array.shape for x in self.images]
        if all([x == shapes[0] for x in shapes]):
            self.shape = shapes[0]
        else:
            raise ValueError("All images must have the same shape")
        self.n_images = len(self.images)
        # BUG FIX: was `self.images[0] == 2`, which compared an Image object to
        # an int (always False), so 2D batches also got 'trilinear'
        self.interpolate_mode = 'bilinear' if self.images[0].dims == 2 else 'trilinear'

    def __call__(self):
        # concatenate per-image tensors along the batch dimension
        return torch.cat([x.array for x in self.images], dim=0)

    @property
    def device(self):
        # device of the underlying images (all assumed identical)
        return self.images[0].device

    @property
    def dims(self):
        # number of spatial dimensions (2 or 3)
        return self.images[0].dims

    def size(self):
        # number of images in the batch
        return self.n_images

    def batch_shape(self):
        '''Shape of the tensor returned by __call__().

        Replaces the old `shape()` method, which was unreachable (shadowed by
        the `self.shape` attribute set in `__init__`) and broken (it read a
        non-existent `Image.shape` and mutated an immutable `torch.Size`).
        '''
        return (self.n_images,) + tuple(self.shape[1:])

    def get_torch2phy(self):
        # stacked (B, d+1, d+1) torch->physical matrices
        return torch.cat([x.torch2phy for x in self.images], dim=0)

    def get_phy2torch(self):
        # stacked (B, d+1, d+1) physical->torch matrices
        return torch.cat([x.phy2torch for x in self.images], dim=0)
107
+
108
+
109
if __name__ == '__main__':
    # smoke test: load an intensity volume and a label map from a local IBSR
    # dataset and print basic tensor statistics (hard-coded developer paths)
    # image = Image.load_file('/data/BRATS2021/training/BraTS2021_00598/BraTS2021_00598_t1.nii.gz')
    # print(image.torch2phy)
    # image2 = Image.load_file('/data/BRATS2021/training/BraTS2021_00599/BraTS2021_00599_t1.nii.gz')
    # batch = BatchedImages([image, image2])
    # print(batch().shape)
    # print(batch.get_torch2phy().shape)
    from glob import glob
    files = sorted(glob("/data/IBSR_braindata/IBSR_01/*nii.gz"))
    image = Image.load_file(files[2])
    print(image.array.shape, image.array.min(), image.array.max())
    # get label
    label = Image.load_file(files[-1], is_segmentation=True)
    print(label.array.shape, label.array.min(), label.array.max())
@@ -0,0 +1,4 @@
1
+ from .mi import GlobalMutualInformationLoss
2
+ from .cc import LocalNormalizedCrossCorrelationLoss
3
+
4
+ __all__ = ['GlobalMutualInformationLoss', 'LocalNormalizedCrossCorrelationLoss']
@@ -0,0 +1,323 @@
1
+ '''
2
+ Cross correlation
3
+ '''
4
+ from time import time, sleep
5
+ import torch
6
+ from torch.utils.checkpoint import checkpoint
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+ from typing import Union, Tuple, List, Optional, Dict, Any, Callable
10
+ from fireants.types import ItemOrList
11
+
12
@torch.jit.script
def gaussian_1d(
    sigma: torch.Tensor, truncated: float = 4.0, approx: str = "erf", normalize: bool = True
) -> torch.Tensor:
    """
    one dimensional Gaussian kernel.
    Args:
        sigma: std of the kernel
        truncated: tail length
        approx: discrete Gaussian kernel type; options implemented here are
            "erf" and "sampled".
            - ``erf`` approximation integrates the Gaussian over each voxel
              via the error function;
            - ``sampled`` evaluates the Gaussian density at voxel centers.
            NOTE: the upstream MONAI version also offers "scalespace"
            (https://en.wikipedia.org/wiki/Scale_space_implementation#The_discrete_Gaussian_kernel);
            this trimmed port raises NotImplementedError for it.
        normalize: whether to normalize the kernel with `kernel.sum()`.
    Raises:
        ValueError: When ``truncated`` is non-positive.
    Returns:
        1D torch tensor of length 2 * tail + 1, where tail ~ sigma * truncated
    """
    sigma = torch.as_tensor(sigma, dtype=torch.float, device=sigma.device if isinstance(sigma, torch.Tensor) else None)
    device = sigma.device
    if truncated <= 0.0:
        raise ValueError(f"truncated must be positive, got {truncated}.")
    # half-width of the kernel support, rounded to the nearest integer
    tail = int(max(float(sigma) * truncated, 0.5) + 0.5)
    if approx.lower() == "erf":
        x = torch.arange(-tail, tail + 1, dtype=torch.float, device=device)
        t = 0.70710678 / torch.abs(sigma)  # 1/sqrt(2) / sigma
        # integral of the unit Gaussian over [x-0.5, x+0.5]
        out = 0.5 * ((t * (x + 0.5)).erf() - (t * (x - 0.5)).erf())
        out = out.clamp(min=0)
    elif approx.lower() == "sampled":
        x = torch.arange(-tail, tail + 1, dtype=torch.float, device=sigma.device)
        out = torch.exp(-0.5 / (sigma * sigma) * x**2)
        if not normalize:  # scale to a density (divide by sqrt(2*pi)*sigma) since sum-normalization below is skipped
            out = out / (2.5066282 * sigma)
    else:
        raise NotImplementedError(f"Unsupported option: approx='{approx}'.")
    return out / out.sum() if normalize else out  # type: ignore
51
+
52
+
53
@torch.jit.script
def make_rectangular_kernel(kernel_size: int) -> torch.Tensor:
    """Flat (box) window: a float tensor of `kernel_size` ones."""
    return torch.full((kernel_size,), 1.0)
56
+
57
@torch.jit.script
def make_triangular_kernel(kernel_size: int) -> torch.Tensor:
    """Triangular window of length `kernel_size`.

    Obtained by convolving a normalized box filter with itself; the box
    support is the largest odd integer not exceeding (kernel_size + 1) // 2.
    """
    half = (kernel_size + 1) // 2
    # keep the box support odd so the triangle stays centered
    if half % 2 == 0:
        half = half - 1
    box = torch.full((1, 1, half), 1.0 / half)
    pad = (kernel_size - half) // 2 + half // 2
    # box (*) box = triangle; flatten the (1, 1, k) conv output to 1-D
    return F.conv1d(box, box, padding=pad).view(-1)
65
+
66
@torch.jit.script
def make_gaussian_kernel(kernel_size: int) -> torch.Tensor:
    """Sampled Gaussian window with sigma = kernel_size / 3, peak value 1.

    The `* (2.5066282 * sigma)` factor cancels the 1/(sqrt(2*pi)*sigma)
    density scaling that `gaussian_1d(..., normalize=False)` applies.
    NOTE(review): when the sampled support (2*tail+1) exceeds `kernel_size`,
    the `[:kernel_size]` slice keeps the leftmost taps, giving a slightly
    asymmetric window — this mirrors the upstream MONAI implementation.
    """
    sigma = torch.tensor(kernel_size / 3.0)
    kernel = gaussian_1d(sigma=sigma, truncated=(kernel_size // 2) * 1.0, approx="sampled", normalize=False) * (
        2.5066282 * sigma
    )
    return kernel[:kernel_size]
73
+
74
@torch.jit.script
def _separable_filtering_conv(
    input_: torch.Tensor,
    kernels: List[torch.Tensor],
    pad_mode: str,
    spatial_dims: int,
    paddings: List[int],
    num_channels: int,
) -> torch.Tensor:
    """Apply one 1-D depthwise convolution per spatial dimension, in sequence.

    Helper for `separable_filtering`. Written as a flat loop (rather than the
    usual recursion) so that torch.jit can script it. Each per-dimension
    convolution uses `groups=num_channels` with the kernel repeated per
    channel, i.e. channels are filtered independently.
    """
    # re-write from recursive to non-recursive for torch.jit to work
    # for d in range(spatial_dims-1, -1, -1):
    for d in range(spatial_dims):
        # reshape the 1-D kernel so it extends only along spatial dim d
        s = [1] * len(input_.shape)
        s[d + 2] = -1
        _kernel = kernels[d].reshape(s)
        # if filter kernel is unity, don't convolve
        if _kernel.numel() == 1 and _kernel[0] == 1:
            continue

        _kernel = _kernel.repeat([num_channels, 1] + [1] * spatial_dims)
        _padding = [0] * spatial_dims
        _padding[d] = paddings[d]
        # F.pad expects pairs ordered from the LAST spatial dim to the first
        _reversed_padding = _padding[::-1]

        # translate padding for input to torch.nn.functional.pad
        _reversed_padding_repeated_twice: list[list[int]] = [[p, p] for p in _reversed_padding]
        _sum_reversed_padding_repeated_twice: list[int] = []
        for p in _reversed_padding_repeated_twice:
            _sum_reversed_padding_repeated_twice.extend(p)
        # _sum_reversed_padding_repeated_twice: list[int] = sum(_reversed_padding_repeated_twice, [])

        padded_input = F.pad(input_, _sum_reversed_padding_repeated_twice, mode=pad_mode)
        # update input
        if spatial_dims == 1:
            input_ = F.conv1d(input=padded_input, weight=_kernel, groups=num_channels)
        elif spatial_dims == 2:
            input_ = F.conv2d(input=padded_input, weight=_kernel, groups=num_channels)
        elif spatial_dims == 3:
            input_ = F.conv3d(input=padded_input, weight=_kernel, groups=num_channels)
        else:
            raise NotImplementedError(f"Unsupported spatial_dims: {spatial_dims}.")
    return input_
117
+
118
@torch.jit.script
def separable_filtering(x: torch.Tensor, kernels: ItemOrList[torch.Tensor], mode: str = "zeros") -> torch.Tensor:
    """
    Apply 1-D convolutions along each spatial dimension of `x`.
    Args:
        x: the input image. must have shape (batch, channels, H[, W, ...]).
        kernels: kernel along each spatial dimension.
            could be a single kernel (duplicated for all spatial dimensions), or
            a list of `spatial_dims` number of kernels.
        mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'``
            or ``'circular'``. Default: ``'zeros'``. See ``torch.nn.Conv1d()`` for more information.
    Raises:
        TypeError: When ``x`` is not a ``torch.Tensor``.
    Examples:
    .. code-block:: python
        >>> import torch
        >>> img = torch.randn(2, 4, 32, 32)  # batch_size 2, channels 4, 32x32 2D images
        # applying a [-1, 0, 1] filter along each of the spatial dimensions.
        # the output shape is the same as the input shape.
        >>> out = separable_filtering(img, torch.tensor((-1., 0., 1.)))
        # applying `[-1, 0, 1]`, `[1, 0, -1]` filters along two spatial dimensions respectively.
        # the output shape is the same as the input shape.
        >>> out = separable_filtering(img, [torch.tensor((-1., 0., 1.)), torch.tensor((1., 0., -1.))])
    """

    if not isinstance(x, torch.Tensor):
        raise TypeError(f"x must be a torch.Tensor but is {type(x).__name__}.")

    spatial_dims = len(x.shape) - 2
    # broadcast a single kernel to every spatial dimension
    if isinstance(kernels, torch.Tensor):
        kernels = [kernels] * spatial_dims
    # move kernels to x's dtype/device; pad by half the (odd) kernel length
    _kernels = [s.to(x) for s in kernels]
    _paddings = [(k.shape[0] - 1) // 2 for k in _kernels]
    n_chs = x.shape[1]
    # F.pad uses 'constant' where conv modules say 'zeros'
    pad_mode = "constant" if mode == "zeros" else mode
    return _separable_filtering_conv(x, _kernels, pad_mode, spatial_dims, _paddings, n_chs)
154
+
155
+
156
# dispatch table: kernel_type name -> 1-D kernel factory (used by
# LocalNormalizedCrossCorrelationLoss.__init__)
kernel_dict = {
    "rectangular": make_rectangular_kernel,
    "triangular": make_triangular_kernel,
    "gaussian": make_gaussian_kernel,
}
162
+
163
class LocalNormalizedCrossCorrelationLoss(nn.Module):
    """
    Local squared zero-normalized cross-correlation.
    The loss is based on a moving kernel/window over the y_true/y_pred,
    within the window the square of zncc is calculated.
    The kernel can be a rectangular / triangular / gaussian window.
    The final loss is the averaged loss over all windows.
    Returned values are negated (higher correlation -> lower loss).
    Adapted from:
        https://github.com/voxelmorph/voxelmorph/blob/legacy/src/losses.py
        DeepReg (https://github.com/DeepRegNet/DeepReg)
    """

    def __init__(
        self,
        spatial_dims: int = 3,
        kernel_size: int = 3,
        kernel_type: str = "rectangular",
        reduction: str = "mean",
        smooth_nr: float = 1e-5,
        smooth_dr: float = 1e-5,
        unsigned: bool = True,
        checkpointing: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions, {``1``, ``2``, ``3``}. Defaults to 3.
            kernel_size: kernel spatial size, must be odd.
            kernel_type: {``"rectangular"``, ``"triangular"``, ``"gaussian"``}. Defaults to ``"rectangular"``.
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.
                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.
            smooth_nr: a small constant added to the numerator to avoid nan.
            smooth_dr: a small constant added to the denominator to avoid nan.
            unsigned: if True, use squared ncc (cross^2 / (var_t * var_p));
                otherwise use signed ncc (cross / (std_t * std_p)).
            checkpointing: if True, wrap the windowed computation in
                torch.utils.checkpoint to trade recompute for activation memory.
        """
        super().__init__()
        self.ndim = spatial_dims
        if self.ndim not in {1, 2, 3}:
            raise ValueError(f"Unsupported ndim: {self.ndim}-d, only 1-d, 2-d, and 3-d inputs are supported")
        self.reduction = reduction
        self.unsigned = unsigned

        self.kernel_size = kernel_size
        if self.kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be odd, got {self.kernel_size}")

        # _kernel = look_up_option(kernel_type, kernel_dict)
        _kernel = kernel_dict[kernel_type]
        self.kernel = _kernel(self.kernel_size)
        self.kernel.requires_grad = False
        self.kernel_nd, self.kernel_vol = self.get_kernel_vol()  # get nD kernel and its volume
        self.smooth_nr = float(smooth_nr)
        self.smooth_dr = float(smooth_dr)
        self.checkpointing = checkpointing

    def get_kernel_vol(self):
        # outer-product the 1-D kernel with itself (ndim - 1) times to build the
        # separable n-D kernel; also return its total mass (sum of weights)
        vol = self.kernel
        for _ in range(self.ndim - 1):
            vol = torch.matmul(vol.unsqueeze(-1), self.kernel.unsqueeze(0))
        return vol, torch.sum(vol)

    def forward(self, pred: torch.Tensor, target: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            pred: the shape should be BNH[WD].
            target: the shape should be BNH[WD].
            mask: optional weight map multiplied into the ncc map; each (B, N)
                slice is divided by its mask mean so the reduction stays
                comparable to the unmasked case.
        Raises:
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
        """
        if pred.ndim - 2 != self.ndim:
            raise ValueError(f"expecting pred with {self.ndim} spatial dimensions, got pred of shape {pred.shape}")
        if target.shape != pred.shape:
            raise ValueError(f"ground truth has differing shape ({target.shape}) from pred ({pred.shape})")

        # sum over kernel
        def cc_checkpoint_fn(target, pred, kernel, kernel_vol):
            '''
            This function is used to compute the intermediate results of the loss.
            Kept as a closure taking only tensor args so it can be passed to
            torch.utils.checkpoint when `self.checkpointing` is set.
            '''
            t2, p2, tp = target * target, pred * pred, target * pred
            kernel, kernel_vol = kernel.to(pred), kernel_vol.to(pred)
            # kernel_nd = self.kernel_nd.to(pred)
            kernels = [kernel] * self.ndim
            kernels_t = kernels_p = kernels
            kernel_vol_t = kernel_vol_p = kernel_vol
            # compute intermediates (windowed sums via separable filtering)
            t_sum = separable_filtering(target, kernels=kernels_t)
            p_sum = separable_filtering(pred, kernels=kernels_p)
            t2_sum = separable_filtering(t2, kernels=kernels_t)
            p2_sum = separable_filtering(p2, kernels=kernels_p)
            tp_sum = separable_filtering(tp, kernels=kernels_t)  # use target device's output
            # average over kernel
            t_avg = t_sum / kernel_vol_t
            p_avg = p_sum / kernel_vol_p
            # normalized cross correlation between t and p
            # sum[(t - mean[t]) * (p - mean[p])] / std[t] / std[p]
            # denoted by num / denom
            # assume we sum over N values
            # num = sum[t * p - mean[t] * p - t * mean[p] + mean[t] * mean[p]]
            #     = sum[t*p] - sum[t] * sum[p] / N * 2 + sum[t] * sum[p] / N
            #     = sum[t*p] - sum[t] * sum[p] / N
            #     = sum[t*p] - sum[t] * mean[p] = cross
            # the following is actually squared ncc
            cross = (tp_sum.to(pred) - p_avg * t_sum.to(pred))  # on pred device
            # clamp variances at smooth_dr to keep the denominator positive
            t_var = torch.max(
                t2_sum - t_avg * t_sum, torch.as_tensor(self.smooth_dr, dtype=t2_sum.dtype, device=t2_sum.device)
            ).to(pred)
            # NOTE(review): unlike t_var, p_var is not moved with .to(pred);
            # harmless when pred and target share a device — confirm for the
            # cross-device case hinted at by the .to(pred) calls above
            p_var = torch.max(
                p2_sum - p_avg * p_sum, torch.as_tensor(self.smooth_dr, dtype=p2_sum.dtype, device=p2_sum.device)
            )
            if self.unsigned:
                ncc: torch.Tensor = (cross * cross + self.smooth_nr) / ((t_var * p_var) + self.smooth_dr)
            else:
                ncc: torch.Tensor = (cross + self.smooth_nr) / ((torch.sqrt(t_var) * torch.sqrt(p_var)) + self.smooth_dr)
            return ncc

        if self.checkpointing:
            ncc = checkpoint(cc_checkpoint_fn, target, pred, self.kernel, self.kernel_vol)
        else:
            ncc = cc_checkpoint_fn(target, pred, self.kernel, self.kernel_vol)

        if mask is not None:
            # renormalize by the mask mean so masked regions don't deflate the reduction
            maskmean = mask.flatten(2).mean(2)  # [B, N]
            for _ in range(self.ndim):
                maskmean = maskmean.unsqueeze(-1)  # [B, N, 1, 1, ...]
            ncc = ncc * mask / maskmean

        # negate: maximizing correlation == minimizing the loss
        if self.reduction == 'sum':
            return torch.sum(ncc).neg()  # sum over the batch, channel and spatial ndims
        if self.reduction == 'none':
            return ncc.neg()
        if self.reduction == 'mean':
            return torch.mean(ncc).neg()  # average over the batch, channel and spatial ndims
        raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
300
+
301
+
302
if __name__ == '__main__':
    # quick benchmark: average LNCC over 200 iterations on random volumes
    N = 64
    img1 = torch.rand(1, 1, N, N, N).cuda()
    img2 = torch.rand(1, 1, N, N, N).cuda()
    # loss = torch.jit.script(LocalNormalizedCrossCorrelationLoss(3, kernel_type='rectangular', reduction='mean')).cuda()
    loss = LocalNormalizedCrossCorrelationLoss(3, kernel_type='rectangular', reduction='mean').cuda()
    total = 0.0

    # BUG FIX: the old `@torch.jit.script`-decorated `train` helper captured the
    # module-level `loss` (an nn.Module); TorchScript cannot compile free
    # functions referencing Python nn.Module globals, so scripting failed at
    # definition time even though the call site was commented out. Benchmark
    # with a plain eager loop instead.
    a = time()
    for i in range(200):
        out = loss(img1, img2)
        total += out.item()
    print(time() - a)
    print(total / 200)