cornucopia 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cornucopia/__init__.py +73 -0
- cornucopia/base.py +1915 -0
- cornucopia/baseutils.py +575 -0
- cornucopia/contrast.py +260 -0
- cornucopia/ctx.py +25 -0
- cornucopia/fov.py +707 -0
- cornucopia/geometric.py +2068 -0
- cornucopia/intensity.py +1358 -0
- cornucopia/io.py +161 -0
- cornucopia/kspace.py +505 -0
- cornucopia/labels.py +1872 -0
- cornucopia/noise.py +508 -0
- cornucopia/psf.py +463 -0
- cornucopia/qmri.py +1288 -0
- cornucopia/random.py +1480 -0
- cornucopia/special.py +159 -0
- cornucopia/synth.py +708 -0
- cornucopia/tests/__init__.py +0 -0
- cornucopia/tests/test_backward_geometric.py +173 -0
- cornucopia/tests/test_backward_intensity.py +243 -0
- cornucopia/tests/test_backward_kspace.py +115 -0
- cornucopia/tests/test_backward_noise.py +169 -0
- cornucopia/tests/test_backward_psf.py +142 -0
- cornucopia/tests/test_backward_qmri.py +249 -0
- cornucopia/tests/test_backward_random.py +44 -0
- cornucopia/tests/test_backward_synth.py +72 -0
- cornucopia/tests/test_base.py +401 -0
- cornucopia/tests/test_geometric.py +26 -0
- cornucopia/tests/test_intensity.py +9 -0
- cornucopia/tests/test_random.py +722 -0
- cornucopia/tests/test_run_contrast.py +28 -0
- cornucopia/tests/test_run_fov.py +132 -0
- cornucopia/tests/test_run_geometric.py +157 -0
- cornucopia/tests/test_run_intensity.py +192 -0
- cornucopia/tests/test_run_kspace.py +70 -0
- cornucopia/tests/test_run_labels.py +224 -0
- cornucopia/tests/test_run_noise.py +127 -0
- cornucopia/tests/test_run_psf.py +115 -0
- cornucopia/tests/test_run_qmri.py +114 -0
- cornucopia/tests/test_run_synth.py +67 -0
- cornucopia/typing.py +97 -0
- cornucopia/utils/__init__.py +0 -0
- cornucopia/utils/b0.py +745 -0
- cornucopia/utils/bounds.py +412 -0
- cornucopia/utils/compat.py +47 -0
- cornucopia/utils/conv.py +305 -0
- cornucopia/utils/gmm.py +169 -0
- cornucopia/utils/indexing.py +911 -0
- cornucopia/utils/io.py +258 -0
- cornucopia/utils/jit.py +128 -0
- cornucopia/utils/kernels.py +288 -0
- cornucopia/utils/morpho.py +234 -0
- cornucopia/utils/mrf.py +574 -0
- cornucopia/utils/padding.py +173 -0
- cornucopia/utils/patch.py +302 -0
- cornucopia/utils/pool.py +282 -0
- cornucopia/utils/py.py +348 -0
- cornucopia/utils/smart_inplace.py +163 -0
- cornucopia/utils/version.py +57 -0
- cornucopia/utils/warps.py +606 -0
- cornucopia-0.0.0.dist-info/METADATA +92 -0
- cornucopia-0.0.0.dist-info/RECORD +65 -0
- cornucopia-0.0.0.dist-info/WHEEL +5 -0
- cornucopia-0.0.0.dist-info/licenses/LICENSE +21 -0
- cornucopia-0.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,606 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch.nn import functional as F
|
|
3
|
+
from .py import ensure_list, cartesian_grid, meshgrid_ij
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def add_identity_(flow):
    """Add the identity grid to a displacement field, in place.

    Parameters
    ----------
    flow : (..., *shape, dim) tensor
        Displacement field. Its storage is modified in place.

    Returns
    -------
    flow : (..., *shape, dim) tensor
        Transformation field (a view of the input tensor).

    """
    ndim = flow.shape[-1]
    lattice = flow.shape[-ndim-1:-1]
    coords = cartesian_grid(lattice, dtype=flow.dtype, device=flow.device)
    # put the vector dimension first so that each component can be
    # updated with a simple (in-place) indexed view
    components = flow.movedim(-1, 0)
    for axis, coord in enumerate(coords):
        components[axis].add_(coord)
    return components.movedim(0, -1)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def sub_identity_(flow):
    """Subtract the identity grid from a transformation field, in place.

    Parameters
    ----------
    flow : (..., *shape, dim) tensor
        Transformation field. Its storage is modified in place.

    Returns
    -------
    flow : (..., *shape, dim) tensor
        Displacement field (a view of the input tensor).

    """
    ndim = flow.shape[-1]
    lattice = flow.shape[-ndim-1:-1]
    coords = cartesian_grid(lattice, dtype=flow.dtype, device=flow.device)
    # put the vector dimension first so that each component can be
    # updated with a simple (in-place) indexed view
    components = flow.movedim(-1, 0)
    for axis, coord in enumerate(coords):
        components[axis].sub_(coord)
    return components.movedim(0, -1)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def add_identity(flow):
    """Add the identity grid to a displacement field (out of place).

    The input tensor is left untouched: a copy is made and updated
    with `add_identity_`.

    Parameters
    ----------
    flow : (..., *shape, dim) tensor
        Displacement field

    Returns
    -------
    flow : (..., *shape, dim) tensor
        Transformation field

    """
    copy = flow.clone()
    return add_identity_(copy)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def sub_identity(flow):
    """Subtract the identity grid from a transformation field (out of place).

    The input tensor is left untouched: a copy is made and updated
    with `sub_identity_`.

    Parameters
    ----------
    flow : (..., *shape, dim) tensor
        Transformation field

    Returns
    -------
    flow : (..., *shape, dim) tensor
        Displacement field

    """
    copy = flow.clone()
    return sub_identity_(copy)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def identity(shape, **backend):
    """Build an identity transformation field.

    Parameters
    ----------
    shape : (dim,) sequence of int
        Spatial shape of the field.
    **backend
        Tensor options (e.g. ``dtype``, ``device``) forwarded to
        `cartesian_grid`. ``dtype`` defaults to
        ``torch.get_default_dtype()`` when not provided.

    Returns
    -------
    grid : (*shape, dim) tensor
        Transformation field

    """
    if 'dtype' not in backend:
        backend['dtype'] = torch.get_default_dtype()
    axes = cartesian_grid(shape, **backend)
    return torch.stack(axes, dim=-1)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def affine_flow(affine, shape, with_identity=False):
    """Generate an affine flow field.

    Parameters
    ----------
    affine : ([B], D+1, D+1) tensor
        Affine matrix (voxel to voxel).
    shape : (D,) list[int]
        Lattice size
    with_identity : bool, default=False
        If True, the returned flow contains absolute coordinates.
        If False, the returned flow contains relative displacements.

    Returns
    -------
    flow : ([B], *shape, D) tensor
        Affine flow

    """
    ndim = len(shape)
    backend = dict(dtype=affine.dtype, device=affine.device)

    # Insert one singleton axis per spatial dimension, just before the two
    # matrix axes, so that the matrix broadcasts against the lattice in a
    # batched matmul: ([B], 1, ..., 1, D+1, D+1)
    index = (Ellipsis,) + (None,) * ndim + (slice(None), slice(None))
    affine = affine[index]
    linear = affine[..., :ndim, :ndim]
    shift = affine[..., :ndim, -1]

    # apply the affine transform to every lattice coordinate
    coords = identity(shape, **backend).unsqueeze(-1)
    flow = torch.matmul(linear, coords).squeeze(-1)
    flow.add_(shift)

    if with_identity:
        return flow
    # remove the identity to obtain a displacement field
    return sub_identity_(flow)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def flow_to_torch(flow, shape, align_corners=True, has_identity=False):
    """Convert a voxel displacement field to a torch sampling grid.

    Parameters
    ----------
    flow : (..., *shape_out, D) tensor
        Displacement field (or transformation field if `has_identity`),
        in voxels. The input tensor is not modified.
    shape : list[int]
        Spatial shape of the input image (the image that gets sampled).
    align_corners : bool, default=True
        Torch's grid mode
    has_identity : bool, default=False
        If False, `flow` contains relative displacements.
        If True, `flow` contains absolute coordinates.

    Returns
    -------
    grid : (..., *shape_out, D) tensor
        Sampling grid to be used with torch's `grid_sample`

    """
    backend = dict(dtype=flow.dtype, device=flow.device)
    ndim = flow.shape[-1]
    # 1) reverse last dimension (torch's grid_sample expects the last
    #    spatial dimension first). `flip` returns a copy, so the in-place
    #    operations below never touch the caller's tensor.
    flow = torch.flip(flow, [-1])
    # 2) add identity grid. Displacements are relative to the lattice of
    #    the flow itself, which may differ from the input image lattice.
    if not has_identity:
        lattice = flow.shape[-ndim-1:-1]
        grid = cartesian_grid(lattice, **backend)
        grid = list(reversed(grid))
        for d, g in enumerate(grid):
            flow[..., d].add_(g)
    # 3) convert voxel coordinates to torch's normalized coordinates,
    #    using the input image shape (in reversed order, like `flow`)
    for d, s in enumerate(reversed(list(shape))):
        if align_corners:
            if s == 1:
                # With a single voxel, torch maps every normalized
                # coordinate to voxel 0; also avoids dividing by zero.
                flow[..., d].zero_()
            else:
                # (0, N-1) -> (-1, 1)
                flow[..., d].mul_(2/(s-1)).add_(-1)
        else:
            # (-0.5, N-0.5) -> (-1, 1)
            flow[..., d].mul_(2/s).add_(1/s-1)
    return flow
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def apply_flow(image, flow, has_identity=False, **kwargs):
    """Warp an image according to a (voxel) displacement field.

    Parameters
    ----------
    image : (B, C, *shape_in) tensor
        Input image.
        If input dtype is integer, assumes labels: each unique label
        gets warped using linear interpolation, and the label map gets
        reconstructed by argmax.
    flow : ([B], *shape_out, D) tensor
        Displacement field, in voxels.
        Note that the order of the last dimension is inverse of what's
        usually expected in torch's grid_sample.
    has_identity : bool, default=False
        - If False, `flow` contains relative displacements.
        - If True, `flow` contains absolute coordinates.
    **kwargs
        Forwarded to `torch.nn.functional.grid_sample`
        (``align_corners`` defaults to True).

    Returns
    -------
    warped : (B, C, *shape_out) tensor
        Warped image

    """
    kwargs.setdefault('align_corners', True)
    batch, channels, *shape_in = image.shape
    ndim = flow.shape[-1]
    if flow.dim() == ndim + 1:
        # unbatched flow: add a leading batch dimension
        flow = flow[None]
    shape_out = flow.shape[1:-1]
    grid = flow_to_torch(flow, shape_in,
                         align_corners=kwargs['align_corners'],
                         has_identity=has_identity)

    # broadcast the batch dimensions of the grid and of the image
    batch = max(len(grid), len(image))
    if len(grid) != batch:
        grid = grid.expand([batch, *grid.shape[1:]])
    if len(image) != batch:
        image = image.expand([batch, *image.shape[1:]])

    is_nearest = kwargs.get('mode', 'bilinear') == 'nearest'
    if image.dtype.is_floating_point or is_nearest:
        # plain interpolation (integer images are sampled in float and
        # cast back to their original dtype)
        dtype = image.dtype
        if not dtype.is_floating_point:
            image = image.float()
        return F.grid_sample(image, grid, **kwargs).to(dtype)

    # label map: interpolate each label's mask and keep, per voxel, the
    # label with the highest interpolated value
    best = grid.new_full([batch, channels, *shape_out], -float('inf'))
    warped = image.new_zeros([batch, channels, *shape_out])
    for label in image.unique():
        score = F.grid_sample((image == label).to(grid), grid, **kwargs)
        warped[score > best] = label
        best = torch.maximum(best, score)
    return warped
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def downsample(image, factor=None, shape=None, anchor='center'):
    """Downsample using centers or edges of the corner voxels as anchors.

    Parameters
    ----------
    image : (B, C, *shape_in) tensor
        Input image.
    factor OR shape : int or list[int]
        Downsampling factor (output size is ``shape_in / factor``),
        or explicit output spatial shape. Only one may be given.
    anchor : {'center', 'edge'}
        Align the centers (``align_corners=True``) or the edges
        (``align_corners=False``) of the corner voxels.

    Returns
    -------
    image : (B, C, *shape_out) tensor

    """
    if shape and factor:
        raise ValueError('Only one of `shape` and `factor` should be used.')
    ndim = image.dim() - 2
    mode = {1: 'linear', 2: 'bilinear'}.get(ndim, 'trilinear')
    align_corners = anchor[0].lower() == 'c'
    recompute_scale_factor = factor is not None

    # torch expects an *upsampling* scale: invert the factor(s)
    scale = factor
    if factor:
        if isinstance(factor, (list, tuple)):
            scale = [1/f for f in factor]
        else:
            scale = 1/factor

    return F.interpolate(image, size=shape, scale_factor=scale,
                         mode=mode, align_corners=align_corners,
                         recompute_scale_factor=recompute_scale_factor)
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def upsample(image, factor=None, shape=None, anchor='center'):
    """Upsample using centers or edges of the corner voxels as anchors.

    Parameters
    ----------
    image : (B, C, *shape_in) tensor
        Input image.
    factor OR shape : int or list[int]
        Upsampling factor, or explicit output spatial shape.
        Only one may be given.
    anchor : {'center', 'edge'}
        Align the centers (``align_corners=True``) or the edges
        (``align_corners=False``) of the corner voxels.

    Returns
    -------
    image : (B, C, *shape_out) tensor

    """
    if shape and factor:
        raise ValueError('Only one of `shape` and `factor` should be used.')
    ndim = image.dim() - 2
    mode = {1: 'linear', 2: 'bilinear'}.get(ndim, 'trilinear')
    return F.interpolate(
        image,
        size=shape,
        scale_factor=factor,
        mode=mode,
        align_corners=(anchor[0].lower() == 'c'),
        recompute_scale_factor=(factor is not None),
    )
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def downsample_flow(flow, factor=None, shape=None, anchor='center'):
    """Downsample a flow field using centers or edges of the corner
    voxels as anchors.

    The field is resampled onto the coarser lattice, and its values
    (displacements expressed in input voxels) are rescaled so that they
    are expressed in output voxels.

    Parameters
    ----------
    flow : (B, *shape_in, D) tensor
        Displacement field, in voxels.
    factor OR shape : int or list[int]
        Downsampling factor, or output spatial shape.
    anchor : {'center', 'edge'}
        Align the centers or the edges of the corner voxels.

    Returns
    -------
    flow : (B, *shape_out, D) tensor
        Displacement field, in output voxels.

    """
    shape_in = flow.shape[1:-1]

    # downsample flow
    flow = flow.movedim(-1, 1)
    flow = downsample(flow, factor, shape, anchor)
    flow = flow.movedim(1, -1)

    # compute the ratio of output to input voxel density. The anchor test
    # is case-insensitive, matching the `align_corners` choice made by
    # `downsample`.
    shape_out = flow.shape[1:-1]
    centers = anchor[0].lower() == 'c'
    scales = []
    for fout, fin in zip(shape_out, shape_in):
        if centers and fin > 1:
            scales.append((fout - 1) / (fin - 1))
        else:
            # edge anchors, or a singleton input dimension where the
            # center-based ratio is undefined (division by zero)
            scales.append(fout / fin)

    # rescale displacement: u input voxels span u * scale output voxels
    ndim = flow.dim() - 2
    for d in range(ndim):
        flow[..., d] *= scales[d]

    return flow
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def upsample_flow(flow, factor=None, shape=None, anchor='center'):
    """Upsample a flow field using centers or edges of the corner
    voxels as anchors.

    The field is resampled onto the finer lattice, and its values
    (displacements expressed in input voxels) are rescaled so that they
    are expressed in output voxels.

    Parameters
    ----------
    flow : (B, *shape_in, D) tensor
        Displacement field, in voxels.
    factor OR shape : int or list[int]
        Upsampling factor, or output spatial shape.
    anchor : {'center', 'edge'}
        Align the centers or the edges of the corner voxels.

    Returns
    -------
    flow : (B, *shape_out, D) tensor
        Displacement field, in output voxels.

    """
    shape_in = flow.shape[1:-1]

    # upsample flow
    flow = flow.movedim(-1, 1)
    flow = upsample(flow, factor, shape, anchor)
    flow = flow.movedim(1, -1)

    # compute the ratio of output to input voxel density. The anchor test
    # is case-insensitive, matching the `align_corners` choice made by
    # `upsample`.
    shape_out = flow.shape[1:-1]
    centers = anchor[0].lower() == 'c'
    scales = []
    for fout, fin in zip(shape_out, shape_in):
        if centers and fin > 1:
            scales.append((fout - 1) / (fin - 1))
        else:
            # edge anchors, or a singleton input dimension where the
            # center-based ratio is undefined (division by zero)
            scales.append(fout / fin)

    # rescale displacement: u input voxels span u * scale output voxels
    ndim = flow.dim() - 2
    for d in range(ndim):
        flow[..., d] *= scales[d]

    return flow
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
def downsample_convlike(image, kernel_size, stride, padding=0):
    """Downsample using the same alignment pattern as a strided convolution.

    Output voxel ``i`` is sampled at the center of the receptive field of
    the equivalent convolution, i.e., at input coordinate
    ``i * stride + (kernel_size - 1) / 2 - padding``.

    Parameters
    ----------
    image : (B, C, *shape_in) tensor
        Input image (integer images are warped as label maps by
        `apply_flow`).
    kernel_size : int or list[int]
        Kernel size of the equivalent convolution
    stride : int or list[int]
        Stride of the equivalent convolution
    padding : int or list[int]
        Padding of the equivalent convolution

    Returns
    -------
    image : (B, C, *shape_out) tensor
        With ``shape_out = (shape_in + 2 * padding - kernel_size) // stride + 1``

    """
    shape_in = image.shape[2:]
    ndim = image.dim() - 2
    kernel_size = ensure_list(kernel_size, ndim)
    stride = ensure_list(stride, ndim)
    padding = ensure_list(padding, ndim)

    # create sampling grid. Coordinates are fractional, so the grid must be
    # floating point even when the image holds integer labels.
    dtype = image.dtype
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()
    backend = dict(dtype=dtype, device=image.device)
    shape_out = [(size + 2 * p - k)//s + 1 for size, k, s, p
                 in zip(shape_in, kernel_size, stride, padding)]

    flow = [torch.arange(n, **backend) for n in shape_out]
    for f, k, s, p in zip(flow, kernel_size, stride, padding):
        # the window of output voxel i starts at input voxel i*s - p and
        # spans k voxels: its center is at i*s + (k-1)/2 - p
        f.mul_(s).add_((k - 1) / 2 - p)
    flow = torch.stack(meshgrid_ij(*flow), -1)

    # interpolate
    return apply_flow(image, flow[None], has_identity=True)
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
def downsample_flow_convlike(flow, kernel_size, stride, padding=0):
    """Downsample a flow field using the same alignment pattern as a
    strided convolution.

    Parameters
    ----------
    flow : (B, *shape_in, D) tensor
        Displacement field, in voxels.
    kernel_size : int or list[int]
        Kernel size of the equivalent convolution
    stride : int or list[int]
        Stride of the equivalent convolution
    padding : int or list[int]
        Padding of the equivalent convolution

    Returns
    -------
    flow : (B, *shape_out, D) tensor
        Displacement field, with values divided by the stride.

    """
    # resample the field (vector components treated as channels)
    channels_first = flow.movedim(-1, 1)
    resampled = downsample_convlike(channels_first, kernel_size, stride,
                                    padding)
    resampled = resampled.movedim(1, -1)

    # one output voxel spans `stride` input voxels
    ndim = resampled.dim() - 2
    strides = ensure_list(stride, ndim)
    for d in range(ndim):
        resampled[..., d] /= strides[d]

    return resampled
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
def upsample_convlike(image, kernel_size, stride, padding=0, shape=None):
    """Upsample using the same alignment pattern as a transposed convolution.

    This inverts the alignment of `downsample_convlike`: output voxel
    ``j`` is sampled at input coordinate
    ``(j - (kernel_size - 1) / 2 + padding) / stride``.

    Parameters
    ----------
    image : (B, C, *shape_in) tensor
        Input image (integer images are warped as label maps by
        `apply_flow`).
    kernel_size : int or list[int]
        Kernel size of the equivalent transposed convolution
    stride : int or list[int]
        Stride of the equivalent transposed convolution
    padding : int or list[int]
        Padding of the equivalent transposed convolution
    shape : int or list[int], optional
        Output spatial shape. Default:
        ``(shape_in - 1) * stride - 2 * padding + kernel_size``

    Returns
    -------
    image : (B, C, *shape_out) tensor

    """
    shape_in = image.shape[2:]
    ndim = image.dim() - 2
    kernel_size = ensure_list(kernel_size, ndim)
    stride = ensure_list(stride, ndim)
    padding = ensure_list(padding, ndim)
    if shape:
        shape = ensure_list(shape, ndim)

    # create sampling grid. Coordinates are fractional, so the grid must be
    # floating point even when the image holds integer labels.
    dtype = image.dtype
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()
    backend = dict(dtype=dtype, device=image.device)
    if not shape:
        shape = [(size - 1) * s - 2 * p + k for size, k, s, p
                 in zip(shape_in, kernel_size, stride, padding)]

    flow = [torch.arange(n, **backend) for n in shape]
    for f, k, s, p in zip(flow, kernel_size, stride, padding):
        # invert x_out = x_in * s + (k-1)/2 - p
        f.sub_((k - 1) / 2 - p).div_(s)
    flow = torch.stack(meshgrid_ij(*flow), -1)

    # interpolate
    return apply_flow(image, flow[None], has_identity=True)
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
def upsample_flow_convlike(flow, kernel_size, stride, padding=0, shape=None):
    """Upsample a flow field using the same alignment pattern as a
    transposed convolution.

    Parameters
    ----------
    flow : (B, *shape_in, D) tensor
        Displacement field, in voxels.
    kernel_size : int or list[int]
        Kernel size of the equivalent transposed convolution
    stride : int or list[int]
        Stride of the equivalent transposed convolution
    padding : int or list[int]
        Padding of the equivalent transposed convolution
    shape : int or list[int], optional
        Output spatial shape

    Returns
    -------
    flow : (B, *shape_out, D) tensor
        Displacement field, with values multiplied by the stride.

    """
    # resample the field (vector components treated as channels)
    channels_first = flow.movedim(-1, 1)
    resampled = upsample_convlike(channels_first, kernel_size, stride,
                                  padding, shape)
    resampled = resampled.movedim(1, -1)

    # one input voxel spans `stride` output voxels
    ndim = resampled.dim() - 2
    strides = ensure_list(stride, ndim)
    for d in range(len(strides)):
        resampled[..., d].mul_(strides[d])

    return resampled
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
def compose_flows(flow_left, flow_right, has_identity=False):
    """Compute the composition ``flow_left o flow_right``.

    Parameters
    ----------
    flow_left : (B, *shape, D) tensor
        Field applied last.
    flow_right : (B, *shape, D) tensor
        Field applied first.
    has_identity : bool, default=False
        If True, both fields (and the output) contain absolute
        coordinates; otherwise they contain relative displacements.

    Returns
    -------
    flow : (B, *shape, D) tensor

    """
    # Only the displacement part of the left field is resampled at the
    # locations given by the right field; the right field is then added.
    disp_left = sub_identity(flow_left) if has_identity else flow_left
    resampled = apply_flow(disp_left.movedim(-1, 1), flow_right,
                           has_identity=has_identity)
    composed = resampled.movedim(1, -1)
    composed += flow_right
    return composed
