cornucopia 0.1.0__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cornucopia-0.1.0 → cornucopia-0.3.0}/LICENSE +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/PKG-INFO +10 -7
- {cornucopia-0.1.0 → cornucopia-0.3.0}/README.md +9 -6
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/__init__.py +19 -12
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/_version.py +3 -3
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/base.py +51 -19
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/baseutils.py +4 -4
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/contrast.py +4 -3
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/ctx.py +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/fov.py +141 -16
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/geometric.py +6 -7
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/intensity.py +9 -8
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/io.py +2 -2
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/kspace.py +3 -3
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/labels.py +9 -9
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/noise.py +2 -2
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/psf.py +4 -4
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/qmri.py +6 -6
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/random.py +5 -6
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/special.py +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/synth.py +11 -1
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/tests/__init__.py +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_contrast.py +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_fov.py +34 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_geometric.py +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_intensity.py +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_kspace.py +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_labels.py +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_noise.py +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_psf.py +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_qmri.py +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/tests/test_run_synth.py +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/__init__.py +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/b0.py +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/bounds.py +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/conv.py +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/gmm.py +7 -2
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/indexing.py +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/io.py +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/jit.py +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/kernels.py +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/morpho.py +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/padding.py +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/patch.py +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/py.py +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/version.py +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia/utils/warps.py +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia.egg-info/PKG-INFO +10 -7
- {cornucopia-0.1.0 → cornucopia-0.3.0}/pyproject.toml +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/setup.cfg +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/setup.py +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/versioneer.py +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia.egg-info/SOURCES.txt +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia.egg-info/dependency_links.txt +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia.egg-info/requires.txt +0 -0
- {cornucopia-0.1.0 → cornucopia-0.3.0}/cornucopia.egg-info/top_level.txt +0 -0
|
File without changes
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: cornucopia
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.3.0
|
|
4
4
|
Summary: An abundance of augmentation layers
|
|
5
5
|
Home-page: UNKNOWN
|
|
6
6
|
Author: Yael Balbastre
|
|
@@ -10,7 +10,7 @@ Project-URL: Source Code, https://github.com/balbasty/cornucopia
|
|
|
10
10
|
Description: <picture align="center">
|
|
11
11
|
<source media="(prefers-color-scheme: dark)" srcset="docs/icons/cornucopia_lightorange.svg">
|
|
12
12
|
<source media="(prefers-color-scheme: light)" srcset="docs/icons/cornucopia_orange.svg">
|
|
13
|
-
<img alt="Cornucopia logo" src="docs/icons/cornucopia_orange.svg">
|
|
13
|
+
<img alt="Cornucopia logo" src="https://github.com/balbasty/cornucopia/raw/main/docs/icons/cornucopia_orange.svg">
|
|
14
14
|
</picture>
|
|
15
15
|
|
|
16
16
|
The `cornucopia` package provides a generic framework for preprocessing,
|
|
@@ -27,9 +27,6 @@ Description: <picture align="center">
|
|
|
27
27
|
theoretically be used within any dataloader pipeline,
|
|
28
28
|
independent of the downstream learning framework (pytorch, tensorflow, jax, ...).
|
|
29
29
|
|
|
30
|
-
## Installation
|
|
31
|
-
|
|
32
|
-
|
|
33
30
|
## Installation
|
|
34
31
|
|
|
35
32
|
### Dependencies
|
|
@@ -43,15 +40,21 @@ Description: <picture align="center">
|
|
|
43
40
|
### Conda
|
|
44
41
|
|
|
45
42
|
```sh
|
|
46
|
-
conda install cornucopia -c balbasty -c pytorch
|
|
43
|
+
conda install cornucopia -c balbasty -c pytorch -c conda-forge
|
|
47
44
|
```
|
|
48
45
|
|
|
49
|
-
### Pip
|
|
46
|
+
### Pip (release)
|
|
50
47
|
|
|
51
48
|
```sh
|
|
52
49
|
pip install cornucopia
|
|
53
50
|
```
|
|
54
51
|
|
|
52
|
+
### Pip (dev)
|
|
53
|
+
|
|
54
|
+
```sh
|
|
55
|
+
pip install cornucopia@git+https://github.com/balbasty/cornucopia
|
|
56
|
+
```
|
|
57
|
+
|
|
55
58
|
## Documentation
|
|
56
59
|
|
|
57
60
|
Read the [documentation](https://cornucopia.readthedocs.io) and in particular:
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
<picture align="center">
|
|
2
2
|
<source media="(prefers-color-scheme: dark)" srcset="docs/icons/cornucopia_lightorange.svg">
|
|
3
3
|
<source media="(prefers-color-scheme: light)" srcset="docs/icons/cornucopia_orange.svg">
|
|
4
|
-
<img alt="Cornucopia logo" src="docs/icons/cornucopia_orange.svg">
|
|
4
|
+
<img alt="Cornucopia logo" src="https://github.com/balbasty/cornucopia/raw/main/docs/icons/cornucopia_orange.svg">
|
|
5
5
|
</picture>
|
|
6
6
|
|
|
7
7
|
The `cornucopia` package provides a generic framework for preprocessing,
|
|
@@ -18,9 +18,6 @@ Since gradients are not expected to backpropagate through its layers, it can
|
|
|
18
18
|
theoretically be used within any dataloader pipeline,
|
|
19
19
|
independent of the downstream learning framework (pytorch, tensorflow, jax, ...).
|
|
20
20
|
|
|
21
|
-
## Installation
|
|
22
|
-
|
|
23
|
-
|
|
24
21
|
## Installation
|
|
25
22
|
|
|
26
23
|
### Dependencies
|
|
@@ -34,15 +31,21 @@ independent of the downstream learning framework (pytorch, tensorflow, jax, ...)
|
|
|
34
31
|
### Conda
|
|
35
32
|
|
|
36
33
|
```sh
|
|
37
|
-
conda install cornucopia -c balbasty -c pytorch
|
|
34
|
+
conda install cornucopia -c balbasty -c pytorch -c conda-forge
|
|
38
35
|
```
|
|
39
36
|
|
|
40
|
-
### Pip
|
|
37
|
+
### Pip (release)
|
|
41
38
|
|
|
42
39
|
```sh
|
|
43
40
|
pip install cornucopia
|
|
44
41
|
```
|
|
45
42
|
|
|
43
|
+
### Pip (dev)
|
|
44
|
+
|
|
45
|
+
```sh
|
|
46
|
+
pip install cornucopia@git+https://github.com/balbasty/cornucopia
|
|
47
|
+
```
|
|
48
|
+
|
|
46
49
|
## Documentation
|
|
47
50
|
|
|
48
51
|
Read the [documentation](https://cornucopia.readthedocs.io) and in particular:
|
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""
|
|
2
|
+
Flexible transforms for pre-processing and augmentation
|
|
2
3
|
|
|
3
4
|
Example on how to use this machinery to generate within-subject
|
|
4
5
|
image pairs with a random affine deformation between them::
|
|
@@ -30,19 +31,25 @@ image pairs with a random affine deformation between them::
|
|
|
30
31
|
|
|
31
32
|
"""
|
|
32
33
|
|
|
33
|
-
# TODO:
|
|
34
|
-
# [x] Make it a standalone package?
|
|
35
|
-
# [x] Move samplers in their own file
|
|
36
|
-
# [x] Add IO transforms (that transform filenames in tensors)
|
|
37
|
-
# [ ] Better deal with separable/shared transforms
|
|
38
|
-
# [ ] Add a SharedTransform class (like Randomized) that does the heavy
|
|
39
|
-
# lifting
|
|
40
|
-
# [ ] By default (non shared), let Transforms handle multi-channel
|
|
41
|
-
# data (currently we loop across channels in the base class)
|
|
42
|
-
|
|
43
34
|
from . import random # noqa: F401
|
|
44
35
|
from . import ctx # noqa: F401
|
|
45
|
-
from .
|
|
36
|
+
from . import base # noqa: F401
|
|
37
|
+
from . import special # noqa: F401
|
|
38
|
+
from . import contrast # noqa: F401
|
|
39
|
+
from . import geometric # noqa: F401
|
|
40
|
+
from . import intensity # noqa: F401
|
|
41
|
+
from . import io # noqa: F401
|
|
42
|
+
from . import fov # noqa: F401
|
|
43
|
+
from . import kspace # noqa: F401
|
|
44
|
+
from . import labels # noqa: F401
|
|
45
|
+
from . import noise # noqa: F401
|
|
46
|
+
from . import psf # noqa: F401
|
|
47
|
+
from . import qmri # noqa: F401
|
|
48
|
+
from . import synth # noqa: F401
|
|
49
|
+
from . import utils # noqa: F401
|
|
50
|
+
|
|
51
|
+
from .random import * # noqa: F401,F403
|
|
52
|
+
from .ctx import * # noqa: F401,F403
|
|
46
53
|
from .base import * # noqa: F401,F403
|
|
47
54
|
from .special import * # noqa: F401,F403
|
|
48
55
|
from .contrast import * # noqa: F401,F403
|
|
@@ -8,11 +8,11 @@ import json
|
|
|
8
8
|
|
|
9
9
|
version_json = '''
|
|
10
10
|
{
|
|
11
|
-
"date": "
|
|
11
|
+
"date": "2024-04-19T14:23:50+0100",
|
|
12
12
|
"dirty": false,
|
|
13
13
|
"error": null,
|
|
14
|
-
"full-revisionid": "
|
|
15
|
-
"version": "0.
|
|
14
|
+
"full-revisionid": "37de94f181b9a97eebd21460f4df63ae4a0750f8",
|
|
15
|
+
"version": "0.3.0"
|
|
16
16
|
}
|
|
17
17
|
''' # END VERSION_JSON
|
|
18
18
|
|
|
@@ -148,7 +148,7 @@ class Transform(nn.Module):
|
|
|
148
148
|
return x
|
|
149
149
|
|
|
150
150
|
# now we're working with a single tensor (or str)
|
|
151
|
-
y = self.
|
|
151
|
+
y = self.xform(x)
|
|
152
152
|
if not isinstance(y, Returned):
|
|
153
153
|
if not isinstance(y, type(self.returns)):
|
|
154
154
|
y = dict(input=x, output=y)
|
|
@@ -240,7 +240,7 @@ class FinalTransform(Transform):
|
|
|
240
240
|
def is_final(self):
|
|
241
241
|
return True
|
|
242
242
|
|
|
243
|
-
def
|
|
243
|
+
def xform(self, x):
|
|
244
244
|
"""Apply the transform to a tensor
|
|
245
245
|
|
|
246
246
|
Parameters
|
|
@@ -283,7 +283,7 @@ class IdentityTransform(FinalTransform):
|
|
|
283
283
|
def __init__(self, **kwargs):
|
|
284
284
|
super().__init__(**kwargs)
|
|
285
285
|
|
|
286
|
-
def
|
|
286
|
+
def xform(self, x):
|
|
287
287
|
return x
|
|
288
288
|
|
|
289
289
|
def make_inverse(self):
|
|
@@ -305,12 +305,12 @@ class SharedMixin:
|
|
|
305
305
|
shared = ''
|
|
306
306
|
return shared
|
|
307
307
|
|
|
308
|
-
def
|
|
308
|
+
def xform(self, x):
|
|
309
309
|
if 'channels' in self.shared:
|
|
310
310
|
xform = self.make_final(x[:1], max_depth=1)
|
|
311
311
|
else:
|
|
312
312
|
xform = self.make_final(x, max_depth=1)
|
|
313
|
-
return xform.
|
|
313
|
+
return xform.xform(x)
|
|
314
314
|
|
|
315
315
|
def forward(self, *a, **k):
|
|
316
316
|
return self._shared_forward(*a, **k)
|
|
@@ -359,7 +359,7 @@ class NonFinalTransform(SharedMixin, Transform):
|
|
|
359
359
|
super().__init__(**kwargs)
|
|
360
360
|
self.shared = self._prepare_shared(shared)
|
|
361
361
|
|
|
362
|
-
def make_final(self, x, max_depth=float('inf')
|
|
362
|
+
def make_final(self, x, max_depth=float('inf')):
|
|
363
363
|
if self.is_final or max_depth == 0:
|
|
364
364
|
return self
|
|
365
365
|
return NotImplemented
|
|
@@ -447,7 +447,7 @@ class SequentialTransform(SpecialMixin, SharedMixin, Transform):
|
|
|
447
447
|
def make_final(self, x, max_depth=float('inf')):
|
|
448
448
|
if max_depth == 0:
|
|
449
449
|
return self
|
|
450
|
-
x = VirtualTensor.from_any(x, compute_stats=True)
|
|
450
|
+
# x = VirtualTensor.from_any(x, compute_stats=True)
|
|
451
451
|
trf = []
|
|
452
452
|
for t in self:
|
|
453
453
|
t = t.make_final(x, max_depth=max_depth-1)
|
|
@@ -525,7 +525,7 @@ class PerChannelTransform(SpecialMixin, Transform):
|
|
|
525
525
|
trf = PerChannelTransform(trf, **prm)
|
|
526
526
|
return trf
|
|
527
527
|
|
|
528
|
-
def
|
|
528
|
+
def xform(self, x):
|
|
529
529
|
results = []
|
|
530
530
|
for i, t in enumerate(self.transforms):
|
|
531
531
|
with ReturningTransform(t, self.returns), \
|
|
@@ -1068,33 +1068,67 @@ class RandomizedTransform(NonFinalTransform):
|
|
|
1068
1068
|
"""
|
|
1069
1069
|
Transform generated by randomizing some parameters of another transform.
|
|
1070
1070
|
|
|
1071
|
+
!!! note "`ctx.randomize` is an alias for `RandomizedTransform`"
|
|
1072
|
+
|
|
1071
1073
|
!!! example "Gaussian noise with randomized variance"
|
|
1072
1074
|
Object call
|
|
1073
1075
|
```python
|
|
1074
1076
|
import cornucopia as cc
|
|
1075
|
-
hypernoise = RandomizedTransform(cc.GaussianNoise, [cc.Uniform(
|
|
1077
|
+
hypernoise = cc.RandomizedTransform(cc.GaussianNoise, [cc.Uniform()])
|
|
1076
1078
|
img = hypernoise(img)
|
|
1077
1079
|
```
|
|
1078
|
-
|
|
1080
|
+
|
|
1081
|
+
Delayed call
|
|
1079
1082
|
```python
|
|
1080
1083
|
import cornucopia as cc
|
|
1081
|
-
|
|
1084
|
+
MyRandomNoise = cc.randomize(cc.GaussianNoise)
|
|
1085
|
+
hypernoise = MyRandomNoise(cc.Uniform())
|
|
1082
1086
|
img = hypernoise(img)
|
|
1083
1087
|
```
|
|
1084
1088
|
|
|
1085
1089
|
"""
|
|
1086
1090
|
|
|
1087
|
-
|
|
1091
|
+
class Delayed:
|
|
1092
|
+
# Temproary parameter holder for delayed calls
|
|
1093
|
+
def __init__(self, transform, **kwargs):
|
|
1094
|
+
self.transform = transform
|
|
1095
|
+
self.kwargs = kwargs
|
|
1096
|
+
|
|
1097
|
+
def __call__(self, *args, **kwargs):
|
|
1098
|
+
return RandomizedTransform(
|
|
1099
|
+
self.transform, args, kwargs, **self.kwargs)
|
|
1100
|
+
|
|
1101
|
+
def __new__(cls, *args, **kwargs):
|
|
1102
|
+
if cls is RandomizedTransform:
|
|
1103
|
+
return cls._base_new(*args, **kwargs)
|
|
1104
|
+
return super().__new__(cls)
|
|
1105
|
+
|
|
1106
|
+
@classmethod
|
|
1107
|
+
def _base_new(cls, transform, sample=tuple(), ksample=dict(),
|
|
1108
|
+
*, shared=False, **kwargs):
|
|
1109
|
+
assert cls is RandomizedTransform
|
|
1110
|
+
if not sample and not ksample:
|
|
1111
|
+
# If no arguments are passed, it means that the user calls
|
|
1112
|
+
# this in "delayed/functional" mode. In that case, we return
|
|
1113
|
+
# a callable object that returns the constructed instance
|
|
1114
|
+
# using the call-time arguments.
|
|
1115
|
+
return cls.Delayed(transform, shared=shared, **kwargs)
|
|
1116
|
+
# Otherwise, we're in object mode and we instantiate the
|
|
1117
|
+
# randomized object.
|
|
1118
|
+
return super().__new__(cls)
|
|
1119
|
+
|
|
1120
|
+
def __init__(self, transform, sample=tuple(), ksample=dict(),
|
|
1088
1121
|
*, shared=False, **kwargs):
|
|
1089
1122
|
"""
|
|
1090
|
-
|
|
1091
1123
|
Parameters
|
|
1092
1124
|
----------
|
|
1093
1125
|
transform : callable(...) -> Transform
|
|
1094
1126
|
A Transform subclass or a function that constructs a Transform.
|
|
1095
1127
|
sample : [list or dict of] callable
|
|
1096
1128
|
A collection of functions that generate parameter values provided
|
|
1097
|
-
to `transform`.
|
|
1129
|
+
to `transform`. Can be args-like or kwargs-like arguments.
|
|
1130
|
+
ksample : dict[callable]
|
|
1131
|
+
Must be kwargs-like arguments.
|
|
1098
1132
|
|
|
1099
1133
|
Other Parameters
|
|
1100
1134
|
----------------
|
|
@@ -1130,9 +1164,7 @@ class RandomizedTransform(NonFinalTransform):
|
|
|
1130
1164
|
|
|
1131
1165
|
def __repr__(self):
|
|
1132
1166
|
if type(self) is RandomizedTransform:
|
|
1133
|
-
|
|
1134
|
-
|
|
1135
|
-
|
|
1136
|
-
except TypeError:
|
|
1137
|
-
pass
|
|
1167
|
+
xform = self.subtransform
|
|
1168
|
+
if isinstance(xform, type) and issubclass(xform, Transform):
|
|
1169
|
+
return f'Randomized{xform.__name__}()'
|
|
1138
1170
|
return super().__repr__()
|
|
@@ -24,7 +24,7 @@ def get_first_element(x, include=None, exclude=None, types=None):
|
|
|
24
24
|
if ok:
|
|
25
25
|
return v, True
|
|
26
26
|
return None, False
|
|
27
|
-
if torch.is_tensor(x) or isinstance(x, types):
|
|
27
|
+
if torch.is_tensor(x) or (types and isinstance(x, types)):
|
|
28
28
|
return x, True
|
|
29
29
|
return x, False
|
|
30
30
|
|
|
@@ -213,9 +213,9 @@ class VirtualTensor:
|
|
|
213
213
|
@classmethod
|
|
214
214
|
def from_tensor(cls, x, compute_stats=False):
|
|
215
215
|
if compute_stats:
|
|
216
|
-
vmin = x.
|
|
217
|
-
vmax = x.
|
|
218
|
-
vmean = x.mean(dim=list(range(1, x.ndim)))
|
|
216
|
+
vmin = x.reshape([len(x), -1]).min(dim=-1).values
|
|
217
|
+
vmax = x.reshape([len(x), -1]).max(dim=-1).values
|
|
218
|
+
vmean = x.float().mean(dim=list(range(1, x.ndim)))
|
|
219
219
|
else:
|
|
220
220
|
vmin = vmax = vmean = None
|
|
221
221
|
return VirtualTensor(x.shape, dtype=x.dtype, device=x.device,
|
|
@@ -45,7 +45,7 @@ class ContrastMixtureTransform(NonFinalTransform):
|
|
|
45
45
|
self.mu = mu
|
|
46
46
|
self.sigma = sigma
|
|
47
47
|
|
|
48
|
-
def
|
|
48
|
+
def xform(self, x):
|
|
49
49
|
z = self.z.to(x)
|
|
50
50
|
mu0 = self.mu0.to(x)
|
|
51
51
|
sigma0 = self.sigma0.to(x)
|
|
@@ -114,7 +114,8 @@ class ContrastMixtureTransform(NonFinalTransform):
|
|
|
114
114
|
mu = torch.rand_like(
|
|
115
115
|
old_mu).mul_(old_mu_max - old_mu_min).add_(old_mu_min)
|
|
116
116
|
sigma = torch.rand_like(
|
|
117
|
-
old_sigma_diag
|
|
117
|
+
old_sigma_diag
|
|
118
|
+
).mul_(old_sigma_max - old_sigma_min).add_(old_sigma_min)
|
|
118
119
|
corr = torch.rand([len(old_mu), nc*(nc-1)//2], **backend).mul_(0.5)
|
|
119
120
|
|
|
120
121
|
fullsigma = torch.eye(nc, **backend).expand([nk, nc, nc]).clone()
|
|
@@ -145,7 +146,7 @@ class ContrastLookupTransform(NonFinalTransform):
|
|
|
145
146
|
self.edges = edges
|
|
146
147
|
self.mu = mu
|
|
147
148
|
|
|
148
|
-
def
|
|
149
|
+
def xform(self, x):
|
|
149
150
|
edges, mu = self.edges.to(x), self.mu.to(x)
|
|
150
151
|
mu0 = (edges[:-1] + edges[1:]) / 2
|
|
151
152
|
nk = len(mu)
|
|
File without changes
|
|
@@ -8,14 +8,16 @@ __all__ = [
|
|
|
8
8
|
'CropTransform',
|
|
9
9
|
'PadTransform',
|
|
10
10
|
'PowerTwoTransform',
|
|
11
|
+
'Rot90Transform',
|
|
12
|
+
'Rot180Transform',
|
|
13
|
+
'RandomRot90Transform',
|
|
11
14
|
]
|
|
12
|
-
|
|
13
15
|
import math
|
|
14
16
|
from random import shuffle
|
|
15
|
-
from .base import FinalTransform, NonFinalTransform
|
|
17
|
+
from .base import FinalTransform, NonFinalTransform, PerChannelTransform
|
|
16
18
|
from .utils.py import ensure_list
|
|
17
19
|
from .utils.padding import pad
|
|
18
|
-
from .random import Uniform, RandKFrom, Sampler
|
|
20
|
+
from .random import Uniform, RandKFrom, Sampler, RandInt, make_range
|
|
19
21
|
|
|
20
22
|
|
|
21
23
|
class FlipTransform(FinalTransform):
|
|
@@ -23,7 +25,6 @@ class FlipTransform(FinalTransform):
|
|
|
23
25
|
|
|
24
26
|
def __init__(self, axis=None, **kwargs):
|
|
25
27
|
"""
|
|
26
|
-
|
|
27
28
|
Parameters
|
|
28
29
|
----------
|
|
29
30
|
axis : [list of] int
|
|
@@ -32,7 +33,7 @@ class FlipTransform(FinalTransform):
|
|
|
32
33
|
super().__init__(**kwargs)
|
|
33
34
|
self.axis = axis
|
|
34
35
|
|
|
35
|
-
def
|
|
36
|
+
def xform(self, x):
|
|
36
37
|
axis = self.axis
|
|
37
38
|
if axis is None:
|
|
38
39
|
axis = list(range(1, x.ndim))
|
|
@@ -46,24 +47,30 @@ class FlipTransform(FinalTransform):
|
|
|
46
47
|
class RandomFlipTransform(NonFinalTransform):
|
|
47
48
|
"""Randomly flip one or more axes"""
|
|
48
49
|
|
|
49
|
-
def __init__(self, axes=None, **kwargs):
|
|
50
|
+
def __init__(self, axes=None, *, shared=True, **kwargs):
|
|
50
51
|
"""
|
|
51
|
-
|
|
52
52
|
Parameters
|
|
53
53
|
----------
|
|
54
54
|
axes : Sampler or [list of] int
|
|
55
55
|
Axes that can be flipped (default: all)
|
|
56
|
+
|
|
57
|
+
Other Parameters
|
|
58
|
+
----------------
|
|
56
59
|
shared : {'channels', 'tensors', 'channels+tensors', ''}
|
|
57
60
|
Apply the same flip to all channels and/or tensors
|
|
58
61
|
"""
|
|
59
62
|
axes = kwargs.pop('axis', axes)
|
|
60
|
-
|
|
61
|
-
super().__init__(**kwargs)
|
|
63
|
+
super().__init__(shared=shared, **kwargs)
|
|
62
64
|
self.axes = axes
|
|
63
65
|
|
|
64
66
|
def make_final(self, x, max_depth=float('inf')):
|
|
65
67
|
if max_depth == 0:
|
|
66
68
|
return self
|
|
69
|
+
if 'channels' not in self.shared and len(x) > 1:
|
|
70
|
+
return PerChannelTransform(
|
|
71
|
+
[self.make_final(x[i:i+1], max_depth) for i in range(len(x))],
|
|
72
|
+
**self.get_prm()
|
|
73
|
+
).make_final(x, max_depth-1)
|
|
67
74
|
axes = self.axes or range(1, x.ndim)
|
|
68
75
|
if not isinstance(axes, Sampler):
|
|
69
76
|
rand_axes = RandKFrom(ensure_list(axes))
|
|
@@ -76,7 +83,6 @@ class PermuteAxesTransform(FinalTransform):
|
|
|
76
83
|
|
|
77
84
|
def __init__(self, permutation=None, **kwargs):
|
|
78
85
|
"""
|
|
79
|
-
|
|
80
86
|
Parameters
|
|
81
87
|
----------
|
|
82
88
|
permutation : [list of] int
|
|
@@ -86,7 +92,7 @@ class PermuteAxesTransform(FinalTransform):
|
|
|
86
92
|
super().__init__(**kwargs)
|
|
87
93
|
self.permutation = permutation
|
|
88
94
|
|
|
89
|
-
def
|
|
95
|
+
def xform(self, x):
|
|
90
96
|
permutation = self.permutation
|
|
91
97
|
if permutation is None:
|
|
92
98
|
permutation = list(reversed(range(x.dim()-1)))
|
|
@@ -105,23 +111,29 @@ class PermuteAxesTransform(FinalTransform):
|
|
|
105
111
|
class RandomPermuteAxesTransform(NonFinalTransform):
|
|
106
112
|
"""Randomly permute axes"""
|
|
107
113
|
|
|
108
|
-
def __init__(self, axes=None, **kwargs):
|
|
114
|
+
def __init__(self, axes=None, *, shared=True, **kwargs):
|
|
109
115
|
"""
|
|
110
|
-
|
|
111
116
|
Parameters
|
|
112
117
|
----------
|
|
113
118
|
axes : [list of] int
|
|
114
119
|
Axes that can be permuted (default: all)
|
|
120
|
+
|
|
121
|
+
Other Parameters
|
|
122
|
+
----------------
|
|
115
123
|
shared : {'channels', 'tensors', 'channels+tensors', ''}
|
|
116
124
|
Apply the same permutation to all channels and/or tensors
|
|
117
125
|
"""
|
|
118
|
-
|
|
119
|
-
super().__init__(**kwargs)
|
|
126
|
+
super().__init__(shared=shared, **kwargs)
|
|
120
127
|
self.axes = axes
|
|
121
128
|
|
|
122
129
|
def make_final(self, x, max_depth=float('inf')):
|
|
123
130
|
if max_depth == 0:
|
|
124
131
|
return self
|
|
132
|
+
if 'channels' not in self.shared and len(x) > 1:
|
|
133
|
+
return PerChannelTransform(
|
|
134
|
+
[self.make_final(x[i:i+1], max_depth) for i in range(len(x))],
|
|
135
|
+
**self.get_prm()
|
|
136
|
+
).make_final(x, max_depth-1)
|
|
125
137
|
axes = list(self.axes or range(x.ndim-1))
|
|
126
138
|
shuffle(axes)
|
|
127
139
|
return PermuteAxesTransform(
|
|
@@ -129,6 +141,119 @@ class RandomPermuteAxesTransform(NonFinalTransform):
|
|
|
129
141
|
).make_final(x, max_depth-1)
|
|
130
142
|
|
|
131
143
|
|
|
144
|
+
class Rot90Transform(FinalTransform):
|
|
145
|
+
"""
|
|
146
|
+
Apply a 90 (or 180) rotation along one or several axes
|
|
147
|
+
"""
|
|
148
|
+
|
|
149
|
+
def __init__(self, axis=0, negative=False, double=False, **kwargs):
|
|
150
|
+
"""
|
|
151
|
+
Parameters
|
|
152
|
+
----------
|
|
153
|
+
axis : int or list[int]
|
|
154
|
+
Rotation axis (indexing does not account for the channel axis)
|
|
155
|
+
negative : bool or list[bool]
|
|
156
|
+
Rotate by -90 deg instead of 90 deg
|
|
157
|
+
double : bool or list[bool]
|
|
158
|
+
Rotate be 180 instead of 90 (`negative` is then unused)
|
|
159
|
+
"""
|
|
160
|
+
super().__init__(**kwargs)
|
|
161
|
+
self.axis = ensure_list(axis)
|
|
162
|
+
self.negative = ensure_list(negative, len(self.axis))
|
|
163
|
+
self.double = ensure_list(double, len(self.axis))
|
|
164
|
+
|
|
165
|
+
def xform(self, x):
|
|
166
|
+
# this implementation is suboptimal. We should fuse all transpose
|
|
167
|
+
# and all flips into a single "transpose + flip" operation so that
|
|
168
|
+
# a single allocation happens. This will be fine for now.
|
|
169
|
+
|
|
170
|
+
ndim = x.ndim - 1
|
|
171
|
+
axis = [1 + (ndim + a if a < 0 else a) for a in self.axis]
|
|
172
|
+
for ax, neg, dbl in zip(axis, self.negative, self.double):
|
|
173
|
+
if dbl:
|
|
174
|
+
if ndim == 2:
|
|
175
|
+
dims = [1, 2]
|
|
176
|
+
else:
|
|
177
|
+
assert ndim == 3
|
|
178
|
+
dims = [d for d in (1, 2, 3) if d != ax]
|
|
179
|
+
x = x.flip(dims)
|
|
180
|
+
else:
|
|
181
|
+
if ndim == 2:
|
|
182
|
+
dims = [1, 2]
|
|
183
|
+
else:
|
|
184
|
+
assert ndim == 3
|
|
185
|
+
dims = [d for d in (1, 2, 3) if d != ax]
|
|
186
|
+
x = x.transpose(*dims).flip(dims[1] if neg else dims[0])
|
|
187
|
+
return x
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class Rot180Transform(Rot90Transform):
|
|
191
|
+
"""Apply a 180 deg rotation along one or several axes"""
|
|
192
|
+
|
|
193
|
+
def __init__(self, axis=0, **kwargs):
|
|
194
|
+
"""
|
|
195
|
+
Parameters
|
|
196
|
+
----------
|
|
197
|
+
axis : int or list[int]
|
|
198
|
+
Rotation axis (indexing does not account for the channel axis)
|
|
199
|
+
"""
|
|
200
|
+
super().__init__(axis, double=True, **kwargs)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
class RandomRot90Transform(NonFinalTransform):
|
|
204
|
+
"""Random set of 90 transforms"""
|
|
205
|
+
|
|
206
|
+
def __init__(self, axes=None, max_rot=2, negative=True,
|
|
207
|
+
*, shared=True, **kwargs):
|
|
208
|
+
"""
|
|
209
|
+
Parameters
|
|
210
|
+
----------
|
|
211
|
+
axes : int or list[int]
|
|
212
|
+
Axes along which rotations can happen.
|
|
213
|
+
If `None`, all axes.
|
|
214
|
+
max_rot : int or Sampler
|
|
215
|
+
Maximum number of consecutive rotations.
|
|
216
|
+
negative : bool
|
|
217
|
+
Whether to authorize negative rotations.
|
|
218
|
+
|
|
219
|
+
Other Parameters
|
|
220
|
+
----------------
|
|
221
|
+
shared : {'channels', 'tensors', 'channels+tensors', ''}
|
|
222
|
+
Apply the same permutation to all channels and/or tensors
|
|
223
|
+
"""
|
|
224
|
+
super().__init__(shared=shared, **kwargs)
|
|
225
|
+
self.axes = axes
|
|
226
|
+
self.max_rot = RandInt.make(make_range(1, max_rot))
|
|
227
|
+
self.negative = negative
|
|
228
|
+
|
|
229
|
+
def make_final(self, x, max_depth=float('inf')):
|
|
230
|
+
if max_depth == 0:
|
|
231
|
+
return self
|
|
232
|
+
if 'channels' not in self.shared and len(x) > 1:
|
|
233
|
+
return PerChannelTransform(
|
|
234
|
+
[self.make_final(x[i:i+1], max_depth) for i in range(len(x))],
|
|
235
|
+
**self.get_prm()
|
|
236
|
+
).make_final(x, max_depth-1)
|
|
237
|
+
ndim = x.ndim - 1
|
|
238
|
+
max_rot = self.max_rot
|
|
239
|
+
if isinstance(max_rot, Sampler):
|
|
240
|
+
max_rot = max_rot()
|
|
241
|
+
axes = self.axes
|
|
242
|
+
if axes is None:
|
|
243
|
+
axes = list(range(ndim))
|
|
244
|
+
if isinstance(axes, (int, list, tuple)):
|
|
245
|
+
axes = ensure_list(axes, max_rot, crop=False)
|
|
246
|
+
if not isinstance(axes, Sampler):
|
|
247
|
+
axes = RandKFrom(axes, max_rot, replacement=True)
|
|
248
|
+
|
|
249
|
+
axes = ensure_list(axes(), max_rot)
|
|
250
|
+
negative = RandKFrom([False, True], max_rot, replacement=True)() \
|
|
251
|
+
if self.negative else [False] * max_rot
|
|
252
|
+
return Rot90Transform(
|
|
253
|
+
axes, negative, **self.get_prm()
|
|
254
|
+
).make_final(max_depth-1)
|
|
255
|
+
|
|
256
|
+
|
|
132
257
|
class CropPadTransform(FinalTransform):
|
|
133
258
|
"""Crop and/or pad a tensor"""
|
|
134
259
|
|
|
@@ -151,7 +276,7 @@ class CropPadTransform(FinalTransform):
|
|
|
151
276
|
self.bound = bound
|
|
152
277
|
self.value = value
|
|
153
278
|
|
|
154
|
-
def
|
|
279
|
+
def xform(self, x):
|
|
155
280
|
crop = tuple([Ellipsis, *self.crop])
|
|
156
281
|
x = x[crop]
|
|
157
282
|
x = pad(x, self.pad, mode=self.bound, value=self.value)
|
|
@@ -218,7 +218,7 @@ class ElasticTransform(NonFinalTransform):
|
|
|
218
218
|
).movedim(-1, 1)
|
|
219
219
|
return flow
|
|
220
220
|
|
|
221
|
-
def
|
|
221
|
+
def xform(self, x):
|
|
222
222
|
"""Deform the input tensor
|
|
223
223
|
|
|
224
224
|
Parameters
|
|
@@ -503,7 +503,7 @@ class AffineTransform(NonFinalTransform):
|
|
|
503
503
|
def make_flow(self, matrix, shape):
|
|
504
504
|
return warps.affine_flow(matrix, shape).movedim(-1, 0)
|
|
505
505
|
|
|
506
|
-
def
|
|
506
|
+
def xform(self, x):
|
|
507
507
|
flow = cast_like(self.flow, x)
|
|
508
508
|
matrix = cast_like(self.matrix, x)
|
|
509
509
|
required = return_requires(self.returns)
|
|
@@ -766,7 +766,7 @@ class AffineElasticTransform(NonFinalTransform):
|
|
|
766
766
|
self.affine = affine
|
|
767
767
|
self.bound = bound
|
|
768
768
|
|
|
769
|
-
def
|
|
769
|
+
def xform(self, x):
|
|
770
770
|
flow = cast_like(self.flow, x)
|
|
771
771
|
controls = cast_like(self.controls, x)
|
|
772
772
|
affine = cast_like(self.affine, x)
|
|
@@ -954,7 +954,7 @@ class MakeAffinePair(NonFinalTransform):
|
|
|
954
954
|
self.left = left
|
|
955
955
|
self.right = right
|
|
956
956
|
|
|
957
|
-
def
|
|
957
|
+
def xform(self, x):
|
|
958
958
|
x1 = self.left(x)
|
|
959
959
|
x2 = self.right(x)
|
|
960
960
|
mat1, mat2 = self.left.matrix, self.right.matrix
|
|
@@ -1119,7 +1119,6 @@ class SlicewiseAffineTransform(NonFinalTransform):
|
|
|
1119
1119
|
F = torch.eye(ndim+1, **backend)
|
|
1120
1120
|
F[:ndim, -1] = -offsets
|
|
1121
1121
|
Z = E.clone()
|
|
1122
|
-
print(zooms.shape, Z.shape)
|
|
1123
1122
|
Z.diagonal(0, -1, -2)[:, :-1].copy_(1 + zooms)
|
|
1124
1123
|
T = E.clone()
|
|
1125
1124
|
T[:, :ndim, -1] = translations
|
|
@@ -1204,7 +1203,7 @@ class SlicewiseAffineTransform(NonFinalTransform):
|
|
|
1204
1203
|
self.subsample = subsample
|
|
1205
1204
|
self.bound = bound
|
|
1206
1205
|
|
|
1207
|
-
def
|
|
1206
|
+
def xform(self, x):
|
|
1208
1207
|
flow = cast_like(self.flow, x)
|
|
1209
1208
|
matrix = cast_like(self.matrix, x)
|
|
1210
1209
|
|
|
@@ -1362,7 +1361,7 @@ class RandomSlicewiseAffineTransform(NonFinalTransform):
|
|
|
1362
1361
|
# get slice direction
|
|
1363
1362
|
slice = self.slice
|
|
1364
1363
|
if slice is None:
|
|
1365
|
-
slice = RandInt(0, ndim)
|
|
1364
|
+
slice = RandInt(0, ndim - 1)
|
|
1366
1365
|
if isinstance(slice, Sampler):
|
|
1367
1366
|
slice = slice()
|
|
1368
1367
|
|