cornucopia 0.1.0__tar.gz → 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. {cornucopia-0.1.0 → cornucopia-0.3.0}/LICENSE +0 -0
  2. {cornucopia-0.1.0 → cornucopia-0.3.0}/PKG-INFO +10 -7
  3. {cornucopia-0.1.0 → cornucopia-0.3.0}/README.md +9 -6
  4. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/__init__.py +19 -12
  5. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/_version.py +3 -3
  6. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/base.py +51 -19
  7. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/baseutils.py +4 -4
  8. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/contrast.py +4 -3
  9. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/ctx.py +0 -0
  10. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/fov.py +141 -16
  11. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/geometric.py +6 -7
  12. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/intensity.py +9 -8
  13. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/io.py +2 -2
  14. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/kspace.py +3 -3
  15. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/labels.py +9 -9
  16. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/noise.py +2 -2
  17. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/psf.py +4 -4
  18. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/qmri.py +6 -6
  19. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/random.py +5 -6
  20. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/special.py +0 -0
  21. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/synth.py +11 -1
  22. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/tests/__init__.py +0 -0
  23. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_contrast.py +0 -0
  24. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_fov.py +34 -0
  25. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_geometric.py +0 -0
  26. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_intensity.py +0 -0
  27. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_kspace.py +0 -0
  28. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_labels.py +0 -0
  29. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_noise.py +0 -0
  30. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_psf.py +0 -0
  31. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_qmri.py +0 -0
  32. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_synth.py +0 -0
  33. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/__init__.py +0 -0
  34. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/b0.py +0 -0
  35. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/bounds.py +0 -0
  36. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/conv.py +0 -0
  37. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/gmm.py +7 -2
  38. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/indexing.py +0 -0
  39. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/io.py +0 -0
  40. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/jit.py +0 -0
  41. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/kernels.py +0 -0
  42. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/morpho.py +0 -0
  43. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/padding.py +0 -0
  44. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/patch.py +0 -0
  45. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/py.py +0 -0
  46. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/version.py +0 -0
  47. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/warps.py +0 -0
  48. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia.egg-info/PKG-INFO +10 -7
  49. {cornucopia-0.1.0 → cornucopia-0.3.0}/pyproject.toml +0 -0
  50. {cornucopia-0.1.0 → cornucopia-0.3.0}/setup.cfg +0 -0
  51. {cornucopia-0.1.0 → cornucopia-0.3.0}/setup.py +0 -0
  52. {cornucopia-0.1.0 → cornucopia-0.3.0}/versioneer.py +0 -0
  53. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia.egg-info/SOURCES.txt +0 -0
  54. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia.egg-info/dependency_links.txt +0 -0
  55. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia.egg-info/requires.txt +0 -0
  56. {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia.egg-info/top_level.txt +0 -0
File without changes
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: cornucopia
3
- Version: 0.1.0
3
+ Version: 0.3.0
4
4
  Summary: An abundance of augmentation layers
5
5
  Home-page: UNKNOWN
6
6
  Author: Yael Balbastre
@@ -10,7 +10,7 @@ Project-URL: Source Code, https://github.com/balbasty/cornucopia
10
10
  Description: <picture align="center">
11
11
  <source media="(prefers-color-scheme: dark)" srcset="docs/icons/cornucopia_lightorange.svg">
12
12
  <source media="(prefers-color-scheme: light)" srcset="docs/icons/cornucopia_orange.svg">
13
- <img alt="Cornucopia logo" src="docs/icons/cornucopia_orange.svg">
13
+ <img alt="Cornucopia logo" src="https://github.com/balbasty/cornucopia/raw/main/docs/icons/cornucopia_orange.svg">
14
14
  </picture>
15
15
 
16
16
  The `cornucopia` package provides a generic framework for preprocessing,
@@ -27,9 +27,6 @@ Description: <picture align="center">
27
27
  theoretically be used within any dataloader pipeline,
28
28
  independent of the downstream learning framework (pytorch, tensorflow, jax, ...).
29
29
 
30
- ## Installation
31
-
32
-
33
30
  ## Installation
34
31
 
35
32
  ### Dependencies
@@ -43,15 +40,21 @@ Description: <picture align="center">
43
40
  ### Conda
44
41
 
45
42
  ```sh
46
- conda install cornucopia -c balbasty -c pytorch
43
+ conda install cornucopia -c balbasty -c pytorch -c conda-forge
47
44
  ```
48
45
 
49
- ### Pip
46
+ ### Pip (release)
50
47
 
51
48
  ```sh
52
49
  pip install cornucopia
53
50
  ```
54
51
 
52
+ ### Pip (dev)
53
+
54
+ ```sh
55
+ pip install cornucopia@git+https://github.com/balbasty/cornucopia
56
+ ```
57
+
55
58
  ## Documentation
56
59
 
57
60
  Read the [documentation](https://cornucopia.readthedocs.io) and in particular:
@@ -1,7 +1,7 @@
1
1
  <picture align="center">
2
2
  <source media="(prefers-color-scheme: dark)" srcset="docs/icons/cornucopia_lightorange.svg">
3
3
  <source media="(prefers-color-scheme: light)" srcset="docs/icons/cornucopia_orange.svg">
4
- <img alt="Cornucopia logo" src="docs/icons/cornucopia_orange.svg">
4
+ <img alt="Cornucopia logo" src="https://github.com/balbasty/cornucopia/raw/main/docs/icons/cornucopia_orange.svg">
5
5
  </picture>
6
6
 
7
7
  The `cornucopia` package provides a generic framework for preprocessing,
@@ -18,9 +18,6 @@ Since gradients are not expected to backpropagate through its layers, it can
18
18
  theoretically be used within any dataloader pipeline,
19
19
  independent of the downstream learning framework (pytorch, tensorflow, jax, ...).
20
20
 
21
- ## Installation
22
-
23
-
24
21
  ## Installation
25
22
 
26
23
  ### Dependencies
@@ -34,15 +31,21 @@ independent of the downstream learning framework (pytorch, tensorflow, jax, ...)
34
31
  ### Conda
35
32
 
36
33
  ```sh
37
- conda install cornucopia -c balbasty -c pytorch
34
+ conda install cornucopia -c balbasty -c pytorch -c conda-forge
38
35
  ```
39
36
 
40
- ### Pip
37
+ ### Pip (release)
41
38
 
42
39
  ```sh
43
40
  pip install cornucopia
44
41
  ```
45
42
 
43
+ ### Pip (dev)
44
+
45
+ ```sh
46
+ pip install cornucopia@git+https://github.com/balbasty/cornucopia
47
+ ```
48
+
46
49
  ## Documentation
47
50
 
48
51
  Read the [documentation](https://cornucopia.readthedocs.io) and in particular:
@@ -1,4 +1,5 @@
1
- """Flexible transforms for pre-processing and augmentation
1
+ """
2
+ Flexible transforms for pre-processing and augmentation
2
3
 
3
4
  Example on how to use this machinery to generate within-subject
4
5
  image pairs with a random affine deformation between them::
@@ -30,19 +31,25 @@ image pairs with a random affine deformation between them::
30
31
 
31
32
  """
32
33
 
33
- # TODO:
34
- # [x] Make it a standalone package?
35
- # [x] Move samplers in their own file
36
- # [x] Add IO transforms (that transform filenames in tensors)
37
- # [ ] Better deal with separable/shared transforms
38
- # [ ] Add a SharedTransform class (like Randomized) that does the heavy
39
- # lifting
40
- # [ ] By default (non shared), let Transforms handle multi-channel
41
- # data (currently we loop across channels in the base class)
42
-
43
34
  from . import random # noqa: F401
44
35
  from . import ctx # noqa: F401
45
- from .ctx import batch # noqa: F401
36
+ from . import base # noqa: F401
37
+ from . import special # noqa: F401
38
+ from . import contrast # noqa: F401
39
+ from . import geometric # noqa: F401
40
+ from . import intensity # noqa: F401
41
+ from . import io # noqa: F401
42
+ from . import fov # noqa: F401
43
+ from . import kspace # noqa: F401
44
+ from . import labels # noqa: F401
45
+ from . import noise # noqa: F401
46
+ from . import psf # noqa: F401
47
+ from . import qmri # noqa: F401
48
+ from . import synth # noqa: F401
49
+ from . import utils # noqa: F401
50
+
51
+ from .random import * # noqa: F401,F403
52
+ from .ctx import * # noqa: F401,F403
46
53
  from .base import * # noqa: F401,F403
47
54
  from .special import * # noqa: F401,F403
48
55
  from .contrast import * # noqa: F401,F403
@@ -8,11 +8,11 @@ import json
8
8
 
9
9
  version_json = '''
10
10
  {
11
- "date": "2023-08-28T17:04:06-0400",
11
+ "date": "2024-04-19T14:23:50+0100",
12
12
  "dirty": false,
13
13
  "error": null,
14
- "full-revisionid": "57fb52ad37a27a28c5894754411fb6bb73e60b5a",
15
- "version": "0.1.0"
14
+ "full-revisionid": "37de94f181b9a97eebd21460f4df63ae4a0750f8",
15
+ "version": "0.3.0"
16
16
  }
17
17
  ''' # END VERSION_JSON
18
18
 
@@ -148,7 +148,7 @@ class Transform(nn.Module):
148
148
  return x
149
149
 
150
150
  # now we're working with a single tensor (or str)
151
- y = self.apply(x)
151
+ y = self.xform(x)
152
152
  if not isinstance(y, Returned):
153
153
  if not isinstance(y, type(self.returns)):
154
154
  y = dict(input=x, output=y)
@@ -240,7 +240,7 @@ class FinalTransform(Transform):
240
240
  def is_final(self):
241
241
  return True
242
242
 
243
- def apply(self, x):
243
+ def xform(self, x):
244
244
  """Apply the transform to a tensor
245
245
 
246
246
  Parameters
@@ -283,7 +283,7 @@ class IdentityTransform(FinalTransform):
283
283
  def __init__(self, **kwargs):
284
284
  super().__init__(**kwargs)
285
285
 
286
- def apply(self, x):
286
+ def xform(self, x):
287
287
  return x
288
288
 
289
289
  def make_inverse(self):
@@ -305,12 +305,12 @@ class SharedMixin:
305
305
  shared = ''
306
306
  return shared
307
307
 
308
- def apply(self, x):
308
+ def xform(self, x):
309
309
  if 'channels' in self.shared:
310
310
  xform = self.make_final(x[:1], max_depth=1)
311
311
  else:
312
312
  xform = self.make_final(x, max_depth=1)
313
- return xform.apply(x)
313
+ return xform.xform(x)
314
314
 
315
315
  def forward(self, *a, **k):
316
316
  return self._shared_forward(*a, **k)
@@ -359,7 +359,7 @@ class NonFinalTransform(SharedMixin, Transform):
359
359
  super().__init__(**kwargs)
360
360
  self.shared = self._prepare_shared(shared)
361
361
 
362
- def make_final(self, x, max_depth=float('inf'), *args, **kwargs):
362
+ def make_final(self, x, max_depth=float('inf')):
363
363
  if self.is_final or max_depth == 0:
364
364
  return self
365
365
  return NotImplemented
@@ -447,7 +447,7 @@ class SequentialTransform(SpecialMixin, SharedMixin, Transform):
447
447
  def make_final(self, x, max_depth=float('inf')):
448
448
  if max_depth == 0:
449
449
  return self
450
- x = VirtualTensor.from_any(x, compute_stats=True)
450
+ # x = VirtualTensor.from_any(x, compute_stats=True)
451
451
  trf = []
452
452
  for t in self:
453
453
  t = t.make_final(x, max_depth=max_depth-1)
@@ -525,7 +525,7 @@ class PerChannelTransform(SpecialMixin, Transform):
525
525
  trf = PerChannelTransform(trf, **prm)
526
526
  return trf
527
527
 
528
- def apply(self, x):
528
+ def xform(self, x):
529
529
  results = []
530
530
  for i, t in enumerate(self.transforms):
531
531
  with ReturningTransform(t, self.returns), \
@@ -1068,33 +1068,67 @@ class RandomizedTransform(NonFinalTransform):
1068
1068
  """
1069
1069
  Transform generated by randomizing some parameters of another transform.
1070
1070
 
1071
+ !!! note "`ctx.randomize` is an alias for `RandomizedTransform`"
1072
+
1071
1073
  !!! example "Gaussian noise with randomized variance"
1072
1074
  Object call
1073
1075
  ```python
1074
1076
  import cornucopia as cc
1075
- hypernoise = RandomizedTransform(cc.GaussianNoise, [cc.Uniform(0, 10)])
1077
+ hypernoise = cc.RandomizedTransform(cc.GaussianNoise, [cc.Uniform()])
1076
1078
  img = hypernoise(img)
1077
1079
  ```
1078
- Functional call
1080
+
1081
+ Delayed call
1079
1082
  ```python
1080
1083
  import cornucopia as cc
1081
- hypernoise = cc.randomize(cc.GaussianNoise)(cc.Uniform(0, 10))
1084
+ MyRandomNoise = cc.randomize(cc.GaussianNoise)
1085
+ hypernoise = MyRandomNoise(cc.Uniform())
1082
1086
  img = hypernoise(img)
1083
1087
  ```
1084
1088
 
1085
1089
  """
1086
1090
 
1087
- def __init__(self, transform, sample, ksample=None,
1091
+ class Delayed:
1092
+ # Temproary parameter holder for delayed calls
1093
+ def __init__(self, transform, **kwargs):
1094
+ self.transform = transform
1095
+ self.kwargs = kwargs
1096
+
1097
+ def __call__(self, *args, **kwargs):
1098
+ return RandomizedTransform(
1099
+ self.transform, args, kwargs, **self.kwargs)
1100
+
1101
+ def __new__(cls, *args, **kwargs):
1102
+ if cls is RandomizedTransform:
1103
+ return cls._base_new(*args, **kwargs)
1104
+ return super().__new__(cls)
1105
+
1106
+ @classmethod
1107
+ def _base_new(cls, transform, sample=tuple(), ksample=dict(),
1108
+ *, shared=False, **kwargs):
1109
+ assert cls is RandomizedTransform
1110
+ if not sample and not ksample:
1111
+ # If no arguments are passed, it means that the user calls
1112
+ # this in "delayed/functional" mode. In that case, we return
1113
+ # a callable object that returns the constructed instance
1114
+ # using the call-time arguments.
1115
+ return cls.Delayed(transform, shared=shared, **kwargs)
1116
+ # Otherwise, we're in object mode and we instantiate the
1117
+ # randomized object.
1118
+ return super().__new__(cls)
1119
+
1120
+ def __init__(self, transform, sample=tuple(), ksample=dict(),
1088
1121
  *, shared=False, **kwargs):
1089
1122
  """
1090
-
1091
1123
  Parameters
1092
1124
  ----------
1093
1125
  transform : callable(...) -> Transform
1094
1126
  A Transform subclass or a function that constructs a Transform.
1095
1127
  sample : [list or dict of] callable
1096
1128
  A collection of functions that generate parameter values provided
1097
- to `transform`.
1129
+ to `transform`. Can be args-like or kwargs-like arguments.
1130
+ ksample : dict[callable]
1131
+ Must be kwargs-like arguments.
1098
1132
 
1099
1133
  Other Parameters
1100
1134
  ----------------
@@ -1130,9 +1164,7 @@ class RandomizedTransform(NonFinalTransform):
1130
1164
 
1131
1165
  def __repr__(self):
1132
1166
  if type(self) is RandomizedTransform:
1133
- try:
1134
- if issubclass(self.subtransform, Transform):
1135
- return f'Randomized{self.subtransform.__name__}()'
1136
- except TypeError:
1137
- pass
1167
+ xform = self.subtransform
1168
+ if isinstance(xform, type) and issubclass(xform, Transform):
1169
+ return f'Randomized{xform.__name__}()'
1138
1170
  return super().__repr__()
@@ -24,7 +24,7 @@ def get_first_element(x, include=None, exclude=None, types=None):
24
24
  if ok:
25
25
  return v, True
26
26
  return None, False
27
- if torch.is_tensor(x) or isinstance(x, types):
27
+ if torch.is_tensor(x) or (types and isinstance(x, types)):
28
28
  return x, True
29
29
  return x, False
30
30
 
@@ -213,9 +213,9 @@ class VirtualTensor:
213
213
  @classmethod
214
214
  def from_tensor(cls, x, compute_stats=False):
215
215
  if compute_stats:
216
- vmin = x.min(dim=list(range(1, x.ndim)))
217
- vmax = x.max(dim=list(range(1, x.ndim)))
218
- vmean = x.mean(dim=list(range(1, x.ndim)))
216
+ vmin = x.reshape([len(x), -1]).min(dim=-1).values
217
+ vmax = x.reshape([len(x), -1]).max(dim=-1).values
218
+ vmean = x.float().mean(dim=list(range(1, x.ndim)))
219
219
  else:
220
220
  vmin = vmax = vmean = None
221
221
  return VirtualTensor(x.shape, dtype=x.dtype, device=x.device,
@@ -45,7 +45,7 @@ class ContrastMixtureTransform(NonFinalTransform):
45
45
  self.mu = mu
46
46
  self.sigma = sigma
47
47
 
48
- def apply(self, x):
48
+ def xform(self, x):
49
49
  z = self.z.to(x)
50
50
  mu0 = self.mu0.to(x)
51
51
  sigma0 = self.sigma0.to(x)
@@ -114,7 +114,8 @@ class ContrastMixtureTransform(NonFinalTransform):
114
114
  mu = torch.rand_like(
115
115
  old_mu).mul_(old_mu_max - old_mu_min).add_(old_mu_min)
116
116
  sigma = torch.rand_like(
117
- old_sigma_diag).mul_(old_sigma_max - old_sigma_min).add_(old_sigma_min)
117
+ old_sigma_diag
118
+ ).mul_(old_sigma_max - old_sigma_min).add_(old_sigma_min)
118
119
  corr = torch.rand([len(old_mu), nc*(nc-1)//2], **backend).mul_(0.5)
119
120
 
120
121
  fullsigma = torch.eye(nc, **backend).expand([nk, nc, nc]).clone()
@@ -145,7 +146,7 @@ class ContrastLookupTransform(NonFinalTransform):
145
146
  self.edges = edges
146
147
  self.mu = mu
147
148
 
148
- def apply(self, x):
149
+ def xform(self, x):
149
150
  edges, mu = self.edges.to(x), self.mu.to(x)
150
151
  mu0 = (edges[:-1] + edges[1:]) / 2
151
152
  nk = len(mu)
File without changes
@@ -8,14 +8,16 @@ __all__ = [
8
8
  'CropTransform',
9
9
  'PadTransform',
10
10
  'PowerTwoTransform',
11
+ 'Rot90Transform',
12
+ 'Rot180Transform',
13
+ 'RandomRot90Transform',
11
14
  ]
12
-
13
15
  import math
14
16
  from random import shuffle
15
- from .base import FinalTransform, NonFinalTransform
17
+ from .base import FinalTransform, NonFinalTransform, PerChannelTransform
16
18
  from .utils.py import ensure_list
17
19
  from .utils.padding import pad
18
- from .random import Uniform, RandKFrom, Sampler
20
+ from .random import Uniform, RandKFrom, Sampler, RandInt, make_range
19
21
 
20
22
 
21
23
  class FlipTransform(FinalTransform):
@@ -23,7 +25,6 @@ class FlipTransform(FinalTransform):
23
25
 
24
26
  def __init__(self, axis=None, **kwargs):
25
27
  """
26
-
27
28
  Parameters
28
29
  ----------
29
30
  axis : [list of] int
@@ -32,7 +33,7 @@ class FlipTransform(FinalTransform):
32
33
  super().__init__(**kwargs)
33
34
  self.axis = axis
34
35
 
35
- def apply(self, x):
36
+ def xform(self, x):
36
37
  axis = self.axis
37
38
  if axis is None:
38
39
  axis = list(range(1, x.ndim))
@@ -46,24 +47,30 @@ class FlipTransform(FinalTransform):
46
47
  class RandomFlipTransform(NonFinalTransform):
47
48
  """Randomly flip one or more axes"""
48
49
 
49
- def __init__(self, axes=None, **kwargs):
50
+ def __init__(self, axes=None, *, shared=True, **kwargs):
50
51
  """
51
-
52
52
  Parameters
53
53
  ----------
54
54
  axes : Sampler or [list of] int
55
55
  Axes that can be flipped (default: all)
56
+
57
+ Other Parameters
58
+ ----------------
56
59
  shared : {'channels', 'tensors', 'channels+tensors', ''}
57
60
  Apply the same flip to all channels and/or tensors
58
61
  """
59
62
  axes = kwargs.pop('axis', axes)
60
- kwargs.setdefault('shared', True)
61
- super().__init__(**kwargs)
63
+ super().__init__(shared=shared, **kwargs)
62
64
  self.axes = axes
63
65
 
64
66
  def make_final(self, x, max_depth=float('inf')):
65
67
  if max_depth == 0:
66
68
  return self
69
+ if 'channels' not in self.shared and len(x) > 1:
70
+ return PerChannelTransform(
71
+ [self.make_final(x[i:i+1], max_depth) for i in range(len(x))],
72
+ **self.get_prm()
73
+ ).make_final(x, max_depth-1)
67
74
  axes = self.axes or range(1, x.ndim)
68
75
  if not isinstance(axes, Sampler):
69
76
  rand_axes = RandKFrom(ensure_list(axes))
@@ -76,7 +83,6 @@ class PermuteAxesTransform(FinalTransform):
76
83
 
77
84
  def __init__(self, permutation=None, **kwargs):
78
85
  """
79
-
80
86
  Parameters
81
87
  ----------
82
88
  permutation : [list of] int
@@ -86,7 +92,7 @@ class PermuteAxesTransform(FinalTransform):
86
92
  super().__init__(**kwargs)
87
93
  self.permutation = permutation
88
94
 
89
- def apply(self, x):
95
+ def xform(self, x):
90
96
  permutation = self.permutation
91
97
  if permutation is None:
92
98
  permutation = list(reversed(range(x.dim()-1)))
@@ -105,23 +111,29 @@ class PermuteAxesTransform(FinalTransform):
105
111
  class RandomPermuteAxesTransform(NonFinalTransform):
106
112
  """Randomly permute axes"""
107
113
 
108
- def __init__(self, axes=None, **kwargs):
114
+ def __init__(self, axes=None, *, shared=True, **kwargs):
109
115
  """
110
-
111
116
  Parameters
112
117
  ----------
113
118
  axes : [list of] int
114
119
  Axes that can be permuted (default: all)
120
+
121
+ Other Parameters
122
+ ----------------
115
123
  shared : {'channels', 'tensors', 'channels+tensors', ''}
116
124
  Apply the same permutation to all channels and/or tensors
117
125
  """
118
- kwargs.setdefault('shared', True)
119
- super().__init__(**kwargs)
126
+ super().__init__(shared=shared, **kwargs)
120
127
  self.axes = axes
121
128
 
122
129
  def make_final(self, x, max_depth=float('inf')):
123
130
  if max_depth == 0:
124
131
  return self
132
+ if 'channels' not in self.shared and len(x) > 1:
133
+ return PerChannelTransform(
134
+ [self.make_final(x[i:i+1], max_depth) for i in range(len(x))],
135
+ **self.get_prm()
136
+ ).make_final(x, max_depth-1)
125
137
  axes = list(self.axes or range(x.ndim-1))
126
138
  shuffle(axes)
127
139
  return PermuteAxesTransform(
@@ -129,6 +141,119 @@ class RandomPermuteAxesTransform(NonFinalTransform):
129
141
  ).make_final(x, max_depth-1)
130
142
 
131
143
 
144
+ class Rot90Transform(FinalTransform):
145
+ """
146
+ Apply a 90 (or 180) rotation along one or several axes
147
+ """
148
+
149
+ def __init__(self, axis=0, negative=False, double=False, **kwargs):
150
+ """
151
+ Parameters
152
+ ----------
153
+ axis : int or list[int]
154
+ Rotation axis (indexing does not account for the channel axis)
155
+ negative : bool or list[bool]
156
+ Rotate by -90 deg instead of 90 deg
157
+ double : bool or list[bool]
158
+ Rotate be 180 instead of 90 (`negative` is then unused)
159
+ """
160
+ super().__init__(**kwargs)
161
+ self.axis = ensure_list(axis)
162
+ self.negative = ensure_list(negative, len(self.axis))
163
+ self.double = ensure_list(double, len(self.axis))
164
+
165
+ def xform(self, x):
166
+ # this implementation is suboptimal. We should fuse all transpose
167
+ # and all flips into a single "transpose + flip" operation so that
168
+ # a single allocation happens. This will be fine for now.
169
+
170
+ ndim = x.ndim - 1
171
+ axis = [1 + (ndim + a if a < 0 else a) for a in self.axis]
172
+ for ax, neg, dbl in zip(axis, self.negative, self.double):
173
+ if dbl:
174
+ if ndim == 2:
175
+ dims = [1, 2]
176
+ else:
177
+ assert ndim == 3
178
+ dims = [d for d in (1, 2, 3) if d != ax]
179
+ x = x.flip(dims)
180
+ else:
181
+ if ndim == 2:
182
+ dims = [1, 2]
183
+ else:
184
+ assert ndim == 3
185
+ dims = [d for d in (1, 2, 3) if d != ax]
186
+ x = x.transpose(*dims).flip(dims[1] if neg else dims[0])
187
+ return x
188
+
189
+
190
+ class Rot180Transform(Rot90Transform):
191
+ """Apply a 180 deg rotation along one or several axes"""
192
+
193
+ def __init__(self, axis=0, **kwargs):
194
+ """
195
+ Parameters
196
+ ----------
197
+ axis : int or list[int]
198
+ Rotation axis (indexing does not account for the channel axis)
199
+ """
200
+ super().__init__(axis, double=True, **kwargs)
201
+
202
+
203
+ class RandomRot90Transform(NonFinalTransform):
204
+ """Random set of 90 transforms"""
205
+
206
+ def __init__(self, axes=None, max_rot=2, negative=True,
207
+ *, shared=True, **kwargs):
208
+ """
209
+ Parameters
210
+ ----------
211
+ axes : int or list[int]
212
+ Axes along which rotations can happen.
213
+ If `None`, all axes.
214
+ max_rot : int or Sampler
215
+ Maximum number of consecutive rotations.
216
+ negative : bool
217
+ Whether to authorize negative rotations.
218
+
219
+ Other Parameters
220
+ ----------------
221
+ shared : {'channels', 'tensors', 'channels+tensors', ''}
222
+ Apply the same permutation to all channels and/or tensors
223
+ """
224
+ super().__init__(shared=shared, **kwargs)
225
+ self.axes = axes
226
+ self.max_rot = RandInt.make(make_range(1, max_rot))
227
+ self.negative = negative
228
+
229
+ def make_final(self, x, max_depth=float('inf')):
230
+ if max_depth == 0:
231
+ return self
232
+ if 'channels' not in self.shared and len(x) > 1:
233
+ return PerChannelTransform(
234
+ [self.make_final(x[i:i+1], max_depth) for i in range(len(x))],
235
+ **self.get_prm()
236
+ ).make_final(x, max_depth-1)
237
+ ndim = x.ndim - 1
238
+ max_rot = self.max_rot
239
+ if isinstance(max_rot, Sampler):
240
+ max_rot = max_rot()
241
+ axes = self.axes
242
+ if axes is None:
243
+ axes = list(range(ndim))
244
+ if isinstance(axes, (int, list, tuple)):
245
+ axes = ensure_list(axes, max_rot, crop=False)
246
+ if not isinstance(axes, Sampler):
247
+ axes = RandKFrom(axes, max_rot, replacement=True)
248
+
249
+ axes = ensure_list(axes(), max_rot)
250
+ negative = RandKFrom([False, True], max_rot, replacement=True)() \
251
+ if self.negative else [False] * max_rot
252
+ return Rot90Transform(
253
+ axes, negative, **self.get_prm()
254
+ ).make_final(max_depth-1)
255
+
256
+
132
257
  class CropPadTransform(FinalTransform):
133
258
  """Crop and/or pad a tensor"""
134
259
 
@@ -151,7 +276,7 @@ class CropPadTransform(FinalTransform):
151
276
  self.bound = bound
152
277
  self.value = value
153
278
 
154
- def apply(self, x):
279
+ def xform(self, x):
155
280
  crop = tuple([Ellipsis, *self.crop])
156
281
  x = x[crop]
157
282
  x = pad(x, self.pad, mode=self.bound, value=self.value)
@@ -218,7 +218,7 @@ class ElasticTransform(NonFinalTransform):
218
218
  ).movedim(-1, 1)
219
219
  return flow
220
220
 
221
- def apply(self, x):
221
+ def xform(self, x):
222
222
  """Deform the input tensor
223
223
 
224
224
  Parameters
@@ -503,7 +503,7 @@ class AffineTransform(NonFinalTransform):
503
503
  def make_flow(self, matrix, shape):
504
504
  return warps.affine_flow(matrix, shape).movedim(-1, 0)
505
505
 
506
- def apply(self, x):
506
+ def xform(self, x):
507
507
  flow = cast_like(self.flow, x)
508
508
  matrix = cast_like(self.matrix, x)
509
509
  required = return_requires(self.returns)
@@ -766,7 +766,7 @@ class AffineElasticTransform(NonFinalTransform):
766
766
  self.affine = affine
767
767
  self.bound = bound
768
768
 
769
- def apply(self, x):
769
+ def xform(self, x):
770
770
  flow = cast_like(self.flow, x)
771
771
  controls = cast_like(self.controls, x)
772
772
  affine = cast_like(self.affine, x)
@@ -954,7 +954,7 @@ class MakeAffinePair(NonFinalTransform):
954
954
  self.left = left
955
955
  self.right = right
956
956
 
957
- def apply(self, x):
957
+ def xform(self, x):
958
958
  x1 = self.left(x)
959
959
  x2 = self.right(x)
960
960
  mat1, mat2 = self.left.matrix, self.right.matrix
@@ -1119,7 +1119,6 @@ class SlicewiseAffineTransform(NonFinalTransform):
1119
1119
  F = torch.eye(ndim+1, **backend)
1120
1120
  F[:ndim, -1] = -offsets
1121
1121
  Z = E.clone()
1122
- print(zooms.shape, Z.shape)
1123
1122
  Z.diagonal(0, -1, -2)[:, :-1].copy_(1 + zooms)
1124
1123
  T = E.clone()
1125
1124
  T[:, :ndim, -1] = translations
@@ -1204,7 +1203,7 @@ class SlicewiseAffineTransform(NonFinalTransform):
1204
1203
  self.subsample = subsample
1205
1204
  self.bound = bound
1206
1205
 
1207
- def apply(self, x):
1206
+ def xform(self, x):
1208
1207
  flow = cast_like(self.flow, x)
1209
1208
  matrix = cast_like(self.matrix, x)
1210
1209
 
@@ -1362,7 +1361,7 @@ class RandomSlicewiseAffineTransform(NonFinalTransform):
1362
1361
  # get slice direction
1363
1362
  slice = self.slice
1364
1363
  if slice is None:
1365
- slice = RandInt(0, ndim)
1364
+ slice = RandInt(0, ndim - 1)
1366
1365
  if isinstance(slice, Sampler):
1367
1366
  slice = slice()
1368
1367