cornucopia 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. cornucopia/__init__.py +73 -0
  2. cornucopia/base.py +1915 -0
  3. cornucopia/baseutils.py +575 -0
  4. cornucopia/contrast.py +260 -0
  5. cornucopia/ctx.py +25 -0
  6. cornucopia/fov.py +707 -0
  7. cornucopia/geometric.py +2068 -0
  8. cornucopia/intensity.py +1358 -0
  9. cornucopia/io.py +161 -0
  10. cornucopia/kspace.py +505 -0
  11. cornucopia/labels.py +1872 -0
  12. cornucopia/noise.py +508 -0
  13. cornucopia/psf.py +463 -0
  14. cornucopia/qmri.py +1288 -0
  15. cornucopia/random.py +1480 -0
  16. cornucopia/special.py +159 -0
  17. cornucopia/synth.py +708 -0
  18. cornucopia/tests/__init__.py +0 -0
  19. cornucopia/tests/test_backward_geometric.py +173 -0
  20. cornucopia/tests/test_backward_intensity.py +243 -0
  21. cornucopia/tests/test_backward_kspace.py +115 -0
  22. cornucopia/tests/test_backward_noise.py +169 -0
  23. cornucopia/tests/test_backward_psf.py +142 -0
  24. cornucopia/tests/test_backward_qmri.py +249 -0
  25. cornucopia/tests/test_backward_random.py +44 -0
  26. cornucopia/tests/test_backward_synth.py +72 -0
  27. cornucopia/tests/test_base.py +401 -0
  28. cornucopia/tests/test_geometric.py +26 -0
  29. cornucopia/tests/test_intensity.py +9 -0
  30. cornucopia/tests/test_random.py +722 -0
  31. cornucopia/tests/test_run_contrast.py +28 -0
  32. cornucopia/tests/test_run_fov.py +132 -0
  33. cornucopia/tests/test_run_geometric.py +157 -0
  34. cornucopia/tests/test_run_intensity.py +192 -0
  35. cornucopia/tests/test_run_kspace.py +70 -0
  36. cornucopia/tests/test_run_labels.py +224 -0
  37. cornucopia/tests/test_run_noise.py +127 -0
  38. cornucopia/tests/test_run_psf.py +115 -0
  39. cornucopia/tests/test_run_qmri.py +114 -0
  40. cornucopia/tests/test_run_synth.py +67 -0
  41. cornucopia/typing.py +97 -0
  42. cornucopia/utils/__init__.py +0 -0
  43. cornucopia/utils/b0.py +745 -0
  44. cornucopia/utils/bounds.py +412 -0
  45. cornucopia/utils/compat.py +47 -0
  46. cornucopia/utils/conv.py +305 -0
  47. cornucopia/utils/gmm.py +169 -0
  48. cornucopia/utils/indexing.py +911 -0
  49. cornucopia/utils/io.py +258 -0
  50. cornucopia/utils/jit.py +128 -0
  51. cornucopia/utils/kernels.py +288 -0
  52. cornucopia/utils/morpho.py +234 -0
  53. cornucopia/utils/mrf.py +574 -0
  54. cornucopia/utils/padding.py +173 -0
  55. cornucopia/utils/patch.py +302 -0
  56. cornucopia/utils/pool.py +282 -0
  57. cornucopia/utils/py.py +348 -0
  58. cornucopia/utils/smart_inplace.py +163 -0
  59. cornucopia/utils/version.py +57 -0
  60. cornucopia/utils/warps.py +606 -0
  61. cornucopia-0.0.0.dist-info/METADATA +92 -0
  62. cornucopia-0.0.0.dist-info/RECORD +65 -0
  63. cornucopia-0.0.0.dist-info/WHEEL +5 -0
  64. cornucopia-0.0.0.dist-info/licenses/LICENSE +21 -0
  65. cornucopia-0.0.0.dist-info/top_level.txt +1 -0
