cornucopia 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. cornucopia/__init__.py +73 -0
  2. cornucopia/base.py +1915 -0
  3. cornucopia/baseutils.py +575 -0
  4. cornucopia/contrast.py +260 -0
  5. cornucopia/ctx.py +25 -0
  6. cornucopia/fov.py +707 -0
  7. cornucopia/geometric.py +2068 -0
  8. cornucopia/intensity.py +1358 -0
  9. cornucopia/io.py +161 -0
  10. cornucopia/kspace.py +505 -0
  11. cornucopia/labels.py +1872 -0
  12. cornucopia/noise.py +508 -0
  13. cornucopia/psf.py +463 -0
  14. cornucopia/qmri.py +1288 -0
  15. cornucopia/random.py +1480 -0
  16. cornucopia/special.py +159 -0
  17. cornucopia/synth.py +708 -0
  18. cornucopia/tests/__init__.py +0 -0
  19. cornucopia/tests/test_backward_geometric.py +173 -0
  20. cornucopia/tests/test_backward_intensity.py +243 -0
  21. cornucopia/tests/test_backward_kspace.py +115 -0
  22. cornucopia/tests/test_backward_noise.py +169 -0
  23. cornucopia/tests/test_backward_psf.py +142 -0
  24. cornucopia/tests/test_backward_qmri.py +249 -0
  25. cornucopia/tests/test_backward_random.py +44 -0
  26. cornucopia/tests/test_backward_synth.py +72 -0
  27. cornucopia/tests/test_base.py +401 -0
  28. cornucopia/tests/test_geometric.py +26 -0
  29. cornucopia/tests/test_intensity.py +9 -0
  30. cornucopia/tests/test_random.py +722 -0
  31. cornucopia/tests/test_run_contrast.py +28 -0
  32. cornucopia/tests/test_run_fov.py +132 -0
  33. cornucopia/tests/test_run_geometric.py +157 -0
  34. cornucopia/tests/test_run_intensity.py +192 -0
  35. cornucopia/tests/test_run_kspace.py +70 -0
  36. cornucopia/tests/test_run_labels.py +224 -0
  37. cornucopia/tests/test_run_noise.py +127 -0
  38. cornucopia/tests/test_run_psf.py +115 -0
  39. cornucopia/tests/test_run_qmri.py +114 -0
  40. cornucopia/tests/test_run_synth.py +67 -0
  41. cornucopia/typing.py +97 -0
  42. cornucopia/utils/__init__.py +0 -0
  43. cornucopia/utils/b0.py +745 -0
  44. cornucopia/utils/bounds.py +412 -0
  45. cornucopia/utils/compat.py +47 -0
  46. cornucopia/utils/conv.py +305 -0
  47. cornucopia/utils/gmm.py +169 -0
  48. cornucopia/utils/indexing.py +911 -0
  49. cornucopia/utils/io.py +258 -0
  50. cornucopia/utils/jit.py +128 -0
  51. cornucopia/utils/kernels.py +288 -0
  52. cornucopia/utils/morpho.py +234 -0
  53. cornucopia/utils/mrf.py +574 -0
  54. cornucopia/utils/padding.py +173 -0
  55. cornucopia/utils/patch.py +302 -0
  56. cornucopia/utils/pool.py +282 -0
  57. cornucopia/utils/py.py +348 -0
  58. cornucopia/utils/smart_inplace.py +163 -0
  59. cornucopia/utils/version.py +57 -0
  60. cornucopia/utils/warps.py +606 -0
  61. cornucopia-0.0.0.dist-info/METADATA +92 -0
  62. cornucopia-0.0.0.dist-info/RECORD +65 -0
  63. cornucopia-0.0.0.dist-info/WHEEL +5 -0
  64. cornucopia-0.0.0.dist-info/licenses/LICENSE +21 -0
  65. cornucopia-0.0.0.dist-info/top_level.txt +1 -0
