cornucopia 0.2.0__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cornucopia-0.2.0 → cornucopia-0.4.0}/LICENSE +0 -0
- cornucopia-0.4.0/PKG-INFO +93 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/README.md +8 -5
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/__init__.py +19 -12
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/_version.py +3 -3
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/base.py +66 -16
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/baseutils.py +23 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/contrast.py +1 -1
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/ctx.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/fov.py +138 -13
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/geometric.py +224 -71
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/intensity.py +87 -18
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/io.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/kspace.py +29 -17
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/labels.py +28 -22
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/noise.py +18 -12
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/psf.py +8 -7
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/qmri.py +74 -69
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/random.py +65 -20
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/special.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/synth.py +6 -2
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/tests/__init__.py +0 -0
- cornucopia-0.4.0/cornucopia/tests/test_backward_geometric.py +173 -0
- cornucopia-0.4.0/cornucopia/tests/test_backward_intensity.py +243 -0
- cornucopia-0.4.0/cornucopia/tests/test_backward_kspace.py +115 -0
- cornucopia-0.4.0/cornucopia/tests/test_backward_noise.py +169 -0
- cornucopia-0.4.0/cornucopia/tests/test_backward_psf.py +143 -0
- cornucopia-0.4.0/cornucopia/tests/test_backward_qmri.py +249 -0
- cornucopia-0.4.0/cornucopia/tests/test_backward_random.py +44 -0
- cornucopia-0.4.0/cornucopia/tests/test_backward_synth.py +72 -0
- cornucopia-0.4.0/cornucopia/tests/test_geometric.py +26 -0
- cornucopia-0.4.0/cornucopia/tests/test_intensity.py +9 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/tests/test_run_contrast.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/tests/test_run_fov.py +34 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/tests/test_run_geometric.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/tests/test_run_intensity.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/tests/test_run_kspace.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/tests/test_run_labels.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/tests/test_run_noise.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/tests/test_run_psf.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/tests/test_run_qmri.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/tests/test_run_synth.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/__init__.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/b0.py +23 -23
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/bounds.py +0 -0
- cornucopia-0.4.0/cornucopia/utils/compat.py +30 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/conv.py +14 -9
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/gmm.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/indexing.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/io.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/jit.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/kernels.py +124 -115
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/morpho.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/padding.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/patch.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/py.py +19 -2
- cornucopia-0.4.0/cornucopia/utils/smart_inplace.py +163 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/version.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/warps.py +6 -3
- cornucopia-0.4.0/cornucopia.egg-info/PKG-INFO +93 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia.egg-info/SOURCES.txt +12 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/pyproject.toml +0 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/setup.cfg +1 -1
- {cornucopia-0.2.0 → cornucopia-0.4.0}/setup.py +0 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/versioneer.py +0 -0
- cornucopia-0.2.0/PKG-INFO +0 -84
- cornucopia-0.2.0/cornucopia.egg-info/PKG-INFO +0 -84
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia.egg-info/dependency_links.txt +0 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia.egg-info/requires.txt +0 -0
- {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia.egg-info/top_level.txt +0 -0
|
File without changes
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: cornucopia
|
|
3
|
+
Version: 0.4.0
|
|
4
|
+
Summary: An abundance of augmentation layers
|
|
5
|
+
Author: Yael Balbastre
|
|
6
|
+
Author-email: yael.balbastre@gmail.com
|
|
7
|
+
License: MIT
|
|
8
|
+
Project-URL: Source Code, https://github.com/balbasty/cornucopia
|
|
9
|
+
Platform: OS Independent
|
|
10
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
11
|
+
Classifier: Operating System :: OS Independent
|
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
|
13
|
+
Classifier: Intended Audience :: Science/Research
|
|
14
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
15
|
+
Classifier: Topic :: Scientific/Engineering :: Medical Science Apps.
|
|
16
|
+
Requires-Python: >=3.6
|
|
17
|
+
Description-Content-Type: text/markdown
|
|
18
|
+
License-File: LICENSE
|
|
19
|
+
Requires-Dist: torch>=1.8
|
|
20
|
+
Requires-Dist: numpy
|
|
21
|
+
Requires-Dist: nibabel
|
|
22
|
+
Requires-Dist: torch-interpol>=0.2.4
|
|
23
|
+
Requires-Dist: torch-distmap
|
|
24
|
+
Dynamic: license-file
|
|
25
|
+
|
|
26
|
+
<picture align="center">
|
|
27
|
+
<source media="(prefers-color-scheme: dark)" srcset="docs/icons/cornucopia_lightorange.svg">
|
|
28
|
+
<source media="(prefers-color-scheme: light)" srcset="docs/icons/cornucopia_orange.svg">
|
|
29
|
+
<img alt="Cornucopia logo" src="https://github.com/balbasty/cornucopia/raw/main/docs/icons/cornucopia_orange.svg">
|
|
30
|
+
</picture>
|
|
31
|
+
|
|
32
|
+
The `cornucopia` package provides a generic framework for preprocessing,
|
|
33
|
+
augmentation, and domain randomization, along with an abundance of specific layers,
|
|
34
|
+
mostly targeted at (medical) imaging. `cornucopia` is written using a PyTorch
|
|
35
|
+
backend, and therefore runs **on the CPU or GPU**.
|
|
36
|
+
|
|
37
|
+
Cornucopia is *intended* to be used on the GPU for on-line augmentation.
|
|
38
|
+
A quick [benchmark](docs/examples/benchmark.ipynb) of affine and elastic augmentation
|
|
39
|
+
shows that while cornucopia is slower than [TorchIO](https://github.com/fepegar/torchio)
|
|
40
|
+
on the CPU (~ 3s vs 1s), it is greatly accelerated on the GPU (~ 50ms).
|
|
41
|
+
|
|
42
|
+
Since gradients are not expected to backpropagate through its layers, it can
|
|
43
|
+
theoretically be used within any dataloader pipeline,
|
|
44
|
+
independent of the downstream learning framework (pytorch, tensorflow, jax, ...).
|
|
45
|
+
|
|
46
|
+
## Installation
|
|
47
|
+
|
|
48
|
+
### Dependencies
|
|
49
|
+
|
|
50
|
+
- `pytorch >= 1.8`
|
|
51
|
+
- `numpy`
|
|
52
|
+
- `nibabel`
|
|
53
|
+
- `torch-interpol`
|
|
54
|
+
- `torch-distmap`
|
|
55
|
+
|
|
56
|
+
### Conda
|
|
57
|
+
|
|
58
|
+
```sh
|
|
59
|
+
conda install cornucopia -c balbasty -c pytorch -c conda-forge
|
|
60
|
+
```
|
|
61
|
+
|
|
62
|
+
### Pip (release)
|
|
63
|
+
|
|
64
|
+
```sh
|
|
65
|
+
pip install cornucopia
|
|
66
|
+
```
|
|
67
|
+
|
|
68
|
+
### Pip (dev)
|
|
69
|
+
|
|
70
|
+
```sh
|
|
71
|
+
pip install cornucopia@git+https://github.com/balbasty/cornucopia
|
|
72
|
+
```
|
|
73
|
+
|
|
74
|
+
## Documentation
|
|
75
|
+
|
|
76
|
+
Read the [documentation](https://cornucopia.readthedocs.io) and in particular:
|
|
77
|
+
- [installation](https://cornucopia.readthedocs.io/en/latest/install/)
|
|
78
|
+
- [get started](https://cornucopia.readthedocs.io/en/latest/start/)
|
|
79
|
+
- [examples](https://cornucopia.readthedocs.io/en/latest/examples/overview/)
|
|
80
|
+
- [API](https://cornucopia.readthedocs.io/en/latest/api/overview/)
|
|
81
|
+
|
|
82
|
+
## Other augmentation packages
|
|
83
|
+
|
|
84
|
+
There are other great, and much more mature, augmentation packages
|
|
85
|
+
out there (although few run on the GPU). Here's a non-exhaustive list:
|
|
86
|
+
- [MONAI](https://github.com/Project-MONAI/MONAI)
|
|
87
|
+
- [TorchIO](https://github.com/fepegar/torchio)
|
|
88
|
+
- [Albumentations](https://github.com/albumentations-team/albumentations) (2D only)
|
|
89
|
+
- [Volumentations](https://github.com/ZFTurbo/volumentations) (3D extension of Albumentations)
|
|
90
|
+
|
|
91
|
+
## Contributions
|
|
92
|
+
|
|
93
|
+
If you find this project useful and wish to contribute, please reach out!
|
|
@@ -10,7 +10,7 @@ mostly targeted at (medical) imaging. `cornucopia` is written using a PyTorch
|
|
|
10
10
|
backend, and therefore runs **on the CPU or GPU**.
|
|
11
11
|
|
|
12
12
|
Cornucopia is *intended* to be used on the GPU for on-line augmentation.
|
|
13
|
-
A quick [benchmark](examples/benchmark.ipynb) of affine and elastic augmentation
|
|
13
|
+
A quick [benchmark](docs/examples/benchmark.ipynb) of affine and elastic augmentation
|
|
14
14
|
shows that while cornucopia is slower than [TorchIO](https://github.com/fepegar/torchio)
|
|
15
15
|
on the CPU (~ 3s vs 1s), it is greatly accelerated on the GPU (~ 50ms).
|
|
16
16
|
|
|
@@ -18,9 +18,6 @@ Since gradients are not expected to backpropagate through its layers, it can
|
|
|
18
18
|
theoretically be used within any dataloader pipeline,
|
|
19
19
|
independent of the downstream learning framework (pytorch, tensorflow, jax, ...).
|
|
20
20
|
|
|
21
|
-
## Installation
|
|
22
|
-
|
|
23
|
-
|
|
24
21
|
## Installation
|
|
25
22
|
|
|
26
23
|
### Dependencies
|
|
@@ -37,12 +34,18 @@ independent of the downstream learning framework (pytorch, tensorflow, jax, ...)
|
|
|
37
34
|
conda install cornucopia -c balbasty -c pytorch -c conda-forge
|
|
38
35
|
```
|
|
39
36
|
|
|
40
|
-
### Pip
|
|
37
|
+
### Pip (release)
|
|
41
38
|
|
|
42
39
|
```sh
|
|
43
40
|
pip install cornucopia
|
|
44
41
|
```
|
|
45
42
|
|
|
43
|
+
### Pip (dev)
|
|
44
|
+
|
|
45
|
+
```sh
|
|
46
|
+
pip install cornucopia@git+https://github.com/balbasty/cornucopia
|
|
47
|
+
```
|
|
48
|
+
|
|
46
49
|
## Documentation
|
|
47
50
|
|
|
48
51
|
Read the [documentation](https://cornucopia.readthedocs.io) and in particular:
|
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""
|
|
2
|
+
Flexible transforms for pre-processing and augmentation
|
|
2
3
|
|
|
3
4
|
Example on how to use this machinery to generate within-subject
|
|
4
5
|
image pairs with a random affine deformation between them::
|
|
@@ -30,19 +31,25 @@ image pairs with a random affine deformation between them::
|
|
|
30
31
|
|
|
31
32
|
"""
|
|
32
33
|
|
|
33
|
-
# TODO:
|
|
34
|
-
# [x] Make it a standalone package?
|
|
35
|
-
# [x] Move samplers in their own file
|
|
36
|
-
# [x] Add IO transforms (that transform filenames in tensors)
|
|
37
|
-
# [ ] Better deal with separable/shared transforms
|
|
38
|
-
# [ ] Add a SharedTransform class (like Randomized) that does the heavy
|
|
39
|
-
# lifting
|
|
40
|
-
# [ ] By default (non shared), let Transforms handle multi-channel
|
|
41
|
-
# data (currently we loop across channels in the base class)
|
|
42
|
-
|
|
43
34
|
from . import random # noqa: F401
|
|
44
35
|
from . import ctx # noqa: F401
|
|
45
|
-
from .
|
|
36
|
+
from . import base # noqa: F401
|
|
37
|
+
from . import special # noqa: F401
|
|
38
|
+
from . import contrast # noqa: F401
|
|
39
|
+
from . import geometric # noqa: F401
|
|
40
|
+
from . import intensity # noqa: F401
|
|
41
|
+
from . import io # noqa: F401
|
|
42
|
+
from . import fov # noqa: F401
|
|
43
|
+
from . import kspace # noqa: F401
|
|
44
|
+
from . import labels # noqa: F401
|
|
45
|
+
from . import noise # noqa: F401
|
|
46
|
+
from . import psf # noqa: F401
|
|
47
|
+
from . import qmri # noqa: F401
|
|
48
|
+
from . import synth # noqa: F401
|
|
49
|
+
from . import utils # noqa: F401
|
|
50
|
+
|
|
51
|
+
from .random import * # noqa: F401,F403
|
|
52
|
+
from .ctx import * # noqa: F401,F403
|
|
46
53
|
from .base import * # noqa: F401,F403
|
|
47
54
|
from .special import * # noqa: F401,F403
|
|
48
55
|
from .contrast import * # noqa: F401,F403
|
|
@@ -8,11 +8,11 @@ import json
|
|
|
8
8
|
|
|
9
9
|
version_json = '''
|
|
10
10
|
{
|
|
11
|
-
"date": "
|
|
11
|
+
"date": "2025-04-16T18:13:05+0100",
|
|
12
12
|
"dirty": false,
|
|
13
13
|
"error": null,
|
|
14
|
-
"full-revisionid": "
|
|
15
|
-
"version": "0.
|
|
14
|
+
"full-revisionid": "6f8ab58dfcfe8978c9aa9e8b05898dcf7d75bb5b",
|
|
15
|
+
"version": "0.4.0"
|
|
16
16
|
}
|
|
17
17
|
''' # END VERSION_JSON
|
|
18
18
|
|
|
@@ -359,7 +359,7 @@ class NonFinalTransform(SharedMixin, Transform):
|
|
|
359
359
|
super().__init__(**kwargs)
|
|
360
360
|
self.shared = self._prepare_shared(shared)
|
|
361
361
|
|
|
362
|
-
def make_final(self, x, max_depth=float('inf')
|
|
362
|
+
def make_final(self, x, max_depth=float('inf')):
|
|
363
363
|
if self.is_final or max_depth == 0:
|
|
364
364
|
return self
|
|
365
365
|
return NotImplemented
|
|
@@ -447,6 +447,8 @@ class SequentialTransform(SpecialMixin, SharedMixin, Transform):
|
|
|
447
447
|
def make_final(self, x, max_depth=float('inf')):
|
|
448
448
|
if max_depth == 0:
|
|
449
449
|
return self
|
|
450
|
+
if self.is_final:
|
|
451
|
+
return self
|
|
450
452
|
# x = VirtualTensor.from_any(x, compute_stats=True)
|
|
451
453
|
trf = []
|
|
452
454
|
for t in self:
|
|
@@ -477,12 +479,23 @@ class SequentialTransform(SpecialMixin, SharedMixin, Transform):
|
|
|
477
479
|
x = args[0]
|
|
478
480
|
else:
|
|
479
481
|
return None
|
|
480
|
-
for trf in self
|
|
482
|
+
for trf in self:
|
|
481
483
|
with IncludeKeysTransform(trf, self.include), \
|
|
482
484
|
ExcludeKeysTransform(trf, self.exclude):
|
|
483
485
|
x = trf(x)
|
|
484
486
|
return x
|
|
485
487
|
|
|
488
|
+
def xform(self, x):
|
|
489
|
+
# This should only be called when a Layer's `make_final` returns
|
|
490
|
+
# a `SequentialTransform` (i.e., it is created implicitly under
|
|
491
|
+
# the hood, not explicitly by the user).
|
|
492
|
+
# In such cases, `shared=False` and hopefully we can just fall back
|
|
493
|
+
# to `forward()`.
|
|
494
|
+
#
|
|
495
|
+
# FIXME
|
|
496
|
+
# what happens if there's weird stuff in returns/include/exclude?
|
|
497
|
+
return self(x)
|
|
498
|
+
|
|
486
499
|
def __len__(self):
|
|
487
500
|
return len(self.transforms)
|
|
488
501
|
|
|
@@ -556,9 +569,11 @@ class MaybeTransform(SpecialMixin, SharedMixin, Transform):
|
|
|
556
569
|
img = (0.2 * gauss)(img)
|
|
557
570
|
```
|
|
558
571
|
```
|
|
559
|
-
"""
|
|
560
572
|
|
|
561
|
-
|
|
573
|
+
!!! changedin " \
|
|
574
|
+
Default for `shared` changed from `False` to `True`"
|
|
575
|
+
"""
|
|
576
|
+
def __init__(self, transform, prob=0.5, *, shared=True, **kwargs):
|
|
562
577
|
"""
|
|
563
578
|
|
|
564
579
|
Parameters
|
|
@@ -617,9 +632,12 @@ class SwitchTransform(SpecialMixin, SharedMixin, Transform):
|
|
|
617
632
|
```python
|
|
618
633
|
img = cc.switch({gauss: 0.5, chi: 0.5})(img)
|
|
619
634
|
```
|
|
635
|
+
|
|
636
|
+
!!! changedin " \
|
|
637
|
+
Default for `shared` changed from `False` to `True`"
|
|
620
638
|
"""
|
|
621
639
|
|
|
622
|
-
def __init__(self, transforms, prob=0, *, shared=
|
|
640
|
+
def __init__(self, transforms, prob=0, *, shared=True, **kwargs):
|
|
623
641
|
"""
|
|
624
642
|
|
|
625
643
|
Parameters
|
|
@@ -1068,33 +1086,67 @@ class RandomizedTransform(NonFinalTransform):
|
|
|
1068
1086
|
"""
|
|
1069
1087
|
Transform generated by randomizing some parameters of another transform.
|
|
1070
1088
|
|
|
1089
|
+
!!! note "`ctx.randomize` is an alias for `RandomizedTransform`"
|
|
1090
|
+
|
|
1071
1091
|
!!! example "Gaussian noise with randomized variance"
|
|
1072
1092
|
Object call
|
|
1073
1093
|
```python
|
|
1074
1094
|
import cornucopia as cc
|
|
1075
|
-
hypernoise = RandomizedTransform(cc.GaussianNoise, [cc.Uniform(
|
|
1095
|
+
hypernoise = cc.RandomizedTransform(cc.GaussianNoise, [cc.Uniform()])
|
|
1076
1096
|
img = hypernoise(img)
|
|
1077
1097
|
```
|
|
1078
|
-
|
|
1098
|
+
|
|
1099
|
+
Delayed call
|
|
1079
1100
|
```python
|
|
1080
1101
|
import cornucopia as cc
|
|
1081
|
-
|
|
1102
|
+
MyRandomNoise = cc.randomize(cc.GaussianNoise)
|
|
1103
|
+
hypernoise = MyRandomNoise(cc.Uniform())
|
|
1082
1104
|
img = hypernoise(img)
|
|
1083
1105
|
```
|
|
1084
1106
|
|
|
1085
1107
|
"""
|
|
1086
1108
|
|
|
1087
|
-
|
|
1109
|
+
class Delayed:
|
|
1110
|
+
# Temporary parameter holder for delayed calls
|
|
1111
|
+
def __init__(self, transform, **kwargs):
|
|
1112
|
+
self.transform = transform
|
|
1113
|
+
self.kwargs = kwargs
|
|
1114
|
+
|
|
1115
|
+
def __call__(self, *args, **kwargs):
|
|
1116
|
+
return RandomizedTransform(
|
|
1117
|
+
self.transform, args, kwargs, **self.kwargs)
|
|
1118
|
+
|
|
1119
|
+
def __new__(cls, *args, **kwargs):
|
|
1120
|
+
if cls is RandomizedTransform:
|
|
1121
|
+
return cls._base_new(*args, **kwargs)
|
|
1122
|
+
return super().__new__(cls)
|
|
1123
|
+
|
|
1124
|
+
@classmethod
|
|
1125
|
+
def _base_new(cls, transform, sample=tuple(), ksample=dict(),
|
|
1126
|
+
*, shared=False, **kwargs):
|
|
1127
|
+
assert cls is RandomizedTransform
|
|
1128
|
+
if not sample and not ksample:
|
|
1129
|
+
# If no arguments are passed, it means that the user calls
|
|
1130
|
+
# this in "delayed/functional" mode. In that case, we return
|
|
1131
|
+
# a callable object that returns the constructed instance
|
|
1132
|
+
# using the call-time arguments.
|
|
1133
|
+
return cls.Delayed(transform, shared=shared, **kwargs)
|
|
1134
|
+
# Otherwise, we're in object mode and we instantiate the
|
|
1135
|
+
# randomized object.
|
|
1136
|
+
return super().__new__(cls)
|
|
1137
|
+
|
|
1138
|
+
def __init__(self, transform, sample=tuple(), ksample=dict(),
|
|
1088
1139
|
*, shared=False, **kwargs):
|
|
1089
1140
|
"""
|
|
1090
|
-
|
|
1091
1141
|
Parameters
|
|
1092
1142
|
----------
|
|
1093
1143
|
transform : callable(...) -> Transform
|
|
1094
1144
|
A Transform subclass or a function that constructs a Transform.
|
|
1095
1145
|
sample : [list or dict of] callable
|
|
1096
1146
|
A collection of functions that generate parameter values provided
|
|
1097
|
-
to `transform`.
|
|
1147
|
+
to `transform`. Can be args-like or kwargs-like arguments.
|
|
1148
|
+
ksample : dict[callable]
|
|
1149
|
+
Must be kwargs-like arguments.
|
|
1098
1150
|
|
|
1099
1151
|
Other Parameters
|
|
1100
1152
|
----------------
|
|
@@ -1130,9 +1182,7 @@ class RandomizedTransform(NonFinalTransform):
|
|
|
1130
1182
|
|
|
1131
1183
|
def __repr__(self):
|
|
1132
1184
|
if type(self) is RandomizedTransform:
|
|
1133
|
-
|
|
1134
|
-
|
|
1135
|
-
|
|
1136
|
-
except TypeError:
|
|
1137
|
-
pass
|
|
1185
|
+
xform = self.subtransform
|
|
1186
|
+
if isinstance(xform, type) and issubclass(xform, Transform):
|
|
1187
|
+
return f'Randomized{xform.__name__}()'
|
|
1138
1188
|
return super().__repr__()
|
|
@@ -117,6 +117,29 @@ def returns_find(flag, returned, returns):
|
|
|
117
117
|
return None
|
|
118
118
|
|
|
119
119
|
|
|
120
|
+
def returns_update(value, flag, returned, returns):
    """Store `value` under `flag` in a returned structure.

    Write counterpart of `returns_find`: `returns` describes the layout
    of `returned`, and the entry that corresponds to `flag` is replaced
    by `value`.

    Parameters
    ----------
    value : object
        New value (typically a tensor) to store.
    flag : str
        Name of the returned quantity to update (e.g. 'output').
    returned : object
        Structure previously returned by a transform. It is indexed by
        `flag` when `returns` is a dict, by position when `returns` is a
        list/tuple, and is a single object when `returns` is a str/None.
    returns : None or str or list/tuple[str] or dict
        Description of what is returned. `None` behaves like `'output'`.

    Returns
    -------
    returned : object
        The updated structure. Dicts and lists are modified in place and
        returned; tuples are rebuilt. If `flag` is not part of `returns`,
        `returned` is returned unchanged (previously, the str/None cases
        returned `None`, silently discarding the structure).
    """
    if returns is None:
        # `None` means that only the output is returned
        returns = 'output'
    if isinstance(returns, dict):
        if flag in returns:
            returned[flag] = value
        return returned
    if isinstance(returns, (list, tuple)):
        if flag in returns:
            index = returns.index(flag)
            if isinstance(returned, tuple):
                # tuples are immutable: rebuild instead of assigning
                returned = returned[:index] + (value,) + returned[index+1:]
            else:
                returned[index] = value
        return returned
    assert isinstance(returns, str)
    # a single object is returned: replace it only if it is the one asked
    return value if returns == flag else returned
|
|
141
|
+
|
|
142
|
+
|
|
120
143
|
def flatstruct(x):
|
|
121
144
|
"""Flatten a nested structure of tensors"""
|
|
122
145
|
|
|
@@ -184,7 +184,7 @@ class ContrastLookupTransform(NonFinalTransform):
|
|
|
184
184
|
|
|
185
185
|
vmin, vmax = x.min(), x.max()
|
|
186
186
|
edges = torch.linspace(vmin, vmax, self.nk+1)
|
|
187
|
-
new_mu = torch.rand(self.nk) * (vmax - vmin) + vmin
|
|
187
|
+
new_mu = torch.rand(self.nk).to(x) * (vmax - vmin) + vmin
|
|
188
188
|
return self.LookupFinalTransform(
|
|
189
189
|
edges, new_mu, **self.get_prm()
|
|
190
190
|
).make_final(x, max_depth-1)
|
|
File without changes
|
|
@@ -8,14 +8,16 @@ __all__ = [
|
|
|
8
8
|
'CropTransform',
|
|
9
9
|
'PadTransform',
|
|
10
10
|
'PowerTwoTransform',
|
|
11
|
+
'Rot90Transform',
|
|
12
|
+
'Rot180Transform',
|
|
13
|
+
'RandomRot90Transform',
|
|
11
14
|
]
|
|
12
|
-
|
|
13
15
|
import math
|
|
14
16
|
from random import shuffle
|
|
15
|
-
from .base import FinalTransform, NonFinalTransform
|
|
17
|
+
from .base import FinalTransform, NonFinalTransform, PerChannelTransform
|
|
16
18
|
from .utils.py import ensure_list
|
|
17
19
|
from .utils.padding import pad
|
|
18
|
-
from .random import Uniform, RandKFrom, Sampler
|
|
20
|
+
from .random import Uniform, RandKFrom, Sampler, RandInt, make_range
|
|
19
21
|
|
|
20
22
|
|
|
21
23
|
class FlipTransform(FinalTransform):
|
|
@@ -23,7 +25,6 @@ class FlipTransform(FinalTransform):
|
|
|
23
25
|
|
|
24
26
|
def __init__(self, axis=None, **kwargs):
|
|
25
27
|
"""
|
|
26
|
-
|
|
27
28
|
Parameters
|
|
28
29
|
----------
|
|
29
30
|
axis : [list of] int
|
|
@@ -46,24 +47,30 @@ class FlipTransform(FinalTransform):
|
|
|
46
47
|
class RandomFlipTransform(NonFinalTransform):
|
|
47
48
|
"""Randomly flip one or more axes"""
|
|
48
49
|
|
|
49
|
-
def __init__(self, axes=None, **kwargs):
|
|
50
|
+
def __init__(self, axes=None, *, shared=True, **kwargs):
|
|
50
51
|
"""
|
|
51
|
-
|
|
52
52
|
Parameters
|
|
53
53
|
----------
|
|
54
54
|
axes : Sampler or [list of] int
|
|
55
55
|
Axes that can be flipped (default: all)
|
|
56
|
+
|
|
57
|
+
Other Parameters
|
|
58
|
+
----------------
|
|
56
59
|
shared : {'channels', 'tensors', 'channels+tensors', ''}
|
|
57
60
|
Apply the same flip to all channels and/or tensors
|
|
58
61
|
"""
|
|
59
62
|
axes = kwargs.pop('axis', axes)
|
|
60
|
-
|
|
61
|
-
super().__init__(**kwargs)
|
|
63
|
+
super().__init__(shared=shared, **kwargs)
|
|
62
64
|
self.axes = axes
|
|
63
65
|
|
|
64
66
|
def make_final(self, x, max_depth=float('inf')):
|
|
65
67
|
if max_depth == 0:
|
|
66
68
|
return self
|
|
69
|
+
if 'channels' not in self.shared and len(x) > 1:
|
|
70
|
+
return PerChannelTransform(
|
|
71
|
+
[self.make_final(x[i:i+1], max_depth) for i in range(len(x))],
|
|
72
|
+
**self.get_prm()
|
|
73
|
+
).make_final(x, max_depth-1)
|
|
67
74
|
axes = self.axes or range(1, x.ndim)
|
|
68
75
|
if not isinstance(axes, Sampler):
|
|
69
76
|
rand_axes = RandKFrom(ensure_list(axes))
|
|
@@ -76,7 +83,6 @@ class PermuteAxesTransform(FinalTransform):
|
|
|
76
83
|
|
|
77
84
|
def __init__(self, permutation=None, **kwargs):
|
|
78
85
|
"""
|
|
79
|
-
|
|
80
86
|
Parameters
|
|
81
87
|
----------
|
|
82
88
|
permutation : [list of] int
|
|
@@ -105,23 +111,29 @@ class PermuteAxesTransform(FinalTransform):
|
|
|
105
111
|
class RandomPermuteAxesTransform(NonFinalTransform):
|
|
106
112
|
"""Randomly permute axes"""
|
|
107
113
|
|
|
108
|
-
def __init__(self, axes=None, **kwargs):
|
|
114
|
+
def __init__(self, axes=None, *, shared=True, **kwargs):
|
|
109
115
|
"""
|
|
110
|
-
|
|
111
116
|
Parameters
|
|
112
117
|
----------
|
|
113
118
|
axes : [list of] int
|
|
114
119
|
Axes that can be permuted (default: all)
|
|
120
|
+
|
|
121
|
+
Other Parameters
|
|
122
|
+
----------------
|
|
115
123
|
shared : {'channels', 'tensors', 'channels+tensors', ''}
|
|
116
124
|
Apply the same permutation to all channels and/or tensors
|
|
117
125
|
"""
|
|
118
|
-
|
|
119
|
-
super().__init__(**kwargs)
|
|
126
|
+
super().__init__(shared=shared, **kwargs)
|
|
120
127
|
self.axes = axes
|
|
121
128
|
|
|
122
129
|
def make_final(self, x, max_depth=float('inf')):
|
|
123
130
|
if max_depth == 0:
|
|
124
131
|
return self
|
|
132
|
+
if 'channels' not in self.shared and len(x) > 1:
|
|
133
|
+
return PerChannelTransform(
|
|
134
|
+
[self.make_final(x[i:i+1], max_depth) for i in range(len(x))],
|
|
135
|
+
**self.get_prm()
|
|
136
|
+
).make_final(x, max_depth-1)
|
|
125
137
|
axes = list(self.axes or range(x.ndim-1))
|
|
126
138
|
shuffle(axes)
|
|
127
139
|
return PermuteAxesTransform(
|
|
@@ -129,6 +141,119 @@ class RandomPermuteAxesTransform(NonFinalTransform):
|
|
|
129
141
|
).make_final(x, max_depth-1)
|
|
130
142
|
|
|
131
143
|
|
|
144
|
+
class Rot90Transform(FinalTransform):
    """
    Apply a 90 (or 180) degree rotation about one or several axes
    """

    def __init__(self, axis=0, negative=False, double=False, **kwargs):
        """
        Parameters
        ----------
        axis : int or list[int]
            Rotation axis (indexing does not account for the channel axis).
            For 2D inputs, the single spatial plane is always used.
        negative : bool or list[bool]
            Rotate by -90 deg instead of 90 deg
        double : bool or list[bool]
            Rotate by 180 deg instead of 90 deg (`negative` is then unused)
        """
        super().__init__(**kwargs)
        self.axis = ensure_list(axis)
        nb_rot = len(self.axis)
        self.negative = ensure_list(negative, nb_rot)
        self.double = ensure_list(double, nb_rot)

    def xform(self, x):
        # NOTE: each rotation allocates a new tensor. Fusing all transposes
        # and flips into a single "transpose + flip" would be cheaper, but
        # this is fine for now.
        nb_spatial = x.ndim - 1
        # convert (possibly negative) spatial axes to tensor dimensions,
        # skipping the leading channel dimension
        dims = [1 + (a + nb_spatial if a < 0 else a) for a in self.axis]
        rotations = zip(dims, self.negative, self.double)
        for rot_dim, is_negative, is_double in rotations:
            # the two dimensions spanning the rotation plane
            if nb_spatial == 2:
                plane = [1, 2]
            else:
                assert nb_spatial == 3
                plane = [d for d in (1, 2, 3) if d != rot_dim]
            if is_double:
                # 180 deg = mirror both in-plane dimensions
                x = x.flip(plane)
                continue
            # 90 deg = swap the in-plane dimensions, then mirror one of
            # them (which one sets the direction of rotation)
            mirrored = plane[1] if is_negative else plane[0]
            x = x.transpose(*plane).flip(mirrored)
        return x
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class Rot180Transform(Rot90Transform):
    """Apply a 180 degree rotation about one or several axes"""

    def __init__(self, axis=0, **kwargs):
        """
        Parameters
        ----------
        axis : int or list[int]
            Rotation axis (indexing does not account for the channel axis)
        """
        # a 180 deg rotation is a "doubled" 90 deg rotation
        super().__init__(axis=axis, double=True, **kwargs)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
class RandomRot90Transform(NonFinalTransform):
    """Randomly apply a sequence of 90 degree rotations"""

    def __init__(self, axes=None, max_rot=2, negative=True,
                 *, shared=True, **kwargs):
        """
        Parameters
        ----------
        axes : int or list[int] or Sampler
            Axes about which rotations can happen
            (indexing does not account for the channel axis).
            If `None`, all axes.
        max_rot : int or Sampler
            Maximum number of consecutive rotations.
            The number of rotations is drawn with `RandInt` from
            `make_range(1, max_rot)`.
        negative : bool
            Whether to authorize negative rotations.

        Other Parameters
        ----------------
        shared : {'channels', 'tensors', 'channels+tensors', ''}
            Apply the same rotations to all channels and/or tensors
        """
        super().__init__(shared=shared, **kwargs)
        self.axes = axes
        self.max_rot = RandInt.make(make_range(1, max_rot))
        self.negative = negative

    def make_final(self, x, max_depth=float('inf')):
        if max_depth == 0:
            return self
        if 'channels' not in self.shared and len(x) > 1:
            # sample an independent set of rotations for each channel
            return PerChannelTransform(
                [self.make_final(x[i:i+1], max_depth) for i in range(len(x))],
                **self.get_prm()
            ).make_final(x, max_depth-1)
        ndim = x.ndim - 1
        # number of consecutive rotations
        max_rot = self.max_rot
        if isinstance(max_rot, Sampler):
            max_rot = max_rot()
        # rotation axis of each rotation (sampled with replacement)
        axes = self.axes
        if axes is None:
            axes = list(range(ndim))
        if isinstance(axes, (int, list, tuple)):
            axes = ensure_list(axes, max_rot, crop=False)
        if not isinstance(axes, Sampler):
            axes = RandKFrom(axes, max_rot, replacement=True)
        axes = ensure_list(axes(), max_rot)
        # direction of each rotation
        if self.negative:
            negative = RandKFrom([False, True], max_rot, replacement=True)()
        else:
            negative = [False] * max_rot
        # `x` must be forwarded, like in every other `make_final` call
        return Rot90Transform(
            axes, negative, **self.get_prm()
        ).make_final(x, max_depth-1)
|
|
255
|
+
|
|
256
|
+
|
|
132
257
|
class CropPadTransform(FinalTransform):
|
|
133
258
|
"""Crop and/or pad a tensor"""
|
|
134
259
|
|