cornucopia 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cornucopia/__init__.py +73 -0
- cornucopia/base.py +1915 -0
- cornucopia/baseutils.py +575 -0
- cornucopia/contrast.py +260 -0
- cornucopia/ctx.py +25 -0
- cornucopia/fov.py +707 -0
- cornucopia/geometric.py +2068 -0
- cornucopia/intensity.py +1358 -0
- cornucopia/io.py +161 -0
- cornucopia/kspace.py +505 -0
- cornucopia/labels.py +1872 -0
- cornucopia/noise.py +508 -0
- cornucopia/psf.py +463 -0
- cornucopia/qmri.py +1288 -0
- cornucopia/random.py +1480 -0
- cornucopia/special.py +159 -0
- cornucopia/synth.py +708 -0
- cornucopia/tests/__init__.py +0 -0
- cornucopia/tests/test_backward_geometric.py +173 -0
- cornucopia/tests/test_backward_intensity.py +243 -0
- cornucopia/tests/test_backward_kspace.py +115 -0
- cornucopia/tests/test_backward_noise.py +169 -0
- cornucopia/tests/test_backward_psf.py +142 -0
- cornucopia/tests/test_backward_qmri.py +249 -0
- cornucopia/tests/test_backward_random.py +44 -0
- cornucopia/tests/test_backward_synth.py +72 -0
- cornucopia/tests/test_base.py +401 -0
- cornucopia/tests/test_geometric.py +26 -0
- cornucopia/tests/test_intensity.py +9 -0
- cornucopia/tests/test_random.py +722 -0
- cornucopia/tests/test_run_contrast.py +28 -0
- cornucopia/tests/test_run_fov.py +132 -0
- cornucopia/tests/test_run_geometric.py +157 -0
- cornucopia/tests/test_run_intensity.py +192 -0
- cornucopia/tests/test_run_kspace.py +70 -0
- cornucopia/tests/test_run_labels.py +224 -0
- cornucopia/tests/test_run_noise.py +127 -0
- cornucopia/tests/test_run_psf.py +115 -0
- cornucopia/tests/test_run_qmri.py +114 -0
- cornucopia/tests/test_run_synth.py +67 -0
- cornucopia/typing.py +97 -0
- cornucopia/utils/__init__.py +0 -0
- cornucopia/utils/b0.py +745 -0
- cornucopia/utils/bounds.py +412 -0
- cornucopia/utils/compat.py +47 -0
- cornucopia/utils/conv.py +305 -0
- cornucopia/utils/gmm.py +169 -0
- cornucopia/utils/indexing.py +911 -0
- cornucopia/utils/io.py +258 -0
- cornucopia/utils/jit.py +128 -0
- cornucopia/utils/kernels.py +288 -0
- cornucopia/utils/morpho.py +234 -0
- cornucopia/utils/mrf.py +574 -0
- cornucopia/utils/padding.py +173 -0
- cornucopia/utils/patch.py +302 -0
- cornucopia/utils/pool.py +282 -0
- cornucopia/utils/py.py +348 -0
- cornucopia/utils/smart_inplace.py +163 -0
- cornucopia/utils/version.py +57 -0
- cornucopia/utils/warps.py +606 -0
- cornucopia-0.0.0.dist-info/METADATA +92 -0
- cornucopia-0.0.0.dist-info/RECORD +65 -0
- cornucopia-0.0.0.dist-info/WHEEL +5 -0
- cornucopia-0.0.0.dist-info/licenses/LICENSE +21 -0
- cornucopia-0.0.0.dist-info/top_level.txt +1 -0
cornucopia/fov.py
ADDED
|
@@ -0,0 +1,707 @@
|
|
|
1
|
+
"""
Transforms that operate on the field of view (FOV) of an image:
flips, axis permutations, 90/180 degree rotations, cropping, padding
and patch extraction.
"""
__all__ = [
    # flips and axis permutations
    'FlipTransform',
    'RandomFlipTransform',
    'PermuteAxesTransform',
    'RandomPermuteAxesTransform',
    # patches, crops and pads
    'PatchTransform',
    'RandomPatchTransform',
    'CropTransform',
    'PadTransform',
    'PowerTwoTransform',
    # rotations
    'Rot90Transform',
    'Rot180Transform',
    'RandomRot90Transform',
]
|
|
19
|
+
# stdlib
|
|
20
|
+
import math
|
|
21
|
+
from math import inf
|
|
22
|
+
from numbers import Number
|
|
23
|
+
from random import shuffle
|
|
24
|
+
|
|
25
|
+
# dependencies
|
|
26
|
+
from torch import Tensor
|
|
27
|
+
import typing_extensions as tx
|
|
28
|
+
|
|
29
|
+
# internals
|
|
30
|
+
from .base import Transform
|
|
31
|
+
from .base import FinalTransform, NonFinalTransform, PerChannelTransform
|
|
32
|
+
from .utils.py import ensure_list
|
|
33
|
+
from .utils.padding import pad
|
|
34
|
+
from .random import Uniform, RandKFrom, Sampler, RandInt, make_range
|
|
35
|
+
from . import typing as cct
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class FlipTransform(FinalTransform):
    """Flip one or more axes."""

    def __init__(
        self, axis: tx.Optional[cct.ScalarOrSequence[int]] = None, **kwargs
    ) -> None:
        """
        Parameters
        ----------
        axis : [list of] int
            Axes to flip. Indices refer to the full tensor, so axis 0 is
            the channel axis. By default, flip all spatial axes (every
            axis except axis 0).

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.axis = axis

    def _xform(self, x: Tensor) -> Tensor:
        # `None` selects every axis but the leading (channel) one
        dims = self.axis if self.axis is not None else list(range(1, x.ndim))
        return x.flip(ensure_list(dims))

    def make_inverse(self) -> 'FlipTransform':
        # flipping the same axes a second time restores the input
        return self
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class RandomFlipTransform(NonFinalTransform):
    """Randomly flip one or more axes."""

    Final = Next = FlipTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        axes: tx.Union[Sampler, cct.ScalarOrSequence[int], None] = None,
        *,
        shared: bool = True,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        axes : Sampler | [list of] int
            Axes that can be flipped (default: all spatial axes).
            Indices refer to the full tensor (axis 0 is the channel axis).
            If a list, a random selection of these axes (drawn with
            `RandKFrom`) is flipped. If a `Sampler`, it is called to
            obtain the axes to flip.

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        axes = kwargs.pop('axis', axes)  # also accept the singular name
        super().__init__(shared=shared, **kwargs)
        self.axes = axes

    def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
        if max_depth == 0:
            return self
        if 'channels' not in self.shared and len(x) > 1:
            # draw an independent set of flips for each channel
            return PerChannelTransform(
                [self.unroll(x[i:i+1], max_depth) for i in range(len(x))],
                **self.get_prm()
            ).unroll(x, max_depth-1)
        sampler = self.axes
        if not isinstance(sampler, Sampler):
            # Fixed candidate axes: draw a random selection of them.
            # A falsy value (e.g. None) means "all spatial axes".
            candidates = sampler or list(range(1, x.ndim))
            sampler = RandKFrom(ensure_list(candidates))
        # A user-provided Sampler is used directly.
        flip_axes = sampler()
        return FlipTransform(
            flip_axes, **self.get_prm()
        ).unroll(x, max_depth-1)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class PermuteAxesTransform(FinalTransform):
    """Permute axes"""

    def __init__(
        self, permutation: tx.Optional[tx.Sequence[int]] = None, **kwargs
    ) -> None:
        """
        Parameters
        ----------
        permutation : [list of] int
            Axes permutation. By default, reverse the spatial axes.
            Indices refer to spatial axes only: the channel axis is not
            counted and always stays first (axes are numbered [C, 0, 1, 2]).

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.permutation = permutation

    def _xform(self, x: Tensor) -> Tensor:
        permutation = self.permutation
        if permutation is None:
            permutation = list(reversed(range(x.dim()-1)))
        # shift spatial indices past the channel axis, which stays first
        permutation = [0] + [p+1 for p in permutation]
        return x.permute(permutation)

    def make_inverse(self) -> 'PermuteAxesTransform':
        if not self.permutation:
            # The default permutation (axis reversal) is its own inverse.
            return self
        # `x.permute(perm)` moves input axis perm[i] to position i, so the
        # inverse must send position i back to perm[i]: iperm[perm[i]] = i
        # (i.e. iperm is the argsort of perm).
        iperm = [0] * len(self.permutation)
        for i, p in enumerate(self.permutation):
            iperm[p] = i
        return PermuteAxesTransform(iperm, **self.get_prm())
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class RandomPermuteAxesTransform(NonFinalTransform):
    """Randomly permute axes."""

    Final = Next = PermuteAxesTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        axes: tx.Optional[tx.Sequence[int]] = None,
        *,
        shared: bool = True,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        axes : [list of] int
            Spatial axes that can be permuted (default: all). Axes that
            are not listed keep their position. Indices do not count the
            channel axis; negative indices are allowed.

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(shared=shared, **kwargs)
        self.axes = axes

    def _unroll(self, x: Tensor, max_depth: float = inf) -> Transform:
        if max_depth == 0:
            return self
        if 'channels' not in self.shared and len(x) > 1:
            # draw an independent permutation for each channel
            return PerChannelTransform(
                [self.unroll(x[i:i+1], max_depth) for i in range(len(x))],
                **self.get_prm()
            ).unroll(x, max_depth-1)
        ndim = x.ndim - 1
        axes = [a + ndim if a < 0 else a for a in (self.axes or range(ndim))]
        shuffled = list(axes)
        shuffle(shuffled)
        # Start from the identity so that the permutation always covers
        # every spatial axis (as PermuteAxesTransform requires), and only
        # shuffle the selected axes among the positions they occupy.
        permutation = list(range(ndim))
        for src, dst in zip(axes, shuffled):
            permutation[src] = dst
        return PermuteAxesTransform(
            permutation, **self.get_prm()
        ).unroll(x, max_depth-1)
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
class Rot90Transform(FinalTransform):
    """Apply a 90 (or 180) degree rotation about one or several axes."""

    def __init__(
        self,
        axis: cct.ScalarOrSequence[int] = 0,
        negative: cct.ScalarOrSequence[bool] = False,
        double: cct.ScalarOrSequence[bool] = False,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        axis : [list of] int
            Rotation axis (indexing does not account for the channel axis).
            In 2D, the single spatial plane is rotated and `axis` is unused.
            Rotations are applied in the order given.
        negative : [list of] bool
            Rotate by -90 deg instead of 90 deg
        double : [list of] bool
            Rotate by 180 instead of 90 (`negative` is then unused)

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.

        Raises
        ------
        ValueError
            (when applied) If the input is not 2D or 3D, or if an axis is
            out of range for a 3D input.
        """
        super().__init__(**kwargs)
        self.axis = ensure_list(axis)
        self.negative = ensure_list(negative, len(self.axis))
        self.double = ensure_list(double, len(self.axis))

    @staticmethod
    def _rotation_plane(ndim: int, ax: int) -> list:
        """Return the two tensor dims spanning the plane of rotation.

        `ax` is a tensor dimension (channel axis included, so spatial
        axis `a` is dimension `a + 1`).
        """
        if ndim == 2:
            # only one spatial plane exists; `ax` is irrelevant
            return [1, 2]
        if ndim == 3:
            if ax not in (1, 2, 3):
                raise ValueError(
                    f'Rotation axis {ax - 1} out of range for a 3D image'
                )
            return [d for d in (1, 2, 3) if d != ax]
        raise ValueError(
            f'Rot90Transform only supports 2D and 3D images, got {ndim}D'
        )

    def _xform(self, x: Tensor) -> Tensor:
        # NOTE: this implementation is suboptimal. All transposes and flips
        # could be fused into a single "transpose + flip" operation so that
        # a single allocation happens. This will be fine for now.
        ndim = x.ndim - 1
        # spatial (possibly negative) axes -> tensor dimensions
        axes = [1 + (ndim + a if a < 0 else a) for a in self.axis]
        for ax, neg, dbl in zip(axes, self.negative, self.double):
            dims = self._rotation_plane(ndim, ax)
            if dbl:
                # 180 deg: flip both in-plane dimensions
                x = x.flip(dims)
            else:
                # 90 deg: swap the in-plane dimensions, then flip one of
                # them; which one is flipped sets the rotation direction.
                x = x.transpose(*dims).flip(dims[1] if neg else dims[0])
        return x

    def make_inverse(self) -> 'Rot90Transform':
        # Undo the rotations in reverse order. A 90 deg rotation is undone
        # by rotating in the opposite direction (flip(d1)∘transpose undoes
        # flip(d0)∘transpose), while a 180 deg rotation is its own inverse.
        return Rot90Transform(
            list(reversed(self.axis)),
            [not n for n in reversed(self.negative)],
            list(reversed(self.double)),
            **self.get_prm()
        )
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
class Rot180Transform(Rot90Transform):
    """Apply a 180 degree rotation about one or several axes."""

    def __init__(self, axis: cct.ScalarOrSequence[int] = 0, **kwargs) -> None:
        """
        Parameters
        ----------
        axis : [list of] int
            Rotation axis (indexing does not account for the channel axis).

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        # A 180 deg rotation is a "double" 90 deg rotation, for which the
        # rotation direction does not matter.
        super().__init__(axis=axis, double=True, **kwargs)
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
class RandomRot90Transform(NonFinalTransform):
    """Random sequence of 90 degree rotations."""

    Final = Next = Rot90Transform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        axes: tx.Optional[cct.ScalarOrSequence[int]] = None,
        max_rot: cct.SamplerOrBound[int] = 2,
        negative: bool = True,
        *,
        shared: cct.SharedT = True,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        axes : [list of] int
            Axes about which rotations can happen (indexing does not
            account for the channel axis). If `None`, all spatial axes.
        max_rot : Sampler | int
            Maximum number of consecutive rotations. The actual number is
            drawn with `RandInt` between 1 and `max_rot`.
        negative : bool
            Whether to allow negative (-90 deg) rotations.

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(shared=shared, **kwargs)
        self.axes = axes
        self.max_rot = RandInt.make(make_range(1, max_rot))
        self.negative = negative

    def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
        if max_depth == 0:
            return self
        if 'channels' not in self.shared and len(x) > 1:
            # draw an independent set of rotations for each channel
            return PerChannelTransform(
                [self.unroll(x[i:i+1], max_depth) for i in range(len(x))],
                **self.get_prm()
            ).unroll(x, max_depth-1)
        ndim = x.ndim - 1
        # number of successive rotations to apply
        max_rot = self.max_rot
        if isinstance(max_rot, Sampler):
            max_rot = max_rot()
        axes = self.axes
        if axes is None:
            axes = list(range(ndim))
        if isinstance(axes, (int, list, tuple)):
            # NOTE(review): padding the candidate list to `max_rot` entries
            # may over-represent some axes depending on how `ensure_list`
            # pads -- confirm this is intended.
            axes = ensure_list(axes, max_rot, crop=False)
        if not isinstance(axes, Sampler):
            axes = RandKFrom(axes, max_rot, replacement=True)
        axes = ensure_list(axes(), max_rot)
        if self.negative:
            negative = RandKFrom([False, True], max_rot, replacement=True)()
        else:
            negative = [False] * max_rot
        return Rot90Transform(
            axes, negative, **self.get_prm()
        ).unroll(x, max_depth-1)
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
class CropPadTransform(FinalTransform):
    """Crop and/or pad a tensor."""

    def __init__(
        self,
        crop: tx.Sequence[slice] = (),
        pad: tx.Sequence[int] = (),
        bound: cct.ItemOrSequence[str] = 'zero',
        value: Number = 0,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        crop : list[slice]
            Slicing operator per dimension, applied to the trailing
            dimensions of the tensor.
        pad : list[int]
            Left and right padding per dimension, in alternate order
            `[left0, right0, left1, right1, ...]` (the layout produced by
            the transforms of this module).
        bound : [list of] str
            Boundary condition for padding
        value : number
            Padding value in case `bound='constant'`

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.crop = crop
        self.pad = pad
        self.bound = bound
        self.value = value

    def _xform(self, x: Tensor) -> Tensor:
        # crop first (slices index the trailing dimensions), then pad
        crop = tuple([Ellipsis, *self.crop])
        x = x[crop]
        x = pad(x, self.pad, mode=self.bound, value=self.value)
        return x

    def make_inverse(self) -> 'CropPadTransform':
        # The padding (flat, interleaved left/right amounts) is undone by
        # cropping it away...
        flat_pad = list(self.pad)
        ipad = [slice(left, (-right) or None)
                for left, right in zip(flat_pad[::2], flat_pad[1::2])]
        # ...and the cropped extent is restored by padding. This assumes
        # each crop slice has a non-negative (or None) `start` and a
        # negative (or None) `stop`, as built by the transforms of this
        # module; a positive `stop` cannot be inverted without the shape.
        icrop = []
        for s in self.crop:
            icrop.extend([s.start or 0, -s.stop if s.stop else 0])
        return CropPadTransform(
            ipad, icrop, bound=self.bound, value=self.value, **self.get_prm()
        )
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
class PatchTransform(NonFinalTransform):
    """Extract a patch from the volume"""

    Final = Next = CropPadTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        shape: cct.ScalarOrSequence[int] = 64,
        center: cct.ScalarOrSequence[float] = 0,
        bound: cct.ItemOrSequence[str] = 'zero',
        *,
        shared: cct.SharedT = True,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        shape : [list of] int
            Patch shape
        center : [list of] float
            Patch center, in relative coordinates -1..1, where -1 and 1
            are the centers of the first and last voxels.
        bound : [list of] str
            Boundary condition in case padding is needed

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.

            !!! changedin
                Default for `shared` changed from `"channels"` to `True`.

        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        kwargs.setdefault('shared', shared)
        super().__init__(**kwargs)
        self.shape = shape
        self.center = center
        self.bound = bound

    def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
        if max_depth == 0:
            return self
        ndim = x.dim() - 1
        shape = ensure_list(self.shape, ndim)
        center = ensure_list(self.center, ndim)
        # relative coordinates [-1, 1] -> voxel coordinates [0, size-1]
        center = [(c + 1) / 2 * (s - 1) for c, s in zip(center, x.shape[1:])]
        crop = []
        padding = []
        for ss, cc, sv in zip(shape, center, x.shape[1:]):
            # A patch of `ss` voxels starting at `first` is centered at
            # `first + (ss - 1) / 2`, so solve for `first`.
            first = int(math.floor(cc - (ss - 1) / 2))
            last = first + ss
            # whatever falls outside of the input gets padded
            pad_first = max(0, -first)
            pad_last = max(0, last - sv)
            first = max(0, first)
            last = min(sv, last)
            last = (last - sv) or None  # ensure negative for CropPad
            crop.append(slice(first, last))
            padding.extend([pad_first, pad_last])
        return CropPadTransform(
            crop, padding, bound=self.bound, **self.get_prm()
        ).unroll(x, max_depth-1)
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
class RandomPatchTransform(NonFinalTransform):
    """Extract a (randomly located) patch from the volume.

    This transform ensures that the patch is fully contained within the
    original field of view (unless the patch size is larger than the
    input shape).
    """

    Next = PatchTransform
    """The transform type returned by `next`."""

    Final = CropPadTransform
    """The transform type returned by `final`."""

    def __init__(
        self,
        shape: cct.ScalarOrSequence[int],
        bound: cct.ItemOrSequence[str] = 'zero',
        *,
        shared: cct.SharedT = True,
        **kwargs
    ) -> None:
        """

        Parameters
        ----------
        shape : [list of] int
            Patch shape (the legacy keyword `patch_size` is also accepted)
        bound : [list of] str
            Boundary condition in case padding is needed

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.

            !!! changedin
                Default for `shared` changed from `"channels"` to `True`.

        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        # legacy keyword takes precedence when given
        shape = kwargs.pop('patch_size', shape)
        kwargs.setdefault('shared', shared)
        super().__init__(**kwargs)
        self.shape = shape
        self.bound = bound

    def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
        if max_depth == 0:
            return self
        spatial = x.shape[1:]
        size = ensure_list(self.shape, len(spatial))
        center = []
        for p, s in zip(size, spatial):
            # The relative center may move at most `1 - p/s` away from the
            # middle of the image for the patch to stay inside the FOV.
            lowest = max(p / s - 1, -1)
            highest = min(1 - p / s, 1)
            center.append(Uniform(lowest, highest)())
        nxt = PatchTransform(size, center, self.bound, **self.get_prm())
        return nxt.unroll(x, max_depth - 1)
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
class CropTransform(NonFinalTransform):
    """Crop a tensor by some amount"""

    Final = Next = CropPadTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        cropping: cct.ScalarOrSequence[tx.Union[int, float]],
        unit: str = 'vox',
        side: str = 'both',
        **kwargs
    ) -> None:
        """

        Parameters
        ----------
        cropping : [list of] int or float
            Amount of cropping. If `side` is `None`, pre and post cropping
            must be provided in turn (`[pre0, post0, pre1, post1, ...]`);
            missing leading entries are treated as 0.
        unit : {'vox', 'pct'}
            Cropping unit. With 'pct', amounts are fractions of the input
            size, rounded up to whole voxels.
        side : {'pre', 'post', 'both', None}
            Side to crop

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.

            !!! changedin
                Default for `shared` changed from `"channels"` to `True`.

        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.

        Raises
        ------
        ValueError
            (when unrolled) If `side` is not one of the values above.
        """
        kwargs.setdefault('shared', True)
        super().__init__(**kwargs)
        self.cropping = cropping
        self.unit = unit
        self.side = side

    def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
        if max_depth == 0:
            return self
        ndim = x.dim() - 1
        shape = x.shape[1:]
        cropping = self.cropping
        if self.side is not None:
            # one amount per spatial dimension
            cropping = ensure_list(cropping, ndim)
            if self.unit[0] == 'p':
                cropping = [int(math.ceil(c * s))
                            for c, s in zip(cropping, shape)]
            if self.side == 'pre':
                pre, post = cropping, [0] * ndim
            elif self.side == 'post':
                pre, post = [0] * ndim, cropping
            elif self.side == 'both':
                pre, post = cropping, cropping
            else:
                raise ValueError(
                    f"side must be one of 'pre', 'post', 'both' or None, "
                    f"got {self.side!r}"
                )
        else:
            # interleaved [pre0, post0, pre1, post1, ...] amounts
            cropping = ensure_list(cropping)
            cropping = [0] * (2 * ndim - len(cropping)) + cropping
            if self.unit[0] == 'p':
                shape2 = [s for s in shape for _ in range(2)]
                cropping = [int(math.ceil(c * s))
                            for c, s in zip(cropping, shape2)]
            pre, post = cropping[::2], cropping[1::2]
        # negative (or None) stops, as expected by CropPadTransform
        slices = [slice(c0, -c1 if c1 else None) for c0, c1 in zip(pre, post)]
        return CropPadTransform(
            slices, [0]*(2*ndim), **self.get_prm()
        ).unroll(x, max_depth-1)
|
|
575
|
+
|
|
576
|
+
|
|
577
|
+
class PadTransform(NonFinalTransform):
    """Pad a tensor by some amount"""

    Final = Next = CropPadTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        padding: cct.ScalarOrSequence[tx.Union[int, float]],
        unit: str = 'vox',
        side: str = 'both',
        bound: str = 'zero',
        value: Number = 0,
        **kwargs
    ) -> None:
        """

        Parameters
        ----------
        padding : [list of] int or float
            Amount of padding. If `side` is `None`, pre and post padding
            must be provided in turn (`[pre0, post0, pre1, post1, ...]`);
            missing leading entries are treated as 0.
        unit : {'vox', 'pct'}
            Padding unit. With 'pct', amounts are fractions of the input
            size, rounded up to whole voxels.
        side : {'pre', 'post', 'both', None}
            Side to pad
        bound : str
            Boundary condition
        value : float
            Value for case `bound='constant'`

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.

            !!! changedin
                Default for `shared` changed from `"channels"` to `True`.

        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.

        Raises
        ------
        ValueError
            (when unrolled) If `side` is not one of the values above.
        """
        kwargs.setdefault('shared', True)
        super().__init__(**kwargs)
        self.padding = padding
        self.unit = unit
        self.side = side
        self.bound = bound
        self.value = value

    def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
        if max_depth == 0:
            return self
        ndim = x.dim() - 1
        padding = self.padding
        if self.side is not None:
            # one amount per spatial dimension
            padding = ensure_list(padding, ndim)
            if self.unit[0] == 'p':
                padding = [int(math.ceil(p * s))
                           for p, s in zip(padding, x.shape[1:])]
            # expand to the interleaved [pre0, post0, pre1, ...] layout
            if self.side == 'pre':
                padding = [p for pz in zip(padding, [0]*ndim) for p in pz]
            elif self.side == 'post':
                padding = [p for zp in zip([0]*ndim, padding) for p in zp]
            elif self.side == 'both':
                padding = [p for pp in zip(padding, padding) for p in pp]
            else:
                # an unknown side would otherwise leave an ndim-long list
                # where 2*ndim interleaved amounts are expected
                raise ValueError(
                    f"side must be one of 'pre', 'post', 'both' or None, "
                    f"got {self.side!r}"
                )
        else:
            # amounts are already interleaved
            padding = ensure_list(padding)
            padding = [0] * (2 * ndim - len(padding)) + padding
            if self.unit[0] == 'p':
                shape2 = [s for s in x.shape[1:] for _ in range(2)]
                padding = [int(math.ceil(p * s))
                           for p, s in zip(padding, shape2)]

        return CropPadTransform(
            [slice(None)]*ndim, padding, bound=self.bound, value=self.value,
            **self.get_prm()
        ).unroll(x, max_depth-1)
|
|
657
|
+
|
|
658
|
+
|
|
659
|
+
class PowerTwoTransform(NonFinalTransform):
    """Pad the volume so that each spatial size is divisible by 2**exponent"""

    Next = PatchTransform
    """The transform type returned by `next`."""

    Final = CropPadTransform
    """The transform type returned by `final`."""

    def __init__(
        self,
        exponent: cct.ScalarOrSequence[int] = 1,
        bound: cct.ItemOrSequence[str] = 'zero',
        **kwargs
    ) -> None:
        """

        Parameters
        ----------
        exponent : [list of] int
            Ensure that the shape can be divided by 2 ** exponent
        bound : [list of] str
            Boundary condition for padding

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.

            !!! changedin
                Default for `shared` changed from `"channels"` to `True`.

        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        kwargs.setdefault('shared', True)
        super().__init__(**kwargs)
        self.exponent = exponent
        self.bound = bound

    def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
        if max_depth == 0:
            return self
        shape = x.shape[1:]
        exponent = ensure_list(self.exponent, len(shape))
        bigshape = []
        for e, s in zip(exponent, shape):
            factor = 2 ** e
            # smallest (non-zero) multiple of `factor` that is >= s
            nblocks = max(1, -(-s // factor))
            bigshape.append(nblocks * factor)
        # a centered patch of the target shape pads the input symmetrically
        return PatchTransform(
            bigshape, bound=self.bound, **self.get_prm()
        ).unroll(x, max_depth-1)
|