cornucopia 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cornucopia/__init__.py +73 -0
- cornucopia/base.py +1915 -0
- cornucopia/baseutils.py +575 -0
- cornucopia/contrast.py +260 -0
- cornucopia/ctx.py +25 -0
- cornucopia/fov.py +707 -0
- cornucopia/geometric.py +2068 -0
- cornucopia/intensity.py +1358 -0
- cornucopia/io.py +161 -0
- cornucopia/kspace.py +505 -0
- cornucopia/labels.py +1872 -0
- cornucopia/noise.py +508 -0
- cornucopia/psf.py +463 -0
- cornucopia/qmri.py +1288 -0
- cornucopia/random.py +1480 -0
- cornucopia/special.py +159 -0
- cornucopia/synth.py +708 -0
- cornucopia/tests/__init__.py +0 -0
- cornucopia/tests/test_backward_geometric.py +173 -0
- cornucopia/tests/test_backward_intensity.py +243 -0
- cornucopia/tests/test_backward_kspace.py +115 -0
- cornucopia/tests/test_backward_noise.py +169 -0
- cornucopia/tests/test_backward_psf.py +142 -0
- cornucopia/tests/test_backward_qmri.py +249 -0
- cornucopia/tests/test_backward_random.py +44 -0
- cornucopia/tests/test_backward_synth.py +72 -0
- cornucopia/tests/test_base.py +401 -0
- cornucopia/tests/test_geometric.py +26 -0
- cornucopia/tests/test_intensity.py +9 -0
- cornucopia/tests/test_random.py +722 -0
- cornucopia/tests/test_run_contrast.py +28 -0
- cornucopia/tests/test_run_fov.py +132 -0
- cornucopia/tests/test_run_geometric.py +157 -0
- cornucopia/tests/test_run_intensity.py +192 -0
- cornucopia/tests/test_run_kspace.py +70 -0
- cornucopia/tests/test_run_labels.py +224 -0
- cornucopia/tests/test_run_noise.py +127 -0
- cornucopia/tests/test_run_psf.py +115 -0
- cornucopia/tests/test_run_qmri.py +114 -0
- cornucopia/tests/test_run_synth.py +67 -0
- cornucopia/typing.py +97 -0
- cornucopia/utils/__init__.py +0 -0
- cornucopia/utils/b0.py +745 -0
- cornucopia/utils/bounds.py +412 -0
- cornucopia/utils/compat.py +47 -0
- cornucopia/utils/conv.py +305 -0
- cornucopia/utils/gmm.py +169 -0
- cornucopia/utils/indexing.py +911 -0
- cornucopia/utils/io.py +258 -0
- cornucopia/utils/jit.py +128 -0
- cornucopia/utils/kernels.py +288 -0
- cornucopia/utils/morpho.py +234 -0
- cornucopia/utils/mrf.py +574 -0
- cornucopia/utils/padding.py +173 -0
- cornucopia/utils/patch.py +302 -0
- cornucopia/utils/pool.py +282 -0
- cornucopia/utils/py.py +348 -0
- cornucopia/utils/smart_inplace.py +163 -0
- cornucopia/utils/version.py +57 -0
- cornucopia/utils/warps.py +606 -0
- cornucopia-0.0.0.dist-info/METADATA +92 -0
- cornucopia-0.0.0.dist-info/RECORD +65 -0
- cornucopia-0.0.0.dist-info/WHEEL +5 -0
- cornucopia-0.0.0.dist-info/licenses/LICENSE +21 -0
- cornucopia-0.0.0.dist-info/top_level.txt +1 -0
cornucopia/io.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
"""This module contains transforms that load data from disk."""
|
|
2
|
+
__all__ = ['ToTensorTransform', 'LoadTransform']
|
|
3
|
+
# stdlib
|
|
4
|
+
import os.path
|
|
5
|
+
from os import PathLike
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
# dependencies
|
|
9
|
+
import torch
|
|
10
|
+
from torch import Tensor
|
|
11
|
+
import typing_extensions as tx
|
|
12
|
+
|
|
13
|
+
# internals
|
|
14
|
+
from .base import FinalTransform
|
|
15
|
+
from .utils.io import loaders
|
|
16
|
+
from . import typing as cct
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ToTensorTransform(FinalTransform):
    """Convert to Tensor (or to other dtype/device)"""

    def __init__(
        self,
        ndim: tx.Optional[int] = None,
        dtype: tx.Optional[torch.dtype] = None,
        device: tx.Optional[cct.TorchDevice] = None,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        ndim : int, optional
            Number of spatial dimensions (default: guess from data).
            When set (and non-zero), the squeezed tensor is left-padded
            with singleton dimensions until it has `ndim + 1` dimensions,
            and an error is raised if it has more.
        dtype : torch.dtype, optional
            Returned data type (default: keep same)
        device : torch.device, optional
            Returned device (default: keep same)

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        # NOTE: stored under the name `dim` (read by `_xform`).
        self.dim = ndim
        self.dtype = dtype
        self.device = device

    def _xform(self, x: Tensor) -> Tensor:
        # Convert/cast, then drop *every* singleton dimension
        # (including a singleton channel dimension).
        tensor = torch.as_tensor(x, dtype=self.dtype, device=self.device)
        tensor = tensor.squeeze()

        # `ndim` of None (or 0) means: keep whatever shape remains.
        if not self.dim:
            return tensor

        expected = self.dim + 1  # one channel dim + spatial dims
        missing = expected - tensor.ndim
        if missing > 0:
            # Prepend as many singleton dimensions as needed.
            tensor = tensor[(None,) * missing]
        if tensor.ndim > expected:
            raise ValueError(f'Too many dimensions: '
                             f'{tensor.ndim} > 1 + {self.dim}')
        return tensor
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class LoadTransform(FinalTransform):
    """
    Load data from disk.

    Available loaders are:

    - `BabelLoader`: for medical image formats (nifti, mgz, minc, etc.)
    - `TiffLoader`: for TIFF files (including multi-page)
    - `PillowLoader`: for common image formats (png, jpg, etc., with optional rot90)
    - `NumpyLoader`: for .npy and .npz files (with optional field name)

    Custom loaders can be added by registering them in
    `cornucopia.utils.io.loaders`.

    Inputs that can already be converted with `torch.as_tensor` are
    returned directly. Otherwise, the loaders registered for the file's
    extension are tried first (with the loader-specific options); if they
    all fail, every registered loader is tried, in registration order,
    without loader-specific options.
    """

    def __init__(
        self,
        ndim: tx.Optional[int] = None,
        dtype: tx.Optional[torch.dtype] = None,
        *,
        device: tx.Optional[cct.TorchDevice] = None,
        returns: tx.Optional[cct.ReturnsT] = None,
        append: cct.AppendT = False,
        prefix: cct.PrefixT = False,
        include: tx.Optional[cct.IncludeT] = None,
        exclude: tx.Optional[cct.ExcludeT] = None,
        consume: tx.Optional[cct.ConsumeT] = None,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        ndim : int | None
            Number of spatial dimensions (default: guess from file)
        dtype : torch.dtype | str | None
            Data type (default: guess from file)
        device : torch.device | str | None
            Device on which to load data (default: cpu)

        Other Parameters
        ----------------
        to_ras : bool, default=True
            Reorient data so that it has a RAS layout.
            Only used by `BabelLoader`.
        rot90 : bool, default=True
            Rotate by 90 degrees in-plane.
            Only used by `PillowLoader`.
        field : str, default="arr_0"
            Field to load from a npz file.
            Only used by `NumpyLoader`.
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(
            returns=returns,
            append=append,
            prefix=prefix,
            include=include,
            exclude=exclude,
            consume=consume,
        )
        self.ndim = ndim
        self.dtype = dtype
        self.device = device
        # Loader-specific options (e.g. `to_ras`, `rot90`, `field`),
        # forwarded to the loaders registered for the file's extension.
        self.kwargs = kwargs

    @staticmethod
    def _extension(path: PathLike) -> str:
        """Return the lower-case extension of `path`.

        For compressed files, the inner extension is kept as well
        (e.g. `'.nii.gz'` rather than `'.gz'`).
        """
        stem, ext = os.path.splitext(str(path))
        if ext.lower() in ('.gz', '.bz', '.bz2', '.gzip', '.bzip2'):
            ext = os.path.splitext(stem)[1] + ext
        return ext.lower()

    def _xform(self, x: tx.Union[str, Tensor]) -> Tensor:
        # Array-like inputs (tensors, arrays, nested lists) are converted
        # directly; anything that cannot be converted is treated as
        # something to load. Failure here is expected for paths.
        try:
            return torch.as_tensor(x, dtype=self.dtype, device=self.device)
        except Exception:
            pass

        if isinstance(x, str):
            x = Path(x)

        def describe(loader) -> str:
            # Human-readable loader identifier for error reports.
            return getattr(loader, '__name__', repr(loader))

        errors = []
        tried = []

        # 1) Loaders registered for this extension, with loader options.
        if isinstance(x, PathLike):
            ext = self._extension(x)
            if ext in loaders:
                for loader in loaders[ext]:
                    tried.append(loader)
                    try:
                        return loader(self.ndim, self.dtype, self.device,
                                      **self.kwargs)(x)
                    except Exception as e:
                        errors.append(f'{describe(loader)}: {e}')

        # 2) Fallback: every registered loader, deduplicated in a
        #    deterministic (registration) order -- a `set` would make the
        #    order depend on object hashes. Loader options are not
        #    forwarded here, presumably because not every loader accepts
        #    every option.
        candidates = dict.fromkeys(
            loader
            for ext_loaders in loaders.values()
            for loader in ext_loaders
        )
        for loader in candidates:
            if not self.kwargs and loader in tried:
                # The exact same call already failed in step 1.
                continue
            try:
                return loader(self.ndim, self.dtype, self.device)(x)
            except Exception as e:
                errors.append(f'{describe(loader)}: {e}')

        raise ValueError('\n'.join([f'Could not load {x}:'] + errors))
|
cornucopia/kspace.py
ADDED
|
@@ -0,0 +1,505 @@
|
|
|
1
|
+
"""This module contains transforms that operate in k-space (Fourier space)."""
|
|
2
|
+
__all__ = [
|
|
3
|
+
'ArrayCoilCombinationTransform',
|
|
4
|
+
'ArrayCoilTransform',
|
|
5
|
+
'SumOfSquaresTransform',
|
|
6
|
+
'IntraScanMotionFinalTransform',
|
|
7
|
+
'IntraScanMotionTransform',
|
|
8
|
+
'SmallIntraScanMotionTransform',
|
|
9
|
+
]
|
|
10
|
+
# stdlib
|
|
11
|
+
import math
|
|
12
|
+
import random
|
|
13
|
+
from math import inf
|
|
14
|
+
|
|
15
|
+
# dependencies
|
|
16
|
+
import torch
|
|
17
|
+
import typing_extensions as tx
|
|
18
|
+
from torch import Tensor
|
|
19
|
+
|
|
20
|
+
# internals
|
|
21
|
+
from .base import Transform, NonFinalTransform, FinalTransform
|
|
22
|
+
from .baseutils import Returned, prepare_output, return_requires
|
|
23
|
+
from .intensity import MulFieldTransform
|
|
24
|
+
from .geometric import RandomAffineTransform
|
|
25
|
+
from .random import Fixed, Sampler
|
|
26
|
+
from .utils.warps import identity
|
|
27
|
+
from .utils.smart_inplace import sqrt_, square_, abs_, mul_, exp_, sub_, add_
|
|
28
|
+
from . import ctx
|
|
29
|
+
from . import typing as cct
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class ArrayCoilCombinationTransform(FinalTransform):
    """Apply coil sensitivities to an image and combine across coils."""

    def __init__(self, sens: Tensor, **kwargs) -> None:
        """
        Parameters
        ----------
        sens : (K, *spatial) tensor
            Complex coil sensitivities

        Other Parameters
        ----------------
        returns : [(list | dict) of] str
            See [`Transform`][cornucopia.base.Transform] for details.
            Default is `'uncombined'`.

            | Value          | Description                              |
            | -------------- | ---------------------------------------- |
            | `'sos'`        | Sum of square combined (magnitude) image |
            | `'uncombined'` | Uncombined (complex) coil images         |
            | `'sens'`       | Uncombined (complex) coil sensitivities  |
            | `'netsens'`    | Net (magnitude) coil sensitivity         |

        append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.sens = sens

    def _xform(self, x: Tensor) -> Returned:
        # Sensitivities follow the input onto its device.
        coil_sens = self.sens.to(x.device)
        # Broadcast the image against every coil profile.
        coil_images = x * coil_sens
        # Root-sum-of-squares over the leading (coil) dimension, with a
        # singleton channel dimension re-inserted in front.
        net_sens = sqrt_(square_(coil_sens.abs()).sum(0)).unsqueeze(0)
        combined = sqrt_(square_(coil_images.abs()).sum(0)).unsqueeze(0)
        outputs = {
            'input': x,
            'sos': combined,
            'output': coil_images,
            'uncombined': coil_images,
            'netsens': net_sens,
            'sens': coil_sens,
        }
        return prepare_output(outputs, self.returns)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class ArrayCoilTransform(NonFinalTransform):
    """Generate and apply random coil sensitivities (real or complex)"""

    Final = Next = ArrayCoilCombinationTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        ncoils: int = 8,
        fwhm: float = 0.5,
        diameter: float = 0.8,
        jitter: float = 0.01,
        unit: tx.Literal['fov', 'vox'] = 'fov',
        shape: cct.NumberOrSequence[int] = 4,
        sos: bool = True,
        *,
        shared=True,
        **kwargs
    ) -> None:
        """

        Parameters
        ----------
        ncoils : int
            Number of complex receiver channels
        fwhm : float
            Width of each receiver profile
        diameter : float
            Diameter of the ellipsoid on which receivers are centered
        jitter : float
            Amount of jitter off the ellipsoid
            (each coordinate is shifted by a value drawn in `[0, jitter)`)
        unit : {'fov', 'vox'}
            Unit of `fwhm`, `diameter`, `jitter`
        shape : [list of] int
            Number of control points for the underlying smooth component.
        sos : bool
            Stored as `self.sos`; not used when generating sensitivities.

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns : [(list | dict) of] str
            See [`Transform`][cornucopia.base.Transform] for details.
            Default is `'uncombined'`.

            | Value          | Description                              |
            | -------------- | ---------------------------------------- |
            | `'sos'`        | Sum of square combined (magnitude) image |
            | `'uncombined'` | Uncombined (complex) coil images         |
            | `'sens'`       | Uncombined (complex) coil sensitivities  |
            | `'netsens'`    | Net (magnitude) coil sensitivity         |

        append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """  # noqa: E501
        super().__init__(shared=shared, **kwargs)
        self.ncoils = ncoils
        self.fwhm = fwhm
        self.diameter = diameter
        self.jitter = jitter
        self.unit = unit
        self.shape = shape
        self.sos = sos

    def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
        if max_depth == 0:
            return self

        ndim = x.dim() - 1
        spatial = x.shape[1:]
        backend = dict(dtype=x.dtype, device=x.device)

        # Two smooth random fields per coil; each pair defines the phase
        # (atan2) and the magnitude (modulus) of one coil profile.
        fake_x = torch.ones([], **backend)
        fake_x = fake_x.expand([2*self.ncoils, *spatial])
        smooth_bias = MulFieldTransform(shape=self.shape, vmin=-1, vmax=1)
        smooth_bias = smooth_bias(fake_x)
        phase = smooth_bias[::2].atan2(smooth_bias[1::2])
        magnitude = sqrt_(add_(smooth_bias[0::2].square(),
                               smooth_bias[1::2].square()))

        fov = torch.as_tensor(spatial, **backend)
        fwhm = self.fwhm
        if self.unit == 'fov':
            fwhm = fwhm * fov
        # Gaussian precision from the FWHM (sigma = fwhm / 2.355)
        lam = (2.355 / fwhm) ** 2

        # Voxel coordinate grid: identical for every coil, so build it once.
        grid = identity(spatial, **backend)

        for k in range(self.ncoils):
            # Random direction on the unit sphere, scaled by the diameter.
            loc = torch.randn(ndim, **backend)
            loc /= loc.square().sum().sqrt_()
            loc = mul_(loc, self.diameter)
            if self.jitter:
                jitter = torch.rand(ndim, **backend)
                jitter = mul_(jitter, self.jitter)
                loc = add_(loc, jitter)
            # Centre the ellipsoid in the field of view (a diameter of
            # `d` gives a radius of `d/2` around the centre).
            if self.unit == 'fov':
                loc = mul_((1 + loc) / 2, fov)
            else:
                # `loc` is already in voxels: centre = fov / 2
                loc = (fov + loc) / 2
            # Out-of-place subtraction keeps `grid` intact across coils.
            exp_bias = square_(grid - loc)
            exp_bias = mul_(exp_bias, lam).sum(-1)
            exp_bias = exp_(mul_(exp_bias, -0.5))
            if exp_bias.requires_grad:
                magnitude_k = magnitude[k].clone()
                magnitude[k].copy_(magnitude_k * exp_bias)
            else:
                mul_(magnitude[k], exp_bias)

        sens = mul_(exp_(1j * phase), magnitude)
        return self.Next(
            sens, **self.get_prm()
        ).unroll(x, max_depth-1)
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
class SumOfSquaresTransform(FinalTransform):
    """Compute the sum-of-squares across coils/channels

    The root-sum-of-squares is taken over the leading (coil/channel)
    dimension, which is kept as a singleton: `(K, *spatial)` ->
    `(1, *spatial)`.
    """

    def _xform(self, x: Tensor) -> Tensor:
        # `x.abs()` allocates a new tensor, so the in-place helpers below
        # operate on a temporary and never on the caller's input. The
        # smart-inplace `abs_` helper may work in place on a real input.
        return sqrt_(square_(x.abs()).sum(0, keepdim=True))
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
class IntraScanMotionFinalTransform(FinalTransform):
    """Apply pre-computed intra-scan motion"""

    def __init__(
        self,
        motion: tx.Sequence[FinalTransform],
        patterns: Tensor,
        sens: tx.Optional[FinalTransform] = None,
        axis: int = -1,
        freq: bool = False,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        motion : list[FinalTransform]
            One transform per shot, each applying that shot's motion
            to an image
        patterns : tensor
            Binary tensor of shape (N, X) indicating which frequencies/slices
            are acquired in each shot. N is the number of shots, X is the
            number of frequencies/slices along the motion axis.
        sens : FinalTransform
            A transform that generates a set of complex sensitivity profiles.
        axis : int
            Axis along which shots are acquired (slice or phase-encode)
        freq : bool
            Motion happens across a phase-encode direction, which means
            that the k-space is built from pieces with different object
            position. This typically happens in "3D" acquisitions.
            If False, motion happens along the slice direction ("2D"
            acquisition).

        Other Parameters
        ----------------
        returns : [(list | dict) of] str
            See [`Transform`][cornucopia.base.Transform] for details.
            Default is `'sos'`.

            | Value          | Description                              | Shape         |
            | -------------- | ---------------------------------------- | ------------- |
            | `'sos'`        | Sum of square combined (magnitude) image | `(C,X,Y,Z)`   |
            | `'uncombined'` | Uncombined (complex) coil images         | `(K,X,Y,Z)`   |
            | `'sens'`       | Uncombined (complex) coil sensitivities  | `(K,X,Y,Z)`   |
            | `'netsens'`    | Net (magnitude) coil sensitivity         | `(1,X,Y,Z)`   |
            | `'flow'`       | Displacement field in each shot          | `(N,3,X,Y,Z)` |
            | `'matrix'`     | Rigid matrix in each shot                | `(N,4,4)`     |
            | `'pattern'`    | Frequencies acquired in each shot        | `(N,X)`       |

        append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """  # noqa: E501
        super().__init__(**kwargs)
        self.motion = motion
        self.patterns = patterns
        self.sens = sens
        self.axis = axis
        self.freq = freq

    def _xform(self, x: Tensor) -> Returned:
        patterns = self.patterns.to(x.device)

        # Optional coil sensitivities, computed in the input layout.
        if self.sens is not None:
            if len(x) != 1:
                raise ValueError(
                    f'Coil sensitivities require a single-channel input, '
                    f'got {len(x)} channels')
            with ctx.returns(self.sens, ['sens', 'netsens']):
                sens, netsens = self.sens(x)
            y = x.new_zeros([len(sens), *x.shape[1:]],
                            dtype=torch.complex64)
        else:
            ydtype = torch.complex64 if self.freq else x.dtype
            y = torch.zeros_like(x, dtype=ydtype)
            sens = netsens = None
        # `y` starts at zero so that frequencies/slices acquired by no shot
        # are empty rather than uninitialized memory.

        # The output is assembled with the shot axis in position 1.
        y = y.movedim(self.axis, 1)
        # Broadcast a (X,) per-shot mask along dim 1 only.
        mask_shape = [1, -1] + [1] * (y.dim() - 2)

        matrix, flow = [], []
        returned = return_requires(self.returns)
        returns = dict(moved='output')
        if 'matrix' in returned:
            returns['matrix'] = 'matrix'
        if 'flow' in returned:
            returns['flow'] = 'flow'

        for motion_trf, pattern in zip(self.motion, patterns):
            # Motions and sensitivities were generated for the input
            # layout, so apply them before moving the shot axis.
            with ctx.returns(motion_trf, returns):
                moved = motion_trf(x)
            matrix.append(moved.get('matrix', None))
            flow.append(moved.get('flow', None))
            moved = moved['moved']
            if sens is not None:
                moved = moved * sens
            moved = moved.movedim(self.axis, 1)
            if self.freq:
                # Centered FFT along the shot axis only.
                moved = torch.fft.ifftshift(moved, dim=1)
                moved = torch.fft.fft(moved, dim=1)
                moved = torch.fft.fftshift(moved, dim=1)
            # Equivalent to `y[:, pattern] = moved[:, pattern]`.
            # NOTE: In torch < 1.*, index assignment does not support
            # automatic differentiation for outputs with complex dtype,
            # so torch.where is used instead.
            mask = pattern.to(torch.bool).reshape(mask_shape)
            y = torch.where(mask, moved, y)

        if self.freq:
            y = torch.fft.ifftshift(y, dim=1)
            y = torch.fft.ifft(y, dim=1)
            y = torch.fft.fftshift(y, dim=1)
        y = y.movedim(1, self.axis)

        sos = sqrt_(square_(y.abs()).sum(0, keepdim=True))

        if 'matrix' in returned:
            matrix = torch.stack(matrix)
        else:
            matrix = None
        if 'flow' in returned:
            flow = torch.stack(flow)
        else:
            flow = None

        return prepare_output(
            dict(input=x, sos=sos, output=sos, uncombined=y, sens=sens,
                 netsens=netsens, pattern=patterns, flow=flow,
                 matrix=matrix),
            self.returns)
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
class IntraScanMotionTransform(NonFinalTransform):
    """Model intra-scan motion"""

    Final = Next = IntraScanMotionFinalTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        shots: int = 4,
        axis: int = -1,
        freq: bool = True,
        pattern: tx.Union[
            tx.Literal['sequential', 'random'],
            tx.Sequence[Tensor]
        ] = 'sequential',
        translations: tx.Union[Sampler, float] = 0.1,
        rotations: tx.Union[Sampler, float] = 15,
        sos: bool = True,
        coils: tx.Optional[Transform] = None,
        *,
        shared: tx.Union[str, bool] = 'channels',
        **kwargs
    ) -> None:
        """

        Parameters
        ----------
        shots : int
            Number of acquisition shots.
            The object is in a different position in each shot.
        axis : int
            Axis along which shots are acquired (slice or phase-encode)
        freq : bool
            Motion happens across a phase-encode direction, which means
            that the k-space is built from pieces with different object
            position. This typically happens in "3D" acquisitions.
            If False, motion happens along the slice direction ("2D"
            acquisition).
        pattern : {'sequential', 'random'} or list[tensor[bool or int]]
            k-space (or slice) sampling pattern. This argument encodes
            the frequencies (or slices) that are acquired in each shot.
            The 'sequential' option assumes that frequencies are
            acquired in order. The 'random' option assumes that frequencies
            are randomly distributed across shots.
            A list contains, for each shot, either a boolean mask over
            all frequencies or the indices of the acquired frequencies.
            A (N, X) tensor is used as a boolean mask directly.
        translations : Sampler or float
            Sampler (or upper-bound) for random translations (in % of FOV)
        rotations : Sampler or float
            Sampler (or upper-bound) for random rotations (in deg)
        sos : bool
            Whether to return the sum-of-squares combined image across coils.
            (Stored as `self.sos`; not read when unrolling.)
        coils : Transform
            A transform that generates a set of complex sensitivity profiles

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns : [(list | dict) of] str
            See [`Transform`][cornucopia.base.Transform] for details.
            Default is `'sos'`.

            | Value          | Description                              | Shape         |
            | -------------- | ---------------------------------------- | ------------- |
            | `'sos'`        | Sum of square combined (magnitude) image | `(C,X,Y,Z)`   |
            | `'uncombined'` | Uncombined (complex) coil images         | `(K,X,Y,Z)`   |
            | `'sens'`       | Uncombined (complex) coil sensitivities  | `(K,X,Y,Z)`   |
            | `'netsens'`    | Net (magnitude) coil sensitivity         | `(1,X,Y,Z)`   |
            | `'flow'`       | Displacement field in each shot          | `(N,3,X,Y,Z)` |
            | `'matrix'`     | Rigid matrix in each shot                | `(N,4,4)`     |
            | `'pattern'`    | Frequencies acquired in each shot        | `(N,X)`       |

        append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.

        """  # noqa: E501
        super().__init__(shared=shared, **kwargs)
        self.shots = shots
        self.axis = axis
        self.pattern = pattern
        self.sos = sos
        self.coils = coils
        self.freq = freq
        self.motion = RandomAffineTransform(
            translations=translations, rotations=rotations,
            zooms=Fixed(0), shears=Fixed(0), bound='reflection')

    def get_pattern(
        self, n: int, device: tx.Optional[torch.device] = None
    ) -> Tensor:
        """Build the per-shot sampling masks.

        Parameters
        ----------
        n : int
            Number of frequencies (or slices) along the motion axis.
        device : torch.device, optional
            Device of the returned mask.

        Returns
        -------
        pattern : (N, n) tensor[bool]
            `pattern[i, j]` is True if frequency `j` is acquired in shot `i`.
        """
        if isinstance(self.pattern, str):
            if self.pattern not in ('sequential', 'random'):
                raise ValueError(f'Unknown pattern: {self.pattern!r}')
            shots = min(self.shots, n)
            order = list(range(n))
            if self.pattern == 'random':
                random.shuffle(order)
            order = torch.as_tensor(order, dtype=torch.long, device=device)
            # Contiguous chunks of `order`; the last ones may be shorter
            # (or empty) when `n` is not a multiple of `shots`.
            length = int(math.ceil(n / shots))
            pattern = torch.zeros([shots, n], dtype=torch.bool, device=device)
            for shot in range(shots):
                pattern[shot, order[shot*length:(shot+1)*length]] = True

        elif isinstance(self.pattern, (list, tuple)):
            if not self.pattern:
                raise ValueError('`pattern` must contain at least one shot')
            items = [torch.as_tensor(p, device=device) for p in self.pattern]

            def is_mask(p: Tensor) -> bool:
                # A full-length vector of booleans (or of 0/1 values).
                if p.dim() != 1 or p.shape[0] != n:
                    return False
                if p.dtype == torch.bool:
                    return True
                return p.numel() == 0 or bool(p.max() <= 1)

            if all(is_mask(p) for p in items):
                pattern = torch.stack([p.to(torch.bool) for p in items])
            else:
                # Lists of acquired indices.
                pattern = torch.zeros([len(items), n], dtype=torch.bool,
                                      device=device)
                for shot, p in enumerate(items):
                    index = p if p.dtype == torch.bool else p.long()
                    pattern[shot, index] = True

        elif torch.is_tensor(self.pattern):
            pattern = self.pattern

        else:
            raise TypeError(
                f'`pattern` must be "sequential", "random", a list of '
                f'tensors or a tensor, not {type(self.pattern)}')

        return pattern.to(device, torch.bool)

    def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
        if max_depth == 0:
            return self

        n = x.shape[self.axis]

        # Sample motion parameters for each shot (sampled before the
        # pattern, as the random draws were originally ordered).
        motion = [self.motion.final(x) for _ in range(min(self.shots, n))]

        # Compute sampling pattern
        pattern = self.get_pattern(n, x.device)

        # User-provided patterns (or subclasses) may define a different
        # number of shots: use exactly one motion per pattern row.
        while len(motion) < len(pattern):
            motion.append(self.motion.final(x))
        motion = motion[:len(pattern)]

        # Sample coil sensitivities
        sens = None
        if self.coils is not None:
            sens = self.coils.final(x)

        return self.Next(
            motion, pattern, sens, self.axis, self.freq, **self.get_prm()
        ).unroll(x, max_depth-1)
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
class SmallIntraScanMotionTransform(IntraScanMotionTransform):
    """Model intra-scan motion that happens once across k-space"""

    def __init__(
        self,
        translations: tx.Union[Sampler, float] = 0.05,
        rotations: tx.Union[Sampler, float] = 5,
        axis: int = -1,
        *,
        shared: tx.Union[str, bool] = 'channels',
        **kwargs
    ) -> None:
        """

        Parameters
        ----------
        translations : Sampler or float
            Sampler (or upper-bound) for random translations (in % of FOV)
        rotations : Sampler or float
            Sampler (or upper-bound) for random rotations (in deg)
        axis : int
            Axis along which shots are acquired (slice or phase-encode)

        Other Parameters
        ----------------
        shared, returns, append, prefix, include, exclude, consume
            See [`IntraScanMotionTransform`][cornucopia.kspace.IntraScanMotionTransform] for details.
        """
        # Exactly two shots: the object moves once during the scan.
        super().__init__(
            shots=2,
            axis=axis,
            translations=translations,
            rotations=rotations,
            shared=shared,
            **kwargs,
        )

    def get_pattern(
        self, n: int, device: tx.Optional[torch.device] = None
    ) -> Tensor:
        """Split the `n` frequencies at a random index `k` in `[0, n-1]`.

        Returns a (2, n) boolean tensor: the first shot acquires
        frequencies `[0, k)` and the second shot acquires `[k, n)`.
        """
        split = random.randint(0, n-1)
        first_shot = torch.arange(n, device=device) < split
        return torch.stack([first_shot, first_shot.logical_not()])
|