cornucopia/fov.py ADDED
@@ -0,0 +1,707 @@
1
"""
This module contains transforms that operate on the field of view (FOV)
of the image: flips, axis permutations, 90/180 degree rotations,
cropping, padding and patch extraction.
"""
__all__ = [
    'FlipTransform',
    'RandomFlipTransform',
    'PermuteAxesTransform',
    'RandomPermuteAxesTransform',
    'PatchTransform',
    'RandomPatchTransform',
    'CropTransform',
    'PadTransform',
    # final transform produced by Patch/Crop/Pad/PowerTwo transforms
    'CropPadTransform',
    'PowerTwoTransform',
    'Rot90Transform',
    'Rot180Transform',
    'RandomRot90Transform',
]
19
+ # stdlib
20
+ import math
21
+ from math import inf
22
+ from numbers import Number
23
+ from random import shuffle
24
+
25
+ # dependencies
26
+ from torch import Tensor
27
+ import typing_extensions as tx
28
+
29
+ # internals
30
+ from .base import Transform
31
+ from .base import FinalTransform, NonFinalTransform, PerChannelTransform
32
+ from .utils.py import ensure_list
33
+ from .utils.padding import pad
34
+ from .random import Uniform, RandKFrom, Sampler, RandInt, make_range
35
+ from . import typing as cct
36
+
37
+
38
class FlipTransform(FinalTransform):
    """Flip one or more axes."""

    def __init__(
        self, axis: tx.Optional[cct.ScalarOrSequence[int]] = None, **kwargs
    ) -> None:
        """
        Parameters
        ----------
        axis : [list of] int
            Tensor dimensions to flip (dimension 0 holds the channels).
            By default, flip all spatial axes.

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.axis = axis

    def _xform(self, x: Tensor) -> Tensor:
        # `None` selects every dimension except the leading channel one.
        if self.axis is None:
            dims = list(range(1, x.ndim))
        else:
            dims = self.axis
        return x.flip(ensure_list(dims))

    def make_inverse(self) -> 'FlipTransform':
        # Flipping the same axes twice restores the input: self-inverse.
        return self
67
+
68
+
69
class RandomFlipTransform(NonFinalTransform):
    """Randomly flip one or more axes."""

    Final = Next = FlipTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        axes: tx.Union[Sampler, cct.ScalarOrSequence[int], None] = None,
        *,
        shared: bool = True,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        axes : Sampler | [list of] int
            Axes that can be flipped (default: all spatial axes).
            If a `Sampler`, it is called to draw the axes to flip.

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        # `axis` is accepted as an alternative keyword for `axes`
        axes = kwargs.pop('axis', axes)
        super().__init__(shared=shared, **kwargs)
        self.axes = axes

    def _unroll(self, x: Tensor, max_depth: float = inf) -> Transform:
        if max_depth == 0:
            return self
        if 'channels' not in self.shared and len(x) > 1:
            # draw an independent flip for each channel
            return PerChannelTransform(
                [self.unroll(x[i:i+1], max_depth) for i in range(len(x))],
                **self.get_prm()
            ).unroll(x, max_depth-1)
        axes = self.axes or list(range(1, x.ndim))
        if not isinstance(axes, Sampler):
            # random subset of the candidate axes
            axes = RandKFrom(ensure_list(axes))
        # Previously, a user-provided Sampler left the sampled axes unbound
        # (UnboundLocalError) and the transform parameters were dropped.
        return FlipTransform(
            axes(), **self.get_prm()
        ).unroll(x, max_depth-1)
113
+
114
+
115
class PermuteAxesTransform(FinalTransform):
    """Permute axes"""

    def __init__(
        self, permutation: tx.Optional[tx.Sequence[int]] = None, **kwargs
    ) -> None:
        """
        Parameters
        ----------
        permutation : [list of] int
            Axes permutation. By default, reverse axes.
            Only applies to spatial axes, so axes are numbered [C, 0, 1, 2]

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.permutation = permutation

    def _xform(self, x: Tensor) -> Tensor:
        permutation = self.permutation
        if permutation is None:
            permutation = list(reversed(range(x.dim()-1)))
        # the channel dimension (0) stays in place; spatial indices shift by 1
        permutation = [0] + [p+1 for p in permutation]
        return x.permute(permutation)

    def make_inverse(self) -> 'PermuteAxesTransform':
        if self.permutation:
            # If output axis `i` comes from input axis `p`, the inverse must
            # send axis `i` back to position `p`: iperm[p] = i.
            iperm = [0] * len(self.permutation)
            for i, p in enumerate(self.permutation):
                iperm[p] = i
            return PermuteAxesTransform(iperm, **self.get_prm())
        else:
            # the default (axis reversal) is its own inverse
            return self
