cornucopia 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. cornucopia/__init__.py +73 -0
  2. cornucopia/base.py +1915 -0
  3. cornucopia/baseutils.py +575 -0
  4. cornucopia/contrast.py +260 -0
  5. cornucopia/ctx.py +25 -0
  6. cornucopia/fov.py +707 -0
  7. cornucopia/geometric.py +2068 -0
  8. cornucopia/intensity.py +1358 -0
  9. cornucopia/io.py +161 -0
  10. cornucopia/kspace.py +505 -0
  11. cornucopia/labels.py +1872 -0
  12. cornucopia/noise.py +508 -0
  13. cornucopia/psf.py +463 -0
  14. cornucopia/qmri.py +1288 -0
  15. cornucopia/random.py +1480 -0
  16. cornucopia/special.py +159 -0
  17. cornucopia/synth.py +708 -0
  18. cornucopia/tests/__init__.py +0 -0
  19. cornucopia/tests/test_backward_geometric.py +173 -0
  20. cornucopia/tests/test_backward_intensity.py +243 -0
  21. cornucopia/tests/test_backward_kspace.py +115 -0
  22. cornucopia/tests/test_backward_noise.py +169 -0
  23. cornucopia/tests/test_backward_psf.py +142 -0
  24. cornucopia/tests/test_backward_qmri.py +249 -0
  25. cornucopia/tests/test_backward_random.py +44 -0
  26. cornucopia/tests/test_backward_synth.py +72 -0
  27. cornucopia/tests/test_base.py +401 -0
  28. cornucopia/tests/test_geometric.py +26 -0
  29. cornucopia/tests/test_intensity.py +9 -0
  30. cornucopia/tests/test_random.py +722 -0
  31. cornucopia/tests/test_run_contrast.py +28 -0
  32. cornucopia/tests/test_run_fov.py +132 -0
  33. cornucopia/tests/test_run_geometric.py +157 -0
  34. cornucopia/tests/test_run_intensity.py +192 -0
  35. cornucopia/tests/test_run_kspace.py +70 -0
  36. cornucopia/tests/test_run_labels.py +224 -0
  37. cornucopia/tests/test_run_noise.py +127 -0
  38. cornucopia/tests/test_run_psf.py +115 -0
  39. cornucopia/tests/test_run_qmri.py +114 -0
  40. cornucopia/tests/test_run_synth.py +67 -0
  41. cornucopia/typing.py +97 -0
  42. cornucopia/utils/__init__.py +0 -0
  43. cornucopia/utils/b0.py +745 -0
  44. cornucopia/utils/bounds.py +412 -0
  45. cornucopia/utils/compat.py +47 -0
  46. cornucopia/utils/conv.py +305 -0
  47. cornucopia/utils/gmm.py +169 -0
  48. cornucopia/utils/indexing.py +911 -0
  49. cornucopia/utils/io.py +258 -0
  50. cornucopia/utils/jit.py +128 -0
  51. cornucopia/utils/kernels.py +288 -0
  52. cornucopia/utils/morpho.py +234 -0
  53. cornucopia/utils/mrf.py +574 -0
  54. cornucopia/utils/padding.py +173 -0
  55. cornucopia/utils/patch.py +302 -0
  56. cornucopia/utils/pool.py +282 -0
  57. cornucopia/utils/py.py +348 -0
  58. cornucopia/utils/smart_inplace.py +163 -0
  59. cornucopia/utils/version.py +57 -0
  60. cornucopia/utils/warps.py +606 -0
  61. cornucopia-0.0.0.dist-info/METADATA +92 -0
  62. cornucopia-0.0.0.dist-info/RECORD +65 -0
  63. cornucopia-0.0.0.dist-info/WHEEL +5 -0
  64. cornucopia-0.0.0.dist-info/licenses/LICENSE +21 -0
  65. cornucopia-0.0.0.dist-info/top_level.txt +1 -0
