cornucopia 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cornucopia/__init__.py +73 -0
- cornucopia/base.py +1915 -0
- cornucopia/baseutils.py +575 -0
- cornucopia/contrast.py +260 -0
- cornucopia/ctx.py +25 -0
- cornucopia/fov.py +707 -0
- cornucopia/geometric.py +2068 -0
- cornucopia/intensity.py +1358 -0
- cornucopia/io.py +161 -0
- cornucopia/kspace.py +505 -0
- cornucopia/labels.py +1872 -0
- cornucopia/noise.py +508 -0
- cornucopia/psf.py +463 -0
- cornucopia/qmri.py +1288 -0
- cornucopia/random.py +1480 -0
- cornucopia/special.py +159 -0
- cornucopia/synth.py +708 -0
- cornucopia/tests/__init__.py +0 -0
- cornucopia/tests/test_backward_geometric.py +173 -0
- cornucopia/tests/test_backward_intensity.py +243 -0
- cornucopia/tests/test_backward_kspace.py +115 -0
- cornucopia/tests/test_backward_noise.py +169 -0
- cornucopia/tests/test_backward_psf.py +142 -0
- cornucopia/tests/test_backward_qmri.py +249 -0
- cornucopia/tests/test_backward_random.py +44 -0
- cornucopia/tests/test_backward_synth.py +72 -0
- cornucopia/tests/test_base.py +401 -0
- cornucopia/tests/test_geometric.py +26 -0
- cornucopia/tests/test_intensity.py +9 -0
- cornucopia/tests/test_random.py +722 -0
- cornucopia/tests/test_run_contrast.py +28 -0
- cornucopia/tests/test_run_fov.py +132 -0
- cornucopia/tests/test_run_geometric.py +157 -0
- cornucopia/tests/test_run_intensity.py +192 -0
- cornucopia/tests/test_run_kspace.py +70 -0
- cornucopia/tests/test_run_labels.py +224 -0
- cornucopia/tests/test_run_noise.py +127 -0
- cornucopia/tests/test_run_psf.py +115 -0
- cornucopia/tests/test_run_qmri.py +114 -0
- cornucopia/tests/test_run_synth.py +67 -0
- cornucopia/typing.py +97 -0
- cornucopia/utils/__init__.py +0 -0
- cornucopia/utils/b0.py +745 -0
- cornucopia/utils/bounds.py +412 -0
- cornucopia/utils/compat.py +47 -0
- cornucopia/utils/conv.py +305 -0
- cornucopia/utils/gmm.py +169 -0
- cornucopia/utils/indexing.py +911 -0
- cornucopia/utils/io.py +258 -0
- cornucopia/utils/jit.py +128 -0
- cornucopia/utils/kernels.py +288 -0
- cornucopia/utils/morpho.py +234 -0
- cornucopia/utils/mrf.py +574 -0
- cornucopia/utils/padding.py +173 -0
- cornucopia/utils/patch.py +302 -0
- cornucopia/utils/pool.py +282 -0
- cornucopia/utils/py.py +348 -0
- cornucopia/utils/smart_inplace.py +163 -0
- cornucopia/utils/version.py +57 -0
- cornucopia/utils/warps.py +606 -0
- cornucopia-0.0.0.dist-info/METADATA +92 -0
- cornucopia-0.0.0.dist-info/RECORD +65 -0
- cornucopia-0.0.0.dist-info/WHEEL +5 -0
- cornucopia-0.0.0.dist-info/licenses/LICENSE +21 -0
- cornucopia-0.0.0.dist-info/top_level.txt +1 -0
cornucopia/base.py
ADDED
|
@@ -0,0 +1,1915 @@
|
|
|
1
|
+
# Public API of this module (names re-exported by `from ... import *`).
__all__ = ["Transform", "FinalTransform", "NonFinalTransform"]
|
|
6
|
+
# stdlib
|
|
7
|
+
import inspect
|
|
8
|
+
import random
|
|
9
|
+
import re
|
|
10
|
+
from abc import ABC
|
|
11
|
+
from copy import copy
|
|
12
|
+
from fnmatch import fnmatch
|
|
13
|
+
from math import inf
|
|
14
|
+
|
|
15
|
+
# dependencies
|
|
16
|
+
import torch
|
|
17
|
+
import typing_extensions as tx
|
|
18
|
+
from torch import nn, Tensor
|
|
19
|
+
|
|
20
|
+
# internal
|
|
21
|
+
from .random import Sampler
|
|
22
|
+
from .utils.py import ensure_list, cumsum
|
|
23
|
+
from .baseutils import (
|
|
24
|
+
Arguments, Args, Kwargs, ArgsAndKwargs, NoArguments, Returned,
|
|
25
|
+
get_first_element, prepare_output, UNSET, recursive_cat,
|
|
26
|
+
)
|
|
27
|
+
from . import typing as cct
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _key_matches(key: str, patterns: tx.Sequence) -> bool:
    """Return whether `key` matches at least one of `patterns`.

    Compiled regular expressions (`re.Pattern`) are tested with
    `pattern.match(key)`, i.e. anchored at the start of the key. Every
    other pattern is treated as a shell-style glob and tested with
    `fnmatch.fnmatch`.
    """
    for pattern in patterns:
        if isinstance(pattern, re.Pattern):
            if pattern.match(key):
                return True
        elif fnmatch(key, pattern):
            return True
    return False


