cornucopia 0.2.0__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. {cornucopia-0.2.0 → cornucopia-0.4.0}/LICENSE +0 -0
  2. cornucopia-0.4.0/PKG-INFO +93 -0
  3. {cornucopia-0.2.0 → cornucopia-0.4.0}/README.md +8 -5
  4. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/__init__.py +19 -12
  5. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/_version.py +3 -3
  6. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/base.py +66 -16
  7. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/baseutils.py +23 -0
  8. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/contrast.py +1 -1
  9. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/ctx.py +0 -0
  10. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/fov.py +138 -13
  11. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/geometric.py +224 -71
  12. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/intensity.py +87 -18
  13. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/io.py +0 -0
  14. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/kspace.py +29 -17
  15. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/labels.py +28 -22
  16. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/noise.py +18 -12
  17. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/psf.py +8 -7
  18. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/qmri.py +74 -69
  19. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/random.py +65 -20
  20. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/special.py +0 -0
  21. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/synth.py +6 -2
  22. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/tests/__init__.py +0 -0
  23. cornucopia-0.4.0/cornucopia/tests/test_backward_geometric.py +173 -0
  24. cornucopia-0.4.0/cornucopia/tests/test_backward_intensity.py +243 -0
  25. cornucopia-0.4.0/cornucopia/tests/test_backward_kspace.py +115 -0
  26. cornucopia-0.4.0/cornucopia/tests/test_backward_noise.py +169 -0
  27. cornucopia-0.4.0/cornucopia/tests/test_backward_psf.py +143 -0
  28. cornucopia-0.4.0/cornucopia/tests/test_backward_qmri.py +249 -0
  29. cornucopia-0.4.0/cornucopia/tests/test_backward_random.py +44 -0
  30. cornucopia-0.4.0/cornucopia/tests/test_backward_synth.py +72 -0
  31. cornucopia-0.4.0/cornucopia/tests/test_geometric.py +26 -0
  32. cornucopia-0.4.0/cornucopia/tests/test_intensity.py +9 -0
  33. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/tests/test_run_contrast.py +0 -0
  34. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/tests/test_run_fov.py +34 -0
  35. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/tests/test_run_geometric.py +0 -0
  36. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/tests/test_run_intensity.py +0 -0
  37. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/tests/test_run_kspace.py +0 -0
  38. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/tests/test_run_labels.py +0 -0
  39. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/tests/test_run_noise.py +0 -0
  40. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/tests/test_run_psf.py +0 -0
  41. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/tests/test_run_qmri.py +0 -0
  42. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/tests/test_run_synth.py +0 -0
  43. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/__init__.py +0 -0
  44. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/b0.py +23 -23
  45. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/bounds.py +0 -0
  46. cornucopia-0.4.0/cornucopia/utils/compat.py +30 -0
  47. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/conv.py +14 -9
  48. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/gmm.py +0 -0
  49. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/indexing.py +0 -0
  50. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/io.py +0 -0
  51. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/jit.py +0 -0
  52. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/kernels.py +124 -115
  53. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/morpho.py +0 -0
  54. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/padding.py +0 -0
  55. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/patch.py +0 -0
  56. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/py.py +19 -2
  57. cornucopia-0.4.0/cornucopia/utils/smart_inplace.py +163 -0
  58. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/version.py +0 -0
  59. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia/utils/warps.py +6 -3
  60. cornucopia-0.4.0/cornucopia.egg-info/PKG-INFO +93 -0
  61. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia.egg-info/SOURCES.txt +12 -0
  62. {cornucopia-0.2.0 → cornucopia-0.4.0}/pyproject.toml +0 -0
  63. {cornucopia-0.2.0 → cornucopia-0.4.0}/setup.cfg +1 -1
  64. {cornucopia-0.2.0 → cornucopia-0.4.0}/setup.py +0 -0
  65. {cornucopia-0.2.0 → cornucopia-0.4.0}/versioneer.py +0 -0
  66. cornucopia-0.2.0/PKG-INFO +0 -84
  67. cornucopia-0.2.0/cornucopia.egg-info/PKG-INFO +0 -84
  68. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia.egg-info/dependency_links.txt +0 -0
  69. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia.egg-info/requires.txt +0 -0
  70. {cornucopia-0.2.0 → cornucopia-0.4.0}/cornucopia.egg-info/top_level.txt +0 -0