cornucopia/base.py ADDED
@@ -0,0 +1,1915 @@
1
# Public API of this module (names exported by `from ... import *`).
__all__ = [
    'Transform',
    'FinalTransform',
    'NonFinalTransform',
]
6
+ # stdlib
7
+ import inspect
8
+ import random
9
+ import re
10
+ from abc import ABC
11
+ from copy import copy
12
+ from fnmatch import fnmatch
13
+ from math import inf
14
+
15
+ # dependencies
16
+ import torch
17
+ import typing_extensions as tx
18
+ from torch import nn, Tensor
19
+
20
+ # internal
21
+ from .random import Sampler
22
+ from .utils.py import ensure_list, cumsum
23
+ from .baseutils import (
24
+ Arguments, Args, Kwargs, ArgsAndKwargs, NoArguments, Returned,
25
+ get_first_element, prepare_output, UNSET, recursive_cat,
26
+ )
27
+ from . import typing as cct
28
+
29
+
30
class Transform(nn.Module, ABC):
    """Base class for all transforms."""

    def __init__(
        self, *,
        returns: tx.Union[str, tx.Sequence[str], tx.Mapping[str, str], None] = None,
        append: tx.Union[bool, str] = False,
        prefix: tx.Union[bool, str] = True,
        include: tx.Optional[cct.IncludeT] = None,
        exclude: tx.Optional[cct.ExcludeT] = None,
        consume: tx.Optional[cct.ConsumeT] = None,
    ):
        """
        Parameters
        ----------
        returns : [list or dict of] str, optional
            Which tensors to return. Can be a nested structure.
            Most transforms accept `'input'` and `'output'` as valid
            returns. The default is `'output'`.
        append : bool | str
            Append the (structure of) returned tensors to the parent
            structure.

            !!! warning
                This option does not keep the input tensors in the returned
                structure! To preserve the input tensors, you should
                use `append` in conjunction with `returns`.

            !!! example
                ```python
                # With lists
                trf = MyTransform(returns=['input', 'output'], append=True)
                x1, y1, x2, y2 = trf([x1, x2])

                # With dicts
                trf = MyTransform(returns={'x': 'input', 'y': 'output'}, append=True, prefix=True)
                out = trf({'path1': x1, 'path2': x2})
                assert out.keys() == {'path1.x', 'path1.y', 'path2.x', 'path2.y'}
                ```

            !!! changedin "![v0.5](https://img.shields.io/badge/v0.5-yellow) \
                Can be a string since `v0.5`"
                If it is a `str` and parent is a `dict`, its value will be
                used as a separator between the prefix and the key.
                See `prefix`.
        prefix : bool | str
            If `append` and parent is a `dict`, prefix the returned key
            before inserting it in the output dictionary.

            If `True`, the prefix is the input key.

            !!! changedin "![v0.5](https://img.shields.io/badge/v0.5-yellow) \
                Can be a string since `v0.5`"
        include : [list of] str | re.Pattern, optional
            List of keys to which the transform should apply.
            Default: all.

            !!! changedin "![v0.5](https://img.shields.io/badge/v0.5-yellow) \
                Can be a regex or glob pattern since `v0.5`"
        exclude : [list of] str | re.Pattern, optional
            List of keys to which the transform should not apply.
            Default: none.

            !!! changedin "![v0.5](https://img.shields.io/badge/v0.5-yellow) \
                Can be a regex or glob pattern since `v0.5`"
        consume : [list of] str | re.Pattern, optional
            List of keys to remove from the output after applying the
            transform. Default: none.

            !!! addedin "![v0.5](https://img.shields.io/badge/v0.5-green) \
                Added in `v0.5`. Can be a regex or glob pattern."
        """
        super().__init__()
        self.returns = returns
        self.append = append
        self.prefix = prefix
        # `include=None` means "every key", so it is kept as `None`;
        # `exclude` and `consume` default to empty lists ("no key").
        self.include = ensure_list(include) if include is not None else None
        self.exclude = ensure_list(exclude or tuple())
        self.consume = ensure_list(consume or tuple())

    @property
    def is_final(self) -> bool:
        """
        Returns
        -------
        bool
            Whether the transform is final (i.e., deterministic) or not.
        """
        return False

    def get_prm(self) -> dict:
        """Get the parameters of the transform, for use in subtransforms.

        Returns
        -------
        dict
            A dictionary containing the attributes
            `returns`, `append`, `prefix`, `include`, `exclude`, and
            `consume`.

        """
        return dict(
            returns=self.returns,
            append=self.append,
            prefix=self.prefix,
            include=self.include,
            exclude=self.exclude,
            consume=self.consume
        )

    @staticmethod
    def _match_any(key: str, patterns: tx.Iterable) -> bool:
        """Return whether `key` matches at least one of `patterns`.

        Compiled regexes (`re.Pattern`) are tested with `pattern.match`
        (anchored at the start of the key); anything else is treated as
        a glob pattern and tested with `fnmatch`. An empty iterable
        matches nothing.
        """
        # NOTE: every pattern must be tried; returning on the first
        # pattern regardless of its result would ignore the others.
        for pattern in patterns:
            if isinstance(pattern, re.Pattern):
                if pattern.match(key):
                    return True
            elif fnmatch(key, pattern):
                return True
        return False

    def _is_included(self, key: str) -> bool:
        """Whether `key` matches an `include` pattern (True if unset)."""
        if self.include is None:
            return True
        return self._match_any(key, self.include)

    def _is_excluded(self, key: str) -> bool:
        """Whether `key` matches an `exclude` pattern."""
        if self.exclude is None:
            return False
        return self._match_any(key, self.exclude)

    def _is_consumed(self, key: str) -> bool:
        """Whether `key` matches a `consume` pattern."""
        if self.consume is None:
            return False
        return self._match_any(key, self.consume)

    def __enter__(self) -> "Transform":
        # On most transforms, this does nothing but return the transform
        # itself. However, some subclasses use this to act as context
        # managers that temporarily modify the transform's behavior.
        # See: `IncludeKeysTransform` and `ExcludeKeysTransform`.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
        # Never suppress exceptions raised inside the `with` block.
        return False

    def __add__(self, other: "Transform") -> "SequentialTransform":
        return SequentialTransform([self, other])

    def __radd__(self, other: "Transform") -> "SequentialTransform":
        return SequentialTransform([other, self])

    def __iadd__(self, other: "Transform") -> "SequentialTransform":
        return SequentialTransform([self, other])

    def __mul__(self, prob: float) -> "MaybeTransform":
        return MaybeTransform(self, prob)

    def __rmul__(self, prob: float) -> "MaybeTransform":
        return MaybeTransform(self, prob)

    def __imul__(self, prob: float) -> "MaybeTransform":
        return MaybeTransform(self, prob)

    def __or__(self, other: "Transform") -> "SwitchTransform":
        return SwitchTransform([self, other])

    def __ior__(self, other: "Transform") -> "SwitchTransform":
        return SwitchTransform([self, other])

    def __call__(self, *a, **k):
        # Use the torch machinery, although `Returned` objects get unwrapped.
        out = super().__call__(*a, **k)
        if isinstance(out, Returned):
            out = out.obj
        return out

    def forward(self, *a, **k) -> Returned:
        """Apply the transform recursively.

        Parameters
        ----------
        *a, **k : [nested list or dict of] tensor
            Input tensors, with shape `(C, *shape)`

        Returns
        -------
        [nested list or dict of] tensor
            Output tensors. with shape `(C, *shape)`

        """
        # We wrap positional and keyword arguments in special classes
        # to differentiate them from inputs that are lists or dicts.
        x = args = Arguments(*a, **k)

        if not args:
            # If no input arguments, return None.
            # NOTE: Only `NoArguments()` reduces to `False`.
            return None

        # Arguments are passed to `_Forward.__init__` and `_Forward.__call__`.
        # The former is preserved as is in the `_Forward` object, and passed
        # to each `xform` or `unroll` call, while the latter is
        # recursively unwrapped and processed by `_Forward.__call__`.
        return self._Forward(self, args, **self.get_prm())(x)

    class _Forward:
        """Recursive driver that walks a nested input structure and
        applies `transform.xform` to each leaf.

        `prm` holds the options returned by `Transform.get_prm()`
        (`returns`, `append`, `prefix`, `include`, `exclude`, `consume`).
        """

        def __init__(self, transform, args: Arguments, **prm) -> None:
            self.transform = transform
            self.args = args
            self.prm = prm

        @property
        def include(self) -> tx.Optional[tx.Sequence[str]]:
            return self.prm.get('include')

        @property
        def exclude(self) -> tx.Optional[tx.Sequence[str]]:
            return self.prm.get('exclude')

        @property
        def consume(self) -> tx.Optional[tx.Sequence[str]]:
            return self.prm.get('consume')

        @property
        def append(self) -> tx.Union[bool, str]:
            return self.prm.get('append', False)

        @property
        def prefix(self) -> tx.Union[bool, str]:
            return self.prm.get('prefix', True)

        @property
        def returns(self) -> tx.Optional[tx.Union[
            str, tx.Sequence[str], tx.Mapping[str, str]
        ]]:
            return self.prm.get('returns')

        def __call__(
            self, x: tx.Union[Tensor, tx.Sequence, tx.Mapping, Arguments]
        ) -> tx.Union[Tensor, tx.Sequence, tx.Mapping, Arguments]:
            if isinstance(x, NoArguments):
                return None

            # At this point, there is a single positional argument
            # (which may be an Args, Kwargs, or ArgsAndKwargs object)
            # and no keyword arguments. We save the original input type
            # to be able to return the same kind of `Arguments` object.
            intype = type(x)

            # If the input is an `Arguments` object, convert it to list/dict
            if isinstance(x, Arguments):
                x = x.unwrap()

            def outtype(obj):
                # Convert back to the original input type if needed
                if intype is ArgsAndKwargs:
                    args, kwargs = obj
                elif intype is Args:
                    args, kwargs = obj, {}
                elif intype is Kwargs:
                    args, kwargs = (), obj
                else:
                    return obj
                return Arguments(*args, **kwargs)

            # Not shared across tensors -> unfold
            if isinstance(x, (list, tuple)):
                return outtype(self._forward_list(x))

            if hasattr(x, 'items'):
                x = self._forward_dict(x)
                # `_forward_dict` may turn a dict into a list (append mode);
                # keyword arguments then become positional ones.
                if not isinstance(x, dict) and intype in (Kwargs, ArgsAndKwargs):
                    return Args(*x)
                return outtype(x)

            # ---- Now we're working with a single tensor (or str) ----

            # Apply the transform to the input tensor
            y = self.transform.xform(x, args=self.args)

            # Most transforms return a well-formatted `Returned` object,
            # which contain all possible outputs of a transform, mapped
            # into a structure specified by the `returns` argument.
            if not isinstance(y, Returned):
                # When they do not, we have to build the `Returned`
                # object ourselves.
                if not isinstance(y, type(self.returns)):
                    # The transform returned a single output (likely a
                    # tensor), which we assign to the `output` key,
                    # while the input is assigned to the `input` key.
                    y = dict(input=x, output=y)
                    y = prepare_output(y, self.returns).obj
                else:
                    # `returns` and `y` have the same type, but `y` may
                    # have been obtained from a subtransform. We cannot
                    # guarantee that the keys of `y` are the same as
                    # those of `returns`, but we can insert the correct
                    # `input` tensor, if it was requested.
                    y = self._insert_input(x, y)
                # Wrap output in a `Returned` object, so that helpers
                # know how to handle it (e.g. when `append=True`).
                y = Returned(y)

            # ---- Now we're working with a `Returned` object ----
            return y

        def _insert_input(self, x, y):
            """Write `x` into every slot of `y` whose `returns` target
            is `'input'`, and return the (possibly rebuilt) `y`."""
            # NOTE: we cannot break early once `"input"` is
            # encountered, because multiple outputs elements can
            # contain the same target (e.g. `returns=['input', 'input']`).
            returns = self.returns
            if isinstance(returns, dict):
                for key, target in returns.items():
                    if target == 'input':
                        y[key] = x
            elif isinstance(returns, (list, tuple)):
                # Tuples do not support item assignment: edit a list
                # copy, then convert back.
                was_tuple = isinstance(y, tuple)
                if was_tuple:
                    y = list(y)
                for i, target in enumerate(returns):
                    if target == 'input':
                        y[i] = x
                if was_tuple:
                    y = tuple(y)
            return y

        def _is_included(self, key: str) -> bool:
            """Whether `key` matches an `include` pattern (True if unset)."""
            if self.include is None:
                return True
            return Transform._match_any(key, self.include)

        def _is_excluded(self, key: str) -> bool:
            """Whether `key` matches an `exclude` pattern."""
            if self.exclude is None:
                return False
            return Transform._match_any(key, self.exclude)

        def _is_consumed(self, key: str) -> bool:
            """Whether `key` matches a `consume` pattern."""
            if self.consume is None:
                return False
            return Transform._match_any(key, self.consume)

        def _get_valid_keys(self, x: tx.Mapping) -> tx.Sequence[str]:
            """Keys of `x` that are included and not excluded."""
            valid_keys = x.keys()
            if self.include is not None:
                valid_keys = [
                    k for k in valid_keys
                    if self._is_included(k)
                ]
            if self.exclude:
                valid_keys = [
                    k for k in valid_keys
                    if not self._is_excluded(k)
                ]
            return valid_keys

        def _forward_list(
            self, x: list, forward: tx.Optional[tx.Callable] = None
        ) -> list:
            """Apply forward pass to elements of a list.

            If `append` is truthy, returned lists/tuples (or dict values)
            are flattened into the parent list. A string `append` is only
            meaningful as a key separator for dict parents, so for lists
            it behaves like `True`.
            """
            forward = forward or self
            append = bool(self.append)
            y = []
            for elem in x:
                elem = forward(elem)
                if isinstance(elem, Returned):
                    elem = elem.obj
                if append and isinstance(elem, dict):
                    y.extend(elem.values())
                elif append and isinstance(elem, (list, tuple)):
                    y.extend(elem)
                else:
                    y.append(elem)
            return type(x)(y)

        def _forward_dict(
            self, x: dict, forward: tx.Optional[tx.Callable] = None
        ) -> tx.Union[dict, list]:
            """Apply forward pass to elements of a dict.

            Keys that are not selected by `include`/`exclude` are copied
            unchanged. If `append` forces a non-dict output to be merged,
            the result degrades to a list. Consumed keys are removed from
            a dict result.
            """
            forward = forward or self
            # A set makes the membership tests below O(1); iteration
            # order still follows `x`.
            valid_keys = set(self._get_valid_keys(x))

            append = self.append
            sep = "."
            if isinstance(append, str):
                # A string `append` doubles as the prefix/key separator.
                sep, append = append, True

            # Initialise output dictionary with input keys and values
            # that *will not* be transformed so that they are preserved.
            y = {
                key: value
                for key, value in x.items()
                if key not in valid_keys and not self._is_consumed(key)
            }

            # For each input item, apply the transform and save its outputs
            for key, value in x.items():
                if key not in valid_keys:
                    continue

                # Compute prefix: `True` means "use the input key".
                prefix = self.prefix
                if prefix and not isinstance(prefix, str):
                    prefix = key

                # Apply transform to input value
                value = forward(value)
                if isinstance(value, Returned):
                    value = value.obj

                if append:

                    if isinstance(y, dict) and isinstance(value, dict):
                        # Insert the returned values in the output dict,
                        # using a new key, so that the input is preserved.
                        if prefix:
                            value = {
                                prefix + sep + child_key: child_value
                                for child_key, child_value in value.items()
                            }
                        y.update(value)
                        continue

                    if isinstance(y, dict):
                        # The transform did not return a dictionary, so
                        # we cannot insert its values in the output dict.
                        # We transform it into a list and append the
                        # returned values to it.
                        y = list(y.values())

                    if isinstance(value, (list, tuple)):
                        # Append the returned values to the output list.
                        y.extend(value)
                        continue

                    if isinstance(value, dict):
                        # The output dictionary was previously transformed
                        # into a list, so we cannot insert the returned
                        # (key, value) pairs. We append the values instead.
                        y.extend(value.values())
                        continue

                # We insert the returned value (whether it is a single
                # tensor, or a nested structure of tensors) in the output
                # structure, in place of the input value.
                if isinstance(y, dict):
                    y[key] = value
                else:
                    y.append(value)

            if isinstance(y, dict):
                # Consume keys
                for key in list(y.keys()):
                    if self._is_consumed(key):
                        y.pop(key)
                # Convert to dictionary subtype
                return type(x)(y)
            return y

    def xform(
        self, x: Tensor, /,
        args: Arguments = NoArguments(),
    ) -> Returned:
        """Apply the transform to a tensor.

        Non-final transforms do not implement this method in general.

        Parameters
        ----------
        x : (C_inp, *spatial_inp) tensor
            A single input tensor
        args: Arguments, optional
            The original inputs arguments to the transform, in case
            they are needed.

        Returns
        -------
        y : Returned | (C_out, *spatial_out) tensor
            A single output tensor, or a `Returned` object containing
            multiple output tensors and their corresponding keys.
        """
        # Wrapper that calls `_xform`, but only passes `args` if the
        # method accepts it, to avoid errors with legacy implementations.
        if 'args' in inspect.signature(self._xform).parameters:
            return self._xform(x, args=args)
        return self._xform(x)

    def _xform(
        self, x: Tensor, /,
        args: Arguments = NoArguments(),
    ) -> Returned:
        raise NotImplementedError("This transform does not implement `xform`.")

    def final(
        self,
        x: Tensor, /,
        args: Arguments = NoArguments(),
        **kwargs
    ) -> "FinalTransform":
        """
        Generate the final version of the transform.

        Some transforms save the output type of this function in their
        `Final` attribute.

        !!! addedin "![v0.5](https://img.shields.io/badge/v0.5-green) \
            Added `final` method in `v0.5`."
            Before this, one had to use `make_final(x, max_depth=inf)`.

        Parameters
        ----------
        x : tensor
            A single input tensor, with shape `(C, *shape)`.
        args: Arguments, optional
            The original inputs arguments to the transform, in case
            they are needed.

        Returns
        -------
        FinalTransform
            A final version of the transform.
        """
        return self.unroll(x, max_depth=inf, args=args, **kwargs)

    def next(
        self,
        x: Tensor, /,
        args: Arguments = NoArguments(),
        **kwargs
    ) -> "FinalTransform":
        """
        Generate the next version of the transform.

        Some transforms save the output type of this function in their
        `Next` attribute.

        !!! addedin "![v0.5](https://img.shields.io/badge/v0.5-green) \
            Added `next` method in `v0.5`."
            Before this, one had to use `make_final(x, max_depth=1)`.

        Parameters
        ----------
        x : tensor
            A single input tensor, with shape `(C, *shape)`.
        args: Arguments, optional
            The original inputs arguments to the transform, in case
            they are needed.

        Returns
        -------
        Transform
            A more specialized version of the transform.
        """
        return self.unroll(x, max_depth=1, args=args, **kwargs)

    def unroll(
        self, x: Tensor, /,
        max_depth: int = inf,
        args: Arguments = NoArguments(),
        **kwargs
    ) -> "Transform":
        """
        Generate the next (i.e., more final) version(s) of the transform.

        * To completely finalize a transform,
          call `unroll(x, max_depth=inf)` or `final()`.
        * To get the next version of a transform,
          call `unroll(x, max_depth=1)` or `next()`.

        !!! addedin "![v0.5](https://img.shields.io/badge/v0.5-green) \
            Added `unroll` method in `v0.5`."
            Before this, it was named `make_final`.

        Parameters
        ----------
        x : tensor
            A single input tensor, with shape `(C, *shape)`.
        max_depth : int | {inf}
            Maximum depth to apply `unroll` recursively.
            If not `inf`, the resulting transform may not be fully final.
            Default: no limit.
        args: Arguments, optional
            The original inputs arguments to the transform, in case
            they are needed.

        Returns
        -------
        Transform
            A more specialized version of the transform.
        """
        if max_depth <= 0:
            # This is always valid, so let's catch it
            return self
        # Wrapper that calls `_unroll`, but only passes `args` if the
        # method accepts it.
        if 'args' in inspect.signature(self._unroll).parameters:
            return self._unroll(x, max_depth, args=args, **kwargs)
        return self._unroll(x, max_depth, **kwargs)

    def make_final(
        self, x: Tensor, /,
        max_depth: int = inf,
        args: Arguments = NoArguments(),
        **kwargs
    ) -> "Transform":
        # Deprecated, but keep it for backward compatibility
        return self.unroll(x, max_depth=max_depth, args=args, **kwargs)

    def _unroll(
        self, x: Tensor, /,
        max_depth: int = inf,
        args: Arguments = NoArguments(),
    ) -> "Transform":
        if self.is_final or max_depth == 0:
            return self
        raise NotImplementedError("This transform does not implement `unroll`")

    def inverse(self, *a, **k) -> "FinalTransform":
        """Apply the inverse transform recursively

        Parameters
        ----------
        *a, **k : [nested list or dict of] tensor
            Input tensors, with shape `(C, *shape)`

        Returns
        -------
        [nested list or dict of] tensor
            Output tensors. with shape `(C, *shape)`

        """
        return self.make_inverse()(*a, **k)

    def make_inverse(self) -> "FinalTransform":
        """Generate the inverse transform"""
        # We fallback to the identity, rather than raising an error.
        return IdentityTransform()