class Transform(nn.Module, ABC):
    """Base class for all transforms.

    Calling a transform goes through `nn.Module.__call__` and then
    `forward`. `forward` wraps the positional and keyword inputs in an
    `Arguments` object and hands them to the nested `_Forward` helper.
    `_Forward` walks (nested) lists and dicts and calls `xform` on each
    leaf, honouring the `returns`, `append`, `prefix`, `include`,
    `exclude` and `consume` options.

    Transforms can be combined with operators:

    * `t1 + t2` builds `SequentialTransform([t1, t2])`
    * `t * prob` or `prob * t` builds `MaybeTransform(t, prob)`
    * `t1 | t2` builds `SwitchTransform([t1, t2])`
    """

    def __init__(
        self, *,
        returns: tx.Union[str, tx.Sequence[str], tx.Mapping[str, str], None] = None,
        append: tx.Union[bool, str] = False,
        prefix: tx.Union[bool, str] = True,
        include: tx.Optional[cct.IncludeT] = None,
        exclude: tx.Optional[cct.ExcludeT] = None,
        consume: tx.Optional[cct.ConsumeT] = None,
    ):
        """
        Parameters
        ----------
        returns : [list or dict of] str, optional
            Which tensors to return. Can be a nested structure.
            Most transforms accept `'input'` and `'output'` as valid
            returns. The default is `'output'`.
        append : bool | str
            Append the (structure of) returned tensors to the parent
            structure.

            !!! warning
                This option does not keep the input tensors in the returned
                structure! To preserve the input tensors, you should
                use `append` in conjunction with `returns`.

            !!! example
                ```python
                # With lists
                trf = MyTransform(returns=['input', 'output'], append=True)
                x1, y1, x2, y2 = trf([x1, x2])

                # With dicts
                trf = MyTransform(returns={'x': 'input', 'y': 'output'}, append=True, prefix=True)
                out = trf({'path1': x1, 'path2': x2})
                assert out.keys() == {'path1.x', 'path1.y', 'path2.x', 'path2.y'}
                ```

            !!! changedin " \
                Can be a string since `v0.5`"
                If it is a `str`, appending is enabled and, when the
                parent is a `dict`, its value is used as the separator
                between the prefix and the key (default separator: `'.'`).
                See `prefix`.
        prefix : bool | str
            If `append` and parent is a `dict`, prefix the returned key
            before inserting it in the output dictionary.

            If `True`, the prefix is the input key.

            !!! changedin " \
                Can be a string since `v0.5`"
        include : [list of] str | re.Pattern, optional
            List of keys to which the transform should apply.
            A key is included if it matches *any* of the patterns.
            Default: all.

            !!! changedin " \
                Can be a regex or glob pattern since `v0.5`"
        exclude : [list of] str | re.Pattern, optional
            List of keys to which the transform should not apply.
            A key is excluded if it matches *any* of the patterns.
            Default: none.

            !!! changedin " \
                Can be a regex or glob pattern since `v0.5`"
        consume : [list of] str | re.Pattern, optional
            List of keys to remove from the output after applying the
            transform. A key is consumed if it matches *any* of the
            patterns. Default: none.

            !!! addedin " \
                Added in `v0.5`. Can be a regex or glob pattern."
        """
        super().__init__()
        self.returns = returns
        self.append = append
        self.prefix = prefix
        # `None` means "include everything"; an empty exclude/consume list
        # means "exclude/consume nothing".
        self.include = ensure_list(include) if include is not None else None
        self.exclude = ensure_list(exclude or tuple())
        self.consume = ensure_list(consume or tuple())

    @property
    def is_final(self) -> bool:
        """
        Returns
        -------
        bool
            Whether the transform is final (i.e., deterministic) or not.
        """
        return False

    def get_prm(self) -> dict:
        """Get the parameters of the transform, for use in subtransforms.

        Returns
        -------
        dict
            A dictionary containing the attributes
            `returns`, `append`, `prefix`, `include`, `exclude`, and
            `consume`.

        """
        return dict(
            returns=self.returns,
            append=self.append,
            prefix=self.prefix,
            include=self.include,
            exclude=self.exclude,
            consume=self.consume
        )

    def _is_included(self, key: str) -> bool:
        """Whether `key` matches any `include` pattern (all keys if `None`)."""
        if self.include is None:
            return True
        return _key_matches(key, self.include)

    def _is_excluded(self, key: str) -> bool:
        """Whether `key` matches any `exclude` pattern."""
        if self.exclude is None:
            return False
        return _key_matches(key, self.exclude)

    def _is_consumed(self, key: str) -> bool:
        """Whether `key` matches any `consume` pattern."""
        if self.consume is None:
            return False
        return _key_matches(key, self.consume)

    def __enter__(self) -> "Transform":
        # On most transforms, this does nothing but return the transform
        # itself. However, some subclasses use this to act as context
        # managers that temporarily modify the transform's behavior.
        # See: `IncludeKeysTransform` and `ExcludeKeysTransform`.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
        # Never suppress exceptions raised inside the `with` block.
        return False

    def __add__(self, other: "Transform") -> "SequentialTransform":
        # `self + other`: apply `self`, then `other`.
        return SequentialTransform([self, other])

    def __radd__(self, other: "Transform") -> "SequentialTransform":
        # `other + self`: apply `other`, then `self`.
        return SequentialTransform([other, self])

    def __iadd__(self, other: "Transform") -> "SequentialTransform":
        # Returns a new object rather than mutating `self`.
        return SequentialTransform([self, other])

    def __mul__(self, prob: float) -> "MaybeTransform":
        return MaybeTransform(self, prob)

    def __rmul__(self, prob: float) -> "MaybeTransform":
        return MaybeTransform(self, prob)

    def __imul__(self, prob: float) -> "MaybeTransform":
        # Returns a new object rather than mutating `self`.
        return MaybeTransform(self, prob)

    def __or__(self, other: "Transform") -> "SwitchTransform":
        return SwitchTransform([self, other])

    def __ior__(self, other: "Transform") -> "SwitchTransform":
        # Returns a new object rather than mutating `self`.
        return SwitchTransform([self, other])

    def __call__(self, *a, **k):
        # Use the torch machinery (`nn.Module.__call__` -> `forward`),
        # but unwrap `Returned` objects so users get plain outputs.
        out = super().__call__(*a, **k)
        if isinstance(out, Returned):
            out = out.obj
        return out

    def forward(self, *a, **k) -> Returned:
        """Apply the transform recursively.

        Parameters
        ----------
        *a, **k : [nested list or dict of] tensor
            Input tensors, with shape `(C, *shape)`

        Returns
        -------
        [nested list or dict of] tensor
            Output tensors. with shape `(C, *shape)`

        """
        # We wrap positional and keyword arguments in special classes
        # to differentiate them from inputs that are lists or dicts.
        x = args = Arguments(*a, **k)

        if isinstance(args, NoArguments):
            # If no input arguments, return None.
            # NOTE: `NoArguments()` is the only falsy `Arguments` object;
            # an explicit type check avoids relying on `__bool__` and
            # matches the check done in `_Forward.__call__`.
            return None

        # Arguments are passed to `_Forward.__init__` and `_Forward.__call__`.
        # The former is preserved as is in the `_Forward` object, and passed
        # to each `xform` or `unroll` call, while the latter is
        # recursively unwrapped and processed by `_Forward.__call__`.
        return self._Forward(self, args, **self.get_prm())(x)

    class _Forward:
        """Recursive driver used by `Transform.forward`.

        Walks a (nested) structure of lists/tuples/dicts and calls
        `transform.xform` on every leaf, then assembles the outputs
        according to the `returns`, `append`, `prefix`, `include`,
        `exclude` and `consume` options stored in `prm`.
        """

        def __init__(self, transform, args: Arguments, **prm) -> None:
            # transform: the `Transform` whose `xform` is applied to leaves.
            self.transform = transform
            # args: the original (wrapped) call arguments, forwarded as-is
            # to every `xform` call.
            self.args = args
            # prm: options, as returned by `Transform.get_prm()`.
            self.prm = prm

        @property
        def include(self) -> tx.Optional[tx.Sequence[str]]:
            return self.prm.get('include')

        @property
        def exclude(self) -> tx.Optional[tx.Sequence[str]]:
            return self.prm.get('exclude')

        @property
        def consume(self) -> tx.Optional[tx.Sequence[str]]:
            return self.prm.get('consume')

        @property
        def append(self) -> tx.Union[bool, str]:
            return self.prm.get('append', False)

        @property
        def prefix(self) -> tx.Union[bool, str]:
            return self.prm.get('prefix', True)

        @property
        def returns(self) -> tx.Optional[tx.Union[
            str, tx.Sequence[str], tx.Mapping[str, str]
        ]]:
            return self.prm.get('returns')

        def __call__(
            self, x: tx.Union[Tensor, tx.Sequence, tx.Mapping, Arguments]
        ) -> tx.Union[Tensor, tx.Sequence, tx.Mapping, Arguments]:
            """Apply the transform to `x`, recursing into lists and dicts."""
            if isinstance(x, NoArguments):
                return None

            # At this point, there is a single positional argument
            # (which may be an Args, Kwargs, or ArgsAndKwargs object)
            # and no keyword arguments. We save the original input type
            # to be able to return the same kind of `Arguments` object.
            intype = type(x)

            # If the input is an `Arguments` object, convert it to list/dict
            if isinstance(x, Arguments):
                x = x.unwrap()

            def outtype(x):
                # Convert back to the original input type if needed
                if intype is ArgsAndKwargs:
                    args, kwargs = x
                elif intype is Args:
                    args, kwargs = x, {}
                elif intype is Kwargs:
                    args, kwargs = (), x
                else:
                    return x
                return Arguments(*args, **kwargs)

            # Not shared across tensors -> unfold
            if isinstance(x, (list, tuple)):
                return outtype(self._forward_list(x))

            if hasattr(x, 'items'):
                x = self._forward_dict(x)
                if not isinstance(x, dict) and intype in (Kwargs, ArgsAndKwargs):
                    # Appending turned the keyword dict into a list: the
                    # outputs can only be returned as positional arguments.
                    return Args(*x)
                else:
                    return outtype(x)

            # ---- Now we're working with a single tensor (or str) ----

            # Apply the transform to the input tensor
            y = self.transform.xform(x, args=self.args)

            # Most transforms return a well-formatted `Returned` object,
            # which contain all possible outputs of a transform, mapped
            # into a structure specified by the `returns` argument.
            if not isinstance(y, Returned):
                # When they do not, we have to build the `Returned`
                # object ourselves.
                if not isinstance(y, type(self.returns)):
                    # The transform returned a single output (likely a
                    # tensor), which we assign to the `output` key,
                    # while the input is assigned to the `input` key.
                    y = dict(input=x, output=y)
                    y = prepare_output(y, self.returns).obj
                else:
                    # `returns` and `y` have the same type, but `y` may
                    # have been obtained from a subtransform. We cannot
                    # guarantee that the keys of `y` are the same as
                    # those of `returns`, but we can insert the correct
                    # `input` tensor, if it was requested.
                    # NOTE: we cannot break early once `"input"` is
                    # encountered, because multiple outputs elements can
                    # contain the same target
                    # (e.g. `returns=['input', 'input']`).
                    if isinstance(self.returns, dict):
                        for key, target in self.returns.items():
                            if target == 'input':
                                y[key] = x
                    elif isinstance(self.returns, (list, tuple)):
                        for i, target in enumerate(self.returns):
                            if target == 'input':
                                y[i] = x
                # Wrap output in a `Returned` object, so that helpers
                # know how to handle it (e.g. when `append=True`).
                y = Returned(y)

            # ---- Now we're working with a `Returned` object ----
            return y

        def _is_included(self, key: str) -> bool:
            """Whether `key` matches any `include` pattern (all keys if `None`)."""
            if self.include is None:
                return True
            return _key_matches(key, self.include)

        def _is_excluded(self, key: str) -> bool:
            """Whether `key` matches any `exclude` pattern."""
            if self.exclude is None:
                return False
            return _key_matches(key, self.exclude)

        def _is_consumed(self, key: str) -> bool:
            """Whether `key` matches any `consume` pattern."""
            if self.consume is None:
                return False
            return _key_matches(key, self.consume)

        def _get_valid_keys(self, x: tx.Mapping[str, str]) -> tx.Sequence[str]:
            """Keys of `x` that are included and not excluded."""
            valid_keys = x.keys()
            if self.include is not None:
                valid_keys = [
                    k for k in valid_keys
                    if self._is_included(k)
                ]
            if self.exclude:
                valid_keys = [
                    k for k in valid_keys
                    if not self._is_excluded(k)
                ]
            return valid_keys

        def _forward_list(
            self, x: list, forward: tx.Optional[tx.Callable] = None
        ) -> list:
            """Apply forward pass to elements of a list"""
            forward = forward or self
            # A string `append` means "append" (its value is only used as a
            # key separator for dicts), consistent with `_forward_dict`.
            append = isinstance(self.append, str) or bool(self.append)
            y = []
            for elem in x:
                elem = forward(elem)
                if isinstance(elem, Returned):
                    elem = elem.obj
                if append:
                    # Flatten the returned structure into the parent list.
                    if isinstance(elem, dict):
                        y.extend(elem.values())
                        continue
                    elif isinstance(elem, (list, tuple)):
                        y.extend(elem)
                        continue
                y.append(elem)
            return type(x)(y)

        def _forward_dict(
            self, x: dict, forward: tx.Optional[tx.Callable] = None
        ) -> tx.Union[dict, list]:
            """Apply forward pass to elements of a dict.

            Returns a dict of the same type as `x`, or a plain list if
            `append` is set and a transformed value could not be inserted
            under a key (see comments below).
            """
            forward = forward or self
            valid_keys = self._get_valid_keys(x)

            # A string `append` enables appending and sets the separator
            # used to build prefixed keys.
            append = self.append
            sep = "."
            if isinstance(append, str):
                sep, append = append, True

            # Initialise output dictionary with input keys and values
            # that *will not* be transformed so that they are preserved.
            y = {
                key: value
                for key, value in x.items()
                if key not in valid_keys and not self._is_consumed(key)
            }

            # For each input item, apply the transform and save its outputs
            for key, value in x.items():
                if key not in valid_keys:
                    continue

                # Compute prefix
                prefix = self.prefix
                if prefix and not isinstance(self.prefix, str):
                    prefix = key

                # Apply transform to input value
                value = forward(value)

                # Deal with returned values.
                if isinstance(value, Returned):
                    value = value.obj

                if append:

                    if isinstance(y, dict) and isinstance(value, dict):
                        # Insert the returned values in the output dictionary,
                        # using a new key, so that the input value is preserved.
                        if prefix:
                            value = {
                                prefix + sep + child_key: child_value
                                for child_key, child_value in value.items()
                            }
                        y.update(value)
                        continue

                    if isinstance(y, dict):
                        # The transform did not return a dictionary, so
                        # we cannot insert its values in the output dict.
                        # We transform it into a list and append the
                        # returned values to it.
                        y = list(y.values())

                    if isinstance(value, (list, tuple)):
                        # Append the returned values to the output list.
                        y.extend(value)
                        continue

                    if isinstance(value, dict):
                        # The output dictionary was previously transformed
                        # into a list, so we cannot insert the returned
                        # (key, value) pairs. We append the values instead.
                        y.extend(value.values())
                        continue

                # We insert the returned value (whether it is a single tensor,
                # or a nested structure of tensors) in the output dictionary,
                # in place of the input value.
                if isinstance(y, dict):
                    y[key] = value
                else:
                    y.append(value)

            if isinstance(y, dict):

                # Consume keys
                for key in list(y.keys()):
                    if self._is_consumed(key):
                        y.pop(key)

                # Convert to dictionary subtype
                return type(x)(y)
            else:
                return y

    def xform(
        self, x: Tensor, /,
        args: Arguments = NoArguments(),
    ) -> Returned:
        """Apply the transform to a tensor.

        Non-final transforms do not implement this method in general.

        Parameters
        ----------
        x : (C_inp, *spatial_inp) tensor
            A single input tensor
        args: Arguments, optional
            The original inputs arguments to the transform, in case
            they are needed.

        Returns
        -------
        y : Returned | (C_out, *spatial_out) tensor
            A single output tensor, or a `Returned` object containing
            multiple output tensors and their corresponding keys.
        """
        # Wrapper that calls `_xform`, but only passes `args` if the
        # method accepts it, to avoid errors with legacy implementations.
        if 'args' in inspect.signature(self._xform).parameters:
            return self._xform(x, args=args)
        else:
            return self._xform(x)

    def _xform(
        self, x: Tensor, /,
        args: Arguments = NoArguments(),
    ) -> Returned:
        # To be overridden by subclasses that can act on a single tensor.
        raise NotImplementedError("This transform does not implement `xform`.")

    def final(
        self,
        x: Tensor, /,
        args: Arguments = NoArguments(),
        **kwargs
    ) -> "FinalTransform":
        """
        Generate the final version of the transform.

        Some transforms save the output type of this function in their
        `Final` attribute.

        !!! addedin " \
            Added `final` method in `v0.5`."
            Before this, one had to use `make_final(x, max_depth=inf)`.

        Parameters
        ----------
        x : tensor
            A single input tensor, with shape `(C, *shape)`.
        args: Arguments, optional
            The original inputs arguments to the transform, in case
            they are needed.

        Returns
        -------
        FinalTransform
            A final version of the transform.
        """
        return self.unroll(x, max_depth=inf, args=args, **kwargs)

    def next(
        self,
        x: Tensor, /,
        args: Arguments = NoArguments(),
        **kwargs
    ) -> "Transform":
        """
        Generate the next version of the transform.

        Some transforms save the output type of this function in their
        `Next` attribute.

        !!! addedin " \
            Added `next` method in `v0.5`."
            Before this, one had to use `make_final(x, max_depth=1)`.

        Parameters
        ----------
        x : tensor
            A single input tensor, with shape `(C, *shape)`.
        args: Arguments, optional
            The original inputs arguments to the transform, in case
            they are needed.

        Returns
        -------
        Transform
            A more specialized version of the transform (not necessarily
            final).
        """
        return self.unroll(x, max_depth=1, args=args, **kwargs)

    def unroll(
        self, x: Tensor, /,
        max_depth: int = inf,
        args: Arguments = NoArguments(),
        **kwargs
    ) -> "Transform":
        """
        Generate the next (i.e., more final) version(s) of the transform.

        * To completely finalize a transform,
          call `unroll(x, max_depth=inf)` or `final()`.
        * To get the next version of a transform,
          call `unroll(x, max_depth=1)` or `next()`.

        !!! addedin " \
            Added `unroll` method in `v0.5`."
            Before this, it was named `make_final`.

        Parameters
        ----------
        x : tensor
            A single input tensor, with shape `(C, *shape)`.
        max_depth : int | {inf}
            Maximum depth to apply `unroll` recursively.
            If not `inf`, the resulting transform may not be fully final.
            Default: no limit.
        args: Arguments, optional
            The original inputs arguments to the transform, in case
            they are needed.

        Returns
        -------
        Transform
            A more specialized version of the transform.
        """
        if max_depth <= 0:
            # This is always valid, so let's catch it
            return self
        # Wrapper that calls `_unroll`, but only passes `args` if the
        # method accepts it.
        if 'args' in inspect.signature(self._unroll).parameters:
            return self._unroll(x, max_depth, args=args, **kwargs)
        else:
            return self._unroll(x, max_depth, **kwargs)

    def make_final(
        self, x: Tensor, /,
        max_depth: int = inf,
        args: Arguments = NoArguments(),
        **kwargs
    ) -> "Transform":
        # Deprecated alias of `unroll`, kept for backward compatibility.
        return self.unroll(x, max_depth=max_depth, args=args, **kwargs)

    def _unroll(
        self, x: Tensor, /,
        max_depth: int = inf,
        args: Arguments = NoArguments(),
    ) -> "Transform":
        # Default implementation: only final transforms (or a zero depth)
        # can be "unrolled" without subclass-specific logic.
        if self.is_final or max_depth == 0:
            return self
        raise NotImplementedError("This transform does not implement `unroll`")

    def inverse(self, *a, **k) -> "FinalTransform":
        """Apply the inverse transform recursively.

        Parameters
        ----------
        *a, **k : [nested list or dict of] tensor
            Input tensors, with shape `(C, *shape)`

        Returns
        -------
        [nested list or dict of] tensor
            Output tensors. with shape `(C, *shape)`, i.e. the result of
            calling `make_inverse()` on the inputs.

        """
        return self.make_inverse()(*a, **k)

    def make_inverse(self) -> "FinalTransform":
        """Generate the inverse transform"""
        # We fallback to the identity, rather than raising an error.
        return IdentityTransform()
