cornucopia 0.2.0__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cornucopia-0.2.0 → cornucopia-0.3.0}/LICENSE +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/PKG-INFO +8 -5
- {cornucopia-0.2.0 → cornucopia-0.3.0}/README.md +7 -4
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/__init__.py +19 -12
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/_version.py +3 -3
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/base.py +44 -12
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/baseutils.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/contrast.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/ctx.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/fov.py +138 -13
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/geometric.py +1 -2
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/intensity.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/io.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/kspace.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/labels.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/noise.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/psf.py +1 -1
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/qmri.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/random.py +5 -6
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/special.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/synth.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/tests/__init__.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_contrast.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_fov.py +34 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_geometric.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_intensity.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_kspace.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_labels.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_noise.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_psf.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_qmri.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_synth.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/__init__.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/b0.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/bounds.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/conv.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/gmm.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/indexing.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/io.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/jit.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/kernels.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/morpho.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/padding.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/patch.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/py.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/version.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia/utils/warps.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia.egg-info/PKG-INFO +8 -5
- {cornucopia-0.2.0 → cornucopia-0.3.0}/pyproject.toml +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/setup.cfg +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/setup.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/versioneer.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia.egg-info/SOURCES.txt +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia.egg-info/dependency_links.txt +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia.egg-info/requires.txt +0 -0
- {cornucopia-0.2.0 → cornucopia-0.3.0}/cornucopia.egg-info/top_level.txt +0 -0
|
File without changes
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: cornucopia
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.3.0
|
|
4
4
|
Summary: An abundance of augmentation layers
|
|
5
5
|
Home-page: UNKNOWN
|
|
6
6
|
Author: Yael Balbastre
|
|
@@ -27,9 +27,6 @@ Description: <picture align="center">
|
|
|
27
27
|
theoretically be used within any dataloader pipeline,
|
|
28
28
|
independent of the downstream learning framework (pytorch, tensorflow, jax, ...).
|
|
29
29
|
|
|
30
|
-
## Installation
|
|
31
|
-
|
|
32
|
-
|
|
33
30
|
## Installation
|
|
34
31
|
|
|
35
32
|
### Dependencies
|
|
@@ -46,12 +43,18 @@ Description: <picture align="center">
|
|
|
46
43
|
conda install cornucopia -c balbasty -c pytorch -c conda-forge
|
|
47
44
|
```
|
|
48
45
|
|
|
49
|
-
### Pip
|
|
46
|
+
### Pip (release)
|
|
50
47
|
|
|
51
48
|
```sh
|
|
52
49
|
pip install cornucopia
|
|
53
50
|
```
|
|
54
51
|
|
|
52
|
+
### Pip (dev)
|
|
53
|
+
|
|
54
|
+
```sh
|
|
55
|
+
pip install cornucopia@git+https://github.com/balbasty/cornucopia
|
|
56
|
+
```
|
|
57
|
+
|
|
55
58
|
## Documentation
|
|
56
59
|
|
|
57
60
|
Read the [documentation](https://cornucopia.readthedocs.io) and in particular:
|
|
@@ -18,9 +18,6 @@ Since gradients are not expected to backpropagate through its layers, it can
|
|
|
18
18
|
theoretically be used within any dataloader pipeline,
|
|
19
19
|
independent of the downstream learning framework (pytorch, tensorflow, jax, ...).
|
|
20
20
|
|
|
21
|
-
## Installation
|
|
22
|
-
|
|
23
|
-
|
|
24
21
|
## Installation
|
|
25
22
|
|
|
26
23
|
### Dependencies
|
|
@@ -37,12 +34,18 @@ independent of the downstream learning framework (pytorch, tensorflow, jax, ...)
|
|
|
37
34
|
conda install cornucopia -c balbasty -c pytorch -c conda-forge
|
|
38
35
|
```
|
|
39
36
|
|
|
40
|
-
### Pip
|
|
37
|
+
### Pip (release)
|
|
41
38
|
|
|
42
39
|
```sh
|
|
43
40
|
pip install cornucopia
|
|
44
41
|
```
|
|
45
42
|
|
|
43
|
+
### Pip (dev)
|
|
44
|
+
|
|
45
|
+
```sh
|
|
46
|
+
pip install cornucopia@git+https://github.com/balbasty/cornucopia
|
|
47
|
+
```
|
|
48
|
+
|
|
46
49
|
## Documentation
|
|
47
50
|
|
|
48
51
|
Read the [documentation](https://cornucopia.readthedocs.io) and in particular:
|
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""
|
|
2
|
+
Flexible transforms for pre-processing and augmentation
|
|
2
3
|
|
|
3
4
|
Example on how to use this machinery to generate within-subject
|
|
4
5
|
image pairs with a random affine deformation between them::
|
|
@@ -30,19 +31,25 @@ image pairs with a random affine deformation between them::
|
|
|
30
31
|
|
|
31
32
|
"""
|
|
32
33
|
|
|
33
|
-
# TODO:
|
|
34
|
-
# [x] Make it a standalone package?
|
|
35
|
-
# [x] Move samplers in their own file
|
|
36
|
-
# [x] Add IO transforms (that transform filenames in tensors)
|
|
37
|
-
# [ ] Better deal with separable/shared transforms
|
|
38
|
-
# [ ] Add a SharedTransform class (like Randomized) that does the heavy
|
|
39
|
-
# lifting
|
|
40
|
-
# [ ] By default (non shared), let Transforms handle multi-channel
|
|
41
|
-
# data (currently we loop across channels in the base class)
|
|
42
|
-
|
|
43
34
|
from . import random # noqa: F401
|
|
44
35
|
from . import ctx # noqa: F401
|
|
45
|
-
from .
|
|
36
|
+
from . import base # noqa: F401
|
|
37
|
+
from . import special # noqa: F401
|
|
38
|
+
from . import contrast # noqa: F401
|
|
39
|
+
from . import geometric # noqa: F401
|
|
40
|
+
from . import intensity # noqa: F401
|
|
41
|
+
from . import io # noqa: F401
|
|
42
|
+
from . import fov # noqa: F401
|
|
43
|
+
from . import kspace # noqa: F401
|
|
44
|
+
from . import labels # noqa: F401
|
|
45
|
+
from . import noise # noqa: F401
|
|
46
|
+
from . import psf # noqa: F401
|
|
47
|
+
from . import qmri # noqa: F401
|
|
48
|
+
from . import synth # noqa: F401
|
|
49
|
+
from . import utils # noqa: F401
|
|
50
|
+
|
|
51
|
+
from .random import * # noqa: F401,F403
|
|
52
|
+
from .ctx import * # noqa: F401,F403
|
|
46
53
|
from .base import * # noqa: F401,F403
|
|
47
54
|
from .special import * # noqa: F401,F403
|
|
48
55
|
from .contrast import * # noqa: F401,F403
|
|
@@ -8,11 +8,11 @@ import json
|
|
|
8
8
|
|
|
9
9
|
version_json = '''
|
|
10
10
|
{
|
|
11
|
-
"date": "
|
|
11
|
+
"date": "2024-04-19T14:23:50+0100",
|
|
12
12
|
"dirty": false,
|
|
13
13
|
"error": null,
|
|
14
|
-
"full-revisionid": "
|
|
15
|
-
"version": "0.
|
|
14
|
+
"full-revisionid": "37de94f181b9a97eebd21460f4df63ae4a0750f8",
|
|
15
|
+
"version": "0.3.0"
|
|
16
16
|
}
|
|
17
17
|
''' # END VERSION_JSON
|
|
18
18
|
|
|
@@ -359,7 +359,7 @@ class NonFinalTransform(SharedMixin, Transform):
|
|
|
359
359
|
super().__init__(**kwargs)
|
|
360
360
|
self.shared = self._prepare_shared(shared)
|
|
361
361
|
|
|
362
|
-
def make_final(self, x, max_depth=float('inf')
|
|
362
|
+
def make_final(self, x, max_depth=float('inf')):
|
|
363
363
|
if self.is_final or max_depth == 0:
|
|
364
364
|
return self
|
|
365
365
|
return NotImplemented
|
|
@@ -1068,33 +1068,67 @@ class RandomizedTransform(NonFinalTransform):
|
|
|
1068
1068
|
"""
|
|
1069
1069
|
Transform generated by randomizing some parameters of another transform.
|
|
1070
1070
|
|
|
1071
|
+
!!! note "`ctx.randomize` is an alias for `RandomizedTransform`"
|
|
1072
|
+
|
|
1071
1073
|
!!! example "Gaussian noise with randomized variance"
|
|
1072
1074
|
Object call
|
|
1073
1075
|
```python
|
|
1074
1076
|
import cornucopia as cc
|
|
1075
|
-
hypernoise = RandomizedTransform(cc.GaussianNoise, [cc.Uniform(
|
|
1077
|
+
hypernoise = cc.RandomizedTransform(cc.GaussianNoise, [cc.Uniform()])
|
|
1076
1078
|
img = hypernoise(img)
|
|
1077
1079
|
```
|
|
1078
|
-
|
|
1080
|
+
|
|
1081
|
+
Delayed call
|
|
1079
1082
|
```python
|
|
1080
1083
|
import cornucopia as cc
|
|
1081
|
-
|
|
1084
|
+
MyRandomNoise = cc.randomize(cc.GaussianNoise)
|
|
1085
|
+
hypernoise = MyRandomNoise(cc.Uniform())
|
|
1082
1086
|
img = hypernoise(img)
|
|
1083
1087
|
```
|
|
1084
1088
|
|
|
1085
1089
|
"""
|
|
1086
1090
|
|
|
1087
|
-
|
|
1091
|
+
class Delayed:
    """
    Temporary holder for the arguments of a delayed `RandomizedTransform`.

    Calling the holder with samplers builds the actual
    `RandomizedTransform`, forwarding the options captured at creation.
    """

    def __init__(self, transform, **kwargs):
        # keyword options (e.g. `shared`) captured at creation time
        self.kwargs = kwargs
        self.transform = transform

    def __call__(self, *args, **kwargs):
        # positional samplers become `sample`, keyword ones `ksample`
        sample, ksample = args, kwargs
        return RandomizedTransform(
            self.transform, sample, ksample, **self.kwargs
        )
|
|
1100
|
+
|
|
1101
|
+
def __new__(cls, *args, **kwargs):
    """
    Dispatch construction of `RandomizedTransform`.

    A direct call `RandomizedTransform(...)` with arguments goes through
    `_base_new`. Depending on whether samplers were provided, that returns
    either a new instance or a `Delayed` holder. Subclasses, and argument-less
    calls, allocate a plain instance.
    """
    # `copy.deepcopy` and `pickle` reconstruct objects via `cls.__new__(cls)`
    # with no arguments, then restore their state. Routing that call to
    # `_base_new` would fail on the missing `transform` argument, so only
    # intercept calls that actually carry constructor arguments.
    # A user calling `RandomizedTransform()` with no arguments still gets a
    # TypeError, now raised from `__init__`.
    if cls is RandomizedTransform and (args or kwargs):
        return cls._base_new(*args, **kwargs)
    return super().__new__(cls)
|
|
1105
|
+
|
|
1106
|
+
@classmethod
def _base_new(cls, transform, sample=tuple(), ksample=dict(),
              *, shared=False, **kwargs):
    """
    Allocate a `RandomizedTransform` or return a delayed constructor.

    If neither `sample` nor `ksample` holds anything, the user is calling
    the class in "delayed/functional" mode. A `Delayed` holder is returned;
    calling it later with the samplers builds the instance. Otherwise
    (object mode), a bare instance is allocated and `__init__` runs next.
    """
    assert cls is RandomizedTransform
    has_samplers = bool(sample) or bool(ksample)
    if has_samplers:
        # object mode: allocate the instance, `__init__` will fill it
        return super().__new__(cls)
    # delayed mode: capture the options, wait for the samplers
    return cls.Delayed(transform, shared=shared, **kwargs)
|
|
1119
|
+
|
|
1120
|
+
def __init__(self, transform, sample=tuple(), ksample=dict(),
|
|
1088
1121
|
*, shared=False, **kwargs):
|
|
1089
1122
|
"""
|
|
1090
|
-
|
|
1091
1123
|
Parameters
|
|
1092
1124
|
----------
|
|
1093
1125
|
transform : callable(...) -> Transform
|
|
1094
1126
|
A Transform subclass or a function that constructs a Transform.
|
|
1095
1127
|
sample : [list or dict of] callable
|
|
1096
1128
|
A collection of functions that generate parameter values provided
|
|
1097
|
-
to `transform`.
|
|
1129
|
+
to `transform`. Can be args-like or kwargs-like arguments.
|
|
1130
|
+
ksample : dict[callable]
|
|
1131
|
+
Must be kwargs-like arguments.
|
|
1098
1132
|
|
|
1099
1133
|
Other Parameters
|
|
1100
1134
|
----------------
|
|
@@ -1130,9 +1164,7 @@ class RandomizedTransform(NonFinalTransform):
|
|
|
1130
1164
|
|
|
1131
1165
|
def __repr__(self):
|
|
1132
1166
|
if type(self) is RandomizedTransform:
|
|
1133
|
-
|
|
1134
|
-
|
|
1135
|
-
|
|
1136
|
-
except TypeError:
|
|
1137
|
-
pass
|
|
1167
|
+
xform = self.subtransform
|
|
1168
|
+
if isinstance(xform, type) and issubclass(xform, Transform):
|
|
1169
|
+
return f'Randomized{xform.__name__}()'
|
|
1138
1170
|
return super().__repr__()
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
@@ -8,14 +8,16 @@ __all__ = [
|
|
|
8
8
|
'CropTransform',
|
|
9
9
|
'PadTransform',
|
|
10
10
|
'PowerTwoTransform',
|
|
11
|
+
'Rot90Transform',
|
|
12
|
+
'Rot180Transform',
|
|
13
|
+
'RandomRot90Transform',
|
|
11
14
|
]
|
|
12
|
-
|
|
13
15
|
import math
|
|
14
16
|
from random import shuffle
|
|
15
|
-
from .base import FinalTransform, NonFinalTransform
|
|
17
|
+
from .base import FinalTransform, NonFinalTransform, PerChannelTransform
|
|
16
18
|
from .utils.py import ensure_list
|
|
17
19
|
from .utils.padding import pad
|
|
18
|
-
from .random import Uniform, RandKFrom, Sampler
|
|
20
|
+
from .random import Uniform, RandKFrom, Sampler, RandInt, make_range
|
|
19
21
|
|
|
20
22
|
|
|
21
23
|
class FlipTransform(FinalTransform):
|
|
@@ -23,7 +25,6 @@ class FlipTransform(FinalTransform):
|
|
|
23
25
|
|
|
24
26
|
def __init__(self, axis=None, **kwargs):
|
|
25
27
|
"""
|
|
26
|
-
|
|
27
28
|
Parameters
|
|
28
29
|
----------
|
|
29
30
|
axis : [list of] int
|
|
@@ -46,24 +47,30 @@ class FlipTransform(FinalTransform):
|
|
|
46
47
|
class RandomFlipTransform(NonFinalTransform):
|
|
47
48
|
"""Randomly flip one or more axes"""
|
|
48
49
|
|
|
49
|
-
def __init__(self, axes=None, **kwargs):
|
|
50
|
+
def __init__(self, axes=None, *, shared=True, **kwargs):
|
|
50
51
|
"""
|
|
51
|
-
|
|
52
52
|
Parameters
|
|
53
53
|
----------
|
|
54
54
|
axes : Sampler or [list of] int
|
|
55
55
|
Axes that can be flipped (default: all)
|
|
56
|
+
|
|
57
|
+
Other Parameters
|
|
58
|
+
----------------
|
|
56
59
|
shared : {'channels', 'tensors', 'channels+tensors', ''}
|
|
57
60
|
Apply the same flip to all channels and/or tensors
|
|
58
61
|
"""
|
|
59
62
|
axes = kwargs.pop('axis', axes)
|
|
60
|
-
|
|
61
|
-
super().__init__(**kwargs)
|
|
63
|
+
super().__init__(shared=shared, **kwargs)
|
|
62
64
|
self.axes = axes
|
|
63
65
|
|
|
64
66
|
def make_final(self, x, max_depth=float('inf')):
|
|
65
67
|
if max_depth == 0:
|
|
66
68
|
return self
|
|
69
|
+
if 'channels' not in self.shared and len(x) > 1:
|
|
70
|
+
return PerChannelTransform(
|
|
71
|
+
[self.make_final(x[i:i+1], max_depth) for i in range(len(x))],
|
|
72
|
+
**self.get_prm()
|
|
73
|
+
).make_final(x, max_depth-1)
|
|
67
74
|
axes = self.axes or range(1, x.ndim)
|
|
68
75
|
if not isinstance(axes, Sampler):
|
|
69
76
|
rand_axes = RandKFrom(ensure_list(axes))
|
|
@@ -76,7 +83,6 @@ class PermuteAxesTransform(FinalTransform):
|
|
|
76
83
|
|
|
77
84
|
def __init__(self, permutation=None, **kwargs):
|
|
78
85
|
"""
|
|
79
|
-
|
|
80
86
|
Parameters
|
|
81
87
|
----------
|
|
82
88
|
permutation : [list of] int
|
|
@@ -105,23 +111,29 @@ class PermuteAxesTransform(FinalTransform):
|
|
|
105
111
|
class RandomPermuteAxesTransform(NonFinalTransform):
|
|
106
112
|
"""Randomly permute axes"""
|
|
107
113
|
|
|
108
|
-
def __init__(self, axes=None, **kwargs):
|
|
114
|
+
def __init__(self, axes=None, *, shared=True, **kwargs):
|
|
109
115
|
"""
|
|
110
|
-
|
|
111
116
|
Parameters
|
|
112
117
|
----------
|
|
113
118
|
axes : [list of] int
|
|
114
119
|
Axes that can be permuted (default: all)
|
|
120
|
+
|
|
121
|
+
Other Parameters
|
|
122
|
+
----------------
|
|
115
123
|
shared : {'channels', 'tensors', 'channels+tensors', ''}
|
|
116
124
|
Apply the same permutation to all channels and/or tensors
|
|
117
125
|
"""
|
|
118
|
-
|
|
119
|
-
super().__init__(**kwargs)
|
|
126
|
+
super().__init__(shared=shared, **kwargs)
|
|
120
127
|
self.axes = axes
|
|
121
128
|
|
|
122
129
|
def make_final(self, x, max_depth=float('inf')):
|
|
123
130
|
if max_depth == 0:
|
|
124
131
|
return self
|
|
132
|
+
if 'channels' not in self.shared and len(x) > 1:
|
|
133
|
+
return PerChannelTransform(
|
|
134
|
+
[self.make_final(x[i:i+1], max_depth) for i in range(len(x))],
|
|
135
|
+
**self.get_prm()
|
|
136
|
+
).make_final(x, max_depth-1)
|
|
125
137
|
axes = list(self.axes or range(x.ndim-1))
|
|
126
138
|
shuffle(axes)
|
|
127
139
|
return PermuteAxesTransform(
|
|
@@ -129,6 +141,119 @@ class RandomPermuteAxesTransform(NonFinalTransform):
|
|
|
129
141
|
).make_final(x, max_depth-1)
|
|
130
142
|
|
|
131
143
|
|
|
144
|
+
class Rot90Transform(FinalTransform):
    """
    Apply a 90 (or 180) degree rotation about one or several axes
    """

    def __init__(self, axis=0, negative=False, double=False, **kwargs):
        """
        Parameters
        ----------
        axis : int or list[int]
            Rotation axis (indexing does not account for the channel axis).
            Ignored for 2D inputs, which have a single rotation plane.
        negative : bool or list[bool]
            Rotate by -90 deg instead of 90 deg
        double : bool or list[bool]
            Rotate by 180 deg instead of 90 (`negative` is then unused)
        """
        super().__init__(**kwargs)
        self.axis = ensure_list(axis)
        self.negative = ensure_list(negative, len(self.axis))
        self.double = ensure_list(double, len(self.axis))

    @staticmethod
    def _rotation_plane(ax, ndim):
        """Return the two tensor dims spanning the plane orthogonal to `ax`.

        `ax` is a tensor dim (i.e., it already accounts for the channel
        axis) and `ndim` is the number of spatial dimensions.
        """
        if ndim == 2:
            # a single rotation plane: the axis is irrelevant
            return [1, 2]
        if ndim == 3:
            if ax not in (1, 2, 3):
                # without this check, `dims` would hold all three spatial
                # dims and a 180 "rotation" would become a point reflection
                raise ValueError(
                    f'Rotation axis out of range for a 3D input: {ax - 1}')
            return [d for d in (1, 2, 3) if d != ax]
        raise ValueError(
            f'Rot90Transform only supports 2D and 3D inputs, got {ndim}D')

    def xform(self, x):
        # this implementation is suboptimal. We should fuse all transpose
        # and all flips into a single "transpose + flip" operation so that
        # a single allocation happens. This will be fine for now.

        ndim = x.ndim - 1
        # map (possibly negative) spatial axes to tensor dims (skip channel)
        axis = [1 + (ndim + a if a < 0 else a) for a in self.axis]
        for ax, neg, dbl in zip(axis, self.negative, self.double):
            dims = self._rotation_plane(ax, ndim)
            if dbl:
                # 180 deg rotation == flip both in-plane dims
                x = x.flip(dims)
            else:
                # 90 deg rotation == swap in-plane dims, then flip one;
                # flipping the other one gives the -90 deg rotation
                x = x.transpose(*dims).flip(dims[1] if neg else dims[0])
        return x
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class Rot180Transform(Rot90Transform):
    """
    Apply a 180 deg rotation about one or several axes.

    Shorthand for `Rot90Transform(axis, double=True)`.
    """

    def __init__(self, axis=0, **kwargs):
        """
        Parameters
        ----------
        axis : int or list[int]
            Rotation axis (indexing does not account for the channel axis)
        """
        super().__init__(axis=axis, double=True, **kwargs)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
class RandomRot90Transform(NonFinalTransform):
    """Random sequence of 90 deg rotations"""

    def __init__(self, axes=None, max_rot=2, negative=True,
                 *, shared=True, **kwargs):
        """
        Parameters
        ----------
        axes : int or list[int] or Sampler
            Axes about which rotations can happen
            (indexing does not account for the channel axis).
            If `None`, all axes.
        max_rot : int or Sampler
            Maximum number of consecutive rotations. The actual number is
            drawn by `RandInt.make(make_range(1, max_rot))`
            (presumably in `[1, max_rot]` -- confirm against `make_range`).
        negative : bool
            Whether to authorize negative rotations.

        Other Parameters
        ----------------
        shared : {'channels', 'tensors', 'channels+tensors', ''}
            Apply the same rotation to all channels and/or tensors
        """
        super().__init__(shared=shared, **kwargs)
        self.axes = axes
        self.max_rot = RandInt.make(make_range(1, max_rot))
        self.negative = negative

    def make_final(self, x, max_depth=float('inf')):
        if max_depth == 0:
            return self
        if 'channels' not in self.shared and len(x) > 1:
            # draw an independent rotation sequence for each channel
            return PerChannelTransform(
                [self.make_final(x[i:i+1], max_depth) for i in range(len(x))],
                **self.get_prm()
            ).make_final(x, max_depth-1)
        ndim = x.ndim - 1

        # number of rotations to apply
        max_rot = self.max_rot
        if isinstance(max_rot, Sampler):
            max_rot = max_rot()

        # axis of each rotation
        axes = self.axes
        if axes is None:
            axes = list(range(ndim))
        if isinstance(axes, (int, list, tuple)):
            # Copy, and do NOT pad to `max_rot` elements. Padding repeats
            # the last axis, which biases the draw towards it, and may
            # extend `self.axes` in place. Sampling with replacement already
            # supports drawing more elements than the list holds.
            axes = list(ensure_list(axes))
        if not isinstance(axes, Sampler):
            axes = RandKFrom(axes, max_rot, replacement=True)
        axes = ensure_list(axes(), max_rot)

        # direction of each rotation
        if self.negative:
            negative = RandKFrom([False, True], max_rot, replacement=True)()
        else:
            negative = [False] * max_rot

        return Rot90Transform(
            axes, negative, **self.get_prm()
        ).make_final(x, max_depth-1)
|
|
255
|
+
|
|
256
|
+
|
|
132
257
|
class CropPadTransform(FinalTransform):
|
|
133
258
|
"""Crop and/or pad a tensor"""
|
|
134
259
|
|
|
@@ -1119,7 +1119,6 @@ class SlicewiseAffineTransform(NonFinalTransform):
|
|
|
1119
1119
|
F = torch.eye(ndim+1, **backend)
|
|
1120
1120
|
F[:ndim, -1] = -offsets
|
|
1121
1121
|
Z = E.clone()
|
|
1122
|
-
print(zooms.shape, Z.shape)
|
|
1123
1122
|
Z.diagonal(0, -1, -2)[:, :-1].copy_(1 + zooms)
|
|
1124
1123
|
T = E.clone()
|
|
1125
1124
|
T[:, :ndim, -1] = translations
|
|
@@ -1362,7 +1361,7 @@ class RandomSlicewiseAffineTransform(NonFinalTransform):
|
|
|
1362
1361
|
# get slice direction
|
|
1363
1362
|
slice = self.slice
|
|
1364
1363
|
if slice is None:
|
|
1365
|
-
slice = RandInt(0, ndim)
|
|
1364
|
+
slice = RandInt(0, ndim - 1)
|
|
1366
1365
|
if isinstance(slice, Sampler):
|
|
1367
1366
|
slice = slice()
|
|
1368
1367
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
@@ -163,6 +163,7 @@ class Uniform(Sampler):
|
|
|
163
163
|
def __init__(self, *args, **kwargs):
|
|
164
164
|
"""
|
|
165
165
|
```python
|
|
166
|
+
Uniform()
|
|
166
167
|
Uniform(max)
|
|
167
168
|
Uniform(min, max)
|
|
168
169
|
```
|
|
@@ -171,10 +172,10 @@ class Uniform(Sampler):
|
|
|
171
172
|
----------
|
|
172
173
|
min : float or sequence[float], default=0
|
|
173
174
|
Lower bound (inclusive)
|
|
174
|
-
max : float or sequence[float]
|
|
175
|
+
max : float or sequence[float], default=1
|
|
175
176
|
Upper bound (inclusive or exclusive, depending on rounding)
|
|
176
177
|
"""
|
|
177
|
-
min, max = 0,
|
|
178
|
+
min, max = 0, 1
|
|
178
179
|
if len(args) == 2:
|
|
179
180
|
min, max = args
|
|
180
181
|
elif len(args) == 1:
|
|
@@ -183,8 +184,6 @@ class Uniform(Sampler):
|
|
|
183
184
|
min = kwargs['min']
|
|
184
185
|
if 'max' in kwargs:
|
|
185
186
|
max = kwargs['max']
|
|
186
|
-
if max is None:
|
|
187
|
-
raise ValueError('Expected at least one argument')
|
|
188
187
|
super().__init__(min=min, max=max)
|
|
189
188
|
|
|
190
189
|
def __call__(self, n=None, **backend):
|
|
@@ -261,7 +260,7 @@ class RandKFrom(Sampler):
|
|
|
261
260
|
self.replacement = replacement
|
|
262
261
|
|
|
263
262
|
def __call__(self, n=None, **backend):
|
|
264
|
-
k = self.k or RandInt(len(self.range))()
|
|
263
|
+
k = self.k or RandInt(1, len(self.range))()
|
|
265
264
|
if isinstance(n, (list, tuple)) or n:
|
|
266
265
|
raise ValueError('RandKFrom cannot sample multiple elements')
|
|
267
266
|
if not self.replacement:
|
|
@@ -269,7 +268,7 @@ class RandKFrom(Sampler):
|
|
|
269
268
|
random.shuffle(range)
|
|
270
269
|
return range[:k]
|
|
271
270
|
else:
|
|
272
|
-
index = RandInt(len(self.range))(k)
|
|
271
|
+
index = RandInt(0, len(self.range)-1)(k)
|
|
273
272
|
return [self.range[i] for i in index]
|
|
274
273
|
|
|
275
274
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
@@ -11,6 +11,9 @@ from cornucopia.fov import (
|
|
|
11
11
|
CropTransform,
|
|
12
12
|
PadTransform,
|
|
13
13
|
PowerTwoTransform,
|
|
14
|
+
Rot90Transform,
|
|
15
|
+
Rot180Transform,
|
|
16
|
+
RandomRot90Transform,
|
|
14
17
|
)
|
|
15
18
|
|
|
16
19
|
SEED = 12345678
|
|
@@ -53,6 +56,37 @@ def test_run_fov_permute_random(size):
|
|
|
53
56
|
assert True
|
|
54
57
|
|
|
55
58
|
|
|
59
|
+
@pytest.mark.parametrize("size", sizes)
@pytest.mark.parametrize("axes", [0, 1, [0, 1], [0, 0]])
@pytest.mark.parametrize("negative", [False, True])
@pytest.mark.parametrize("double", [False, True])
def test_run_rot90_permute(size, axes, negative, double):
    random.seed(SEED)
    torch.random.manual_seed(SEED)
    x = torch.randn(size)
    y = Rot90Transform(axes, negative, double)(x)
    # Undo the rotations: apply them in reverse order and in the opposite
    # direction (180 deg rotations are their own inverse). Rotations only
    # move values around, so the round trip must be exact.
    inv_axes = list(reversed(axes)) if isinstance(axes, list) else axes
    x_back = Rot90Transform(inv_axes, not negative, double)(y)
    assert x_back.shape == x.shape
    assert torch.equal(x_back, x)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@pytest.mark.parametrize("size", sizes)
@pytest.mark.parametrize("axes", [0, 1, [0, 1], [0, 0]])
def test_run_rot180_permute(size, axes):
    random.seed(SEED)
    torch.random.manual_seed(SEED)
    x = torch.randn(size)
    rot = Rot180Transform(axes)
    y = rot(x)
    # a 180 deg rotation never changes the shape
    assert y.shape == x.shape
    # 180 deg rotations are involutions and commute with each other,
    # so applying the same transform twice recovers the input exactly
    assert torch.equal(rot(y), x)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@pytest.mark.parametrize("size", sizes)
def test_run_rot90_random(size):
    random.seed(SEED)
    torch.random.manual_seed(SEED)
    x = torch.randn(size)
    y = RandomRot90Transform()(x)
    # rotations only permute voxels: the multiset of values is unchanged
    assert y.numel() == x.numel()
    assert torch.equal(y.flatten().sort().values, x.flatten().sort().values)
|
|
88
|
+
|
|
89
|
+
|
|
56
90
|
@pytest.mark.parametrize("size", sizes)
|
|
57
91
|
def test_run_fov_patch(size):
|
|
58
92
|
random.seed(SEED)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: cornucopia
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.3.0
|
|
4
4
|
Summary: An abundance of augmentation layers
|
|
5
5
|
Home-page: UNKNOWN
|
|
6
6
|
Author: Yael Balbastre
|
|
@@ -27,9 +27,6 @@ Description: <picture align="center">
|
|
|
27
27
|
theoretically be used within any dataloader pipeline,
|
|
28
28
|
independent of the downstream learning framework (pytorch, tensorflow, jax, ...).
|
|
29
29
|
|
|
30
|
-
## Installation
|
|
31
|
-
|
|
32
|
-
|
|
33
30
|
## Installation
|
|
34
31
|
|
|
35
32
|
### Dependencies
|
|
@@ -46,12 +43,18 @@ Description: <picture align="center">
|
|
|
46
43
|
conda install cornucopia -c balbasty -c pytorch -c conda-forge
|
|
47
44
|
```
|
|
48
45
|
|
|
49
|
-
### Pip
|
|
46
|
+
### Pip (release)
|
|
50
47
|
|
|
51
48
|
```sh
|
|
52
49
|
pip install cornucopia
|
|
53
50
|
```
|
|
54
51
|
|
|
52
|
+
### Pip (dev)
|
|
53
|
+
|
|
54
|
+
```sh
|
|
55
|
+
pip install cornucopia@git+https://github.com/balbasty/cornucopia
|
|
56
|
+
```
|
|
57
|
+
|
|
55
58
|
## Documentation
|
|
56
59
|
|
|
57
60
|
Read the [documentation](https://cornucopia.readthedocs.io) and in particular:
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|