691
+
692
+
693
class FinalTransform(Transform):
    """
    Base class for deterministic transforms.

    A final transform has nothing left to specialize: `unroll` (and
    therefore `next`/`final`) hands back the transform itself.
    Subclasses *must* implement `_xform`, which `xform` dispatches to.
    """

    @property
    def is_final(self) -> bool:
        # By definition, a final transform is always final.
        final = True
        return final

    def _unroll(
        self, x: Tensor, /,
        max_depth: int = inf,
        args: Arguments = NoArguments(),
    ) -> "FinalTransform":
        # Regardless of the input or depth, there is no more specific
        # version of this transform to produce.
        unrolled = self
        return unrolled
710
+
711
+
712
class _SharedMixin:
    """
    Mixin for transforms that have parameters (e.g. random ones)
    that may be shared across tensors and/or channels or independent
    across tensors and/or channels.

    Classes using this mixin are expected to set `self.shared` (a string
    normalized by `_prepare_shared`) and to provide `next`, `unroll`,
    `get_prm`, `include` and `exclude` (as `Transform` does).
    """

    @classmethod
    def _prepare_shared(cls, shared):
        """Normalize the `shared` option to a string.

        `True` becomes `'channels+tensors'`; `False` and `None` become
        `''` (nothing shared). Any other value is returned unchanged.
        """
        if shared is True:
            return 'channels+tensors'
        if shared is False or shared is None:
            # `None` used to be kept as is, which made later
            # `'channels' in self.shared` checks raise a TypeError.
            return ''
        return shared

    def _xform(
        self, x: Tensor, /, args: Arguments = NoArguments(),
    ) -> Returned:
        """Specialize the transform on `x` (or its first channel if
        shared across channels), then apply it to all of `x`."""
        template = x
        if 'channels' in self.shared:
            # Use the first channel only to compute the final transform
            template = template[:1]
        # Compute the next form of this transform
        # NOTE: we do not use `max_depth=inf` because the `shared` option
        # may differ across transformations in the hierarchy. For example,
        # the top-level parameters of a transformation may be shared
        # (e.g., the number of control points in a bias field), but not
        # the lower level ones (e.g., the values of the control points).
        transformation = self.next(template, args=args)
        if transformation is self:
            # Avoid infinite recursion. This should not happen.
            raise ValueError(
                f"The transform is not final, but calling `next` "
                f"returned itself. Transform: {self}"
            )
        # Apply the final transform to all channels
        return transformation.xform(x, args=args)

    def forward(self, *a, **k) -> Returned:
        return self._shared_forward(*a, **k)

    def _shared_forward(self, *a, _fallback=None, **k) -> Returned:
        """Forward pass honoring `shared='tensors'`.

        When parameters are shared across tensors, the transform is
        specialized once on the first valid tensor and that specialized
        transform is applied to every input. Otherwise `_fallback`
        (default: the parent class' `forward`) is called.
        """
        _fallback = _fallback or super().forward
        args = Arguments(*a, **k)
        a, k = args.to_args_kwargs()

        if 'tensors' in self.shared:
            # Get the first valid tensor across all inputs, and use it
            # as the template to compute the final transform.
            first_tensor = get_first_element(
                [a, k], include=self.include, exclude=self.exclude)
            if 'channels' in self.shared:
                # Get the first channel only, to compute the final transform.
                first_tensor = first_tensor[:1]
            # Compute the next form of this transform...
            transform = self.next(first_tensor, args=args)
            # ...and apply it to all tensors.
            return transform(*a, **k)

        # Else, we let `xform` deal with shared parameters across channels.
        return _fallback(*a, **k)

    def make_per_channel(
        self, x: Tensor, /,
        max_depth: int = inf,
        args: Arguments = NoArguments(),
        **kwargs
    ) -> "PerChannelTransform":
        """Build a `PerChannelTransform` holding one independently
        unrolled transform per channel of `x` (channels are sliced along
        the first dimension)."""
        prm = dict(self.get_prm())
        # Sharing is meaningless once each channel has its own transform.
        prm.pop('shared', None)
        return PerChannelTransform([
            self.unroll(x[i:i+1], max_depth, args=args, **kwargs)
            for i in range(len(x))
        ], **prm).unroll(x, max_depth-1)