|
|
691
|
+
|
|
692
|
+
|
|
693
|
+
class FinalTransform(Transform):
    """
    Base class for deterministic (final) transforms.

    A final transform cannot be specialized any further: unrolling it, at
    any depth, yields the transform itself. Concrete subclasses provide
    the actual computation by overriding `_xform`, which the public
    `xform` wrapper dispatches to.
    """

    @property
    def is_final(self) -> bool:
        # By definition, every subclass of this base is final.
        return True

    def _unroll(
        self, x: Tensor, /,
        max_depth: int = inf,
        args: Arguments = NoArguments(),
    ) -> "FinalTransform":
        # Nothing left to specialize, whatever `x`, `max_depth` or `args`.
        return self
|
|
710
|
+
|
|
711
|
+
|
|
712
|
+
class _SharedMixin:
|
|
713
|
+
"""
|
|
714
|
+
Mixin for transforms that have parameters (e.g. random ones)
|
|
715
|
+
that may be shared across tensors and/or channels or independent
|
|
716
|
+
across tensors and/or channels.
|
|
717
|
+
"""
|
|
718
|
+
|
|
719
|
+
@classmethod
|
|
720
|
+
def _prepare_shared(cls, shared):
|
|
721
|
+
if shared is True:
|
|
722
|
+
shared = 'channels+tensors'
|
|
723
|
+
if shared is False:
|
|
724
|
+
shared = ''
|
|
725
|
+
return shared
|
|
726
|
+
|
|
727
|
+
def _xform(
|
|
728
|
+
self, x: Tensor, /, args: Arguments = NoArguments(),
|
|
729
|
+
) -> Returned:
|
|
730
|
+
template = x
|
|
731
|
+
if 'channels' in self.shared:
|
|
732
|
+
# Use the first channel only to compute the final transform
|
|
733
|
+
template = template[:1]
|
|
734
|
+
# Compute the next form of this transform
|
|
735
|
+
# NOTE: we do not use `max_depth=inf` because the `shared` option
|
|
736
|
+
# may differ across transformations in the hierarchy. For example,
|
|
737
|
+
# the top-level parameters of a transformation may be shared
|
|
738
|
+
# (e.g., the number of control points in a bias field), but not
|
|
739
|
+
# the lower level ones (e.g., the values of the control points).
|
|
740
|
+
transformation = self.next(template, args=args)
|
|
741
|
+
if transformation is self:
|
|
742
|
+
# Avoid infinite recursion. This should not happen.
|
|
743
|
+
raise ValueError(
|
|
744
|
+
f"The transform is not final, but calling `next` "
|
|
745
|
+
f"returned itself. Transform: {self}"
|
|
746
|
+
)
|
|
747
|
+
# Apply the final transform to all channels
|
|
748
|
+
return transformation.xform(x, args=args)
|
|
749
|
+
|
|
750
|
+
def forward(self, *a, **k) -> Returned:
|
|
751
|
+
return self._shared_forward(*a, **k)
|
|
752
|
+
|
|
753
|
+
def _shared_forward(self, *a, _fallback=None, **k) -> Returned:
|
|
754
|
+
_fallback = _fallback or super().forward
|
|
755
|
+
args = Arguments(*a, **k)
|
|
756
|
+
a, k = args.to_args_kwargs()
|
|
757
|
+
|
|
758
|
+
if 'tensors' in self.shared:
|
|
759
|
+
# Get the first valid tensor across all inputs, and use it
|
|
760
|
+
# as the template to compute the final transform.
|
|
761
|
+
first_tensor = get_first_element(
|
|
762
|
+
[a, k], include=self.include, exclude=self.exclude)
|
|
763
|
+
if 'channels' in self.shared:
|
|
764
|
+
# Get the first channel only, to compute the final transform.
|
|
765
|
+
first_tensor = first_tensor[:1]
|
|
766
|
+
# Compute the next form of this transform...
|
|
767
|
+
transform = self.next(first_tensor, args=args)
|
|
768
|
+
# ...and apply it to all tensors.
|
|
769
|
+
return transform(*a, **k)
|
|
770
|
+
|
|
771
|
+
# Else, we let `xform` deal with shared parameters across channels.
|
|
772
|
+
return _fallback(*a, **k)
|
|
773
|
+
|
|
774
|
+
def make_per_channel(
    self, x: Tensor, /,
    max_depth: int = float('inf'),
    args: Arguments = NoArguments(),
    **kwargs
) -> "PerChannelTransform":
    """Build one transform per channel of `x`.

    Each channel `x[c:c+1]` is unrolled separately (with `max_depth`,
    `args` and any extra keyword arguments). The results are gathered in
    a `PerChannelTransform` that receives this transform's parameters
    (except `shared`), and that container is finally unrolled on `x`
    with `max_depth - 1`.
    """
    options = dict(self.get_prm())
    # `shared` is not forwarded to the per-channel container.
    options.pop('shared', None)
    channel_transforms = []
    for c in range(len(x)):
        channel_transforms.append(
            self.unroll(x[c:c+1], max_depth, args=args, **kwargs)
        )
    container = PerChannelTransform(channel_transforms, **options)
    return container.unroll(x, max_depth - 1)