|
|
533
|
+
|
|
534
|
+
|
|
535
|
+
def bracket(vel_left, vel_right):
    """Compute the Lie bracket of two SVFs.

    Evaluated as the difference of the two compositions of displacement
    fields: ``(vel_left o vel_right) - (vel_right o vel_left)``.

    Parameters
    ----------
    vel_left : (B, *shape, D) tensor
    vel_right : (B, *shape, D) tensor

    Returns
    -------
    bkt : (B, *shape, D) tensor

    """
    left_right = compose_flows(vel_left, vel_right)
    right_left = compose_flows(vel_right, vel_left)
    return left_right - right_left
|
|
550
|
+
|
|
551
|
+
|
|
552
|
+
def exp_velocity(vel, steps=8):
    """Exponentiate a stationary velocity field by scaling and squaring.

    Parameters
    ----------
    vel : (B, *shape, D) tensor
        Stationary velocity
    steps : int, default=8
        Number of scaling and squaring steps

    Returns
    -------
    flow : (B, *shape, D) tensor
        Displacement field

    """
    # scaling: start from a small deformation vel / 2**steps
    disp = vel / (2**steps)
    # squaring: phi <- phi o phi, i.e. u <- u + u(x + u)
    for _ in range(steps):
        warped = apply_flow(disp.movedim(-1, 1), disp).movedim(1, -1)
        disp = disp + warped
    return disp