786
+
787
+
788
class NonFinalTransform(_SharedMixin, Transform):
    """
    Transforms whose parameters depend on features of the input
    tensor (shape, dtype, etc).

    Non-final transforms implement `unroll`, and do not implement
    `xform`. Their aim is to generate a more-specialized transform
    at call time.
    """
    def __init__(
        self, *, shared: tx.Union[bool, str] = False, **kwargs
    ) -> None:
        """
        Parameters
        ----------
        shared : {'channels', 'tensors', 'channels+tensors', ''} | bool

            - `'channels'`: the same transform is applied to all channels
              in a tensor, but different transforms are used in different
              tensors.
            - `'tensors'`: the same transform is applied to all tensors,
              but with a different transform for each channel.
            - `'channels+tensors'` or `True`: the same transform is applied
              to all channels of all tensors.
            - `''` or `False`: A different transform is applied to each
              channel and each tensor.

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            Forwarded to `Transform.__init__`.
        """
        super().__init__(**kwargs)
        self.shared = self._prepare_shared(shared)
815
+
816
+
817
class SpecialTransform(Transform):
    """Base class for transforms that act on other transforms.

    Whether such a transform is "final" or "non-final" cannot be decided
    in general: it depends on the transforms it wraps.

    Every special transform implements `unroll`. Some additionally
    provide a "fast-track" `xform` used in simple cases (e.g., when the
    transform is not shared across tensors) for efficiency.

    !!! addedin "![v0.5](https://img.shields.io/badge/v0.5-green) \
        Added `SpecialTransform` class in `v0.5`."
        Before this, special transforms inherited directly from `Transform`.
    """
832
+
833
+
834
class IdentityTransform(FinalTransform):
    """Identity transform: the output is the input, unchanged.

    Its inverse is itself.
    """

    def _xform(
        self, x: Tensor, /, args: Arguments = NoArguments()
    ) -> Returned:
        # Both the `input` and `output` slots refer to the same tensor;
        # `returns` decides which of them end up in the result.
        outputs = {'input': x, 'output': x}
        return prepare_output(outputs, self.returns)

    def make_inverse(self) -> "IdentityTransform":
        # The identity is its own inverse.
        inverse = self
        return inverse