|
|
786
|
+
|
|
787
|
+
|
|
788
|
+
class NonFinalTransform(_SharedMixin, Transform):
    """
    Transforms whose parameters depend on features of the input
    (shape, dtype, etc).

    Non-final transforms implement `unroll` but not `xform`: at call
    time, they generate a more-specialized transform, which is the one
    that actually gets applied.
    """

    def __init__(self, *, shared: bool = False, **kwargs) -> None:
        """
        Parameters
        ----------
        shared : {'channels', 'tensors', 'channels+tensors', ''} | bool

            - `'channels'`: the same transform is applied to all channels
              in a tensor, but different transforms are used in different
              tensors.
            - `'tensors'`: the same transform is applied to all tensors,
              but with a different transform for each channel.
            - `'channels+tensors'` or `True`: the same transform is applied
              to all channels of all tensors.
            - `''` or `False`: A different transform is applied to each
              channel and each tensor.
        """
        super().__init__(**kwargs)
        # Keep the sharing mode as processed by `_prepare_shared`.
        sharing_mode = self._prepare_shared(shared)
        self.shared = sharing_mode
|
|
815
|
+
|
|
816
|
+
|
|
817
|
+
class SpecialTransform(Transform):
    """Base class for transforms that act on other transforms.

    Whether such a transform is "final" or "non-final" cannot be decided
    once and for all, because it depends on the transforms they embed.

    All special transforms implement `unroll`. Some of them also
    implement a "fast-track" `xform`, used in simple cases (e.g., when
    the transform is not shared across tensors) for efficiency.

    !!! addedin " \
        Added `SpecialTransform` class in `v0.5`."
        Before this, special transforms inherited directly from `Transform`.
    """
|
|
832
|
+
|
|
833
|
+
|
|
834
|
+
class IdentityTransform(FinalTransform):
    """Identity transform: the output is the input itself."""

    def _xform(
        self, x: Tensor, /, args: Arguments = NoArguments()
    ) -> Returned:
        # Input and output are the very same object.
        outputs = {'input': x, 'output': x}
        return prepare_output(outputs, self.returns)

    def make_inverse(self) -> "IdentityTransform":
        # The identity is its own inverse.
        return self