|
|
572
|
+
|
|
573
|
+
|
|
574
|
+
def compose_velocities(vel_left, vel_right, order=2):
    """Find v such that exp(v) = exp(u) o exp(w) using the
    (truncated) Baker–Campbell–Hausdorff formula.

    https://en.wikipedia.org/wiki/BCH_formula

    Parameters
    ----------
    vel_left : (B, *shape, D) tensor
        Velocity u (applied last).
    vel_right : (B, *shape, D) tensor
        Velocity w (applied first).
    order : 1..4, default=2
        Truncation order.

    Returns
    -------
    vel : (B, *shape, D) tensor

    Raises
    ------
    ValueError
        If `order > 4` (raised before any bracket is computed).

    """
    # validate up front so that no expensive bracket is computed in vain
    if order > 4:
        raise ValueError('BCH only implemented up to order 4')
    vel = vel_left + vel_right
    if order > 1:
        # + 1/2 [u, w]
        b1 = bracket(vel_left, vel_right)
        vel.add_(b1, alpha=1/2)
        if order > 2:
            # + 1/12 [u, [u, w]] - 1/12 [w, [u, w]]
            b2_left = bracket(vel_left, b1)
            vel.add_(b2_left, alpha=1/12)
            b2_right = bracket(vel_right, b1)
            vel.add_(b2_right, alpha=-1/12)
            if order > 3:
                # - 1/24 [w, [u, [u, w]]]
                b3 = bracket(vel_right, b2_left)
                vel.add_(b3, alpha=-1/24)
    return vel
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: cornucopia
|
|
3
|
+
Version: 0.0.0
|
|
4
|
+
Summary: An abundance of augmentation layers
|
|
5
|
+
Author-email: Yael Balbastre <yael.balbastre@gmail.com>
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/balbasty/cornucopia
|
|
8
|
+
Project-URL: Issues, https://github.com/balbasty/cornucopia/issues
|
|
9
|
+
Classifier: Operating System :: OS Independent
|
|
10
|
+
Classifier: Programming Language :: Python :: 3
|
|
11
|
+
Classifier: Intended Audience :: Science/Research
|
|
12
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
13
|
+
Classifier: Topic :: Scientific/Engineering :: Medical Science Apps.
|
|
14
|
+
Requires-Python: >=3.8
|
|
15
|
+
Description-Content-Type: text/markdown
|
|
16
|
+
License-File: LICENSE
|
|
17
|
+
Requires-Dist: torch>=1.8
|
|
18
|
+
Requires-Dist: torch-interpol>=0.3
|
|
19
|
+
Requires-Dist: torch-distmap>=0.3
|
|
20
|
+
Requires-Dist: typing_extensions>=4.7
|
|
21
|
+
Requires-Dist: nibabel
|
|
22
|
+
Provides-Extra: io
|
|
23
|
+
Requires-Dist: nibabel; extra == "io"
|
|
24
|
+
Requires-Dist: pillow; extra == "io"
|
|
25
|
+
Requires-Dist: tifffile; extra == "io"
|
|
26
|
+
Requires-Dist: numpy; extra == "io"
|
|
27
|
+
Provides-Extra: typing
|
|
28
|
+
Requires-Dist: numpy>=1.20; extra == "typing"
|
|
29
|
+
Dynamic: license-file
|
|
30
|
+
|
|
31
|
+
<picture align="center">
|
|
32
|
+
<source media="(prefers-color-scheme: dark)" srcset="docs/icons/cornucopia_lightorange.svg">
|
|
33
|
+
<source media="(prefers-color-scheme: light)" srcset="docs/icons/cornucopia_orange.svg">
|
|
34
|
+
<img alt="Cornucopia logo" src="https://github.com/balbasty/cornucopia/raw/main/docs/icons/cornucopia_orange.svg">
|
|
35
|
+
</picture>
|
|
36
|
+
|
|
37
|
+
The `cornucopia` package provides a generic framework for preprocessing,
|
|
38
|
+
augmentation, and domain randomization, along with an abundance of specific layers,
|
|
39
|
+
mostly targeted at (medical) imaging. `cornucopia` is written using a PyTorch
|
|
40
|
+
backend, and therefore runs **on the CPU or GPU**.
|
|
41
|
+
|
|
42
|
+
Cornucopia is *intended* to be used on the GPU for on-line augmentation.
|
|
43
|
+
A quick [benchmark](docs/examples/benchmark.ipynb) of affine and elastic augmentation
|
|
44
|
+
shows that while cornucopia is slower than [TorchIO](https://github.com/fepegar/torchio)
|
|
45
|
+
on the CPU (~ 3s vs 1s), it is greatly accelerated on the GPU (~ 50ms).
|
|
46
|
+
|
|
47
|
+
Since gradients are not expected to backpropagate through its layers, it can
|
|
48
|
+
theoretically be used within any dataloader pipeline,
|
|
49
|
+
independent of the downstream learning framework (PyTorch, TensorFlow, JAX, ...).
|
|
50
|
+
|
|
51
|
+
## Installation
|
|
52
|
+
|
|
53
|
+
### Dependencies
|
|
54
|
+
|
|
55
|
+
- `pytorch >= 1.8`
|
|
56
|
+
- `numpy`
|
|
57
|
+
- `nibabel`
|
|
58
|
+
- `torch-interpol`
|
|
59
|
+
- `torch-distmap`
|
|
60
|
+
|
|
61
|
+
### Pip (release)
|
|
62
|
+
|
|
63
|
+
```sh
|
|
64
|
+
pip install cornucopia
|
|
65
|
+
```
|
|
66
|
+
|
|
67
|
+
### Pip (dev)
|
|
68
|
+
|
|
69
|
+
```sh
|
|
70
|
+
pip install cornucopia@git+https://github.com/balbasty/cornucopia
|
|
71
|
+
```
|
|
72
|
+
|
|
73
|
+
## Documentation
|
|
74
|
+
|
|
75
|
+
Read the [documentation](https://balbasty.github.io/cornucopia) and in particular:
|
|
76
|
+
- [installation](https://balbasty.github.io/cornucopia/install/)
|
|
77
|
+
- [get started](https://balbasty.github.io/cornucopia/start/)
|
|
78
|
+
- [examples](https://balbasty.github.io/cornucopia/examples/)
|
|
79
|
+
- [API](https://balbasty.github.io/cornucopia/api/)
|
|
80
|
+
|
|
81
|
+
## Other augmentation packages
|
|
82
|
+
|
|
83
|
+
There are other great, and much more mature, augmentation packages
|
|
84
|
+
out there (although few run on the GPU). Here's a non-exhaustive list:
|
|
85
|
+
- [MONAI](https://github.com/Project-MONAI/MONAI)
|
|
86
|
+
- [TorchIO](https://github.com/fepegar/torchio)
|
|
87
|
+
- [Albumentations](https://github.com/albumentations-team/albumentations) (2D only)
|
|
88
|
+
- [Volumentations](https://github.com/ZFTurbo/volumentations) (3D extension of Albumentations)
|
|
89
|
+
|
|
90
|
+
## Contributions
|
|
91
|
+
|
|
92
|
+
If you find this project useful and wish to contribute, please reach out!
|