844
+
845
+
846
class SequentialTransform(_SharedMixin, SpecialTransform):
    """A sequence of transforms

    !!! example
        Sequences can be built explicitly, or simply by adding transforms
        together:
        ```python
        t1 = MultFieldTransform()
        t2 = GaussianNoiseTransform()
        seq = SequentialTransform([t1, t2])  # explicit
        seq = t1 + t2                        # implicit
        ```

        Sequences can also be extended by addition:
        ```python
        seq += SmoothTransform()
        ```

    """

    def __init__(self, transforms: tx.Sequence[Transform], **kwargs) -> None:
        """
        Parameters
        ----------
        transforms : list[Transform]
            A list of transforms to apply sequentially.

        Other Parameters
        ------------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        # `shared` is handled here (normalized by `_prepare_shared`);
        # every other option goes to the base class.
        shared_mode = kwargs.pop('shared', False)
        super().__init__(**kwargs)
        self.shared = self._prepare_shared(shared_mode)
        self.transforms = transforms

    def _unroll(
        self, x: Tensor, /,
        max_depth: int = inf,
        args: Arguments = NoArguments()
    ) -> Transform:
        # Stop when the depth budget is exhausted or when every element
        # of the sequence is already final.
        if max_depth == 0 or self.is_final:
            return self
        unrolled = []
        for sub in self:
            sub = sub.unroll(x, max_depth=max_depth-1, args=args)
            # Each element is unrolled against the output of the
            # previous one, so run it before moving on.
            x = sub(x)
            args = Arguments(x)
            unrolled.append(sub)
        return SequentialTransform(unrolled, **self.get_prm())

    @property
    def is_final(self) -> bool:
        return all(sub.is_final for sub in self)

    def make_inverse(self) -> Transform:
        # Invert each element and apply them in reverse order.
        inverses = [sub.make_inverse() for sub in reversed(self)]
        return SequentialTransform(inverses)

    def forward(self, *a, **k) -> Returned:
        # When the whole sequence is shared across tensors, `_SharedMixin`
        # unrolls it on the first valid tensor and applies the result to
        # every tensor. Sequences are not shared by default, so this path
        # should be rare. Otherwise (not shared, or only shared across
        # channels) the elements are simply applied one after the other.
        return self._shared_forward(*a, **k, _fallback=self._forward_impl)

    def _forward_impl(self, *args, **kwargs):
        result = Arguments(*args, **kwargs)
        for sub in self:
            # `returns` is deliberately not propagated: it does not make
            # much sense for the inner elements of a sequence.
            with IncludeKeysTransform(sub, self.include), \
                    ExcludeKeysTransform(sub, self.exclude), \
                    ConsumeKeysTransform(sub, self.consume):
                result = sub(result)
        return result

    def _xform(
        self, x: Tensor, /, args: Arguments = NoArguments()
    ) -> Returned:
        # Only reached when some layer's `unroll` implicitly builds a
        # `SequentialTransform`. There, `shared=False`, so deferring to
        # `forward()` should be enough.
        # FIXME: unusual returns/include/exclude settings are not handled.
        return self(x)

    def __len__(self) -> int:
        return len(self.transforms)

    def __iter__(self) -> tx.Iterator[Transform]:
        yield from self.transforms

    def __getitem__(self, item: tx.Union[int, slice]) -> Transform:
        if not isinstance(item, slice):
            return self.transforms[item]
        return SequentialTransform(self.transforms[item])

    def __repr__(self) -> str:
        return f'{type(self).__name__}({self.transforms!r})'
964
+
965
+
966
class PerChannelTransform(SpecialTransform):
    """Apply a different transform to each channel

    The `i`-th transform is applied to the slice `x[i:i+1]` of the
    input, and the per-channel results are concatenated back together
    (with `recursive_cat`).
    """

    def __init__(self, transforms: tx.Sequence[Transform], **kwargs) -> None:
        """
        Parameters
        ----------
        transforms : list[Transform]
            A list of transforms to apply to each channel.

        Other Parameters
        ----------------
        **kwargs
            Forwarded to the base class. The `include`, `exclude`,
            `consume` and `returns` options are propagated to each
            per-channel transform.
        """
        super().__init__(**kwargs)
        self.transforms = transforms

    def _unroll(
        self, x: Tensor, /,
        max_depth: int = inf,
        args: Arguments = NoArguments()
    ) -> Transform:
        if max_depth == 0:
            return self
        propagate_options = (
            self.include is not None or
            self.exclude or self.consume or self.returns
        )
        trf = []
        for i, t in enumerate(self.transforms):
            if propagate_options:
                # NOTE
                # Context managers cannot be used here because they would
                # exit on return. Instead, make a shallow copy of the
                # transform and set its options permanently. Usually
                # harmless, since `unroll` often creates a new transform,
                # but it matters when `max_depth < 2`.
                t = copy(t)
                # Each option is combined with its own counterpart
                # (include with include, exclude with exclude).
                t.include = IncludeKeysTransform._combine(
                    self.include, t.include)
                t.exclude = ExcludeKeysTransform._combine(
                    self.exclude, t.exclude)
                t.consume = ConsumeKeysTransform._combine(
                    self.consume, t.consume)
                if self.returns:
                    t.returns = self.returns
            t = t.unroll(x[i:i+1], max_depth-1, args=args)
            trf.append(t)
        prm = dict(self.get_prm())
        prm.pop('shared', None)
        return PerChannelTransform(trf, **prm)

    def _xform(
        self, x: Tensor, /, args: Arguments = NoArguments()
    ) -> Returned:
        results = []
        for i, t in enumerate(self.transforms):
            with \
                    ReturningTransform(t, self.returns), \
                    IncludeKeysTransform(t, self.include), \
                    ExcludeKeysTransform(t, self.exclude), \
                    ConsumeKeysTransform(t, self.consume):
                results.append(t(x[i:i+1]))
        return Returned(recursive_cat(results))

    @property
    def is_final(self) -> bool:
        return all(t.is_final for t in self.transforms)
1027
+
1028
+
1029
class MaybeTransform(_SharedMixin, SpecialTransform):
    """Randomly apply a transform

    !!! note "[`ctx.maybe`][cornucopia.ctx.maybe] is an alias for [`MaybeTransform`][cornucopia.special.MaybeTransform]"

    !!! example "20% chance of adding noise"
        ```python
        import cornucopia as cc
        gauss = cc.GaussianNoiseTransform()
        ```
        Explicit call to the class:
        ```python
        img = cc.MaybeTransform(gauss, 0.2)(img)
        ```
        Implicit call using syntactic sugar:
        ```python
        img = (0.2 * gauss)(img)
        ```

    !!! changedin "![v0.4](https://img.shields.io/badge/v0.4-yellow) \
        Default for `shared` changed from `False` to `True`"
    """
    def __init__(
        self,
        transform: Transform,
        prob: float = 0.5,
        *,
        shared: bool = True,
        **kwargs
    ) -> None:
        """

        Parameters
        ----------
        transform : Transform
            A transform to randomly apply
        prob : float
            Probability to apply the transform
        shared : {'channels', 'tensors', 'channels+tensors', ''} | bool
            Roll the dice once for all input tensors
        """
        super().__init__(**kwargs)
        self.shared = self._prepare_shared(shared)
        self.subtransform = transform
        self.prob = prob

    def throw_dice(self) -> bool:
        """Return True (i.e., apply the transform) with probability `prob`."""
        return random.random() > 1 - self.prob

    def _unroll(
        self, x: Tensor, /,
        max_depth: int = float('inf'),
        args: Arguments = NoArguments()
    ) -> Transform:
        if max_depth == 0:
            return self
        if self.throw_dice():
            trf = self.subtransform
            if self.include is not None or self.exclude or self.consume:
                # NOTE
                # * Context managers are not used because they exit on
                #   return. They would work in most cases, since `unroll`
                #   often creates a new transform, but they can fail when
                #   `max_depth<2`. Better safe than sorry.
                # * `returns` is not propagated; it should be dealt with
                #   by the subtransform.
                trf = copy(trf)
                trf.include = IncludeKeysTransform._combine(self.include, trf.include)
                trf.exclude = ExcludeKeysTransform._combine(self.exclude, trf.exclude)
                trf.consume = ConsumeKeysTransform._combine(self.consume, trf.consume)
            return trf.unroll(x, max_depth-1, args=args)
        else:
            # Not applied: pass the input through, still honoring `consume`.
            return IdentityTransform(consume=self.consume)

    def __repr__(self) -> str:
        s = f'{repr(self.subtransform)}?'
        if self.prob != 0.5:
            s += f'[{self.prob}]'
        return s
1110
+
1111
+
1112
class SwitchTransform(_SharedMixin, SpecialTransform):
    """Randomly choose a transform to apply

    !!! note "[`ctx.switch`][cornucopia.ctx.switch] is an alias for [`SwitchTransform`][cornucopia.special.SwitchTransform]"

    !!! example "Randomly apply either Gaussian or Chi noise"
        ```python
        import cornucopia as cc
        gauss = cc.GaussianNoiseTransform()
        chi = cc.ChiNoiseTransform()
        ```
        Explicit call to the class:
        ```python
        img = cc.SwitchTransform([gauss, chi])(img)
        ```
        Implicit call using syntactic sugar:
        ```python
        img = (gauss | chi)(img)
        ```
        Functional call:
        ```python
        img = cc.switch({gauss: 0.5, chi: 0.5})(img)
        ```

    !!! changedin "![v0.4](https://img.shields.io/badge/v0.4-yellow) \
        Default for `shared` changed from `False` to `True`"
    """

    def __init__(
        self,
        transforms: tx.Sequence[Transform],
        prob: cct.ScalarOrSequence[float] = 0,
        *,
        shared: cct.SharedT = True,
        **kwargs
    ) -> None:
        """

        Parameters
        ----------
        transforms : list[Transform] | dict[Transform, float]
            A list of transforms to sample from. If a dict, its values
            are used as probabilities (and `prob` must not be provided).
        prob : list[float]
            Probability of applying each transform. Probabilities are
            normalized to sum to one; if all are zero (the default),
            transforms are picked uniformly.
        shared : {'channels', 'tensors', 'channels+tensors', ''} | bool
            Roll the dice once for all input tensors
        """
        super().__init__(**kwargs)
        if isinstance(transforms, dict):
            if prob:
                raise ValueError(
                    "When `transforms` is a dict, `prob` should not be provided."
                )
            prob = list(transforms.values())
            transforms = list(transforms.keys())
        self.shared = self._prepare_shared(shared)
        self.transforms = list(transforms)
        self.prob = prob or []

    def _make_prob(self) -> tx.Sequence[float]:
        """Return one normalized probability per transform."""
        prob = ensure_list(self.prob, len(self.transforms), default=0)
        sumprob = sum(prob)
        if not sumprob:
            prob = [1/len(self.transforms)] * len(self.transforms)
        else:
            prob = [x / sumprob for x in prob]
        return prob

    def throw_dice(self) -> int:
        """Return the index of the transform to apply."""
        prob = cumsum(self._make_prob())
        dice = random.random()
        for k in range(len(self.transforms)):
            if dice > 1 - prob[k]:
                return k
        return len(self.transforms) - 1

    def _unroll(
        self, x: Tensor, /,
        max_depth: int = float('inf'),
        args: Arguments = NoArguments()
    ) -> Transform:
        if max_depth == 0:
            return self
        t = self.transforms[self.throw_dice()]
        t = t.unroll(x, max_depth-1, args=args)
        if self.include is not None or self.exclude or self.consume:
            # NOTE
            # Context managers cannot be used because they exit on
            # return. Instead, make a shallow copy of the transform and
            # change its options.
            t = copy(t)
            t.include = IncludeKeysTransform._combine(self.include, t.include)
            t.exclude = ExcludeKeysTransform._combine(self.exclude, t.exclude)
            t.consume = ConsumeKeysTransform._combine(self.consume, t.consume)
        return t

    def __or__(self, other: Transform) -> "SwitchTransform":
        if (isinstance(other, SwitchTransform)
                and not self.prob and not other.prob):
            return SwitchTransform([*self.transforms, *other.transforms])
        else:
            return SwitchTransform([self, other])

    def __ior__(self, other: Transform) -> "SwitchTransform":
        if (isinstance(other, SwitchTransform)
                and not self.prob and not other.prob):
            # Merge in place: each of `other`'s transforms becomes its own
            # option (mirrors `__or__`), rather than appending the whole
            # list as a single, non-callable element.
            self.transforms.extend(other.transforms)
            return self
        else:
            return SwitchTransform([self, other])

    def __repr__(self) -> str:
        if self.prob:
            prob = self._make_prob()
            s = [f'{p} * {str(t)}' for p, t in zip(prob, self.transforms)]
        else:
            s = [str(t) for t in self.transforms]
        s = ' | '.join(s)
        s = f'({s})'
        return s
1234
+
1235
+
1236
class IncludeKeysTransform(SpecialTransform):
    """
    Context manager for keys to include.
    Can also be used as a transform.

    !!! note "[`ctx.include`][cornucopia.ctx.include] is an alias for [`IncludeKeysTransform`][cornucopia.special.IncludeKeysTransform]"

    !!! example "Use as a transform"
        ```python
        from cornucopia import IncludeKeysTransform
        newxform = IncludeKeysTransform(xform, "image")
        image, label = newxform(image=image, label=label)
        ```

    !!! example "Use as a context manager `with as`"
        ```python
        from cornucopia import IncludeKeysTransform
        with IncludeKeysTransform(xform, "image") as newxform:
            image, label = newxform(image=image, label=label)
        ```

    !!! example "Use as a context manager `with`"
        ```python
        from cornucopia import IncludeKeysTransform
        with IncludeKeysTransform(xform, "image"):
            image, label = xform(image=image, label=label)
        ```

    !!! example "Use as a context manager (alias)"
        ```python
        from cornucopia import ctx
        with ctx.include(xform, "image") as newxform:
            image, label = newxform(image=image, label=label)
        ```
    """

    def __init__(
        self,
        transform: Transform,
        keys: cct.IncludeT,
        union: bool = True
    ) -> None:
        """
        Parameters
        ----------
        transform : Transform
            Transform to apply
        keys : [sequence of] str
            Keys to include
        union : bool
            Include the union of what was already included and `keys`
        """
        super().__init__()
        if keys is not None:
            keys = ensure_list(keys)
        self.transform = transform
        self.keys = keys
        self.union = union

    def forward(self, *a, **k) -> Returned:
        with self as transform:
            return transform.forward(*a, **k)

    def _unroll(
        self, x: Tensor, /,
        max_depth: int = float('inf'),
        args: Arguments = NoArguments()
    ) -> Transform:
        if max_depth == 0:
            return self
        # Unroll while the keys are active, so that transforms created
        # during unrolling inherit the combined `include`.
        with self as trf:
            final_trf = trf.unroll(x, max_depth, args=args)
        # Exiting the context restores the original `include`, which also
        # undoes it on `final_trf` when unrolling returned the wrapped
        # transform itself. Set the keys permanently on a shallow copy.
        final_trf = copy(final_trf)
        final_trf.include = self._combine(
            self.keys, final_trf.include, union=self.union)
        return final_trf

    def make_inverse(self) -> Transform:
        with self as trf:
            inv_trf = trf.make_inverse()
        # Same reasoning as in `_unroll`.
        inv_trf = copy(inv_trf)
        inv_trf.include = self._combine(
            self.keys, inv_trf.include, union=self.union)
        return inv_trf

    @classmethod
    def _combine(cls, *includes, union: bool = True) -> tx.Sequence[str]:
        """Combine several `include` specifications.

        Parameters
        ----------
        *includes : [sequence of] str | None
            The first entry is the base specification. The other entries
            are merged into it when `union=True`. `None` entries are
            skipped.
        union : bool
            If False, only the first specification is used.

        Returns
        -------
        list[str] | None
            A new de-duplicated list (first-seen order), or `None` if no
            specification is available. The inputs are never modified.
        """
        def as_list(keys):
            # A bare string is a single key, not a sequence of characters.
            return [keys] if isinstance(keys, str) else list(keys)

        first, *others = includes
        # Always work on a fresh list so callers' lists are not extended.
        combined = None if first is None else as_list(first)
        if union:
            for other in others:
                if other is None:
                    continue
                if combined is None:
                    combined = []
                combined.extend(as_list(other))
        if combined is not None:
            combined = list(dict.fromkeys(combined))
        return combined

    def __enter__(self) -> Transform:
        # Save the wrapped transform's keys in a private slot, so the
        # wrapper's own `include` option is left untouched.
        self._saved_include = self.transform.include
        self.transform.include = self._combine(
            self.keys, self._saved_include, union=self.union)
        return self.transform

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.transform.include = self._saved_include
        del self._saved_include
1339
+
1340
+
1341
class ExcludeKeysTransform(SpecialTransform):
    """
    Context manager for keys to exclude.
    Can also be used as a transform.

    !!! note "[`ctx.exclude`][cornucopia.ctx.exclude] is an alias for [`ExcludeKeysTransform`][cornucopia.special.ExcludeKeysTransform]"

    !!! example "Use as a transform"
        ```python
        from cornucopia import ExcludeKeysTransform
        newxform = ExcludeKeysTransform(xform, "image")
        image, label = newxform(image=image, label=label)
        ```

    !!! example "Use as a context manager `with as`"
        ```python
        from cornucopia import ExcludeKeysTransform
        with ExcludeKeysTransform(xform, "image") as newxform:
            image, label = newxform(image=image, label=label)
        ```

    !!! example "Use as a context manager `with`"
        ```python
        from cornucopia import ExcludeKeysTransform
        with ExcludeKeysTransform(xform, "image"):
            image, label = xform(image=image, label=label)
        ```

    !!! example "Use as a context manager (alias)"
        ```python
        from cornucopia import ctx
        with ctx.exclude(xform, "image") as newxform:
            image, label = newxform(image=image, label=label)
        ```
    """

    def __init__(
        self,
        transform: Transform,
        keys: cct.ExcludeT,
        union: bool = True
    ) -> None:
        """
        Parameters
        ----------
        transform : Transform
            Transform to apply
        keys : [sequence of] str
            Keys to exclude
        union : bool
            Exclude the union of what was already excluded and `keys`
        """
        super().__init__()
        if keys is not None:
            keys = ensure_list(keys)
        self.transform = transform
        self.keys = keys
        self.union = union

    def forward(self, *a, **k) -> Returned:
        with self as transform:
            return transform.forward(*a, **k)

    def _unroll(
        self, x: Tensor, /,
        max_depth: int = float('inf'),
        args: Arguments = NoArguments()
    ) -> Transform:
        if max_depth == 0:
            return self
        # Unroll while the keys are active, so that transforms created
        # during unrolling inherit the combined `exclude`.
        with self as trf:
            final_trf = trf.unroll(x, max_depth, args=args)
        # Exiting the context restores the original `exclude`, which also
        # undoes it on `final_trf` when unrolling returned the wrapped
        # transform itself. Set the keys permanently on a shallow copy.
        final_trf = copy(final_trf)
        final_trf.exclude = self._combine(
            self.keys, final_trf.exclude, union=self.union)
        return final_trf

    def make_inverse(self) -> Transform:
        with self as trf:
            inv_trf = trf.make_inverse()
        # Same reasoning as in `_unroll`.
        inv_trf = copy(inv_trf)
        inv_trf.exclude = self._combine(
            self.keys, inv_trf.exclude, union=self.union)
        return inv_trf

    @classmethod
    def _combine(cls, *excludes, union: bool = True) -> tx.Sequence[str]:
        """Combine several `exclude` specifications.

        Parameters
        ----------
        *excludes : [sequence of] str | None
            The first entry is the base specification. The other entries
            are merged into it when `union=True`. `None` entries are
            skipped.
        union : bool
            If False, only the first specification is used.

        Returns
        -------
        list[str] | None
            A new de-duplicated list (first-seen order), or `None` if no
            specification is available. The inputs are never modified.
        """
        def as_list(keys):
            # A bare string is a single key, not a sequence of characters.
            return [keys] if isinstance(keys, str) else list(keys)

        first, *others = excludes
        # Always work on a fresh list so callers' lists are not extended.
        combined = None if first is None else as_list(first)
        if union:
            for other in others:
                if other is None:
                    continue
                if combined is None:
                    combined = []
                combined.extend(as_list(other))
        if combined is not None:
            combined = list(dict.fromkeys(combined))
        return combined

    def __enter__(self) -> Transform:
        # Save the wrapped transform's keys in a private slot, so the
        # wrapper's own `exclude` option is left untouched.
        self._saved_exclude = self.transform.exclude
        self.transform.exclude = self._combine(
            self.keys, self._saved_exclude, union=self.union)
        return self.transform

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.transform.exclude = self._saved_exclude
        del self._saved_exclude
1445
+
1446
+
1447
class ConsumeKeysTransform(SpecialTransform):
    """
    Context manager for keys to consume.
    Can also be used as a transform.

    !!! note "[`ctx.consume`][cornucopia.ctx.consume] is an alias for [`ConsumeKeysTransform`][cornucopia.special.ConsumeKeysTransform]"

    !!! example "Use as a transform"
        ```python
        from cornucopia import ConsumeKeysTransform
        newxform = ConsumeKeysTransform(xform, "image")
        label = newxform(image=image, label=label)
        ```

    !!! example "Use as a context manager `with as`"
        ```python
        from cornucopia import ConsumeKeysTransform
        with ConsumeKeysTransform(xform, "image") as newxform:
            label = newxform(image=image, label=label)
        ```

    !!! example "Use as a context manager `with`"
        ```python
        from cornucopia import ConsumeKeysTransform
        with ConsumeKeysTransform(xform, "image"):
            label = xform(image=image, label=label)
        ```

    !!! example "Use as a context manager (alias)"
        ```python
        from cornucopia import ctx
        with ctx.consume(xform, "image") as newxform:
            label = newxform(image=image, label=label)
        ```

    !!! addedin "![v0.5](https://img.shields.io/badge/v0.5-green) \
        Added in `v0.5`"
    """

    def __init__(
        self,
        transform: Transform,
        keys: cct.ConsumeT,
        union: bool = True
    ) -> None:
        """
        Parameters
        ----------
        transform : Transform
            Transform to apply
        keys : [sequence of] str
            Keys to consume
        union : bool
            Consume the union of what was already consumed and `keys`
        """
        super().__init__()
        if keys is not None:
            keys = ensure_list(keys)
        self.transform = transform
        self.keys = keys
        self.union = union

    def forward(self, *a, **k) -> Returned:
        with self as transform:
            return transform.forward(*a, **k)

    def _unroll(
        self, x: Tensor, /,
        max_depth: int = float('inf'),
        args: Arguments = NoArguments()
    ) -> Transform:
        if max_depth == 0:
            return self
        # Unroll while the keys are active, so that transforms created
        # during unrolling inherit the combined `consume`.
        with self as trf:
            final_trf = trf.unroll(x, max_depth, args=args)
        # Exiting the context restores the original `consume`, which also
        # undoes it on `final_trf` when unrolling returned the wrapped
        # transform itself. Set the keys permanently on a shallow copy.
        final_trf = copy(final_trf)
        final_trf.consume = self._combine(
            self.keys, final_trf.consume, union=self.union)
        return final_trf

    def make_inverse(self) -> Transform:
        with self as trf:
            inv_trf = trf.make_inverse()
        # Same reasoning as in `_unroll`.
        inv_trf = copy(inv_trf)
        inv_trf.consume = self._combine(
            self.keys, inv_trf.consume, union=self.union)
        return inv_trf

    @classmethod
    def _combine(cls, *consumes, union: bool = True) -> tx.Sequence[str]:
        """Combine several `consume` specifications.

        Parameters
        ----------
        *consumes : [sequence of] str | None
            The first entry is the base specification. The other entries
            are merged into it when `union=True`. `None` entries are
            skipped.
        union : bool
            If False, only the first specification is used.

        Returns
        -------
        list[str] | None
            A new de-duplicated list (first-seen order), or `None` if no
            specification is available. The inputs are never modified.
        """
        def as_list(keys):
            # A bare string is a single key, not a sequence of characters.
            return [keys] if isinstance(keys, str) else list(keys)

        first, *others = consumes
        # Always work on a fresh list so callers' lists are not extended.
        combined = None if first is None else as_list(first)
        if union:
            for other in others:
                if other is None:
                    continue
                if combined is None:
                    combined = []
                combined.extend(as_list(other))
        if combined is not None:
            combined = list(dict.fromkeys(combined))
        return combined

    def __enter__(self) -> Transform:
        # Save the wrapped transform's keys in a private slot, so the
        # wrapper's own `consume` option is left untouched.
        self._saved_consume = self.transform.consume
        self.transform.consume = self._combine(
            self.keys, self._saved_consume, union=self.union)
        return self.transform

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.transform.consume = self._saved_consume
        del self._saved_consume
1554
+
1555
+
1556
class SharedTransform(_SharedMixin, SpecialTransform):
    """
    Context manager for sharing transforms across channels / tensors.
    Can also be used as a transform.

    !!! note "[`ctx.shared`][cornucopia.ctx.shared] is an alias for [`SharedTransform`][cornucopia.special.SharedTransform]"

    !!! example "Use as a context manager (alias)"
        ```python
        from cornucopia import ctx
        with ctx.shared(xform, "channels") as newxform:
            image = newxform(image)
        ```
    """

    def __init__(
        self, transform: Transform, mode: cct.SharedT = UNSET
    ) -> None:
        """
        Parameters
        ----------
        transform : Transform
            Transform to apply
        mode : {'channels', 'tensors', 'channels+tensors', ''} | bool

            - `'channels'`: the same transform is applied to all channels
              in a tensor, but different transforms are used in different
              tensors.
            - `'tensors'`: the same transform is applied to all tensors,
              but with a different transform for each channel.
            - `'channels+tensors'` or `True`: the same transform is applied
              to all channels of all tensors.
            - `''` or `False`: A different transform is applied to each
              channel and each tensor.

            If not provided, the transform's own `shared` setting is kept.
        """
        super().__init__()
        self.transform = transform
        self.mode = mode

    def forward(self, *a, **k) -> Returned:
        with self as transform:
            return transform.forward(*a, **k)

    def __enter__(self) -> Transform:
        # Remember whether the wrapped transform had a `shared` attribute
        # at all, so `__exit__` can restore (or remove) it exactly.
        self.hasattr = hasattr(self.transform, 'shared')
        self.saved_mode = getattr(self.transform, 'shared', None)
        if self.mode is not UNSET:
            # Normalize the mode with `_prepare_shared`, like the other
            # shared transforms (Sequential/Maybe/Switch) do for their own
            # `shared` option, so that e.g. `True` is not stored raw.
            self.transform.shared = self._prepare_shared(self.mode)
        return self.transform

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        if self.hasattr:
            self.transform.shared = self.saved_mode
        elif hasattr(self.transform, 'shared'):
            delattr(self.transform, 'shared')
        delattr(self, 'hasattr')
        delattr(self, 'saved_mode')
1614
+
1615
+
1616
class ReturningTransform(SpecialTransform):
    """
    Context manager for temporarily setting the `returns` option of a
    transform.

    !!! note "[`ctx.returns`][cornucopia.ctx.returns] is an alias for [`ReturningTransform`][cornucopia.special.ReturningTransform]"

    !!! example "Use as a context manager (alias)"
        ```python
        from cornucopia import ctx
        with ctx.returns(xform, "output") as newxform:
            image = newxform(image)
        ```
    """

    def __init__(
        self,
        transform: Transform,
        returns: tx.Optional[cct.ReturnsT] = None
    ) -> None:
        """
        Parameters
        ----------
        transform : Transform
            Transform whose `returns` option is overridden while the
            context is active.
        returns : [sequence or dict of] str, optional
            Value assigned to `transform.returns` on entry. The previous
            value is restored on exit.
        """
        super().__init__()
        self.transform = transform
        self.returns = returns

    def __enter__(self) -> Transform:
        self.saved_returns = self.transform.returns
        self.transform.returns = self.returns
        return self.transform

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.transform.returns = self.saved_returns
        delattr(self, 'saved_returns')
1647
+
1648
+
1649
class MappedTransform(SpecialTransform):
    """
    Transforms that are applied to specific positional or keyword arguments

    !!! note "[`ctx.map`][cornucopia.ctx.map] is an alias for [`MappedTransform`][cornucopia.special.MappedTransform]"

    !!! example
        ```python
        img = torch.randn([1, 32, 32])
        seg = torch.randn([3, 32, 32]).softmax(0)

        # positional variant
        trf = MappedTransform(GaussianNoise(), None)
        img, seg = trf(img, seg)

        # keyword variant
        trf = MappedTransform(image=GaussianNoise())
        img, seg = trf(image=img, label=seg)

        # alternative version
        dat = {'img': torch.randn([1, 32, 32]),
               'seg': torch.randn([3, 32, 32]).softmax(0)}
        dat = MappedTransform(img=GaussianNoise(), nested=True)(dat)
        ```
    """

    def __init__(
        self,
        *mapargs,
        nested: bool = False,
        default: tx.Optional[Transform] = None,
        **mapkwargs
    ) -> None:
        """

        Parameters
        ----------
        mapargs : tuple[Transform]
            Transform to apply to positional arguments
        mapkwargs : dict[str, Transform]
            Transform to apply to keyword arguments
        nested : bool, default=False
            Recursively traverse the inputs until we find matching
            dictionaries. Only `mapkwargs` are accepted if "nested".
        default : Transform
            Transform to apply if nothing is specifically mapped

        Other Parameters
        ----------------
        shared, include, exclude, consume
            Options of the mapped transform itself (these names are
            therefore reserved and cannot be used as mapped keys).
            `include`, `exclude` and `consume` are propagated to every
            subtransform.
        """
        options = dict(
            shared=mapkwargs.pop('shared', False),
            include=mapkwargs.pop('include', None),
            exclude=mapkwargs.pop('exclude', None),
        )
        # `consume` is read in `forward`, so it must be an option of the
        # transform rather than a mapped key.
        if 'consume' in mapkwargs:
            options['consume'] = mapkwargs.pop('consume')
        super().__init__(**options)
        self.mapargs = mapargs
        self.mapkwargs = mapkwargs
        self.nested = nested
        self.default = default
        if nested and mapargs:
            raise ValueError(
                'Cannot have both `nested` and positional transforms'
            )

    def forward(self, *args, **kwargs) -> Returned:

        if self.include is not None or self.exclude or self.consume:
            def wrap(f: tx.Callable) -> tx.Callable:
                if not f:
                    return f

                def ff(*a, **k):
                    # NOTE
                    # `returns` is not propagated; it should be dealt
                    # with by the subtransforms.
                    with \
                            IncludeKeysTransform(f, self.include), \
                            ExcludeKeysTransform(f, self.exclude), \
                            ConsumeKeysTransform(f, self.consume):
                        return f(*a, **k)
                return ff
        else:
            def wrap(f: tx.Callable) -> tx.Callable:
                return f

        # If the input is a wrapped `Arguments`, unwrap it.
        arguments = Arguments(*args, **kwargs)
        if not arguments:
            return None
        args, kwargs = arguments.to_args_kwargs()

        # The default transform honors include/exclude/consume exactly
        # like the explicitly mapped transforms.
        default_transform = wrap(self.default)

        def default(x):
            # Default transform to apply if nothing is mapped.
            if torch.is_tensor(x):
                return default_transform(x) if default_transform else x
            else:
                return self.forward(x) if self.nested else x

        if args:
            # Apply each transform to its corresponding positional argument
            mapargs = tuple(wrap(f) for f in self.mapargs)
            mapargs += (default,) * max(0, len(args) - len(mapargs))
            args = tuple(f(a) if f else a for f, a in zip(mapargs, args))

        if kwargs:
            # Apply each transform to its corresponding keyword argument
            mapkwargs = {k: wrap(f) for k, f in self.mapkwargs.items()}

            def func(key):
                return mapkwargs.get(key, default) or (lambda x: x)

            kwargs = {key: func(key)(value) for key, value in kwargs.items()}

        # If more than a single input argument, wrap them in `Arguments`
        if kwargs or len(args) > 1:
            return Arguments(*args, **kwargs)
        return args[0]

    def __repr__(self) -> str:
        s = []
        for v in self.mapargs:
            s += [f'{v}']
        for k, v in self.mapkwargs.items():
            s += [f'{k}={v}']
        s = ', '.join(s)
        return f'{type(self).__name__}({s})'
1771
+
1772
+
1773
class RandomizedTransform(SpecialTransform, NonFinalTransform):
    """
    Transform generated by randomizing some parameters of another transform.

    !!! note "[`ctx.randomize`][cornucopia.ctx.randomize] is an alias for [`RandomizedTransform`][cornucopia.special.RandomizedTransform]"

    !!! example "Gaussian noise with randomized variance"
        Object call
        ```python
        import cornucopia as cc
        hypernoise = cc.RandomizedTransform(cc.GaussianNoise, [cc.Uniform()])
        img = hypernoise(img)
        ```

        Delayed call
        ```python
        import cornucopia as cc
        MyRandomNoise = cc.randomize(cc.GaussianNoise)
        hypernoise = MyRandomNoise(cc.Uniform())
        img = hypernoise(img)
        ```

    """

    class Delayed:
        # Temporary parameter holder for delayed calls: stores the
        # transform and options, and builds the `RandomizedTransform`
        # once the samplers are provided at call time.
        def __init__(self, transform: Transform, **kwargs) -> None:
            self.transform = transform
            self.kwargs = kwargs

        def __call__(self, *args, **kwargs) -> "RandomizedTransform":
            return RandomizedTransform(
                self.transform, args, kwargs, **self.kwargs)

    def __new__(cls, *args, **kwargs) -> "RandomizedTransform":
        # Only the base class supports the delayed/functional mode;
        # subclasses are always instantiated normally.
        if cls is RandomizedTransform:
            return cls._base_new(*args, **kwargs)
        return super().__new__(cls)

    @classmethod
    def _base_new(
        cls,
        transform: Transform,
        sample: tuple = tuple(),
        ksample: dict = dict(),
        *,
        shared: tx.Union[bool, str] = False,
        **kwargs
    ) -> "RandomizedTransform":
        assert cls is RandomizedTransform
        if not sample and not ksample:
            # If no samplers are passed, the user calls this in
            # "delayed/functional" mode: return a callable object that
            # builds the instance from the call-time arguments.
            # (Python does not call `__init__` on a non-instance.)
            return cls.Delayed(transform, shared=shared, **kwargs)
        # Otherwise, we're in object mode and instantiate the
        # randomized object.
        return super().__new__(cls)

    def __init__(
        self,
        transform: Transform,
        sample: tuple = tuple(),
        ksample: dict = dict(),
        *,
        shared: tx.Union[bool, str] = False,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        transform : callable(...) -> Transform
            A Transform subclass or a function that constructs a Transform.
        sample : [list or dict of] callable
            A collection of functions that generate parameter values provided
            to `transform`. Can be args-like or kwargs-like arguments.
        ksample : dict[callable]
            Must be kwargs-like arguments.

        Other Parameters
        ----------------
        shared : {'channels', 'tensors', 'channels+tensors', ''} | bool
            Share random parameters across tensors and/or channels
        """
        super().__init__(shared=shared, **kwargs)
        self.sample = sample
        self.ksample = ksample
        self.subtransform = transform

    def _unroll(
        self, x: Tensor, /,
        max_depth: int = inf,
        args: Arguments = NoArguments()
    ) -> Transform:
        if max_depth == 0:
            return self
        if 'channels' not in self.shared and len(x) > 1:
            return self.make_per_channel(x, max_depth, args=args)

        # Sampled constructor arguments. These are kept separate from the
        # `args` parameter (the call-time `Arguments`), which is what must
        # be forwarded to the subtransform's `unroll`.
        sampled_args = []
        sampled_kwargs = {}

        if isinstance(self.sample, (list, tuple)):
            sampled_args += [
                f() if isinstance(f, Sampler) else f
                for f in self.sample
            ]

        elif hasattr(self.sample, 'items'):
            sampled_kwargs.update({
                k: f() if isinstance(f, Sampler) else f
                for k, f in self.sample.items()
            })

        else:
            sampled_args += [
                f() if isinstance(f, Sampler) else f
                for f in (self.sample,)
            ]

        if self.ksample:
            sampled_kwargs.update({
                k: f() if isinstance(f, Sampler) else f
                for k, f in self.ksample.items()
            })

        # Propagate general options (include, exclude, ...),
        # unless they are already set by sample/ksample.
        for key, value in self.get_prm().items():
            sampled_kwargs.setdefault(key, value)

        # Build transform with fixed parameters, and recurse.
        xform = self.subtransform(*sampled_args, **sampled_kwargs)
        xform = xform.unroll(x, max_depth-1, args=args)
        return xform

    def __repr__(self) -> str:
        if type(self) is RandomizedTransform:
            xform = self.subtransform
            if isinstance(xform, type) and issubclass(xform, Transform):
                return f'Randomized{xform.__name__}()'
        return super().__repr__()