|
|
844
|
+
|
|
845
|
+
|
|
846
|
+
class SequentialTransform(_SharedMixin, SpecialTransform):
    """A sequence of transforms

    !!! example
        Sequences can be built explicitly, or simply by adding transforms
        together:
        ```python
        t1 = MultFieldTransform()
        t2 = GaussianNoiseTransform()
        seq = SequentialTransform([t1, t2])     # explicit
        seq = t1 + t2                           # implicit
        ```

        Sequences can also be extended by addition:
        ```python
        seq += SmoothTransform()
        ```

    """

    def __init__(self, transforms: tx.Sequence[Transform], **kwargs) -> None:
        """
        Parameters
        ----------
        transforms : list[Transform]
            A list of transforms to apply sequentially.

        Other Parameters
        ------------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        # `shared` is handled here instead of being forwarded to the parent.
        sharing_mode = kwargs.pop('shared', False)
        super().__init__(**kwargs)
        self.shared = self._prepare_shared(sharing_mode)
        self.transforms = transforms

    def _unroll(
        self, x: Tensor, /,
        max_depth: int = inf,
        args: Arguments = NoArguments()
    ) -> Transform:
        # Nothing to unroll when the depth budget is exhausted, or when
        # every step of the sequence is already final.
        if max_depth == 0 or self.is_final:
            return self
        steps = []
        for step in self:
            # Each step is unrolled using the output of the previous
            # (already unrolled) step as its template.
            step = step.unroll(x, max_depth=max_depth-1, args=args)
            x = step(x)
            args = Arguments(x)
            steps.append(step)
        return SequentialTransform(steps, **self.get_prm())

    @property
    def is_final(self) -> bool:
        return all(step.is_final for step in self)

    def make_inverse(self) -> Transform:
        # Invert each step, and apply them in reverse order.
        inverses = [step.make_inverse() for step in reversed(self)]
        return SequentialTransform(inverses)

    def forward(self, *a, **k) -> Returned:
        # If the entire sequence is shared across tensors, the behavior
        # from `_SharedMixin` is used: `unroll` is called on the first
        # valid tensor, and the resulting transform is applied to all
        # tensors. Finalizing a sequence of transforms is a bit tricky,
        # but sequences are not shared by default, so this should rarely
        # be used.
        # If the sequence is not shared (or only shared across channels),
        # the transforms are simply applied one after the other.
        return self._shared_forward(*a, **k, _fallback=self._forward_impl)

    def _forward_impl(self, *args, **kwargs):
        current = Arguments(*args, **kwargs)
        for step in self:
            # NOTE: `returns` is deliberately not propagated, as it does
            # not make much sense for sequences.
            with IncludeKeysTransform(step, self.include):
                with ExcludeKeysTransform(step, self.exclude):
                    with ConsumeKeysTransform(step, self.consume):
                        current = step(current)
        return current

    def _xform(
        self, x: Tensor, /, args: Arguments = NoArguments()
    ) -> Returned:
        # This should only be reached when a Layer's `unroll` returns a
        # `SequentialTransform` created implicitly under the hood (not
        # explicitly by the user). In such cases `shared=False`, and we
        # fall back to `forward()`.
        #
        # FIXME
        # what happens if there's weird stuff in returns/include/exclude?
        return self(x)

    def __len__(self) -> int:
        return len(self.transforms)

    def __iter__(self) -> tx.Iterator[Transform]:
        yield from self.transforms

    def __getitem__(self, item: tx.Union[int, slice]) -> Transform:
        if not isinstance(item, slice):
            return self.transforms[item]
        # Slicing gives back a (shorter) sequence.
        return SequentialTransform(self.transforms[item])

    def __repr__(self) -> str:
        name = type(self).__name__
        return f'{name}({self.transforms!r})'
|
|
964
|
+
|
|
965
|
+
|
|
966
|
+
class PerChannelTransform(SpecialTransform):
    """Apply a different transform to each channel

    The `i`-th transform is applied to the `i`-th channel (`x[i:i+1]`)
    of the input, and the per-channel results are gathered with
    `recursive_cat`.
    """

    def __init__(self, transforms: tx.Sequence[Transform], **kwargs) -> None:
        """
        Parameters
        ----------
        transforms : list[Transform]
            A list of transforms to apply to each channel.
        """
        super().__init__(**kwargs)
        self.transforms = transforms

    def _unroll(
        self, x: Tensor, /,
        max_depth: int = inf,
        args: Arguments = NoArguments()
    ) -> Transform:
        if max_depth == 0:
            return self
        # Do the channel transforms need this container's options?
        propagate = (
            self.include is not None or
            self.exclude or self.consume or self.returns
        )
        trf = []
        for i, t in enumerate(self.transforms):
            if propagate:
                # NOTE
                # We cannot use context managers because they exit on
                # return. Instead, we make a shallow copy of the
                # transform and change its options. It is not an issue
                # in most cases, as `unroll` often creates a new
                # transform, but can be one when `max_depth < 2`.
                t = copy(t)
                # Each option is merged with the option of the same kind.
                t.include = IncludeKeysTransform._combine(self.include, t.include)
                t.exclude = ExcludeKeysTransform._combine(self.exclude, t.exclude)
                t.consume = ConsumeKeysTransform._combine(self.consume, t.consume)
                if self.returns:
                    t.returns = self.returns
            trf.append(t.unroll(x[i:i+1], max_depth-1, args=args))
        prm = dict(self.get_prm())
        prm.pop('shared', None)
        return PerChannelTransform(trf, **prm)

    def _xform(
        self, x: Tensor, /, args: Arguments = NoArguments()
    ) -> Returned:
        results = []
        for i, t in enumerate(self.transforms):
            # Temporarily push this container's options onto the
            # channel transform while it runs.
            with \
                    ReturningTransform(t, self.returns), \
                    IncludeKeysTransform(t, self.include), \
                    ExcludeKeysTransform(t, self.exclude), \
                    ConsumeKeysTransform(t, self.consume):
                results.append(t(x[i:i+1]))
        return Returned(recursive_cat(results))

    @property
    def is_final(self) -> bool:
        return all(t.is_final for t in self.transforms)
|
|
1027
|
+
|
|
1028
|
+
|
|
1029
|
+
class MaybeTransform(_SharedMixin, SpecialTransform):
    """Randomly apply a transform

    !!! note "[`ctx.maybe`][cornucopia.ctx.maybe] is an alias for [`MaybeTransform`][cornucopia.special.MaybeTransform]"

    !!! example "20% chance of adding noise"
        ```python
        import cornucopia as cc
        gauss = cc.GaussianNoiseTransform()
        ```
        Explicit call to the class:
        ```python
        img = cc.MaybeTransform(gauss, 0.2)(img)
        ```
        Implicit call using syntactic sugar:
        ```python
        img = (0.2 * gauss)(img)
        ```

    !!! changedin " \
        Default for `shared` changed from `False` to `True`"
    """
    def __init__(
        self,
        transform: Transform,
        prob: float = 0.5,
        *,
        shared: bool = True,
        **kwargs
    ) -> None:
        """

        Parameters
        ----------
        transform : Transform
            A transform to randomly apply
        prob : float
            Probability to apply the transform
        shared : {'channels', 'tensors', 'channels+tensors', ''} | bool
            Roll the dice once for all input tensors
        """
        super().__init__(**kwargs)
        self.shared = self._prepare_shared(shared)
        self.subtransform = transform
        self.prob = prob

    def throw_dice(self) -> bool:
        """Draw one random number; True with probability `prob`."""
        draw = random.random()
        return draw > 1 - self.prob

    def _unroll(
        self, x: Tensor, /,
        max_depth: int = float('inf'),
        args: Arguments = NoArguments()
    ) -> Transform:
        if max_depth == 0:
            return self
        if not self.throw_dice():
            # Not applied: identity, which still consumes the keys.
            return IdentityTransform(consume=self.consume)
        chosen = self.subtransform
        if self.include is not None or self.exclude or self.consume:
            # NOTE
            # * Context managers are not used as they exit on return.
            #   They would work in most cases, as `unroll` often creates
            #   a new transform, but it can be a problem when
            #   `max_depth<2`. Better safe than sorry.
            # * `returns` is not propagated: it should be dealt with by
            #   the subtransform.
            chosen = copy(chosen)
            chosen.include = IncludeKeysTransform._combine(self.include, chosen.include)
            chosen.exclude = ExcludeKeysTransform._combine(self.exclude, chosen.exclude)
            chosen.consume = ConsumeKeysTransform._combine(self.consume, chosen.consume)
        return chosen.unroll(x, max_depth-1, args=args)

    def __repr__(self) -> str:
        # The probability is only shown when it differs from the default.
        suffix = f'[{self.prob}]' if self.prob != 0.5 else ''
        return f'{self.subtransform!r}?{suffix}'
|
|
1110
|
+
|
|
1111
|
+
|
|
1112
|
+
class SwitchTransform(_SharedMixin, SpecialTransform):
    """Randomly choose a transform to apply

    !!! note "[`ctx.switch`][cornucopia.ctx.switch] is an alias for [`SwitchTransform`][cornucopia.special.SwitchTransform]"

    !!! example "Randomly apply either Gaussian or Chi noise"
        ```python
        import cornucopia as cc
        gauss = cc.GaussianNoiseTransform()
        chi = cc.ChiNoiseTransform()
        ```
        Explicit call to the class:
        ```python
        img = cc.SwitchTransform([gauss, chi])(img)
        ```
        Implicit call using syntactic sugar:
        ```python
        img = (gauss | chi)(img)
        ```
        Functional call:
        ```python
        img = cc.switch({gauss: 0.5, chi: 0.5})(img)
        ```

    !!! changedin " \
        Default for `shared` changed from `False` to `True`"
    """

    def __init__(
        self,
        transforms: tx.Sequence[Transform],
        prob: cct.ScalarOrSequence[float] = 0,
        *,
        shared: cct.SharedT = True,
        **kwargs
    ) -> None:
        """

        Parameters
        ----------
        transforms : list[Transform] | dict[Transform, float]
            A list of transforms to sample from, or a dictionary that
            maps each transform to its probability.
        prob : [list of] float
            Probability of applying each transform. Probabilities are
            normalized to sum to one; if they are all zero (default),
            transforms are sampled uniformly.
        shared : {'channels', 'tensors', 'channels+tensors', ''} | bool
            Roll the dice once for all input tensors

        Raises
        ------
        ValueError
            If `transforms` is a dict and `prob` is also provided.
        """
        super().__init__(**kwargs)
        if isinstance(transforms, dict):
            if prob:
                raise ValueError(
                    "When `transforms` is a dict, `prob` should not be provided."
                )
            prob = list(transforms.values())
            transforms = list(transforms.keys())
        self.shared = self._prepare_shared(shared)
        self.transforms = list(transforms)
        self.prob = prob or []

    def _make_prob(self) -> tx.Sequence[float]:
        """Return one normalized probability per transform."""
        prob = ensure_list(self.prob, len(self.transforms), default=0)
        sumprob = sum(prob)
        if not sumprob:
            # No (or all-zero) probabilities: uniform sampling.
            prob = [1/len(self.transforms)] * len(self.transforms)
        else:
            prob = [x / sumprob for x in prob]
        return prob

    def throw_dice(self) -> int:
        """Return the index of the randomly selected transform."""
        prob = cumsum(self._make_prob())
        dice = random.random()
        for k in range(len(self.transforms)):
            if dice > 1 - prob[k]:
                return k
        return len(self.transforms) - 1

    def _unroll(
        self, x: Tensor, /,
        max_depth: int = float('inf'),
        args: Arguments = NoArguments()
    ) -> Transform:
        if max_depth == 0:
            return self
        t = self.transforms[self.throw_dice()]
        t = t.unroll(x, max_depth-1, args=args)
        if self.include is not None or self.exclude or self.consume:
            # NOTE
            # We cannot use the context manager because it exits on
            # return. Instead, we make a shallow copy of the transform
            # and change its options.
            t = copy(t)
            t.include = IncludeKeysTransform._combine(self.include, t.include)
            t.exclude = ExcludeKeysTransform._combine(self.exclude, t.exclude)
            t.consume = ConsumeKeysTransform._combine(self.consume, t.consume)
        return t

    def __or__(self, other: Transform) -> "SwitchTransform":
        if (isinstance(other, SwitchTransform)
                and not self.prob and not other.prob):
            # Two uniform switches: flatten into a single switch.
            return SwitchTransform([*self.transforms, *other.transforms])
        else:
            return SwitchTransform([self, other])

    def __ior__(self, other: Transform) -> "SwitchTransform":
        if (isinstance(other, SwitchTransform)
                and not self.prob and not other.prob):
            # Two uniform switches: add each candidate of `other` to this
            # switch (`append` would nest the whole list as one element).
            self.transforms.extend(other.transforms)
            return self
        else:
            return SwitchTransform([self, other])

    def __repr__(self) -> str:
        if self.prob:
            prob = self._make_prob()
            s = [f'{p} * {str(t)}' for p, t in zip(prob, self.transforms)]
        else:
            s = [str(t) for t in self.transforms]
        s = ' | '.join(s)
        s = f'({s})'
        return s
|
|
1234
|
+
|
|
1235
|
+
|
|
1236
|
+
class IncludeKeysTransform(SpecialTransform):
    """
    Context manager for keys to include

    !!! note "[`ctx.include`][cornucopia.ctx.include] is an alias for [`IncludeKeysTransform`][cornucopia.special.IncludeKeysTransform]"

    !!! example "Use as a transform"
        ```python
        from cornucopia import IncludeKeysTransform
        newxform = IncludeKeysTransform(xform, "image")
        image, label = newxform(image=image, label=label)
        ```

    !!! example "Use as a context manager `with as`"
        ```python
        from cornucopia import IncludeKeysTransform
        with IncludeKeysTransform(xform, "image") as newxform:
            image, label = newxform(image=image, label=label)
        ```

    !!! example "Use as a context manager `with`"
        ```python
        from cornucopia import IncludeKeysTransform
        with IncludeKeysTransform(xform, "image"):
            image, label = xform(image=image, label=label)
        ```

    !!! example "Use as a context manager (alias)"
        ```python
        from cornucopia import ctx
        with ctx.include(xform, "image") as newxform:
            image, label = newxform(image=image, label=label)
        ```
    """

    def __init__(
        self,
        transform: Transform,
        keys: cct.IncludeT,
        union: bool = True
    ) -> None:
        """
        Parameters
        ----------
        transform : Transform
            Transform to apply
        keys : [sequence of] str
            Keys to include
        union : bool
            Include the union of what was already included and `keys`
        """
        super().__init__()
        if keys is not None:
            keys = ensure_list(keys)
        self.transform = transform
        self.keys = keys
        self.union = union

    def forward(self, *a, **k) -> Returned:
        # Temporarily set the keys on the wrapped transform and run it.
        with self as transform:
            return transform.forward(*a, **k)

    def _unroll(
        self, x: Tensor, /,
        max_depth: int = float('inf'),
        args: Arguments = NoArguments()
    ) -> Transform:
        if max_depth == 0:
            return self
        # Unroll the wrapped transform while the keys are set on it.
        with self as trf:
            final_trf = trf.unroll(x, max_depth, args=args)
        # The context manager restores the previous keys on exit (and
        # `unroll` may have returned the wrapped transform itself), so
        # the keys are made permanent on a shallow copy of the result.
        return self._with_keys(final_trf)

    def make_inverse(self) -> Transform:
        with self as trf:
            inv_trf = trf.make_inverse()
        return self._with_keys(inv_trf)

    def _with_keys(self, transform: Transform) -> Transform:
        """Return a shallow copy of `transform` whose `include` option is
        combined with `keys` (following `union`)."""
        transform = copy(transform)
        transform.include = self._combine(
            self.keys, transform.include, union=self.union)
        return transform

    @classmethod
    def _combine(cls, *includes, union: bool = True) -> tx.Sequence[str]:
        """Combine several `include` options.

        Parameters
        ----------
        *includes : list[str] or None
            The first one is the base option. When `union` is True, the
            other ones are merged into it (`None` entries are skipped).
        union : bool
            Whether to merge the following options into the first one.

        Returns
        -------
        list[str] or None
            A new deduplicated list (order not preserved), or `None` when
            the base option is `None` and nothing was merged into it.
            The input lists are never modified.
        """
        new_include, *includes = includes
        if new_include is not None:
            # Work on a copy: callers pass their own option lists
            # (e.g., `self.keys`, `self.include`), which must not grow.
            new_include = list(new_include)
        if union:
            for include in includes:
                if include is not None:
                    if new_include is None:
                        new_include = []
                    new_include.extend(include)
        if new_include is not None:
            new_include = list(set(new_include))
        return new_include

    def __enter__(self) -> Transform:
        old_include = self.transform.include
        new_include = self._combine(self.keys, old_include, union=self.union)
        # Stash the previous value under a private name so that this
        # wrapper's own `include` option is left untouched.
        self._saved_include = old_include
        self.transform.include = new_include
        return self.transform

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.transform.include = self._saved_include
        del self._saved_include
|
|
1339
|
+
|
|
1340
|
+
|
|
1341
|
+
class ExcludeKeysTransform(SpecialTransform):
    """
    Context manager for keys to exclude.
    Can also be used as a transform.

    !!! note "[`ctx.exclude`][cornucopia.ctx.exclude] is an alias for [`ExcludeKeysTransform`][cornucopia.special.ExcludeKeysTransform]"

    !!! example "Use as a transform"
        ```python
        from cornucopia import ExcludeKeysTransform
        newxform = ExcludeKeysTransform(xform, "image")
        image, label = newxform(image=image, label=label)
        ```

    !!! example "Use as a context manager `with as`"
        ```python
        from cornucopia import ExcludeKeysTransform
        with ExcludeKeysTransform(xform, "image") as newxform:
            image, label = newxform(image=image, label=label)
        ```

    !!! example "Use as a context manager `with`"
        ```python
        from cornucopia import ExcludeKeysTransform
        with ExcludeKeysTransform(xform, "image"):
            image, label = xform(image=image, label=label)
        ```

    !!! example "Use as a context manager (alias)"
        ```python
        from cornucopia import ctx
        with ctx.exclude(xform, "image") as newxform:
            image, label = newxform(image=image, label=label)
        ```
    """

    def __init__(
        self,
        transform: Transform,
        keys: cct.ExcludeT,
        union: bool = True
    ) -> None:
        """
        Parameters
        ----------
        transform : Transform
            Transform to apply
        keys : [sequence of] str
            Keys to exclude
        union : bool
            Exclude the union of what was already excluded and `keys`
        """
        super().__init__()
        if keys is not None:
            keys = ensure_list(keys)
        self.transform = transform
        self.keys = keys
        self.union = union

    def forward(self, *a, **k) -> Returned:
        # Temporarily set the keys on the wrapped transform and run it.
        with self as transform:
            return transform.forward(*a, **k)

    def _unroll(
        self, x: Tensor, /,
        max_depth: int = float('inf'),
        args: Arguments = NoArguments()
    ) -> Transform:
        if max_depth == 0:
            return self
        # Unroll the wrapped transform while the keys are set on it.
        with self as trf:
            final_trf = trf.unroll(x, max_depth, args=args)
        # The context manager restores the previous keys on exit (and
        # `unroll` may have returned the wrapped transform itself), so
        # the keys are made permanent on a shallow copy of the result.
        return self._with_keys(final_trf)

    def make_inverse(self) -> Transform:
        with self as trf:
            inv_trf = trf.make_inverse()
        return self._with_keys(inv_trf)

    def _with_keys(self, transform: Transform) -> Transform:
        """Return a shallow copy of `transform` whose `exclude` option is
        combined with `keys` (following `union`)."""
        transform = copy(transform)
        transform.exclude = self._combine(
            self.keys, transform.exclude, union=self.union)
        return transform

    @classmethod
    def _combine(cls, *excludes, union: bool = True) -> tx.Sequence[str]:
        """Combine several `exclude` options.

        Parameters
        ----------
        *excludes : list[str] or None
            The first one is the base option. When `union` is True, the
            other ones are merged into it (`None` entries are skipped).
        union : bool
            Whether to merge the following options into the first one.

        Returns
        -------
        list[str] or None
            A new deduplicated list (order not preserved), or `None` when
            the base option is `None` and nothing was merged into it.
            The input lists are never modified.
        """
        new_exclude, *excludes = excludes
        if new_exclude is not None:
            # Work on a copy: callers pass their own option lists
            # (e.g., `self.keys`, `self.exclude`), which must not grow.
            new_exclude = list(new_exclude)
        if union:
            for exclude in excludes:
                if exclude is not None:
                    if new_exclude is None:
                        new_exclude = []
                    new_exclude.extend(exclude)
        if new_exclude is not None:
            new_exclude = list(set(new_exclude))
        return new_exclude

    def __enter__(self) -> Transform:
        old_exclude = self.transform.exclude
        new_exclude = self._combine(self.keys, old_exclude, union=self.union)
        # Stash the previous value under a private name so that this
        # wrapper's own `exclude` option is left untouched.
        self._saved_exclude = old_exclude
        self.transform.exclude = new_exclude
        return self.transform

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.transform.exclude = self._saved_exclude
        del self._saved_exclude
|
|
1445
|
+
|
|
1446
|
+
|
|
1447
|
+
class ConsumeKeysTransform(SpecialTransform):
    """
    Context manager for keys to consume.
    Can also be used as a transform.

    !!! note "[`ctx.consume`][cornucopia.ctx.consume] is an alias for [`ConsumeKeysTransform`][cornucopia.special.ConsumeKeysTransform]"

    !!! example "Use as a transform"
        ```python
        from cornucopia import ConsumeKeysTransform
        newxform = ConsumeKeysTransform(xform, "image")
        label = newxform(image=image, label=label)
        ```

    !!! example "Use as a context manager `with as`"
        ```python
        from cornucopia import ConsumeKeysTransform
        with ConsumeKeysTransform(xform, "image") as newxform:
            label = newxform(image=image, label=label)
        ```

    !!! example "Use as a context manager `with`"
        ```python
        from cornucopia import ConsumeKeysTransform
        with ConsumeKeysTransform(xform, "image"):
            label = xform(image=image, label=label)
        ```

    !!! example "Use as a context manager (alias)"
        ```python
        from cornucopia import ctx
        with ctx.consume(xform, "image") as newxform:
            label = newxform(image=image, label=label)
        ```

    !!! addedin " \
        Added in `v0.5`"
    """

    def __init__(
        self,
        transform: Transform,
        keys: cct.ConsumeT,
        union: bool = True
    ) -> None:
        """
        Parameters
        ----------
        transform : Transform
            Transform to apply
        keys : [sequence of] str
            Keys to consume
        union : bool
            Consume the union of what was already consumed and `keys`
        """
        super().__init__()
        if keys is not None:
            keys = ensure_list(keys)
        self.transform = transform
        self.keys = keys
        self.union = union

    def forward(self, *a, **k) -> Returned:
        # Temporarily set the keys on the wrapped transform and run it.
        with self as transform:
            return transform.forward(*a, **k)

    def _unroll(
        self, x: Tensor, /,
        max_depth: int = float('inf'),
        args: Arguments = NoArguments()
    ) -> Transform:
        if max_depth == 0:
            return self
        # Unroll the wrapped transform while the keys are set on it.
        with self as trf:
            final_trf = trf.unroll(x, max_depth, args=args)
        # The context manager restores the previous keys on exit (and
        # `unroll` may have returned the wrapped transform itself), so
        # the keys are made permanent on a shallow copy of the result.
        return self._with_keys(final_trf)

    def make_inverse(self) -> Transform:
        with self as trf:
            inv_trf = trf.make_inverse()
        return self._with_keys(inv_trf)

    def _with_keys(self, transform: Transform) -> Transform:
        """Return a shallow copy of `transform` whose `consume` option is
        combined with `keys` (following `union`)."""
        transform = copy(transform)
        transform.consume = self._combine(
            self.keys, transform.consume, union=self.union)
        return transform

    @classmethod
    def _combine(cls, *consumes, union: bool = True) -> tx.Sequence[str]:
        """Combine several `consume` options.

        Parameters
        ----------
        *consumes : list[str] or None
            The first one is the base option. When `union` is True, the
            other ones are merged into it (`None` entries are skipped).
        union : bool
            Whether to merge the following options into the first one.

        Returns
        -------
        list[str] or None
            A new deduplicated list (order not preserved), or `None` when
            the base option is `None` and nothing was merged into it.
            The input lists are never modified.
        """
        new_consume, *consumes = consumes
        if new_consume is not None:
            # Work on a copy: callers pass their own option lists
            # (e.g., `self.keys`, `self.consume`), which must not grow.
            new_consume = list(new_consume)
        if union:
            for consume in consumes:
                if consume is not None:
                    if new_consume is None:
                        new_consume = []
                    new_consume.extend(consume)
        if new_consume is not None:
            new_consume = list(set(new_consume))
        return new_consume

    def __enter__(self) -> Transform:
        old_consume = self.transform.consume
        new_consume = self._combine(self.keys, old_consume, union=self.union)
        # Stash the previous value under a private name so that this
        # wrapper's own `consume` option is left untouched.
        self._saved_consume = old_consume
        self.transform.consume = new_consume
        return self.transform

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.transform.consume = self._saved_consume
        del self._saved_consume
|
|
1554
|
+
|
|
1555
|
+
|
|
1556
|
+
class SharedTransform(_SharedMixin, SpecialTransform):
    """
    Context manager for sharing transforms across channels / tensors.
    Can also be used as a transform.

    !!! note "[`ctx.shared`][cornucopia.ctx.shared] is an alias for [`SharedTransform`][cornucopia.special.SharedTransform]"

    !!! example "Use as a context manager (alias)"
        ```python
        from cornucopia import ctx
        with ctx.shared(xform, "channels") as newxform:
            image = newxform(image)
        ```
    """

    def __init__(
        self, transform: Transform, mode: cct.SharedT = UNSET
    ) -> None:
        """
        Parameters
        ----------
        transform : Transform
            Transform to apply
        mode : {'channels', 'tensors', 'channels+tensors', ''} | bool

            - `'channels'`: the same transform is applied to all channels
              in a tensor, but different transforms are used in different
              tensors.
            - `'tensors'`: the same transform is applied to all tensors,
              but with a different transform for each channel.
            - `'channels+tensors'` or `True`: the same transform is applied
              to all channels of all tensors.
            - `''` or `False`: A different transform is applied to each
              channel and each tensor.

        """
        super().__init__()
        self.transform = transform
        self.mode = mode

    def forward(self, *a, **k) -> Returned:
        # Temporarily set the sharing mode on the wrapped transform.
        with self as transform:
            return transform.forward(*a, **k)

    def __enter__(self) -> Transform:
        # Remember whether (and how) the wrapped transform was shared,
        # so that `__exit__` can restore it exactly.
        self.hasattr = hasattr(self.transform, 'shared')
        self.saved_mode = getattr(self.transform, 'shared', None)
        if self.mode is not UNSET:
            # Prepare the mode the same way sharing-aware transforms do in
            # their constructor: `_shared_forward` tests membership
            # (`'tensors' in self.shared`), which fails on raw booleans.
            self.transform.shared = self._prepare_shared(self.mode)
        return self.transform

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        if self.hasattr:
            self.transform.shared = self.saved_mode
        elif hasattr(self.transform, 'shared'):
            delattr(self.transform, 'shared')
        delattr(self, 'hasattr')
        delattr(self, 'saved_mode')
|
|
1614
|
+
|
|
1615
|
+
|
|
1616
|
+
class ReturningTransform(SpecialTransform):
    """
    Context manager for temporarily changing what a transform returns.

    !!! note "[`ctx.returns`][cornucopia.ctx.returns] is an alias for [`ReturningTransform`][cornucopia.special.ReturningTransform]"

    !!! example "Use as a context manager (alias)"
        ```python
        from cornucopia import ctx
        with ctx.returns(xform, ["input", "output"]) as newxform:
            original, transformed = newxform(image)
        ```
    """

    def __init__(
        self,
        transform: Transform,
        returns: tx.Optional[cct.ReturnsT] = None
    ) -> None:
        """
        Parameters
        ----------
        transform : Transform
            Transform whose `returns` attribute is temporarily overridden.
        returns : ReturnsT, optional
            Value assigned to `transform.returns` while the context is
            active. The previous value is restored on exit.
        """
        super().__init__()
        self.transform = transform
        self.returns = returns

    def __enter__(self) -> Transform:
        """Save `transform.returns`, override it, and return the transform."""
        self.saved_returns = self.transform.returns
        self.transform.returns = self.returns
        return self.transform

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Restore the saved `transform.returns` and drop the saved copy."""
        self.transform.returns = self.saved_returns
        delattr(self, 'saved_returns')