File without changes
@@ -0,0 +1,93 @@
1
+ Metadata-Version: 2.4
2
+ Name: cornucopia
3
+ Version: 0.4.0
4
+ Summary: An abundance of augmentation layers
5
+ Author: Yael Balbastre
6
+ Author-email: yael.balbastre@gmail.com
7
+ License: MIT
8
+ Project-URL: Source Code, https://github.com/balbasty/cornucopia
9
+ Platform: OS Independent
10
+ Classifier: License :: OSI Approved :: MIT License
11
+ Classifier: Operating System :: OS Independent
12
+ Classifier: Programming Language :: Python :: 3
13
+ Classifier: Intended Audience :: Science/Research
14
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
15
+ Classifier: Topic :: Scientific/Engineering :: Medical Science Apps.
16
+ Requires-Python: >=3.6
17
+ Description-Content-Type: text/markdown
18
+ License-File: LICENSE
19
+ Requires-Dist: torch>=1.8
20
+ Requires-Dist: numpy
21
+ Requires-Dist: nibabel
22
+ Requires-Dist: torch-interpol>=0.2.4
23
+ Requires-Dist: torch-distmap
24
+ Dynamic: license-file
25
+
26
+ <picture align="center">
27
+ <source media="(prefers-color-scheme: dark)" srcset="docs/icons/cornucopia_lightorange.svg">
28
+ <source media="(prefers-color-scheme: light)" srcset="docs/icons/cornucopia_orange.svg">
29
+ <img alt="Cornucopia logo" src="https://github.com/balbasty/cornucopia/raw/main/docs/icons/cornucopia_orange.svg">
30
+ </picture>
31
+
32
+ The `cornucopia` package provides a generic framework for preprocessing,
33
+ augmentation, and domain randomization; along with an abundance of specific layers,
34
+ mostly targeted at (medical) imaging. `cornucopia` is written using a PyTorch
35
+ backend, and therefore runs **on the CPU or GPU**.
36
+
37
+ Cornucopia is *intended* to be used on the GPU for on-line augmentation.
38
+ A quick [benchmark](docs/examples/benchmark.ipynb) of affine and elastic augmentation
39
+ shows that while cornucopia is slower than [TorchIO](https://github.com/fepegar/torchio)
40
+ on the CPU (~ 3s vs 1s), it is greatly accelerated on the GPU (~ 50ms).
41
+
42
+ Since gradients are not expected to backpropagate through its layers, it can
43
+ theoretically be used within any dataloader pipeline,
44
+ independent of the downstream learning framework (pytorch, tensorflow, jax, ...).
45
+
46
+ ## Installation
47
+
48
+ ### Dependencies
49
+
50
+ - `pytorch >= 1.8`
51
+ - `numpy`
52
+ - `nibabel`
53
+ - `torch-interpol`
54
+ - `torch-distmap`
55
+
56
+ ### Conda
57
+
58
+ ```sh
59
+ conda install cornucopia -c balbasty -c pytorch -c conda-forge
60
+ ```
61
+
62
+ ### Pip (release)
63
+
64
+ ```sh
65
+ pip install cornucopia
66
+ ```
67
+
68
+ ### Pip (dev)
69
+
70
+ ```sh
71
+ pip install cornucopia@git+https://github.com/balbasty/cornucopia
72
+ ```
73
+
74
+ ## Documentation
75
+
76
+ Read the [documentation](https://cornucopia.readthedocs.io) and in particular:
77
+ - [installation](https://cornucopia.readthedocs.io/en/latest/install/)
78
+ - [get started](https://cornucopia.readthedocs.io/en/latest/start/)
79
+ - [examples](https://cornucopia.readthedocs.io/en/latest/examples/overview/)
80
+ - [API](https://cornucopia.readthedocs.io/en/latest/api/overview/)
81
+
82
+ ## Other augmentation packages
83
+
84
+ There are other great, and much more mature, augmentation packages
85
+ out there (although few run on the GPU). Here's a non-exhaustive list:
86
+ - [MONAI](https://github.com/Project-MONAI/MONAI)
87
+ - [TorchIO](https://github.com/fepegar/torchio)
88
+ - [Albumentations](https://github.com/albumentations-team/albumentations) (2D only)
89
+ - [Volumentations](https://github.com/ZFTurbo/volumentations) (3D extension of Albumentations)
90
+
91
+ ## Contributions
92
+
93
+ If you find this project useful and wish to contribute, please reach out!
@@ -10,7 +10,7 @@ mostly targeted at (medical) imaging. `cornucopia` is written using a PyTorch
10
10
  backend, and therefore runs **on the CPU or GPU**.
11
11
 
12
12
  Cornucopia is *intended* to be used on the GPU for on-line augmentation.
13
- A quick [benchmark](examples/benchmark.ipynb) of affine and elastic augmentation
13
+ A quick [benchmark](docs/examples/benchmark.ipynb) of affine and elastic augmentation
14
14
  shows that while cornucopia is slower than [TorchIO](https://github.com/fepegar/torchio)
15
15
  on the CPU (~ 3s vs 1s), it is greatly accelerated on the GPU (~ 50ms).
16
16
 
@@ -18,9 +18,6 @@ Since gradients are not expected to backpropagate through its layers, it can
18
18
  theoretically be used within any dataloader pipeline,
19
19
  independent of the downstream learning framework (pytorch, tensorflow, jax, ...).
20
20
 
21
- ## Installation
22
-
23
-
24
21
  ## Installation
25
22
 
26
23
  ### Dependencies
@@ -37,12 +34,18 @@ independent of the downstream learning framework (pytorch, tensorflow, jax, ...)
37
34
  conda install cornucopia -c balbasty -c pytorch -c conda-forge
38
35
  ```
39
36
 
40
- ### Pip
37
+ ### Pip (release)
41
38
 
42
39
  ```sh
43
40
  pip install cornucopia
44
41
  ```
45
42
 
43
+ ### Pip (dev)
44
+
45
+ ```sh
46
+ pip install cornucopia@git+https://github.com/balbasty/cornucopia
47
+ ```
48
+
46
49
  ## Documentation
47
50
 
48
51
  Read the [documentation](https://cornucopia.readthedocs.io) and in particular:
@@ -1,4 +1,5 @@
1
- """Flexible transforms for pre-processing and augmentation
1
+ """
2
+ Flexible transforms for pre-processing and augmentation
2
3
 
3
4
  Example on how to use this machinery to generate within-subject
4
5
  image pairs with a random affine deformation between them::
@@ -30,19 +31,25 @@ image pairs with a random affine deformation between them::
30
31
 
31
32
  """
32
33
 
33
- # TODO:
34
- # [x] Make it a standalone package?
35
- # [x] Move samplers in their own file
36
- # [x] Add IO transforms (that transform filenames in tensors)
37
- # [ ] Better deal with separable/shared transforms
38
- # [ ] Add a SharedTransform class (like Randomized) that does the heavy
39
- # lifting
40
- # [ ] By default (non shared), let Transforms handle multi-channel
41
- # data (currently we loop across channels in the base class)
42
-
43
34
  from . import random # noqa: F401
44
35
  from . import ctx # noqa: F401
45
- from .ctx import batch # noqa: F401
36
+ from . import base # noqa: F401
37
+ from . import special # noqa: F401
38
+ from . import contrast # noqa: F401
39
+ from . import geometric # noqa: F401
40
+ from . import intensity # noqa: F401
41
+ from . import io # noqa: F401
42
+ from . import fov # noqa: F401
43
+ from . import kspace # noqa: F401
44
+ from . import labels # noqa: F401
45
+ from . import noise # noqa: F401
46
+ from . import psf # noqa: F401
47
+ from . import qmri # noqa: F401
48
+ from . import synth # noqa: F401
49
+ from . import utils # noqa: F401
50
+
51
+ from .random import * # noqa: F401,F403
52
+ from .ctx import * # noqa: F401,F403
46
53
  from .base import * # noqa: F401,F403
47
54
  from .special import * # noqa: F401,F403
48
55
  from .contrast import * # noqa: F401,F403
@@ -8,11 +8,11 @@ import json
8
8
 
9
9
  version_json = '''
10
10
  {
11
- "date": "2023-11-21T15:24:48-0500",
11
+ "date": "2025-04-16T18:13:05+0100",
12
12
  "dirty": false,
13
13
  "error": null,
14
- "full-revisionid": "6d09025573589db13d99c6df247f138f22f5ab61",
15
- "version": "0.2.0"
14
+ "full-revisionid": "6f8ab58dfcfe8978c9aa9e8b05898dcf7d75bb5b",
15
+ "version": "0.4.0"
16
16
  }
17
17
  ''' # END VERSION_JSON
18
18
 
@@ -359,7 +359,7 @@ class NonFinalTransform(SharedMixin, Transform):
359
359
  super().__init__(**kwargs)
360
360
  self.shared = self._prepare_shared(shared)
361
361
 
362
- def make_final(self, x, max_depth=float('inf'), *args, **kwargs):
362
+ def make_final(self, x, max_depth=float('inf')):
363
363
  if self.is_final or max_depth == 0:
364
364
  return self
365
365
  return NotImplemented
@@ -447,6 +447,8 @@ class SequentialTransform(SpecialMixin, SharedMixin, Transform):
447
447
  def make_final(self, x, max_depth=float('inf')):
448
448
  if max_depth == 0:
449
449
  return self
450
+ if self.is_final:
451
+ return self
450
452
  # x = VirtualTensor.from_any(x, compute_stats=True)
451
453
  trf = []
452
454
  for t in self:
@@ -477,12 +479,23 @@ class SequentialTransform(SpecialMixin, SharedMixin, Transform):
477
479
  x = args[0]
478
480
  else:
479
481
  return None
480
- for trf in self.transforms:
482
+ for trf in self:
481
483
  with IncludeKeysTransform(trf, self.include), \
482
484
  ExcludeKeysTransform(trf, self.exclude):
483
485
  x = trf(x)
484
486
  return x
485
487
 
488
+ def xform(self, x):
489
+ # This should only be called when a Layer's `make_final` returns
490
+ # a `SequentialTransform` (i.e., it is created implicitly under
491
+ # the hood, not explicitly by the user).
492
+ # In such cases, `shared=False` and hopefully we can just fall back
493
+ # to `forward()`.
494
+ #
495
+ # FIXME
496
+ # what happens if there's weird stuff in returns/include/exclude?
497
+ return self(x)
498
+
486
499
  def __len__(self):
487
500
  return len(self.transforms)
488
501
 
@@ -556,9 +569,11 @@ class MaybeTransform(SpecialMixin, SharedMixin, Transform):
556
569
  img = (0.2 * gauss)(img)
557
570
  ```
558
571
  ```
559
- """
560
572
 
561
- def __init__(self, transform, prob=0.5, *, shared=False, **kwargs):
573
+ !!! changedin "![v0.4](https://img.shields.io/badge/v0.4-yellow) \
574
+ Default for `shared` changed from `False` to `True`"
575
+ """
576
+ def __init__(self, transform, prob=0.5, *, shared=True, **kwargs):
562
577
  """
563
578
 
564
579
  Parameters
@@ -617,9 +632,12 @@ class SwitchTransform(SpecialMixin, SharedMixin, Transform):
617
632
  ```python
618
633
  img = cc.switch({gauss: 0.5, chi: 0.5})(img)
619
634
  ```
635
+
636
+ !!! changedin "![v0.4](https://img.shields.io/badge/v0.4-yellow) \
637
+ Default for `shared` changed from `False` to `True`"
620
638
  """
621
639
 
622
- def __init__(self, transforms, prob=0, *, shared=False, **kwargs):
640
+ def __init__(self, transforms, prob=0, *, shared=True, **kwargs):
623
641
  """
624
642
 
625
643
  Parameters
@@ -1068,33 +1086,67 @@ class RandomizedTransform(NonFinalTransform):
1068
1086
  """
1069
1087
  Transform generated by randomizing some parameters of another transform.
1070
1088
 
1089
+ !!! note "`ctx.randomize` is an alias for `RandomizedTransform`"
1090
+
1071
1091
  !!! example "Gaussian noise with randomized variance"
1072
1092
  Object call
1073
1093
  ```python
1074
1094
  import cornucopia as cc
1075
- hypernoise = RandomizedTransform(cc.GaussianNoise, [cc.Uniform(0, 10)])
1095
+ hypernoise = cc.RandomizedTransform(cc.GaussianNoise, [cc.Uniform()])
1076
1096
  img = hypernoise(img)
1077
1097
  ```
1078
- Functional call
1098
+
1099
+ Delayed call
1079
1100
  ```python
1080
1101
  import cornucopia as cc
1081
- hypernoise = cc.randomize(cc.GaussianNoise)(cc.Uniform(0, 10))
1102
+ MyRandomNoise = cc.randomize(cc.GaussianNoise)
1103
+ hypernoise = MyRandomNoise(cc.Uniform())
1082
1104
  img = hypernoise(img)
1083
1105
  ```
1084
1106
 
1085
1107
  """
1086
1108
 
1087
- def __init__(self, transform, sample, ksample=None,
1109
+ class Delayed:
1110
+ # Temporary parameter holder for delayed calls: `Delayed(transform, **opts)(*args, **kwargs)` builds `RandomizedTransform(transform, args, kwargs, **opts)`
1111
+ def __init__(self, transform, **kwargs):
1112
+ self.transform = transform
1113
+ self.kwargs = kwargs
1114
+
1115
+ def __call__(self, *args, **kwargs):
1116
+ return RandomizedTransform(
1117
+ self.transform, args, kwargs, **self.kwargs)
1118
+
1119
+ def __new__(cls, *args, **kwargs):
1120
+ if cls is RandomizedTransform:
1121
+ return cls._base_new(*args, **kwargs)
1122
+ return super().__new__(cls)
1123
+
1124
+ @classmethod
1125
+ def _base_new(cls, transform, sample=tuple(), ksample=dict(),
1126
+ *, shared=False, **kwargs):
1127
+ assert cls is RandomizedTransform
1128
+ if not sample and not ksample:
1129
+ # If no arguments are passed, it means that the user calls
1130
+ # this in "delayed/functional" mode. In that case, we return
1131
+ # a callable object that returns the constructed instance
1132
+ # using the call-time arguments.
1133
+ return cls.Delayed(transform, shared=shared, **kwargs)
1134
+ # Otherwise, we're in object mode and we instantiate the
1135
+ # randomized object.
1136
+ return super().__new__(cls)
1137
+
1138
+ def __init__(self, transform, sample=tuple(), ksample=dict(),
1088
1139
  *, shared=False, **kwargs):
1089
1140
  """
1090
-
1091
1141
  Parameters
1092
1142
  ----------
1093
1143
  transform : callable(...) -> Transform
1094
1144
  A Transform subclass or a function that constructs a Transform.
1095
1145
  sample : [list or dict of] callable
1096
1146
  A collection of functions that generate parameter values provided
1097
- to `transform`.
1147
+ to `transform`. Can be args-like or kwargs-like arguments.
1148
+ ksample : dict[callable]
1149
+ Must be kwargs-like arguments.
1098
1150
 
1099
1151
  Other Parameters
1100
1152
  ----------------
@@ -1130,9 +1182,7 @@ class RandomizedTransform(NonFinalTransform):
1130
1182
 
1131
1183
  def __repr__(self):
1132
1184
  if type(self) is RandomizedTransform:
1133
- try:
1134
- if issubclass(self.subtransform, Transform):
1135
- return f'Randomized{self.subtransform.__name__}()'
1136
- except TypeError:
1137
- pass
1185
+ xform = self.subtransform
1186
+ if isinstance(xform, type) and issubclass(xform, Transform):
1187
+ return f'Randomized{xform.__name__}()'
1138
1188
  return super().__repr__()
@@ -117,6 +117,29 @@ def returns_find(flag, returned, returns):
117
117
  return None
118
118
 
119
119
 
120
def returns_update(value, flag, returned, returns):
    """Replace the object corresponding to `flag` in a returned structure.

    Parameters
    ----------
    value : Any
        New value to store.
    flag : str
        Name of the returned object to replace (e.g. 'output').
    returned : Any
        Structure returned by a transform: a single object, a dict, or a
        list/tuple, laid out according to `returns`.
    returns : None or str or dict or list or tuple
        Description of the returned structure. When `None`, the returned
        structure is the 'output' object itself.

    Returns
    -------
    updated : Any or None
        The updated structure, or `None` if `flag` is not part of
        `returns`. Dicts and lists are updated in place; tuples are
        immutable and are therefore rebuilt.

    Raises
    ------
    TypeError
        If `returns` is not None, a str, a dict, a list or a tuple.
    """
    if returns is None:
        # The returned structure *is* the output
        return value if flag == 'output' else None

    if isinstance(returns, dict):
        if flag not in returns:
            return None
        returned[flag] = value
        return returned

    if isinstance(returns, (list, tuple)):
        if flag not in returns:
            return None
        index = returns.index(flag)
        if isinstance(returned, tuple):
            # Tuples do not support item assignment: rebuild instead.
            items = list(returned)
            items[index] = value
            return tuple(items)
        returned[index] = value
        return returned

    if not isinstance(returns, str):
        raise TypeError(
            f'`returns` must be None, a str, a dict, a list or a tuple, '
            f'but got {type(returns).__name__}.'
        )
    # Single returned object
    return value if returns == flag else None
141
+
142
+
120
143
  def flatstruct(x):
121
144
  """Flatten a nested structure of tensors"""
122
145
 
@@ -184,7 +184,7 @@ class ContrastLookupTransform(NonFinalTransform):
184
184
 
185
185
  vmin, vmax = x.min(), x.max()
186
186
  edges = torch.linspace(vmin, vmax, self.nk+1)
187
- new_mu = torch.rand(self.nk) * (vmax - vmin) + vmin
187
+ new_mu = torch.rand(self.nk).to(x) * (vmax - vmin) + vmin
188
188
  return self.LookupFinalTransform(
189
189
  edges, new_mu, **self.get_prm()
190
190
  ).make_final(x, max_depth-1)
File without changes
@@ -8,14 +8,16 @@ __all__ = [
8
8
  'CropTransform',
9
9
  'PadTransform',
10
10
  'PowerTwoTransform',
11
+ 'Rot90Transform',
12
+ 'Rot180Transform',
13
+ 'RandomRot90Transform',
11
14
  ]
12
-
13
15
  import math
14
16
  from random import shuffle
15
- from .base import FinalTransform, NonFinalTransform
17
+ from .base import FinalTransform, NonFinalTransform, PerChannelTransform
16
18
  from .utils.py import ensure_list
17
19
  from .utils.padding import pad
18
- from .random import Uniform, RandKFrom, Sampler
20
+ from .random import Uniform, RandKFrom, Sampler, RandInt, make_range
19
21
 
20
22
 
21
23
  class FlipTransform(FinalTransform):
@@ -23,7 +25,6 @@ class FlipTransform(FinalTransform):
23
25
 
24
26
  def __init__(self, axis=None, **kwargs):
25
27
  """
26
-
27
28
  Parameters
28
29
  ----------
29
30
  axis : [list of] int
@@ -46,24 +47,30 @@ class FlipTransform(FinalTransform):
46
47
  class RandomFlipTransform(NonFinalTransform):
47
48
  """Randomly flip one or more axes"""
48
49
 
49
- def __init__(self, axes=None, **kwargs):
50
+ def __init__(self, axes=None, *, shared=True, **kwargs):
50
51
  """
51
-
52
52
  Parameters
53
53
  ----------
54
54
  axes : Sampler or [list of] int
55
55
  Axes that can be flipped (default: all)
56
+
57
+ Other Parameters
58
+ ----------------
56
59
  shared : {'channels', 'tensors', 'channels+tensors', ''}
57
60
  Apply the same flip to all channels and/or tensors
58
61
  """
59
62
  axes = kwargs.pop('axis', axes)
60
- kwargs.setdefault('shared', True)
61
- super().__init__(**kwargs)
63
+ super().__init__(shared=shared, **kwargs)
62
64
  self.axes = axes
63
65
 
64
66
  def make_final(self, x, max_depth=float('inf')):
65
67
  if max_depth == 0:
66
68
  return self
69
+ if 'channels' not in self.shared and len(x) > 1:
70
+ return PerChannelTransform(
71
+ [self.make_final(x[i:i+1], max_depth) for i in range(len(x))],
72
+ **self.get_prm()
73
+ ).make_final(x, max_depth-1)
67
74
  axes = self.axes or range(1, x.ndim)
68
75
  if not isinstance(axes, Sampler):
69
76
  rand_axes = RandKFrom(ensure_list(axes))
@@ -76,7 +83,6 @@ class PermuteAxesTransform(FinalTransform):
76
83
 
77
84
  def __init__(self, permutation=None, **kwargs):
78
85
  """
79
-
80
86
  Parameters
81
87
  ----------
82
88
  permutation : [list of] int
@@ -105,23 +111,29 @@ class PermuteAxesTransform(FinalTransform):
105
111
  class RandomPermuteAxesTransform(NonFinalTransform):
106
112
  """Randomly permute axes"""
107
113
 
108
- def __init__(self, axes=None, **kwargs):
114
+ def __init__(self, axes=None, *, shared=True, **kwargs):
109
115
  """
110
-
111
116
  Parameters
112
117
  ----------
113
118
  axes : [list of] int
114
119
  Axes that can be permuted (default: all)
120
+
121
+ Other Parameters
122
+ ----------------
115
123
  shared : {'channels', 'tensors', 'channels+tensors', ''}
116
124
  Apply the same permutation to all channels and/or tensors
117
125
  """
118
- kwargs.setdefault('shared', True)
119
- super().__init__(**kwargs)
126
+ super().__init__(shared=shared, **kwargs)
120
127
  self.axes = axes
121
128
 
122
129
  def make_final(self, x, max_depth=float('inf')):
123
130
  if max_depth == 0:
124
131
  return self
132
+ if 'channels' not in self.shared and len(x) > 1:
133
+ return PerChannelTransform(
134
+ [self.make_final(x[i:i+1], max_depth) for i in range(len(x))],
135
+ **self.get_prm()
136
+ ).make_final(x, max_depth-1)
125
137
  axes = list(self.axes or range(x.ndim-1))
126
138
  shuffle(axes)
127
139
  return PermuteAxesTransform(
@@ -129,6 +141,119 @@ class RandomPermuteAxesTransform(NonFinalTransform):
129
141
  ).make_final(x, max_depth-1)
130
142
 
131
143
 
144
class Rot90Transform(FinalTransform):
    """
    Apply a 90 (or 180) degree rotation about one or several axes.

    Rotations are applied one after the other, in the order of `axis`.
    """

    def __init__(self, axis=0, negative=False, double=False, **kwargs):
        """
        Parameters
        ----------
        axis : int or list[int]
            Rotation axis (indexing does not account for the channel axis).
            For 2D inputs the axis value is ignored: the rotation always
            happens in the plane of the two spatial dimensions.
        negative : bool or list[bool]
            Rotate by -90 deg instead of 90 deg
        double : bool or list[bool]
            Rotate by 180 instead of 90 (`negative` is then unused)
        """
        super().__init__(**kwargs)
        self.axis = ensure_list(axis)
        self.negative = ensure_list(negative, len(self.axis))
        self.double = ensure_list(double, len(self.axis))

    @staticmethod
    def _rotation_plane(ax, ndim):
        """Return the two tensor dims spanning the plane orthogonal to the
        (channel-offset) axis `ax`, for an input with `ndim` spatial dims.
        """
        if ndim == 2:
            return [1, 2]
        if ndim == 3:
            return [d for d in (1, 2, 3) if d != ax]
        raise ValueError(
            f'Rot90Transform only supports 2D or 3D inputs, '
            f'but got {ndim} spatial dimensions.'
        )

    def xform(self, x):
        # This implementation is suboptimal: each rotation allocates a new
        # tensor. Fusing all rotations into a single "permute + flip" would
        # need only one allocation. This will be fine for now.
        ndim = x.ndim - 1  # number of spatial dims (dim 0 = channels)
        for ax, neg, dbl in zip(self.axis, self.negative, self.double):
            ax = 1 + (ndim + ax if ax < 0 else ax)  # wrap + skip channels
            dims = self._rotation_plane(ax, ndim)
            # rot90(k=1)  == transpose(dims) then flip(dims[0])
            # rot90(k=-1) == transpose(dims) then flip(dims[1])
            # rot90(k=2)  == flip(dims)
            k = 2 if dbl else (-1 if neg else 1)
            x = x.rot90(k, dims)
        return x
188
+
189
+
190
class Rot180Transform(Rot90Transform):
    """
    Apply a 180 deg rotation about one or several axes.

    Equivalent to `Rot90Transform(axis, double=True, ...)`.
    """

    def __init__(self, axis=0, **kwargs):
        """
        Parameters
        ----------
        axis : int or list[int]
            Rotation axis (indexing does not account for the channel axis)
        """
        super().__init__(axis=axis, double=True, **kwargs)
201
+
202
+
203
class RandomRot90Transform(NonFinalTransform):
    """Apply a random sequence of 90 (or -90) degree rotations"""

    def __init__(self, axes=None, max_rot=2, negative=True,
                 *, shared=True, **kwargs):
        """
        Parameters
        ----------
        axes : int or list[int] or Sampler
            Axes about which rotations can happen.
            If `None`, all spatial axes.
        max_rot : int or Sampler
            Maximum number of consecutive rotations. The number of
            rotations actually applied is drawn at random (via `RandInt`)
            from a range that starts at 1 and is bounded by `max_rot`.
        negative : bool
            Whether to authorize negative rotations.

        Other Parameters
        ----------------
        shared : {'channels', 'tensors', 'channels+tensors', ''}
            Apply the same rotations to all channels and/or tensors
        """
        super().__init__(shared=shared, **kwargs)
        self.axes = axes
        self.max_rot = RandInt.make(make_range(1, max_rot))
        self.negative = negative

    def make_final(self, x, max_depth=float('inf')):
        if max_depth == 0:
            return self
        if 'channels' not in self.shared and len(x) > 1:
            # Sample an independent set of rotations for each channel
            return PerChannelTransform(
                [self.make_final(x[i:i+1], max_depth) for i in range(len(x))],
                **self.get_prm()
            ).make_final(x, max_depth-1)
        ndim = x.ndim - 1

        # Number of consecutive rotations
        nrot = self.max_rot
        if isinstance(nrot, Sampler):
            nrot = nrot()

        # Axis of each rotation
        axes = self.axes
        if axes is None:
            axes = list(range(ndim))
        if isinstance(axes, (int, list, tuple)):
            axes = ensure_list(axes, nrot, crop=False)
        if not isinstance(axes, Sampler):
            axes = RandKFrom(axes, nrot, replacement=True)
        axes = ensure_list(axes(), nrot)

        # Direction of each rotation
        if self.negative:
            negative = RandKFrom([False, True], nrot, replacement=True)()
        else:
            negative = [False] * nrot

        # NOTE: `x` must be forwarded, like every other `make_final` call.
        return Rot90Transform(
            axes, negative, **self.get_prm()
        ).make_final(x, max_depth-1)
255
+
256
+
132
257
  class CropPadTransform(FinalTransform):
133
258
  """Crop and/or pad a tensor"""
134
259