150
+
151
+
152
class RandomPermuteAxesTransform(NonFinalTransform):
    """Randomly permute axes."""

    Final = Next = PermuteAxesTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        axes: tx.Optional[tx.Sequence[int]] = None,
        *,
        shared: bool = True,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        axes : [list of] int
            Spatial axes that can be permuted among themselves
            (default: all). Spatial axes that are not listed keep
            their position.

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(shared=shared, **kwargs)
        self.axes = axes

    def _unroll(self, x: Tensor, max_depth: float = inf) -> Transform:
        if max_depth == 0:
            return self
        if 'channels' not in self.shared and len(x) > 1:
            # draw an independent permutation for each channel
            return PerChannelTransform(
                [self.unroll(x[i:i+1], max_depth) for i in range(len(x))],
                **self.get_prm()
            ).unroll(x, max_depth-1)
        ndim = x.ndim - 1
        # candidate spatial axes (negative indices are normalized)
        axes = [a + ndim if a < 0 else a
                for a in list(self.axes or range(ndim))]
        shuffled = list(axes)
        shuffle(shuffled)
        # Build a full permutation: listed axes are shuffled among
        # themselves, all other axes stay put. (A bare subset is not a valid
        # argument to `Tensor.permute`.)
        permutation = list(range(ndim))
        for src, dst in zip(axes, shuffled):
            permutation[src] = dst
        return PermuteAxesTransform(
            permutation, **self.get_prm()
        ).unroll(x, max_depth-1)