|
|
1647
|
+
|
|
1648
|
+
|
|
1649
|
+
class MappedTransform(SpecialTransform):
    """
    Transforms that are applied to specific positional or keyword arguments.

    !!! note "[`ctx.map`][cornucopia.ctx.map] is an alias for [`MappedTransform`][cornucopia.special.MappedTransform]"

    !!! example
        ```python
        img = torch.randn([1, 32, 32])
        seg = torch.randn([3, 32, 32]).softmax(0)

        # positional variant
        trf = MappedTransform(GaussianNoise(), None)
        img, seg = trf(img, seg)

        # keyword variant
        trf = MappedTransform(image=GaussianNoise())
        img, seg = trf(image=img, label=seg)

        # alternative version
        dat = {'img': torch.randn([1, 32, 32]),
               'seg': torch.randn([3, 32, 32]).softmax(0)}
        dat = MappedTransform(img=GaussianNoise(), nested=True)(dat)
        ```
    """

    def __init__(
        self,
        *mapargs,
        nested: bool = False,
        default: tx.Optional[Transform] = None,
        **mapkwargs
    ) -> None:
        """

        Parameters
        ----------
        mapargs : tuple[Transform]
            Transform to apply to each positional argument, in order.
            A `None` entry leaves the corresponding argument unchanged.
        mapkwargs : dict[str, Transform]
            Transform to apply to each keyword argument, by name.
            The keys `shared`, `include` and `exclude` are not mappings:
            they are removed and forwarded to the base class constructor.
        nested : bool, default=False
            Recursively traverse the inputs until we find matching
            dictionaries. Only `mapkwargs` are accepted if "nested".
        default : Transform, optional
            Transform to apply to tensors that are not specifically mapped.
            If `None`, unmapped tensors are returned unchanged.

        Raises
        ------
        ValueError
            If `nested` is set and positional transforms are provided.
        """
        super().__init__(
            shared=mapkwargs.pop('shared', False),
            include=mapkwargs.pop('include', None),
            exclude=mapkwargs.pop('exclude', None),
        )
        self.mapargs = mapargs
        self.mapkwargs = mapkwargs
        self.nested = nested
        self.default = default
        if nested and mapargs:
            raise ValueError(
                'Cannot have both `nested` and positional transforms'
            )

    def forward(self, *args, **kwargs) -> Returned:
        """
        Apply each mapped transform to its matching input.

        Returns `None` if there are no inputs, the single transformed
        input if there was exactly one positional input and no keyword
        inputs, and an `Arguments` object otherwise.
        """

        # NOTE: `include` is compared to None (so an explicit empty
        # selection still triggers wrapping), whereas `exclude` and
        # `consume` are tested for truthiness.
        if self.include is not None or self.exclude or self.consume:
            def wrap(f: tx.Callable) -> tx.Callable:
                # Run `f` with this transform's include/exclude/consume
                # keys temporarily installed on it.
                if not f:
                    return f
                def ff(*a, **k):
                    # NOTE
                    # I do not propagate `returns`. I think it should
                    # be dealt with by the subtransforms.
                    with \
                        IncludeKeysTransform(f, self.include), \
                        ExcludeKeysTransform(f, self.exclude), \
                        ConsumeKeysTransform(f, self.consume):
                        return f(*a, **k)
                return ff
        else:
            def wrap(f: tx.Callable) -> tx.Callable:
                return f

        # Normalize the inputs through `Arguments` (presumably this also
        # unwraps an input that is already an `Arguments` — confirm in
        # `Arguments`).
        arguments = Arguments(*args, **kwargs)
        if not arguments:
            return None
        args, kwargs = arguments.to_args_kwargs()

        def default(x):
            # Default transform to apply if nothing is mapped.
            # Tensors go through `self.default` (if any); other values
            # are recursed into when `nested`, else passed through.
            if torch.is_tensor(x):
                return self.default(x) if self.default else x
            else:
                return self.forward(x) if self.nested else x

        if args:
            # Apply each transform to its corresponding positional argument;
            # arguments beyond `mapargs` fall back to `default`.
            mapargs = tuple(wrap(f) for f in self.mapargs)
            mapargs += (default,) * max(0, len(args) - len(mapargs))
            args = tuple(f(a) if f else a for f, a in zip(mapargs, args))

        if kwargs:
            # Apply each transform to its corresponding keyword argument
            mapkwargs = {k: wrap(f) for k, f in self.mapkwargs.items()}

            def func(key):
                # Unmapped keys use `default`; keys mapped to None are
                # passed through unchanged.
                return mapkwargs.get(key, default) or (lambda x: x)

            kwargs = {key: func(key)(value) for key, value in kwargs.items()}

        # If more than a single input argument, wrap them in `Arguments`
        if kwargs or len(args) > 1:
            return Arguments(*args, **kwargs)
        return args[0]

    def __repr__(self) -> str:
        s = []
        for v in self.mapargs:
            s += [f'{v}']
        for k, v in self.mapkwargs.items():
            s += [f'{k}={v}']
        s = ', '.join(s)
        return f'{type(self).__name__}({s})'
