cornucopia 0.2.0__tar.gz → 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. {cornucopia-0.2.0 → cornucopia-0.3.0}/LICENSE +0 -0
  2. {cornucopia-0.2.0 → cornucopia-0.3.0}/PKG-INFO +8 -5
  3. {cornucopia-0.2.0 → cornucopia-0.3.0}/README.md +7 -4
  4. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/__init__.py +19 -12
  5. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/_version.py +3 -3
  6. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/base.py +44 -12
  7. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/baseutils.py +0 -0
  8. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/contrast.py +0 -0
  9. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/ctx.py +0 -0
  10. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/fov.py +138 -13
  11. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/geometric.py +1 -2
  12. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/intensity.py +0 -0
  13. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/io.py +0 -0
  14. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/kspace.py +0 -0
  15. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/labels.py +0 -0
  16. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/noise.py +0 -0
  17. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/psf.py +1 -1
  18. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/qmri.py +0 -0
  19. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/random.py +5 -6
  20. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/special.py +0 -0
  21. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/synth.py +0 -0
  22. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/tests/__init__.py +0 -0
  23. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_contrast.py +0 -0
  24. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_fov.py +34 -0
  25. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_geometric.py +0 -0
  26. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_intensity.py +0 -0
  27. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_kspace.py +0 -0
  28. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_labels.py +0 -0
  29. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_noise.py +0 -0
  30. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_psf.py +0 -0
  31. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_qmri.py +0 -0
  32. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_synth.py +0 -0
  33. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/__init__.py +0 -0
  34. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/b0.py +0 -0
  35. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/bounds.py +0 -0
  36. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/conv.py +0 -0
  37. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/gmm.py +0 -0
  38. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/indexing.py +0 -0
  39. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/io.py +0 -0
  40. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/jit.py +0 -0
  41. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/kernels.py +0 -0
  42. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/morpho.py +0 -0
  43. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/padding.py +0 -0
  44. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/patch.py +0 -0
  45. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/py.py +0 -0
  46. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/version.py +0 -0
  47. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/warps.py +0 -0
  48. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia.egg-info/PKG-INFO +8 -5
  49. {cornucopia-0.2.0 → cornucopia-0.3.0}/pyproject.toml +0 -0
  50. {cornucopia-0.2.0 → cornucopia-0.3.0}/setup.cfg +0 -0
  51. {cornucopia-0.2.0 → cornucopia-0.3.0}/setup.py +0 -0
  52. {cornucopia-0.2.0 → cornucopia-0.3.0}/versioneer.py +0 -0
  53. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia.egg-info/SOURCES.txt +0 -0
  54. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia.egg-info/dependency_links.txt +0 -0
  55. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia.egg-info/requires.txt +0 -0
  56. {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia.egg-info/top_level.txt +0 -0
File without changes
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: cornucopia
3
- Version: 0.2.0
3
+ Version: 0.3.0
4
4
  Summary: An abundance of augmentation layers
5
5
  Home-page: UNKNOWN
6
6
  Author: Yael Balbastre
@@ -27,9 +27,6 @@ Description: <picture align="center">
27
27
  theoretically be used within any dataloader pipeline,
28
28
  independent of the downstream learning framework (pytorch, tensorflow, jax, ...).
29
29
 
30
- ## Installation
31
-
32
-
33
30
  ## Installation
34
31
 
35
32
  ### Dependencies
@@ -46,12 +43,18 @@ Description: <picture align="center">
46
43
  conda install cornucopia -c balbasty -c pytorch -c conda-forge
47
44
  ```
48
45
 
49
- ### Pip
46
+ ### Pip (release)
50
47
 
51
48
  ```sh
52
49
  pip install cornucopia
53
50
  ```
54
51
 
52
+ ### Pip (dev)
53
+
54
+ ```sh
55
+ pip install cornucopia@git+https://github.com/balbasty/cornucopia
56
+ ```
57
+
55
58
  ## Documentation
56
59
 
57
60
  Read the [documentation](https://cornucopia.readthedocs.io) and in particular:
@@ -18,9 +18,6 @@ Since gradients are not expected to backpropagate through its layers, it can
18
18
  theoretically be used within any dataloader pipeline,
19
19
  independent of the downstream learning framework (pytorch, tensorflow, jax, ...).
20
20
 
21
- ## Installation
22
-
23
-
24
21
  ## Installation
25
22
 
26
23
  ### Dependencies
@@ -37,12 +34,18 @@ independent of the downstream learning framework (pytorch, tensorflow, jax, ...)
37
34
  conda install cornucopia -c balbasty -c pytorch -c conda-forge
38
35
  ```
39
36
 
40
- ### Pip
37
+ ### Pip (release)
41
38
 
42
39
  ```sh
43
40
  pip install cornucopia
44
41
  ```
45
42
 
43
+ ### Pip (dev)
44
+
45
+ ```sh
46
+ pip install cornucopia@git+https://github.com/balbasty/cornucopia
47
+ ```
48
+
46
49
  ## Documentation
47
50
 
48
51
  Read the [documentation](https://cornucopia.readthedocs.io) and in particular:
@@ -1,4 +1,5 @@
1
- """Flexible transforms for pre-processing and augmentation
1
+ """
2
+ Flexible transforms for pre-processing and augmentation
2
3
 
3
4
  Example on how to use this machinery to generate within-subject
4
5
  image pairs with a random affine deformation between them::
@@ -30,19 +31,25 @@ image pairs with a random affine deformation between them::
30
31
 
31
32
  """
32
33
 
33
- # TODO:
34
- # [x] Make it a standalone package?
35
- # [x] Move samplers in their own file
36
- # [x] Add IO transforms (that transform filenames in tensors)
37
- # [ ] Better deal with separable/shared transforms
38
- # [ ] Add a SharedTransform class (like Randomized) that does the heavy
39
- # lifting
40
- # [ ] By default (non shared), let Transforms handle multi-channel
41
- # data (currently we loop across channels in the base class)
42
-
43
34
  from . import random # noqa: F401
44
35
  from . import ctx # noqa: F401
45
- from .ctx import batch # noqa: F401
36
+ from . import base # noqa: F401
37
+ from . import special # noqa: F401
38
+ from . import contrast # noqa: F401
39
+ from . import geometric # noqa: F401
40
+ from . import intensity # noqa: F401
41
+ from . import io # noqa: F401
42
+ from . import fov # noqa: F401
43
+ from . import kspace # noqa: F401
44
+ from . import labels # noqa: F401
45
+ from . import noise # noqa: F401
46
+ from . import psf # noqa: F401
47
+ from . import qmri # noqa: F401
48
+ from . import synth # noqa: F401
49
+ from . import utils # noqa: F401
50
+
51
+ from .random import * # noqa: F401,F403
52
+ from .ctx import * # noqa: F401,F403
46
53
  from .base import * # noqa: F401,F403
47
54
  from .special import * # noqa: F401,F403
48
55
  from .contrast import * # noqa: F401,F403
@@ -8,11 +8,11 @@ import json
8
8
 
9
9
  version_json = '''
10
10
  {
11
- "date": "2023-11-21T15:24:48-0500",
11
+ "date": "2024-04-19T14:23:50+0100",
12
12
  "dirty": false,
13
13
  "error": null,
14
- "full-revisionid": "6d09025573589db13d99c6df247f138f22f5ab61",
15
- "version": "0.2.0"
14
+ "full-revisionid": "37de94f181b9a97eebd21460f4df63ae4a0750f8",
15
+ "version": "0.3.0"
16
16
  }
17
17
  ''' # END VERSION_JSON
18
18
 
@@ -359,7 +359,7 @@ class NonFinalTransform(SharedMixin, Transform):
359
359
  super().__init__(**kwargs)
360
360
  self.shared = self._prepare_shared(shared)
361
361
 
362
- def make_final(self, x, max_depth=float('inf'), *args, **kwargs):
362
+ def make_final(self, x, max_depth=float('inf')):
363
363
  if self.is_final or max_depth == 0:
364
364
  return self
365
365
  return NotImplemented
@@ -1068,33 +1068,67 @@ class RandomizedTransform(NonFinalTransform):
1068
1068
  """
1069
1069
  Transform generated by randomizing some parameters of another transform.
1070
1070
 
1071
+ !!! note "`ctx.randomize` is an alias for `RandomizedTransform`"
1072
+
1071
1073
  !!! example "Gaussian noise with randomized variance"
1072
1074
  Object call
1073
1075
  ```python
1074
1076
  import cornucopia as cc
1075
- hypernoise = RandomizedTransform(cc.GaussianNoise, [cc.Uniform(0, 10)])
1077
+ hypernoise = cc.RandomizedTransform(cc.GaussianNoise, [cc.Uniform()])
1076
1078
  img = hypernoise(img)
1077
1079
  ```
1078
- Functional call
1080
+
1081
+ Delayed call
1079
1082
  ```python
1080
1083
  import cornucopia as cc
1081
- hypernoise = cc.randomize(cc.GaussianNoise)(cc.Uniform(0, 10))
1084
+ MyRandomNoise = cc.randomize(cc.GaussianNoise)
1085
+ hypernoise = MyRandomNoise(cc.Uniform())
1082
1086
  img = hypernoise(img)
1083
1087
  ```
1084
1088
 
1085
1089
  """
1086
1090
 
1087
- def __init__(self, transform, sample, ksample=None,
1091
+ class Delayed:
1092
+ # Temproary parameter holder for delayed calls
1093
+ def __init__(self, transform, **kwargs):
1094
+ self.transform = transform
1095
+ self.kwargs = kwargs
1096
+
1097
+ def __call__(self, *args, **kwargs):
1098
+ return RandomizedTransform(
1099
+ self.transform, args, kwargs, **self.kwargs)
1100
+
1101
+ def __new__(cls, *args, **kwargs):
1102
+ if cls is RandomizedTransform:
1103
+ return cls._base_new(*args, **kwargs)
1104
+ return super().__new__(cls)
1105
+
1106
+ @classmethod
1107
+ def _base_new(cls, transform, sample=tuple(), ksample=dict(),
1108
+ *, shared=False, **kwargs):
1109
+ assert cls is RandomizedTransform
1110
+ if not sample and not ksample:
1111
+ # If no arguments are passed, it means that the user calls
1112
+ # this in "delayed/functional" mode. In that case, we return
1113
+ # a callable object that returns the constructed instance
1114
+ # using the call-time arguments.
1115
+ return cls.Delayed(transform, shared=shared, **kwargs)
1116
+ # Otherwise, we're in object mode and we instantiate the
1117
+ # randomized object.
1118
+ return super().__new__(cls)
1119
+
1120
+ def __init__(self, transform, sample=tuple(), ksample=dict(),
1088
1121
  *, shared=False, **kwargs):
1089
1122
  """
1090
-
1091
1123
  Parameters
1092
1124
  ----------
1093
1125
  transform : callable(...) -> Transform
1094
1126
  A Transform subclass or a function that constructs a Transform.
1095
1127
  sample : [list or dict of] callable
1096
1128
  A collection of functions that generate parameter values provided
1097
- to `transform`.
1129
+ to `transform`. Can be args-like or kwargs-like arguments.
1130
+ ksample : dict[callable]
1131
+ Must be kwargs-like arguments.
1098
1132
 
1099
1133
  Other Parameters
1100
1134
  ----------------
@@ -1130,9 +1164,7 @@ class RandomizedTransform(NonFinalTransform):
1130
1164
 
1131
1165
  def __repr__(self):
1132
1166
  if type(self) is RandomizedTransform:
1133
- try:
1134
- if issubclass(self.subtransform, Transform):
1135
- return f'Randomized{self.subtransform.__name__}()'
1136
- except TypeError:
1137
- pass
1167
+ xform = self.subtransform
1168
+ if isinstance(xform, type) and issubclass(xform, Transform):
1169
+ return f'Randomized{xform.__name__}()'
1138
1170
  return super().__repr__()
File without changes
@@ -8,14 +8,16 @@ __all__ = [
8
8
  'CropTransform',
9
9
  'PadTransform',
10
10
  'PowerTwoTransform',
11
+ 'Rot90Transform',
12
+ 'Rot180Transform',
13
+ 'RandomRot90Transform',
11
14
  ]
12
-
13
15
  import math
14
16
  from random import shuffle
15
- from .base import FinalTransform, NonFinalTransform
17
+ from .base import FinalTransform, NonFinalTransform, PerChannelTransform
16
18
  from .utils.py import ensure_list
17
19
  from .utils.padding import pad
18
- from .random import Uniform, RandKFrom, Sampler
20
+ from .random import Uniform, RandKFrom, Sampler, RandInt, make_range
19
21
 
20
22
 
21
23
  class FlipTransform(FinalTransform):
@@ -23,7 +25,6 @@ class FlipTransform(FinalTransform):
23
25
 
24
26
  def __init__(self, axis=None, **kwargs):
25
27
  """
26
-
27
28
  Parameters
28
29
  ----------
29
30
  axis : [list of] int
@@ -46,24 +47,30 @@ class FlipTransform(FinalTransform):
46
47
  class RandomFlipTransform(NonFinalTransform):
47
48
  """Randomly flip one or more axes"""
48
49
 
49
- def __init__(self, axes=None, **kwargs):
50
+ def __init__(self, axes=None, *, shared=True, **kwargs):
50
51
  """
51
-
52
52
  Parameters
53
53
  ----------
54
54
  axes : Sampler or [list of] int
55
55
  Axes that can be flipped (default: all)
56
+
57
+ Other Parameters
58
+ ----------------
56
59
  shared : {'channels', 'tensors', 'channels+tensors', ''}
57
60
  Apply the same flip to all channels and/or tensors
58
61
  """
59
62
  axes = kwargs.pop('axis', axes)
60
- kwargs.setdefault('shared', True)
61
- super().__init__(**kwargs)
63
+ super().__init__(shared=shared, **kwargs)
62
64
  self.axes = axes
63
65
 
64
66
  def make_final(self, x, max_depth=float('inf')):
65
67
  if max_depth == 0:
66
68
  return self
69
+ if 'channels' not in self.shared and len(x) > 1:
70
+ return PerChannelTransform(
71
+ [self.make_final(x[i:i+1], max_depth) for i in range(len(x))],
72
+ **self.get_prm()
73
+ ).make_final(x, max_depth-1)
67
74
  axes = self.axes or range(1, x.ndim)
68
75
  if not isinstance(axes, Sampler):
69
76
  rand_axes = RandKFrom(ensure_list(axes))
@@ -76,7 +83,6 @@ class PermuteAxesTransform(FinalTransform):
76
83
 
77
84
  def __init__(self, permutation=None, **kwargs):
78
85
  """
79
-
80
86
  Parameters
81
87
  ----------
82
88
  permutation : [list of] int
@@ -105,23 +111,29 @@ class PermuteAxesTransform(FinalTransform):
105
111
  class RandomPermuteAxesTransform(NonFinalTransform):
106
112
  """Randomly permute axes"""
107
113
 
108
- def __init__(self, axes=None, **kwargs):
114
+ def __init__(self, axes=None, *, shared=True, **kwargs):
109
115
  """
110
-
111
116
  Parameters
112
117
  ----------
113
118
  axes : [list of] int
114
119
  Axes that can be permuted (default: all)
120
+
121
+ Other Parameters
122
+ ----------------
115
123
  shared : {'channels', 'tensors', 'channels+tensors', ''}
116
124
  Apply the same permutation to all channels and/or tensors
117
125
  """
118
- kwargs.setdefault('shared', True)
119
- super().__init__(**kwargs)
126
+ super().__init__(shared=shared, **kwargs)
120
127
  self.axes = axes
121
128
 
122
129
  def make_final(self, x, max_depth=float('inf')):
123
130
  if max_depth == 0:
124
131
  return self
132
+ if 'channels' not in self.shared and len(x) > 1:
133
+ return PerChannelTransform(
134
+ [self.make_final(x[i:i+1], max_depth) for i in range(len(x))],
135
+ **self.get_prm()
136
+ ).make_final(x, max_depth-1)
125
137
  axes = list(self.axes or range(x.ndim-1))
126
138
  shuffle(axes)
127
139
  return PermuteAxesTransform(
@@ -129,6 +141,119 @@ class RandomPermuteAxesTransform(NonFinalTransform):
129
141
  ).make_final(x, max_depth-1)
130
142
 
131
143
 
144
+ class Rot90Transform(FinalTransform):
145
+ """
146
+ Apply a 90 (or 180) rotation along one or several axes
147
+ """
148
+
149
+ def __init__(self, axis=0, negative=False, double=False, **kwargs):
150
+ """
151
+ Parameters
152
+ ----------
153
+ axis : int or list[int]
154
+ Rotation axis (indexing does not account for the channel axis)
155
+ negative : bool or list[bool]
156
+ Rotate by -90 deg instead of 90 deg
157
+ double : bool or list[bool]
158
+ Rotate be 180 instead of 90 (`negative` is then unused)
159
+ """
160
+ super().__init__(**kwargs)
161
+ self.axis = ensure_list(axis)
162
+ self.negative = ensure_list(negative, len(self.axis))
163
+ self.double = ensure_list(double, len(self.axis))
164
+
165
+ def xform(self, x):
166
+ # this implementation is suboptimal. We should fuse all transpose
167
+ # and all flips into a single "transpose + flip" operation so that
168
+ # a single allocation happens. This will be fine for now.
169
+
170
+ ndim = x.ndim - 1
171
+ axis = [1 + (ndim + a if a < 0 else a) for a in self.axis]
172
+ for ax, neg, dbl in zip(axis, self.negative, self.double):
173
+ if dbl:
174
+ if ndim == 2:
175
+ dims = [1, 2]
176
+ else:
177
+ assert ndim == 3
178
+ dims = [d for d in (1, 2, 3) if d != ax]
179
+ x = x.flip(dims)
180
+ else:
181
+ if ndim == 2:
182
+ dims = [1, 2]
183
+ else:
184
+ assert ndim == 3
185
+ dims = [d for d in (1, 2, 3) if d != ax]
186
+ x = x.transpose(*dims).flip(dims[1] if neg else dims[0])
187
+ return x
188
+
189
+
190
+ class Rot180Transform(Rot90Transform):
191
+ """Apply a 180 deg rotation along one or several axes"""
192
+
193
+ def __init__(self, axis=0, **kwargs):
194
+ """
195
+ Parameters
196
+ ----------
197
+ axis : int or list[int]
198
+ Rotation axis (indexing does not account for the channel axis)
199
+ """
200
+ super().__init__(axis, double=True, **kwargs)
201
+
202
+
203
+ class RandomRot90Transform(NonFinalTransform):
204
+ """Random set of 90 transforms"""
205
+
206
+ def __init__(self, axes=None, max_rot=2, negative=True,
207
+ *, shared=True, **kwargs):
208
+ """
209
+ Parameters
210
+ ----------
211
+ axes : int or list[int]
212
+ Axes along which rotations can happen.
213
+ If `None`, all axes.
214
+ max_rot : int or Sampler
215
+ Maximum number of consecutive rotations.
216
+ negative : bool
217
+ Whether to authorize negative rotations.
218
+
219
+ Other Parameters
220
+ ----------------
221
+ shared : {'channels', 'tensors', 'channels+tensors', ''}
222
+ Apply the same permutation to all channels and/or tensors
223
+ """
224
+ super().__init__(shared=shared, **kwargs)
225
+ self.axes = axes
226
+ self.max_rot = RandInt.make(make_range(1, max_rot))
227
+ self.negative = negative
228
+
229
+ def make_final(self, x, max_depth=float('inf')):
230
+ if max_depth == 0:
231
+ return self
232
+ if 'channels' not in self.shared and len(x) > 1:
233
+ return PerChannelTransform(
234
+ [self.make_final(x[i:i+1], max_depth) for i in range(len(x))],
235
+ **self.get_prm()
236
+ ).make_final(x, max_depth-1)
237
+ ndim = x.ndim - 1
238
+ max_rot = self.max_rot
239
+ if isinstance(max_rot, Sampler):
240
+ max_rot = max_rot()
241
+ axes = self.axes
242
+ if axes is None:
243
+ axes = list(range(ndim))
244
+ if isinstance(axes, (int, list, tuple)):
245
+ axes = ensure_list(axes, max_rot, crop=False)
246
+ if not isinstance(axes, Sampler):
247
+ axes = RandKFrom(axes, max_rot, replacement=True)
248
+
249
+ axes = ensure_list(axes(), max_rot)
250
+ negative = RandKFrom([False, True], max_rot, replacement=True)() \
251
+ if self.negative else [False] * max_rot
252
+ return Rot90Transform(
253
+ axes, negative, **self.get_prm()
254
+ ).make_final(max_depth-1)
255
+
256
+
132
257
  class CropPadTransform(FinalTransform):
133
258
  """Crop and/or pad a tensor"""
134
259
 
@@ -1119,7 +1119,6 @@ class SlicewiseAffineTransform(NonFinalTransform):
1119
1119
  F = torch.eye(ndim+1, **backend)
1120
1120
  F[:ndim, -1] = -offsets
1121
1121
  Z = E.clone()
1122
- print(zooms.shape, Z.shape)
1123
1122
  Z.diagonal(0, -1, -2)[:, :-1].copy_(1 + zooms)
1124
1123
  T = E.clone()
1125
1124
  T[:, :ndim, -1] = translations
@@ -1362,7 +1361,7 @@ class RandomSlicewiseAffineTransform(NonFinalTransform):
1362
1361
  # get slice direction
1363
1362
  slice = self.slice
1364
1363
  if slice is None:
1365
- slice = RandInt(0, ndim)
1364
+ slice = RandInt(0, ndim - 1)
1366
1365
  if isinstance(slice, Sampler):
1367
1366
  slice = slice()
1368
1367
 
File without changes
@@ -52,7 +52,7 @@ class RandomSmoothTransform(RandomizedTransform):
52
52
 
53
53
  Parameters
54
54
  ----------
55
- fwhm : Sampler or float
55
+ fwhm : Sampler or float
56
56
  Sampler or upper bound for the full-width at half-maximum
57
57
 
58
58
  Other Parameters
@@ -163,6 +163,7 @@ class Uniform(Sampler):
163
163
  def __init__(self, *args, **kwargs):
164
164
  """
165
165
  ```python
166
+ Uniform()
166
167
  Uniform(max)
167
168
  Uniform(min, max)
168
169
  ```
@@ -171,10 +172,10 @@ class Uniform(Sampler):
171
172
  ----------
172
173
  min : float or sequence[float], default=0
173
174
  Lower bound (inclusive)
174
- max : float or sequence[float]
175
+ max : float or sequence[float], default=1
175
176
  Upper bound (inclusive or exclusive, depending on rounding)
176
177
  """
177
- min, max = 0, None
178
+ min, max = 0, 1
178
179
  if len(args) == 2:
179
180
  min, max = args
180
181
  elif len(args) == 1:
@@ -183,8 +184,6 @@ class Uniform(Sampler):
183
184
  min = kwargs['min']
184
185
  if 'max' in kwargs:
185
186
  max = kwargs['max']
186
- if max is None:
187
- raise ValueError('Expected at least one argument')
188
187
  super().__init__(min=min, max=max)
189
188
 
190
189
  def __call__(self, n=None, **backend):
@@ -261,7 +260,7 @@ class RandKFrom(Sampler):
261
260
  self.replacement = replacement
262
261
 
263
262
  def __call__(self, n=None, **backend):
264
- k = self.k or RandInt(len(self.range))()
263
+ k = self.k or RandInt(1, len(self.range))()
265
264
  if isinstance(n, (list, tuple)) or n:
266
265
  raise ValueError('RandKFrom cannot sample multiple elements')
267
266
  if not self.replacement:
@@ -269,7 +268,7 @@ class RandKFrom(Sampler):
269
268
  random.shuffle(range)
270
269
  return range[:k]
271
270
  else:
272
- index = RandInt(len(self.range))(k)
271
+ index = RandInt(0, len(self.range)-1)(k)
273
272
  return [self.range[i] for i in index]
274
273
 
275
274
 
@@ -11,6 +11,9 @@ from cornucopia.fov import (
11
11
  CropTransform,
12
12
  PadTransform,
13
13
  PowerTwoTransform,
14
+ Rot90Transform,
15
+ Rot180Transform,
16
+ RandomRot90Transform,
14
17
  )
15
18
 
16
19
  SEED = 12345678
@@ -53,6 +56,37 @@ def test_run_fov_permute_random(size):
53
56
  assert True
54
57
 
55
58
 
59
+ @pytest.mark.parametrize("size", sizes)
60
+ @pytest.mark.parametrize("axes", [0, 1, [0, 1], [0, 0]])
61
+ @pytest.mark.parametrize("negative", [False, True])
62
+ @pytest.mark.parametrize("double", [False, True])
63
+ def test_run_rot90_permute(size, axes, negative, double):
64
+ random.seed(SEED)
65
+ torch.random.manual_seed(SEED)
66
+ x = torch.randn(size)
67
+ _ = Rot90Transform(axes, negative, double)(x)
68
+ assert True
69
+
70
+
71
+ @pytest.mark.parametrize("size", sizes)
72
+ @pytest.mark.parametrize("axes", [0, 1, [0, 1], [0, 0]])
73
+ def test_run_rot180_permute(size, axes):
74
+ random.seed(SEED)
75
+ torch.random.manual_seed(SEED)
76
+ x = torch.randn(size)
77
+ _ = Rot180Transform(axes)(x)
78
+ assert True
79
+
80
+
81
+ @pytest.mark.parametrize("size", sizes)
82
+ def test_run_rot90_random(size):
83
+ random.seed(SEED)
84
+ torch.random.manual_seed(SEED)
85
+ x = torch.randn(size)
86
+ _ = RandomRot90Transform()(x)
87
+ assert True
88
+
89
+
56
90
  @pytest.mark.parametrize("size", sizes)
57
91
  def test_run_fov_patch(size):
58
92
  random.seed(SEED)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: cornucopia
3
- Version: 0.2.0
3
+ Version: 0.3.0
4
4
  Summary: An abundance of augmentation layers
5
5
  Home-page: UNKNOWN
6
6
  Author: Yael Balbastre
@@ -27,9 +27,6 @@ Description: <picture align="center">
27
27
  theoretically be used within any dataloader pipeline,
28
28
  independent of the downstream learning framework (pytorch, tensorflow, jax, ...).
29
29
 
30
- ## Installation
31
-
32
-
33
30
  ## Installation
34
31
 
35
32
  ### Dependencies
@@ -46,12 +43,18 @@ Description: <picture align="center">
46
43
  conda install cornucopia -c balbasty -c pytorch -c conda-forge
47
44
  ```
48
45
 
49
- ### Pip
46
+ ### Pip (release)
50
47
 
51
48
  ```sh
52
49
  pip install cornucopia
53
50
  ```
54
51
 
52
+ ### Pip (dev)
53
+
54
+ ```sh
55
+ pip install cornucopia@git+https://github.com/balbasty/cornucopia
56
+ ```
57
+
55
58
  ## Documentation
56
59
 
57
60
  Read the [documentation](https://cornucopia.readthedocs.io) and in particular:
File without changes
File without changes
File without changes
File without changes