cornucopia/io.py ADDED
@@ -0,0 +1,161 @@
1
+ """This module contains transforms that load data from disk."""
2
+ __all__ = ['ToTensorTransform', 'LoadTransform']
3
+ # stdlib
4
+ import os.path
5
+ from os import PathLike
6
+ from pathlib import Path
7
+
8
+ # dependencies
9
+ import torch
10
+ from torch import Tensor
11
+ import typing_extensions as tx
12
+
13
+ # internals
14
+ from .base import FinalTransform
15
+ from .utils.io import loaders
16
+ from . import typing as cct
17
+
18
+
19
class ToTensorTransform(FinalTransform):
    """Convert the input to a Tensor, optionally casting dtype/device."""

    def __init__(
        self,
        ndim: tx.Optional[int] = None,
        dtype: tx.Optional[torch.dtype] = None,
        device: tx.Optional[cct.TorchDevice] = None,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        ndim : int, optional
            Number of spatial dimensions. When not set (or zero), the
            squeezed input is returned without adding leading dimensions.
        dtype : torch.dtype, optional
            Data type of the returned tensor (default: keep same)
        device : torch.device, optional
            Device of the returned tensor (default: keep same)

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        # The number of spatial dimensions is stored under `dim`.
        self.dim = ndim
        self.dtype = dtype
        self.device = device

    def _xform(self, x: Tensor) -> Tensor:
        # Convert, then drop every singleton dimension.
        out = torch.as_tensor(x, dtype=self.dtype, device=self.device)
        out = out.squeeze()
        if not self.dim:
            return out

        # Expected layout: one channel dimension + `dim` spatial dimensions.
        expected = self.dim + 1
        missing = expected - out.ndim
        if missing > 0:
            # Re-insert the dimensions removed by `squeeze` at the front.
            out = out[(None,) * missing]
        if out.ndim > expected:
            raise ValueError(f'Too many dimensions: '
                             f'{out.ndim} > 1 + {self.dim}')
        return out
58
+
59
+
60
class LoadTransform(FinalTransform):
    """
    Load data from disk.

    Available loaders are:

    - `BabelLoader`: for medical image formats (nifti, mgz, minc, etc.)
    - `TiffLoader`: for TIFF files (including multi-page)
    - `PillowLoader`: for common image formats (png, jpg, etc., with optional rot90)
    - `NumpyLoader`: for .npy and .npz files (with optional field name)

    Custom loaders can be added by registering them in
    `cornucopia.utils.io.loaders`.

    Loaders registered for the file extension are tried first. If they all
    fail, every other registered loader is tried, in the order in which
    they appear in `loaders`. Each loader is tried at most once.
    """

    def __init__(
        self,
        ndim: tx.Optional[int] = None,
        dtype: tx.Optional[torch.dtype] = None,
        *,
        device: tx.Optional[cct.TorchDevice] = None,
        returns: tx.Optional[cct.ReturnsT] = None,
        append: cct.AppendT = False,
        prefix: cct.PrefixT = False,
        include: tx.Optional[cct.IncludeT] = None,
        exclude: tx.Optional[cct.ExcludeT] = None,
        consume: tx.Optional[cct.ConsumeT] = None,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        ndim : int | None
            Number of spatial dimensions (default: guess from file)
        dtype : torch.dtype | str | None
            Data type (default: guess from file)
        device : torch.device | str | None
            Device on which to load data (default: cpu)

        Other Parameters
        ------------------
        to_ras : bool, default=True
            Reorient data so that it has a RAS layout.
            Only used by `BabelLoader`.
        rot90 : bool, default=True
            Rotate by 90 degrees in-plane.
            Only used by `PillowLoader`.
        field : str, default="arr_0"
            Field to load from a npz file.
            Only used by `NumpyLoader`.
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(
            returns=returns,
            append=append,
            prefix=prefix,
            include=include,
            exclude=exclude,
            consume=consume,
        )
        self.ndim = ndim
        self.dtype = dtype
        self.device = device
        # Loader-specific options (to_ras, rot90, field, ...), forwarded
        # to every loader that is tried.
        self.kwargs = kwargs

    @staticmethod
    def _path_extension(path: PathLike) -> str:
        """Return the lower-cased extension of `path`, keeping the inner
        extension of compressed files (e.g. `'.nii.gz'`)."""
        stem, ext = os.path.splitext(str(path))
        if ext.lower() in ('.gz', '.bz', '.bz2', '.gzip', '.bzip2'):
            _, preext = os.path.splitext(stem)
            ext = preext + ext
        return ext.lower()

    def _loaders_for(self, x: tx.Any) -> list:
        """Ordered, de-duplicated list of loaders to try on `x`."""
        candidates = []
        if isinstance(x, PathLike):
            ext = self._path_extension(x)
            if ext in loaders:
                candidates.extend(loaders[ext])
        for loader_ext in loaders.values():
            candidates.extend(loader_ext)
        # `dict.fromkeys` removes duplicates while keeping first-seen order,
        # so the fallback order is deterministic (unlike iterating a set).
        return list(dict.fromkeys(candidates))

    def _xform(self, x: tx.Union[str, Tensor]) -> Tensor:
        # Array-like inputs (tensors, arrays, nested lists) are converted
        # directly. Anything torch cannot convert (e.g. a path) falls
        # through to the file loaders below.
        try:
            return torch.as_tensor(x, dtype=self.dtype, device=self.device)
        except Exception:
            pass

        if isinstance(x, str):
            x = Path(x)

        exceptions = []
        for loader in self._loaders_for(x):
            try:
                # Options are forwarded consistently, including in the
                # fallback path, so that e.g. `to_ras=False` is honoured
                # even when the extension is not recognised.
                return loader(self.ndim, self.dtype, self.device,
                              **self.kwargs)(x)
            except Exception as e:
                exceptions.append(str(e))

        message = '\n'.join([f'Could not load {x}:'] + exceptions)
        raise ValueError(message)
cornucopia/kspace.py ADDED
@@ -0,0 +1,505 @@
1
+ """This module contains transforms that operate in k-space (Fourier space)."""
2
+ __all__ = [
3
+ 'ArrayCoilCombinationTransform',
4
+ 'ArrayCoilTransform',
5
+ 'SumOfSquaresTransform',
6
+ 'IntraScanMotionFinalTransform',
7
+ 'IntraScanMotionTransform',
8
+ 'SmallIntraScanMotionTransform',
9
+ ]
10
+ # stdlib
11
+ import math
12
+ import random
13
+ from math import inf
14
+
15
+ # dependencies
16
+ import torch
17
+ import typing_extensions as tx
18
+ from torch import Tensor
19
+
20
+ # internals
21
+ from .base import Transform, NonFinalTransform, FinalTransform
22
+ from .baseutils import Returned, prepare_output, return_requires
23
+ from .intensity import MulFieldTransform
24
+ from .geometric import RandomAffineTransform
25
+ from .random import Fixed, Sampler
26
+ from .utils.warps import identity
27
+ from .utils.smart_inplace import sqrt_, square_, abs_, mul_, exp_, sub_, add_
28
+ from . import ctx
29
+ from . import typing as cct
30
+
31
+
32
class ArrayCoilCombinationTransform(FinalTransform):
    """Apply coil sensitivities to an image and combine across coils."""

    def __init__(self, sens: Tensor, **kwargs) -> None:
        """
        Parameters
        ----------
        sens : (K, *spatial) tensor
            Complex coil sensitivities

        Other Parameters
        ----------------
        returns : [(list | dict) of] str
            See [`Transform`][cornucopia.base.Transform] for details.
            Default is `'uncombined'`.

            | Value          | Description                              |
            | -------------- | ---------------------------------------- |
            | `'sos'`        | Sum of square combined (magnitude) image |
            | `'uncombined'` | Uncombined (complex) coil images         |
            | `'sens'`       | Uncombined (complex) coil sensitivities  |
            | `'netsens'`    | Net (magnitude) coil sensitivity         |

        append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.sens = sens

    def _xform(self, x: Tensor) -> Returned:
        coil_sens = self.sens.to(x.device)
        coil_images = x * coil_sens
        # Root-sum-of-squares over dim 0 (coils); `[None]` restores a
        # leading singleton dimension.
        net_sens = sqrt_(square_(coil_sens.abs()).sum(0))[None]
        combined = sqrt_(square_(coil_images.abs()).sum(0))[None]
        outputs = dict(
            input=x,
            sos=combined,
            output=coil_images,
            uncombined=coil_images,
            netsens=net_sens,
            sens=coil_sens,
        )
        return prepare_output(outputs, self.returns)
70
+
71
+
72
class ArrayCoilTransform(NonFinalTransform):
    """Generate and apply random coil sensitivities (real or complex)"""

    Final = Next = ArrayCoilCombinationTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        ncoils: int = 8,
        fwhm: float = 0.5,
        diameter: float = 0.8,
        jitter: float = 0.01,
        unit: tx.Literal['fov', 'vox'] = 'fov',
        shape: cct.NumberOrSequence[int] = 4,
        sos: bool = True,
        *,
        shared=True,
        **kwargs
    ) -> None:
        """

        Parameters
        ----------
        ncoils : int
            Number of complex receiver channels
        fwhm : float
            Width of each receiver profile
        diameter : float
            Diameter of the ellipsoid on which receivers are centered
        jitter : float
            Amount of jitter off the ellipsoid
        unit : {'fov', 'vox'}
            Unit of `fwhm`, `diameter`, `jitter`
        shape : [list of] int
            Number of control points for the underlying smooth component.

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns : [(list | dict) of] str
            See [`Transform`][cornucopia.base.Transform] for details.
            Default is `'uncombined'`.

            | Value          | Description                              |
            | -------------- | ---------------------------------------- |
            | `'sos'`        | Sum of square combined (magnitude) image |
            | `'uncombined'` | Uncombined (complex) coil images         |
            | `'sens'`       | Uncombined (complex) coil sensitivities  |
            | `'netsens'`    | Net (magnitude) coil sensitivity         |

        append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """  # noqa: E501
        super().__init__(shared=shared, **kwargs)
        self.ncoils = ncoils
        self.fwhm = fwhm
        self.diameter = diameter
        self.jitter = jitter
        self.unit = unit
        self.shape = shape
        self.sos = sos

    def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
        if max_depth == 0:
            return self

        ndim = x.dim() - 1
        spatial = x.shape[1:]
        backend = dict(dtype=x.dtype, device=x.device)

        # Draw 2*ncoils smooth random fields in [-1, 1] (applied to a tensor
        # of ones). Even/odd fields form pairs from which a phase and a
        # magnitude are derived.
        ones = torch.ones([], **backend).expand([2*self.ncoils, *spatial])
        fields = MulFieldTransform(shape=self.shape, vmin=-1, vmax=1)(ones)
        even, odd = fields[0::2], fields[1::2]
        phase = even.atan2(odd)
        magnitude = sqrt_(add_(even.square(), odd.square()))

        fov = torch.as_tensor(spatial, **backend)
        fwhm = self.fwhm * fov if self.unit == 'fov' else self.fwhm
        lam = (2.355 / fwhm) ** 2

        for coil in range(self.ncoils):
            # Random direction, scaled by the diameter, plus a jitter.
            center = torch.randn(ndim, **backend)
            center /= center.square().sum().sqrt_()
            center = mul_(center, self.diameter)
            if self.jitter:
                offset = mul_(torch.rand(ndim, **backend), self.jitter)
                center = add_(center, offset)
            center = (1 + center) / 2
            if self.unit == 'fov':
                center = mul_(center, fov)

            # Gaussian receiver profile centered on `center`.
            profile = sub_(identity(spatial, **backend), center)
            profile = mul_(square_(profile), lam).sum(-1)
            profile = exp_(mul_(profile, -0.5))

            if profile.requires_grad:
                previous = magnitude[coil].clone()
                magnitude[coil].copy_(previous * profile)
            else:
                mul_(magnitude[coil], profile)

        sens = mul_(exp_(1j * phase), magnitude)
        return self.Next(sens, **self.get_prm()).unroll(x, max_depth-1)
180
+
181
+
182
class SumOfSquaresTransform(FinalTransform):
    """Combine coils/channels by root-sum-of-squares."""

    def _xform(self, x: Tensor) -> Tensor:
        # |x|^2 summed over dim 0, keeping a singleton channel dimension.
        power = square_(abs_(x))
        return sqrt_(power.sum(0, keepdim=True))
187
+
188
+
189
class IntraScanMotionFinalTransform(FinalTransform):
    """Apply pre-computed intra-scan motion"""

    def __init__(
        self,
        motion: FinalTransform,
        patterns: Tensor,
        sens: tx.Optional[FinalTransform] = None,
        axis: int = -1,
        freq: bool = False,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        motion : FinalTransform
            A transform that applies the motion to an image
            (one transform per shot).
        patterns : tensor
            Binary tensor of shape (N, X) indicating which frequencies/slices
            are acquired in each shot. N is the number of shots, X is the
            number of frequencies/slices along the motion axis.
            Positions that are not acquired by any shot are left at zero.
        sens : FinalTransform
            A transform that generates a set of complex sensitivity profiles.
        axis : int
            Axis along which shots are acquired (slice or phase-encode)
        freq : bool
            Motion happens across a phase-encode direction, which means
            that the k-space is built from pieces with different object
            position. This typically happens in "3D" acquisitions.
            If False, motion happens along the slice direction ("2D"
            acquisition).

        Other Parameters
        ----------------
        returns : [(list | dict) of] str
            See [`Transform`][cornucopia.base.Transform] for details.
            Default is `'sos'`.

            | Value          | Description                              | Shape         |
            | -------------- | ---------------------------------------- | ------------- |
            | `'sos'`        | Sum of square combined (magnitude) image | `(C,X,Y,Z)`   |
            | `'uncombined'` | Uncombined (complex) coil images         | `(K,X,Y,Z)`   |
            | `'sens'`       | Uncombined (complex) coil sensitivities  | `(K,X,Y,Z)`   |
            | `'netsens'`    | Net (magnitude) coil sensitivity         | `(1,X,Y,Z)`   |
            | `'flow'`       | Displacement field in each shot          | `(N,3,X,Y,Z)` |
            | `'matrix'`     | Rigid matrix in each shot                | `(N,4,4)`     |
            | `'pattern'`    | Frequencies acquired in each shot        | `(N,X)`       |

        append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """  # noqa: E501
        super().__init__(**kwargs)
        self.motion = motion
        self.patterns = patterns
        self.sens = sens
        self.axis = axis
        self.freq = freq

    def _xform(self, x: Tensor) -> Returned:
        # Everything (motion, coil sensitivities, FFT, masking) is done in
        # the native layout of `x`: the motion transforms and sensitivities
        # were generated for that layout, so permuting dimensions first
        # would misalign them.
        axis = self.axis % x.dim()
        # Shape that broadcasts a per-shot (X,) mask along `axis` only.
        mask_shape = [1] * x.dim()
        mask_shape[axis] = -1
        patterns = self.patterns.to(x.device, torch.bool)

        if self.sens is not None:
            assert len(x) == 1
            with ctx.returns(self.sens, ['sens', 'netsens']):
                sens, netsens = self.sens(x)
            # Zero-initialised so that samples not acquired by any shot
            # are empty rather than uninitialised memory.
            y = x.new_zeros([len(sens), *x.shape[1:]], dtype=torch.complex64)
        else:
            ydtype = torch.complex64 if self.freq else x.dtype
            y = torch.zeros_like(x, dtype=ydtype)
            sens = netsens = None

        returned = return_requires(self.returns)
        returns = dict(moved='output')
        if 'matrix' in returned:
            returns['matrix'] = 'matrix'
        if 'flow' in returned:
            returns['flow'] = 'flow'

        matrix, flow = [], []
        for motion_trf, pattern in zip(self.motion, patterns):
            with ctx.returns(motion_trf, returns):
                moved = motion_trf(x)
            matrix.append(moved.get('matrix', None))
            flow.append(moved.get('flow', None))
            moved = moved['moved']
            if sens is not None:
                moved = moved * sens
            if self.freq:
                # Centered FFT along the acquisition axis.
                moved = torch.fft.ifftshift(moved, dim=axis)
                moved = torch.fft.fft(moved, dim=axis)
                moved = torch.fft.fftshift(moved, dim=axis)
            # Take this shot's data where it was acquired, keep the rest.
            # NOTE: torch.where is used instead of masked assignment
            # (`y[:, pattern] = ...`) because, in old torch versions, index
            # assignment does not support autograd for complex outputs.
            y = torch.where(pattern.reshape(mask_shape), moved, y)

        if self.freq:
            y = torch.fft.ifftshift(y, dim=axis)
            y = torch.fft.ifft(y, dim=axis)
            y = torch.fft.fftshift(y, dim=axis)

        sos = sqrt_(square_(y.abs()).sum(0, keepdim=True))

        matrix = torch.stack(matrix) if 'matrix' in returned else None
        flow = torch.stack(flow) if 'flow' in returned else None

        return prepare_output(
            dict(input=x, sos=sos, output=sos, uncombined=y, sens=sens,
                 netsens=netsens, pattern=patterns, flow=flow,
                 matrix=matrix),
            self.returns)
316
+
317
+
318
class IntraScanMotionTransform(NonFinalTransform):
    """Model intra-scan motion"""

    Final = Next = IntraScanMotionFinalTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        shots: int = 4,
        axis: int = -1,
        freq: bool = True,
        pattern: tx.Union[
            tx.Literal['sequential', 'random'],
            tx.Sequence[Tensor]
        ] = 'sequential',
        translations: tx.Union[Sampler, float] = 0.1,
        rotations: tx.Union[Sampler, float] = 15,
        sos: bool = True,
        coils: tx.Optional[Transform] = None,
        *,
        shared: tx.Union[str, bool] = 'channels',
        **kwargs
    ) -> None:
        """

        Parameters
        ----------
        shots : int
            Number of acquisition shots.
            The object is in a different position in each shot.
        axis : int
            Axis along which shots are acquired (slice or phase-encode)
        freq : bool
            Motion happens across a phase-encode direction, which means
            that the k-space is built from pieces with different object
            position. This typically happens in "3D" acquisitions.
            If False, motion happens along the slice direction ("2D"
            acquisition).
        pattern : {'sequential', 'random'} or list[tensor[bool or int]]
            k-space (or slice) sampling pattern. This argument encodes
            the frequencies (or slices) that are acquired in each shot.
            The 'sequential' option assumes that frequencies are
            acquired in order. The 'random' option assumes that frequencies
            are randomly distributed across shots. In both cases, the
            frequencies are split as evenly as possible across shots.
            A list is interpreted as per-shot indices if any value is
            larger than 1, and as per-shot binary masks otherwise.
        translations : Sampler or float
            Sampler (or upper-bound) for random translations (in % of FOV)
        rotations : Sampler or float
            Sampler (or upper-bound) for random rotations (in deg)
        sos : bool
            Whether to return the sum-of-squares combined image across coils.
        coils : Transform
            A transform that generates a set of complex sensitivity profiles

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns : [(list | dict) of] str
            See [`Transform`][cornucopia.base.Transform] for details.
            Default is `'sos'`.

            | Value          | Description                              | Shape         |
            | -------------- | ---------------------------------------- | ------------- |
            | `'sos'`        | Sum of square combined (magnitude) image | `(C,X,Y,Z)`   |
            | `'uncombined'` | Uncombined (complex) coil images         | `(K,X,Y,Z)`   |
            | `'sens'`       | Uncombined (complex) coil sensitivities  | `(K,X,Y,Z)`   |
            | `'netsens'`    | Net (magnitude) coil sensitivity         | `(1,X,Y,Z)`   |
            | `'flow'`       | Displacement field in each shot          | `(N,3,X,Y,Z)` |
            | `'matrix'`     | Rigid matrix in each shot                | `(N,4,4)`     |
            | `'pattern'`    | Frequencies acquired in each shot        | `(N,X)`       |

        append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.

        """  # noqa: E501
        super().__init__(shared=shared, **kwargs)
        self.shots = shots
        self.axis = axis
        self.pattern = pattern
        self.sos = sos
        self.coils = coils
        self.freq = freq
        self.motion = RandomAffineTransform(
            translations=translations, rotations=rotations,
            zooms=Fixed(0), shears=Fixed(0), bound='reflection')

    def get_pattern(
        self, n: int, device: tx.Optional[torch.device] = None
    ) -> Tensor:
        """
        Build the sampling pattern.

        Parameters
        ----------
        n : int
            Number of frequencies (or slices) along the acquisition axis.
        device : torch.device, optional
            Device of the returned pattern.

        Returns
        -------
        pattern : (shots, n) tensor[bool]
            `pattern[i, j]` is True if frequency `j` is acquired in shot `i`.
        """
        if isinstance(self.pattern, str):
            shots = min(self.shots, n)
            if self.pattern == 'sequential':
                order = torch.arange(n, device=device)
            elif self.pattern == 'random':
                indices = list(range(n))
                random.shuffle(indices)
                order = torch.as_tensor(indices, dtype=torch.long,
                                        device=device)
            else:
                raise ValueError(f'Unknown pattern: {self.pattern}')
            # Balanced split: shot sizes differ by at most one, so no shot
            # is left empty (a ceil-based chunk size can starve the last
            # shots, e.g. n=10 and shots=6).
            bounds = [(shot * n) // shots for shot in range(shots + 1)]
            pattern = torch.zeros([shots, n], dtype=torch.bool, device=device)
            for shot in range(shots):
                pattern[shot, order[bounds[shot]:bounds[shot+1]]] = True

        elif isinstance(self.pattern, (list, tuple)):
            items = [torch.as_tensor(p, device=device) for p in self.pattern]
            if max(int(p.max()) for p in items) > 1:
                # Each item lists the indices acquired in one shot.
                pattern = torch.zeros([len(items), n], dtype=torch.bool,
                                      device=device)
                for shot, indices in enumerate(items):
                    pattern[shot, indices.long()] = True
            else:
                # Each item is a binary mask for one shot.
                pattern = torch.stack(items)

        elif torch.is_tensor(self.pattern):
            pattern = self.pattern

        else:
            raise TypeError(f'Unsupported pattern type: {type(self.pattern)}')

        return pattern.to(device, torch.bool)

    def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
        if max_depth == 0:
            return self

        n = x.shape[self.axis]

        # sample motion parameters for each shot
        shots = min(self.shots, n)
        motion = [self.motion.final(x) for _ in range(shots)]

        # compute sampling pattern
        pattern = self.get_pattern(n, x.device)

        # sample coil sensitivities
        sens = self.coils.final(x) if self.coils is not None else None

        return self.Next(
            motion, pattern, sens, self.axis, self.freq, **self.get_prm()
        ).unroll(x, max_depth-1)
466
+
467
+
468
class SmallIntraScanMotionTransform(IntraScanMotionTransform):
    """Model intra-scan motion that happens once across k-space"""

    def __init__(
        self,
        translations: tx.Union[Sampler, float] = 0.05,
        rotations: tx.Union[Sampler, float] = 5,
        axis: int = -1,
        *,
        shared: tx.Union[str, bool] = 'channels',
        **kwargs
    ) -> None:
        """

        Parameters
        ----------
        translations : Sampler or float
            Sampler (or upper-bound) for random translations (in % of FOV)
        rotations : Sampler or float
            Sampler (or upper-bound) for random rotations (in deg)
        axis : int
            Axis along which shots are acquired (slice or phase-encode)

        Other Parameters
        ----------------
        shared, append, prefix, include, exclude, consume
            See [`IntraScanMotionTransform`][cornucopia.kspace.IntraScanMotionTransform] for details.
        """  # noqa: E501
        super().__init__(translations=translations, rotations=rotations,
                         shared=shared, shots=2, axis=axis, **kwargs)

    def get_pattern(
        self, n: int, device: tx.Optional[torch.device] = None
    ) -> Tensor:
        """
        Split the `n` frequencies into two contiguous shots.

        Returns
        -------
        pattern : (2, n) tensor[bool]
            First row: frequencies `[0, k)`; second row: `[k, n)`.
        """
        # The split point is drawn strictly inside (0, n) so that both shots
        # acquire data. A split at 0 would leave the first shot empty, i.e.
        # no motion event at all. When n == 1, the parent only builds one
        # motion, which uses the first row, so that row must hold the
        # single sample.
        k = random.randint(1, n - 1) if n > 1 else n
        mask = torch.zeros(n, dtype=torch.bool, device=device)
        mask[:k] = True
        return torch.stack([mask, ~mask])