|
|
1771
|
+
|
|
1772
|
+
|
|
1773
|
+
class RandomizedTransform(SpecialTransform, NonFinalTransform):
    """
    Transform generated by randomizing some parameters of another transform.

    !!! note "[`ctx.randomize`][cornucopia.ctx.randomize] is an alias for [`RandomizedTransform`][cornucopia.special.RandomizedTransform]"

    !!! example "Gaussian noise with randomized variance"
        Object call
        ```python
        import cornucopia as cc
        hypernoise = cc.RandomizedTransform(cc.GaussianNoise, [cc.Uniform()])
        img = hypernoise(img)
        ```

        Delayed call
        ```python
        import cornucopia as cc
        MyRandomNoise = cc.randomize(cc.GaussianNoise)
        hypernoise = MyRandomNoise(cc.Uniform())
        img = hypernoise(img)
        ```

    """

    class Delayed:
        # Temporary parameter holder for delayed calls: it remembers the
        # transform and constructor options, and builds the actual
        # `RandomizedTransform` once the samplers are provided.
        def __init__(self, transform: Transform, **kwargs) -> None:
            self.transform = transform
            self.kwargs = kwargs

        def __call__(self, *args, **kwargs) -> "RandomizedTransform":
            return RandomizedTransform(
                self.transform, args, kwargs, **self.kwargs)

    def __new__(cls, *args, **kwargs) -> "RandomizedTransform":
        # Only explicit constructions of the base class go through
        # `_base_new` (which may return a `Delayed` factory).
        # A bare `cls.__new__(cls)` -- as issued by `pickle` and
        # `copy.deepcopy` through `copyreg.__newobj__` -- must produce a
        # plain instance, otherwise instances cannot be copied/unpickled.
        if cls is RandomizedTransform and (args or kwargs):
            return cls._base_new(*args, **kwargs)
        return super().__new__(cls)

    @classmethod
    def _base_new(
        cls,
        transform: Transform,
        sample: tuple = tuple(),
        ksample: tx.Optional[dict] = None,
        *,
        shared: tx.Union[bool, str] = False,
        **kwargs
    ) -> "RandomizedTransform":
        assert cls is RandomizedTransform
        if not sample and not ksample:
            # If no arguments are passed, it means that the user calls
            # this in "delayed/functional" mode. In that case, we return
            # a callable object that returns the constructed instance
            # using the call-time arguments.
            return cls.Delayed(transform, shared=shared, **kwargs)
        # Otherwise, we're in object mode and we instantiate the
        # randomized object.
        return super().__new__(cls)

    def __init__(
        self,
        transform: Transform,
        sample: tuple = tuple(),
        ksample: tx.Optional[dict] = None,
        *,
        shared: tx.Union[bool, str] = False,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        transform : callable(...) -> Transform
            A Transform subclass or a function that constructs a Transform.
        sample : [list or dict of] callable
            A collection of functions that generate parameter values provided
            to `transform`. Can be args-like or kwargs-like arguments.
            Entries that are not `Sampler` instances are passed as-is.
        ksample : dict[callable], optional
            Must be kwargs-like arguments.

        Other Parameters
        ----------------
        shared : {'channels', 'tensors', 'channels+tensors', ''} | bool
            Share random parameters across tensors and/or channels
        **kwargs
            Forwarded to the base class constructor.
        """
        super().__init__(shared=shared, **kwargs)
        self.sample = sample
        # Fresh dict per instance: avoids sharing a mutable default.
        self.ksample = {} if ksample is None else ksample
        self.subtransform = transform

    def _unroll(
        self, x: Tensor, /,
        max_depth: int = inf,
        args: Arguments = NoArguments()
    ) -> Transform:
        """
        Sample the random parameters, build the concrete transform with
        them, and unroll it further (at most `max_depth` levels).

        When channels are not shared and `x` has more than one element
        along its first dimension (presumably channels), a separate
        transform is built per channel through `make_per_channel`.
        """
        if max_depth == 0:
            return self
        if 'channels' not in self.shared and len(x) > 1:
            return self.make_per_channel(x, max_depth, args=args)

        def draw(f):
            # Samplers are called to draw a value; fixed values pass through.
            return f() if isinstance(f, Sampler) else f

        # Use dedicated names for the sampled constructor parameters so
        # that `args` (the call-time `Arguments`) is not shadowed and is
        # forwarded unchanged to the sub-transform's `unroll` below.
        prm_args = []
        prm_kwargs = {}

        if isinstance(self.sample, (list, tuple)):
            prm_args += [draw(f) for f in self.sample]

        elif hasattr(self.sample, 'items'):
            prm_kwargs.update({k: draw(f) for k, f in self.sample.items()})

        else:
            prm_args += [draw(f) for f in (self.sample,)]

        if self.ksample:
            prm_kwargs.update({k: draw(f) for k, f in self.ksample.items()})

        # Propagate general options (include, exclude),
        # unless they are already set by sample/ksample.
        for key, value in self.get_prm().items():
            prm_kwargs.setdefault(key, value)

        # Build transform with fixed parameters, and recurse.
        xform = self.subtransform(*prm_args, **prm_kwargs)
        xform = xform.unroll(x, max_depth-1, args=args)
        return xform

    def __repr__(self) -> str:
        if type(self) is RandomizedTransform:
            xform = self.subtransform
            if isinstance(xform, type) and issubclass(xform, Transform):
                return f'Randomized{xform.__name__}()'
        return super().__repr__()
|