195
+
196
+
197
class Rot90Transform(FinalTransform):
    """Apply a 90 (or 180) rotation along one or several axes."""

    def __init__(
        self,
        axis: cct.ScalarOrSequence[int] = 0,
        negative: cct.ScalarOrSequence[bool] = False,
        double: cct.ScalarOrSequence[bool] = False,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        axis : [list of] int
            Rotation axis (indexing does not account for the channel axis)
        negative : [list of] bool
            Rotate by -90 deg instead of 90 deg
        double : [list of] bool
            Rotate by 180 instead of 90 (`negative` is then unused)

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.axis = ensure_list(axis)
        self.negative = ensure_list(negative, len(self.axis))
        self.double = ensure_list(double, len(self.axis))

    def _xform(self, x: Tensor) -> Tensor:
        # this implementation is suboptimal. We should fuse all transpose
        # and all flips into a single "transpose + flip" operation so that
        # a single allocation happens. This will be fine for now.

        ndim = x.ndim - 1
        # convert spatial axes to tensor dimensions (channel is dim 0)
        axis = [1 + (ndim + a if a < 0 else a) for a in self.axis]
        for ax, neg, dbl in zip(axis, self.negative, self.double):
            # the two tensor dimensions spanning the rotation plane
            if ndim == 2:
                dims = [1, 2]
            else:
                assert ndim == 3
                dims = [d for d in (1, 2, 3) if d != ax]
            if dbl:
                # 180 deg: flip both in-plane dimensions
                x = x.flip(dims)
            else:
                # 90 deg: transpose the plane, then flip one of its axes
                x = x.transpose(*dims).flip(dims[1] if neg else dims[0])
        return x

    def make_inverse(self) -> 'Rot90Transform':
        # Undo the rotations in reverse order. A 180 deg step is its own
        # inverse; a "transpose then flip dims[0]" step is undone by
        # "transpose then flip dims[1]" (and vice versa), i.e. by toggling
        # `negative` about the same axis.
        steps = list(zip(self.axis, self.negative, self.double))[::-1]
        return Rot90Transform(
            [ax for ax, _, _ in steps],
            [not neg for _, neg, _ in steps],
            [dbl for _, _, dbl in steps],
            **self.get_prm()
        )
250
+
251
+
252
class Rot180Transform(Rot90Transform):
    """Apply a 180 deg rotation along one or several axes"""

    def __init__(self, axis: cct.ScalarOrSequence[int] = 0, **kwargs) -> None:
        """
        Parameters
        ----------
        axis : [list of] int
            Rotation axis (indexing does not account for the channel axis)

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        # A 180 deg rotation is a Rot90Transform whose steps are all doubled.
        super().__init__(axis=axis, double=True, **kwargs)
268
+
269
+
270
class RandomRot90Transform(NonFinalTransform):
    """Random set of 90 degree rotations."""

    Final = Next = Rot90Transform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        axes: tx.Optional[cct.ScalarOrSequence[int]] = None,
        max_rot: cct.SamplerOrBound[int] = 2,
        negative: bool = True,
        *,
        shared: cct.SharedT = True,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        axes : [list of] int
            Axes along which rotations can happen.
            If `None`, all axes.
        max_rot : Sampler | int
            Maximum number of consecutive rotations.
        negative : bool
            Whether to authorize negative rotations.

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(shared=shared, **kwargs)
        self.axes = axes
        self.max_rot = RandInt.make(make_range(1, max_rot))
        self.negative = negative

    def _unroll(self, x: Tensor, max_depth: float = inf) -> Transform:
        if max_depth == 0:
            return self
        if 'channels' not in self.shared and len(x) > 1:
            # draw independent rotations for each channel
            return PerChannelTransform(
                [self.unroll(x[i:i+1], max_depth) for i in range(len(x))],
                **self.get_prm()
            ).unroll(x, max_depth-1)
        ndim = x.ndim - 1
        # number of successive rotations
        max_rot = self.max_rot
        if isinstance(max_rot, Sampler):
            max_rot = max_rot()
        axes = self.axes
        if axes is None:
            axes = list(range(ndim))
        if isinstance(axes, (int, list, tuple)):
            axes = ensure_list(axes, max_rot, crop=False)
        if not isinstance(axes, Sampler):
            axes = RandKFrom(axes, max_rot, replacement=True)

        axes = ensure_list(axes(), max_rot)
        if self.negative:
            negative = RandKFrom([False, True], max_rot, replacement=True)()
        else:
            negative = [False] * max_rot
        # The tensor must be forwarded: `unroll(max_depth-1)` passed the
        # depth where the tensor was expected.
        return Rot90Transform(
            axes, negative, **self.get_prm()
        ).unroll(x, max_depth-1)
335
+
336
+
337
class CropPadTransform(FinalTransform):
    """Crop and/or pad a tensor."""

    def __init__(
        self,
        crop: tx.Sequence[slice] = (),
        pad: tx.Sequence[int] = (),
        bound: cct.ItemOrSequence[str] = 'zero',
        value: Number = 0,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        crop : list[slice]
            Slicing operator per dimension. The slices are applied to the
            trailing dimensions of the tensor (they follow an Ellipsis).
        pad : list[int]
            Left and right padding per dimension, flattened as
            `[left0, right0, left1, right1, ...]` (the layout built by
            `PatchTransform` and `PadTransform`).
        bound : [list of] str
            Boundary condition for padding
        value : number
            Padding value in case `bound='constant'`

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.crop = crop
        self.pad = pad
        self.bound = bound
        self.value = value

    def _xform(self, x: Tensor) -> Tensor:
        # crop first, then pad
        x = x[(Ellipsis, *self.crop)]
        x = pad(x, self.pad, mode=self.bound, value=self.value)
        return x

    def make_inverse(self) -> 'CropPadTransform':
        # The forward pass crops then pads, so the inverse first removes the
        # padding (as a crop) and then pads back the cropped amounts.
        # NOTE: only crops whose `stop` is negative or None (as produced by
        # the transforms of this module) can be inverted this way.
        flat_pad = list(self.pad)
        ipad = [
            slice(left, (-right) or None)
            for left, right in zip(flat_pad[::2], flat_pad[1::2])
        ]
        icrop = []
        for s in self.crop:
            icrop.extend([s.start or 0, -s.stop if s.stop else 0])
        return CropPadTransform(
            ipad, icrop, bound=self.bound, value=self.value, **self.get_prm()
        )
383
+
384
+
385
class PatchTransform(NonFinalTransform):
    """Extract a patch from the volume"""

    Final = Next = CropPadTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        shape: cct.ScalarOrSequence[int] = 64,
        center: cct.ScalarOrSequence[float] = 0,
        bound: cct.ItemOrSequence[str] = 'zero',
        *,
        shared: cct.SharedT = True,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        shape : [list of] int
            Patch shape
        center : [list of] float
            Patch center, in relative coordinates -1..1
            (-1 is the first voxel, 1 is the last voxel)
        bound : [list of] str
            Boundary condition in case padding is needed

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.

            !!! changedin "![v0.5](https://img.shields.io/badge/v0.5-yellow) \
            Default for `shared` changed from `"channels"` to `True`"

        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        kwargs.setdefault('shared', shared)
        super().__init__(**kwargs)
        self.shape = shape
        self.center = center
        self.bound = bound

    def _unroll(self, x: Tensor, max_depth: float = inf) -> Transform:
        if max_depth == 0:
            return self
        ndim = x.dim() - 1
        shape = ensure_list(self.shape, ndim)
        center = ensure_list(self.center, ndim)
        # relative center in [-1, 1] -> voxel coordinate in [0, size - 1]
        center = [(c + 1) / 2 * (s - 1) for c, s in zip(center, x.shape[1:])]
        crop = []
        padding = []
        for ss, cc, sv in zip(shape, center, x.shape[1:]):
            # The patch [first, first + ss) is centered on
            # first + (ss - 1) / 2. Using `cc - ss / 2` shifted patches by
            # half a voxel, so centered patches were off by one and
            # "in-bounds" random patches could require padding.
            first = int(math.floor(cc - (ss - 1) / 2))
            last = first + ss
            pad_first = max(0, -first)
            pad_last = max(0, last - sv)
            first = max(0, first)
            last = min(sv, last)
            last = (last - sv) or None  # ensure negative for CropPad
            crop.append(slice(first, last))
            padding.extend([pad_first, pad_last])
        return CropPadTransform(
            crop, padding, bound=self.bound, **self.get_prm()
        ).unroll(x, max_depth-1)
450
+
451
+
452
class RandomPatchTransform(NonFinalTransform):
    """Extract a (randomly located) patch from the volume.

    This transform ensures that the patch is fully contained within the
    original field of view (unless the patch size is larger than the
    input shape).
    """

    Next = PatchTransform
    """The transform type returned by `next`."""

    Final = CropPadTransform
    """The transform type returned by `final`."""

    def __init__(
        self,
        shape: cct.ScalarOrSequence[int],
        bound: cct.ItemOrSequence[str] = 'zero',
        *,
        shared: cct.SharedT = True,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        shape : [list of] int
            Patch shape (legacy keyword: `patch_size`)
        bound : [list of] str
            Boundary condition in case padding is needed

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.

            !!! changedin "![v0.5](https://img.shields.io/badge/v0.5-yellow) \
            Default for `shared` changed from `"channels"` to `True`"
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        # `patch_size` is the historical name of `shape`
        shape = kwargs.pop('patch_size', shape)
        kwargs.setdefault('shared', shared)
        super().__init__(**kwargs)
        self.shape = shape
        self.bound = bound

    def _unroll(self, x, max_depth=float('inf')):
        if max_depth == 0:
            return self
        in_shape = x.shape[1:]
        size = ensure_list(self.shape, len(in_shape))
        center = []
        for p, s in zip(size, in_shape):
            ratio = p / s
            # relative centers for which the patch stays inside the FOV
            lo = max(ratio - 1, -1)
            hi = min(1 - ratio, 1)
            center.append(Uniform(lo, hi)())
        return PatchTransform(
            size, center, self.bound, **self.get_prm()
        ).unroll(x, max_depth-1)
511
+
512
+
513
class CropTransform(NonFinalTransform):
    """Crop a tensor by some amount"""

    Final = Next = CropPadTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        cropping: cct.ScalarOrSequence[tx.Union[int, float]],
        unit: str = 'vox',
        side: tx.Optional[str] = 'both',
        **kwargs
    ) -> None:
        """

        Parameters
        ----------
        cropping : [list of] int or float
            Amount of cropping. If `side` is `None`, pre and post cropping
            must be provided in turn.
        unit : {'vox', 'pct'}
            Cropping unit
        side : {'pre', 'post', 'both', None}
            Side to crop

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.

            !!! changedin "![v0.5](https://img.shields.io/badge/v0.5-yellow) \
            Default for `shared` changed from `"channels"` to `True`"
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.

        Raises
        ------
        ValueError
            When unrolled with a `side` that is not one of the values above.
        """
        kwargs.setdefault('shared', True)
        super().__init__(**kwargs)
        self.cropping = cropping
        self.unit = unit
        self.side = side

    def _unroll(self, x: Tensor, max_depth: float = inf) -> Transform:
        if max_depth == 0:
            return self
        ndim = x.dim() - 1
        shape = x.shape[1:]
        if self.side is not None:
            # one value per dimension, expanded to (pre, post) pairs
            cropping = ensure_list(self.cropping, ndim)
            if self.unit[0] == 'p':
                cropping = [int(math.ceil(c * s))
                            for c, s in zip(cropping, shape)]
            if self.side == 'pre':
                cropping = [v for c in cropping for v in (c, 0)]
            elif self.side == 'post':
                cropping = [v for c in cropping for v in (0, c)]
            elif self.side == 'both':
                cropping = [v for c in cropping for v in (c, c)]
            else:
                raise ValueError(
                    f"side must be 'pre', 'post', 'both' or None, "
                    f"got {self.side!r}"
                )
        else:
            # explicit (pre, post) pairs; missing leading dims are not cropped
            cropping = ensure_list(self.cropping)
            cropping = [0] * (2 * ndim - len(cropping)) + cropping
            if self.unit[0] == 'p':
                shape2 = [s for s in shape for _ in range(2)]
                cropping = [int(math.ceil(c * s))
                            for c, s in zip(cropping, shape2)]
        # negative (or open) stops, as expected by CropPadTransform
        cropping = [slice(c0, -c1 if c1 else None)
                    for c0, c1 in zip(cropping[::2], cropping[1::2])]
        return CropPadTransform(
            cropping, [0]*(2*ndim), **self.get_prm()
        ).unroll(x, max_depth-1)
575
+
576
+
577
class PadTransform(NonFinalTransform):
    """Pad a tensor by some amount"""

    Final = Next = CropPadTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        padding: cct.ScalarOrSequence[tx.Union[int, float]],
        unit: str = 'vox',
        side: tx.Optional[str] = 'both',
        bound: str = 'zero',
        value: Number = 0,
        **kwargs
    ) -> None:
        """

        Parameters
        ----------
        padding : [list of] int or float
            Amount of padding. If `side` is `None`, pre and post padding
            must be provided in turn.
        unit : {'vox', 'pct'}
            Padding unit
        side : {'pre', 'post', 'both', None}
            Side to pad
        bound : str
            Boundary condition
        value : float
            Value for case `bound='constant'`

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.

            !!! changedin "![v0.5](https://img.shields.io/badge/v0.5-yellow) \
            Default for `shared` changed from `"channels"` to `True`"
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.

        Raises
        ------
        ValueError
            When unrolled with a `side` that is not one of the values above.
        """
        kwargs.setdefault('shared', True)
        super().__init__(**kwargs)
        self.padding = padding
        self.unit = unit
        self.side = side
        self.bound = bound
        self.value = value

    def _unroll(self, x: Tensor, max_depth: float = inf) -> Transform:
        if max_depth == 0:
            return self
        ndim = x.dim() - 1
        shape = x.shape[1:]
        if self.side is not None:
            # one value per dimension, expanded to (pre, post) pairs
            padding = ensure_list(self.padding, ndim)
            if self.unit[0] == 'p':
                padding = [int(math.ceil(p * s))
                           for p, s in zip(padding, shape)]
            if self.side == 'pre':
                padding = [v for p in padding for v in (p, 0)]
            elif self.side == 'post':
                padding = [v for p in padding for v in (0, p)]
            elif self.side == 'both':
                padding = [v for p in padding for v in (p, p)]
            else:
                # an unexpanded list would be misread as (pre, post) pairs
                raise ValueError(
                    f"side must be 'pre', 'post', 'both' or None, "
                    f"got {self.side!r}"
                )
        else:
            # explicit (pre, post) pairs; missing leading dims are not padded
            padding = ensure_list(self.padding)
            padding = [0] * (2 * ndim - len(padding)) + padding
            if self.unit[0] == 'p':
                shape2 = [s for s in shape for _ in range(2)]
                padding = [int(math.ceil(p * s))
                           for p, s in zip(padding, shape2)]

        return CropPadTransform(
            [slice(None)]*ndim, padding, bound=self.bound, value=self.value,
            **self.get_prm()
        ).unroll(x, max_depth-1)
657
+
658
+
659
class PowerTwoTransform(NonFinalTransform):
    """Pad the volume such that the tensor shape can be divided by 2**x"""

    Next = PatchTransform
    """The transform type returned by `next`."""

    Final = CropPadTransform
    """The transform type returned by `final`."""

    def __init__(
        self,
        exponent: cct.ScalarOrSequence[int] = 1,
        bound: cct.ItemOrSequence[str] = 'zero',
        **kwargs
    ) -> None:
        """

        Parameters
        ----------
        exponent : [list of] int
            Ensure that the shape can be divided by 2 ** exponent.
            Each spatial dimension is padded up to the next multiple
            of 2 ** exponent.
        bound : [list of] str
            Boundary condition for padding

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.

            !!! changedin "![v0.5](https://img.shields.io/badge/v0.5-yellow) \
            Default for `shared` changed from `"channels"` to `True`"
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        kwargs.setdefault('shared', True)
        super().__init__(**kwargs)
        self.exponent = exponent
        self.bound = bound

    def _unroll(self, x: Tensor, max_depth: float = inf) -> Transform:
        if max_depth == 0:
            return self
        shape = x.shape[1:]
        exponent = ensure_list(self.exponent, len(shape))
        # Smallest multiple of 2**e that is >= the current size.
        # (`max(2**e, s)` only guaranteed a minimum size, e.g. 10 stayed 10
        # for e=2, which is not divisible by 4.)
        bigshape = []
        for e, s in zip(exponent, shape):
            factor = 2 ** e
            bigshape.append(-(-s // factor) * factor)
        return PatchTransform(
            bigshape, bound=self.bound, **self.get_prm()
        ).unroll(x, max_depth-1)