cornucopia 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. cornucopia/__init__.py +73 -0
  2. cornucopia/base.py +1915 -0
  3. cornucopia/baseutils.py +575 -0
  4. cornucopia/contrast.py +260 -0
  5. cornucopia/ctx.py +25 -0
  6. cornucopia/fov.py +707 -0
  7. cornucopia/geometric.py +2068 -0
  8. cornucopia/intensity.py +1358 -0
  9. cornucopia/io.py +161 -0
  10. cornucopia/kspace.py +505 -0
  11. cornucopia/labels.py +1872 -0
  12. cornucopia/noise.py +508 -0
  13. cornucopia/psf.py +463 -0
  14. cornucopia/qmri.py +1288 -0
  15. cornucopia/random.py +1480 -0
  16. cornucopia/special.py +159 -0
  17. cornucopia/synth.py +708 -0
  18. cornucopia/tests/__init__.py +0 -0
  19. cornucopia/tests/test_backward_geometric.py +173 -0
  20. cornucopia/tests/test_backward_intensity.py +243 -0
  21. cornucopia/tests/test_backward_kspace.py +115 -0
  22. cornucopia/tests/test_backward_noise.py +169 -0
  23. cornucopia/tests/test_backward_psf.py +142 -0
  24. cornucopia/tests/test_backward_qmri.py +249 -0
  25. cornucopia/tests/test_backward_random.py +44 -0
  26. cornucopia/tests/test_backward_synth.py +72 -0
  27. cornucopia/tests/test_base.py +401 -0
  28. cornucopia/tests/test_geometric.py +26 -0
  29. cornucopia/tests/test_intensity.py +9 -0
  30. cornucopia/tests/test_random.py +722 -0
  31. cornucopia/tests/test_run_contrast.py +28 -0
  32. cornucopia/tests/test_run_fov.py +132 -0
  33. cornucopia/tests/test_run_geometric.py +157 -0
  34. cornucopia/tests/test_run_intensity.py +192 -0
  35. cornucopia/tests/test_run_kspace.py +70 -0
  36. cornucopia/tests/test_run_labels.py +224 -0
  37. cornucopia/tests/test_run_noise.py +127 -0
  38. cornucopia/tests/test_run_psf.py +115 -0
  39. cornucopia/tests/test_run_qmri.py +114 -0
  40. cornucopia/tests/test_run_synth.py +67 -0
  41. cornucopia/typing.py +97 -0
  42. cornucopia/utils/__init__.py +0 -0
  43. cornucopia/utils/b0.py +745 -0
  44. cornucopia/utils/bounds.py +412 -0
  45. cornucopia/utils/compat.py +47 -0
  46. cornucopia/utils/conv.py +305 -0
  47. cornucopia/utils/gmm.py +169 -0
  48. cornucopia/utils/indexing.py +911 -0
  49. cornucopia/utils/io.py +258 -0
  50. cornucopia/utils/jit.py +128 -0
  51. cornucopia/utils/kernels.py +288 -0
  52. cornucopia/utils/morpho.py +234 -0
  53. cornucopia/utils/mrf.py +574 -0
  54. cornucopia/utils/padding.py +173 -0
  55. cornucopia/utils/patch.py +302 -0
  56. cornucopia/utils/pool.py +282 -0
  57. cornucopia/utils/py.py +348 -0
  58. cornucopia/utils/smart_inplace.py +163 -0
  59. cornucopia/utils/version.py +57 -0
  60. cornucopia/utils/warps.py +606 -0
  61. cornucopia-0.0.0.dist-info/METADATA +92 -0
  62. cornucopia-0.0.0.dist-info/RECORD +65 -0
  63. cornucopia-0.0.0.dist-info/WHEEL +5 -0
  64. cornucopia-0.0.0.dist-info/licenses/LICENSE +21 -0
  65. cornucopia-0.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1358 @@
1
+ """This module contains transforms that operate on image intensities."""
2
# Names exported by `from cornucopia.intensity import *`.
__all__ = [
    'AddValueTransform',
    'MulValueTransform',
    'AddMulTransform',
    'ReturnValueTransform',
    'FillValueTransform',
    'ClipTransform',
    'BaseFieldTransform',
    'AddFieldTransform',
    'MulFieldTransform',
    'RandomAddFieldTransform',
    'RandomMulFieldTransform',
    'RandomSlicewiseMulFieldTransform',
    'RandomMulTransform',
    'RandomAddTransform',
    'RandomAddMulTransform',
    'GammaFinalTransform',
    'GammaTransform',
    'RandomGammaTransform',
    'ZTransform',
    'QuantileTransform',
    'MinMaxTransform',
]
25
+ # stdlib
26
+ import math
27
+ from math import inf
28
+ from numbers import Number
29
+
30
+ # dependencies
31
+ import interpol
32
+ import torch
33
+ import typing_extensions as tx
34
+ from torch import Tensor
35
+ from torch.nn.functional import interpolate
36
+
37
+ # internals
38
+ from .baseutils import Returned, prepare_output
39
+ from .base import Transform, FinalTransform, NonFinalTransform
40
+ from .special import RandomizedTransform, SequentialTransform
41
+ from .random import Sampler, Uniform, RandInt, Fixed, make_range
42
+ from .utils.py import ensure_list, positive_index
43
+ from .utils.smart_inplace import add_, mul_, div_, pow_
44
+ from .utils.compat import clamp, clamp_
45
+ from . import typing as cct
46
+
47
+ # typing
48
# A python scalar or a tensor (used for right-hand sides and bounds).
_NumberOrTensor = tx.Union[Number, Tensor]
# A function mapping a tensor to a tensor.
_UnaryOperator = tx.Callable[[Tensor], Tensor]
# A function mapping (tensor, number-or-tensor) to a tensor,
# e.g. `torch.add` or `torch.mul`.
_BinaryOperator = tx.Callable[[Tensor, _NumberOrTensor], Tensor]
51
+
52
+
53
class OpConstTransform(FinalTransform):
    """Base class for arithmetic operations with a constant value.

    The right-hand side is stored under `self.value`. It can also be
    read and written through the attribute named by `value_name`.
    """

    # Default operator. Concrete subclasses override it.
    _op: tx.Optional[_BinaryOperator] = None
    # Maps each invertible operator to the function that converts its
    # right-hand side into the right-hand side of the inverse operation.
    _inv: tx.Dict[_BinaryOperator, _UnaryOperator] = {
        torch.add: lambda x: -x,
        torch.mul: lambda x: 1/x,
    }

    def __init__(
        self,
        value: _NumberOrTensor,
        op: tx.Optional[_BinaryOperator] = None,
        value_name: str = 'value',
        **kwargs
    ):
        """
        Parameters
        ----------
        value : number or tensor
            right-hand side of the operation
        op : {torch.add, torch.mul}
            Arithmetic operation. If not provided, the class-level
            default `_op` is used.
        value_name : str
            Name used when returning the rhs value

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.value = value
        self.op = op or self._op
        self.value_name = value_name

    def __getattr__(self, name: str) -> _NumberOrTensor:
        # Expose `value` under the alias stored in `value_name`.
        if name == self.__dict__.get("value_name"):
            return self.__dict__.get("value")
        return super().__getattr__(name)

    def __setattr__(self, name: str, value: _NumberOrTensor) -> None:
        # Writing to the alias stored in `value_name` updates `value`.
        if name == self.__dict__.get("value_name"):
            name = 'value'
        super().__setattr__(name, value)

    def _xform(self, x: Tensor) -> Returned:
        value = self.value
        if torch.is_tensor(value):
            # match the dtype and device of the input
            value = value.to(x)
        y = self.op(x, value)
        return prepare_output(
            {'input': x, 'output': y, self.value_name: value}, self.returns
        )

    def make_inverse(self) -> Transform:
        """Build the transform that undoes this one.

        Raises
        ------
        KeyError
            If `self.op` has no registered inverse in `_inv`.
        """
        inv = self._inv[self.op]
        # Forward `op` explicitly: an instance built with a custom `op`
        # (rather than through a subclass that sets `_op`) would
        # otherwise yield an inverse whose operator is `None`.
        return type(self)(
            inv(self.value),
            op=self.op,
            value_name=self.value_name,
            **self.get_prm(),
        )
113
+
114
+
115
class AddValueTransform(OpConstTransform):
    """Add a constant value"""
    # Operator applied as `torch.add(input, value)`.
    _op: _BinaryOperator = torch.add
118
+
119
+
120
class MulValueTransform(OpConstTransform):
    """Multiply with a constant value"""
    # Operator applied as `torch.mul(input, value)`.
    _op: _BinaryOperator = torch.mul
123
+
124
+
125
class FillValueTransform(FinalTransform):
    """Fills the tensor with a value inside a mask"""

    def __init__(
        self,
        mask: Tensor,
        value: _NumberOrTensor,
        mask_name: str = 'mask',
        value_name: str = 'value',
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        mask : tensor
            Mask of voxels in which to set the value
        value : number or tensor
            Value written inside the mask. A tensor with one or more
            dimensions must be broadcastable with the input.
        mask_name : str
            Name used when returning the mask
        value_name : str
            Name used when returning the rhs value

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.mask = mask
        self.value = value
        self.mask_name = mask_name
        self.value_name = value_name

    def _xform(self, x: Tensor) -> Returned:
        mask = self.mask.to(x.device)
        value = self.value
        if torch.is_tensor(value):
            # match the dtype and device of the input
            value = value.to(x)
        if torch.is_tensor(value) and value.ndim > 0:
            # `masked_fill` only accepts scalars and 0-dim tensors;
            # use `where` so that voxel-wise values are supported too.
            y = torch.where(mask.bool(), value, x)
        else:
            y = x.masked_fill(mask, value)
        return prepare_output(
            {'input': x, 'output': y,
             self.mask_name: mask,
             self.value_name: value},
            self.returns
        )
171
+
172
+
173
class ReturnValueTransform(FinalTransform):
    """Return a fixed value, whatever the content of the input.

    Only the device (and, unless `dtype` is set, the data type) of the
    input tensor is used.
    """

    def __init__(
        self,
        value: _NumberOrTensor,
        value_name: str = 'output',
        dtype: tx.Optional[torch.dtype] = None,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        value : number or tensor
            Value to return
        value_name : str
            Attribute name under which `value` is also accessible
        dtype : torch.dtype, optional
            Data type of the returned tensor.
            By default, the data type of the input tensor is used.

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.value = value
        self.value_name = value_name
        self.dtype = dtype

    def __getattr__(self, name: str) -> _NumberOrTensor:
        # Expose `value` under the alias stored in `value_name`.
        if name == self.__dict__.get("value_name"):
            return self.__dict__.get("value")
        return super().__getattr__(name)

    def __setattr__(self, name: str, value: _NumberOrTensor) -> None:
        # Writing to the alias stored in `value_name` updates `value`.
        if name == self.__dict__.get("value_name"):
            name = 'value'
        super().__setattr__(name, value)

    def _xform(self, x: Tensor) -> Returned:
        # NOTE(review): unlike the other transforms in this module, the
        # result is not passed through `prepare_output` -- confirm that
        # this is intended.
        dtype = x.dtype if self.dtype is None else self.dtype
        return torch.as_tensor(self.value, dtype=dtype, device=x.device)
214
+
215
+
216
class AddMulTransform(FinalTransform):
    """Constant intensity affine transform: `y = x * slope + offset`"""

    def __init__(
        self,
        slope: _NumberOrTensor = 1,
        offset: _NumberOrTensor = 0,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        slope : number or tensor
            Multiplicative factor of the affine map
        offset : number or tensor
            Additive term of the affine map

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.slope = slope
        self.offset = offset

    def _xform(self, x: Tensor) -> Returned:
        def like_input(prm: _NumberOrTensor) -> _NumberOrTensor:
            # tensors follow the input's dtype/device; numbers are kept
            return prm.to(x) if torch.is_tensor(prm) else prm

        slope = like_input(self.slope)
        offset = like_input(self.offset)
        y = slope * x + offset
        results = dict(input=x, output=y, slope=slope, offset=offset)
        return prepare_output(results, self.returns)

    def make_inverse(self) -> 'AddMulTransform':
        """Build the affine transform `x = y / slope - offset / slope`."""
        inv_slope = 1/self.slope
        inv_offset = -self.offset/self.slope
        return AddMulTransform(inv_slope, inv_offset, **self.get_prm())
258
+
259
+
260
class ClipTransform(FinalTransform):
    """Clip extremum values"""

    def __init__(
        self,
        vmin: tx.Optional[_NumberOrTensor] = None,
        vmax: tx.Optional[_NumberOrTensor] = None,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        vmin : number or tensor, optional
            Lower bound
        vmax : number or tensor, optional
            Upper bound

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.vmin = vmin
        self.vmax = vmax

    def _xform(self, x: Tensor) -> Returned:
        def like_input(bound):
            # tensors follow the input's dtype/device; others are kept
            return bound.to(x) if torch.is_tensor(bound) else bound

        vmin = like_input(self.vmin)
        vmax = like_input(self.vmax)
        results = dict(input=x, output=clamp(x, vmin, vmax),
                       vmin=vmin, vmax=vmax)
        return prepare_output(results, self.returns)
297
+
298
+
299
class RandomMulTransform(RandomizedTransform):
    """
    Multiply the input with a randomly sampled factor.
    """

    Final = Next = MulValueTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        value: tx.Union[Sampler, float, tx.Tuple[float, float]] = (0.5, 2),
        *,
        shared: cct.SharedT = False,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        value : Sampler | [pair of] float
            Bound for multiplicative value

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        factor_sampler = Uniform.make(make_range(0, value))
        super().__init__(
            MulValueTransform, factor_sampler, shared=shared, **kwargs
        )
334
+
335
+
336
class RandomAddTransform(RandomizedTransform):
    """
    Add a randomly sampled value to the input.
    """

    Final = Next = AddValueTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        value: tx.Union[Sampler, float, tx.Tuple[float, float]] = 1,
        *,
        shared: cct.SharedT = False,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        value : Sampler | [pair of] float
            Bound for additive value

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        offset_sampler = Uniform.make(make_range(value))
        super().__init__(
            AddValueTransform, offset_sampler, shared=shared, **kwargs
        )
371
+
372
+
373
class RandomAddMulTransform(RandomizedTransform):
    """
    Affine intensity transform with randomly sampled slope and offset.
    """

    Final = Next = AddMulTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        slope: tx.Union[Sampler, float, tx.Tuple[float, float]] = 1,
        offset: tx.Union[Sampler, float, tx.Tuple[float, float]] = 0.5,
        *,
        shared: cct.SharedT = False,
        **kwargs
    ) -> None:
        """

        Parameters
        ----------
        slope : Sampler | [pair of] float
            Bound for slope
        offset : Sampler | [pair of] float
            Bound for offset

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        # one sampler per positional argument of `AddMulTransform`
        slope_sampler = Uniform.make(make_range(slope))
        offset_sampler = Uniform.make(make_range(offset))
        super().__init__(
            AddMulTransform,
            (slope_sampler, offset_sampler),
            shared=shared,
            **kwargs
        )
413
+
414
+
415
class SplineUpsampleTransform(FinalTransform):
    """Upsample a field using spline interpolation"""

    def __init__(
        self,
        order: int = 3,
        prefilter: bool = False,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        order : int
            Spline interpolation order
        prefilter : bool
            Spline prefiltering
            (True for interpolation, False for spline evaluation)

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.order = order
        self.prefilter = prefilter

    def _xform(self, x: Tensor) -> Tensor:
        # NOTE(review): the target shape is the input's own spatial
        # shape, so no change of size happens here -- confirm intent.
        fullshape = x.shape[1:]
        if self.order != 1:
            return interpol.resize(
                x, shape=fullshape, interpolation=self.order,
                prefilter=self.prefilter
            )
        # first order: use torch's native (multi)linear interpolation
        linear_modes = {3: 'trilinear', 2: 'bilinear'}
        mode = linear_modes.get(len(fullshape), 'linear')
        batched = x.unsqueeze(0)
        y = interpolate(batched, fullshape, mode=mode, align_corners=True)
        return y.squeeze(0)
458
+
459
+
460
class BaseFieldTransform(NonFinalTransform):
    """Base class for transforms that sample a smooth field"""

    Final = Next = AddValueTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    value_name: str = 'field'

    def __init__(
        self,
        shape: tx.Union[int, tx.Sequence[int]] = 5,
        vmin: float = 0,
        vmax: float = 1,
        order: int = 3,
        slice: tx.Optional[int] = None,
        thickness: tx.Optional[int] = None,
        *,
        shared: cct.SharedT = False,
        **kwargs
    ) -> None:
        """

        Parameters
        ----------
        shape : [list of] int
            Number of spline control points
        vmin : float
            Minimum value
        vmax : float
            Maximum value
        order : int
            Spline order
        slice : int
            Slice direction, if slicewise.
        thickness : int
            Slice thickness, if slicewise.
            Note that `shape` will be scaled along the slice direction
            so that the number of nodes is approximately preserved.

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns : [list or dict of] {'input', 'output', 'field'}
            See [`Transform`][cornucopia.base.Transform] for details.
        append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.

        """
        super().__init__(shared=shared, **kwargs)
        self.shape = shape
        self.vmax = vmax
        self.vmin = vmin
        self.order = order
        self.slice = slice
        self.thickness = thickness

    def make_field(
        self,
        batch: int,
        smallshape: tx.Sequence[int],
        fullshape: tx.Optional[tx.Sequence[int]] = None,
        **backend
    ) -> Tensor:
        """Generate the random coefficients.

        Parameters
        ----------
        batch : int
            Number of fields to generate
        smallshape : list of int
            Number of spline control points
        fullshape : list of int, optional
            If given, the coefficients will be upsampled to this shape.

        Other Parameters
        ----------------
        dtype : torch.dtype
            Data type of the generated field. If missing or not a
            floating point type, torch's default dtype is used.
        device : torch.device | str
            Device on which to generate the field.

        Returns
        -------
        field : (batch, *smallshape) tensor | (batch, *fullshape) tensor
            If `fullshape` is given, returns the upsampled field of values.
            Otherwise, returns the spline coefficients.

        """
        if fullshape is None:
            # no target shape: only the coefficients are generated
            if isinstance(smallshape, (list, tuple)):
                smallshape = list(smallshape)
            else:
                smallshape = [smallshape]
        else:
            fullshape = list(fullshape)
            smallshape = ensure_list(smallshape, len(fullshape))
            # never use more control points than there are voxels
            smallshape = [min(small, full) for small, full
                          in zip(smallshape, fullshape)]
        dtype = backend.get('dtype', None)
        if dtype is None or not dtype.is_floating_point:
            # `torch.rand` requires a floating point type
            backend['dtype'] = torch.get_default_dtype()
        b = torch.rand([batch, *smallshape], **backend)
        if fullshape:
            b = self.upsample_field(b, fullshape)
        return b

    def upsample_field(self, coeff: Tensor, shape: tx.Sequence[int]) -> Tensor:
        """Compute the full-sized field from its spline coefficients.

        Parameters
        ----------
        coeff : (batch, *smallshape) tensor
            Spline coefficients
        shape : list of int
            Target shape for the upsampled field

        Returns
        -------
        field : (batch, *shape) tensor
            Upsampled field of values
        """
        if self.order == 1:
            mode = ('trilinear' if len(shape) == 3 else
                    'bilinear' if len(shape) == 2 else
                    'linear')
            # the batch dimension plays the role of torch's channels
            b = interpolate(
                coeff.unsqueeze(0), shape, mode=mode,
                align_corners=True
            ).squeeze(0)
        else:
            b = interpol.resize(
                coeff, shape=shape, interpolation=self.order,
                prefilter=False
            )
        return b

    def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
        if max_depth == 0:
            return self

        ndim = x.ndim - 1
        fullshape = list(x.shape[1:])
        batch = 1 if 'channels' in self.shared else len(x)
        backend = dict(dtype=x.dtype, device=x.device)

        if self.slice is not None:
            # slicewise bias
            axis = positive_index(self.slice, ndim)
            extent = fullshape[axis]
            thickness = self.thickness or 1
            thickness = min(thickness, extent)
            nb_slices = int(math.ceil(extent / thickness))

            # copy, so that `self.shape` is never modified in place
            smallshape = list(ensure_list(self.shape, ndim))
            smallshape[axis] = int(math.ceil(smallshape[axis] / nb_slices))
            smallshape = [min(small, full) for small, full
                          in zip(smallshape, fullshape)]

            # shape of one full-thickness slab
            slabshape = list(fullshape)
            slabshape[axis] = thickness

            if thickness == 1:
                # bias independent across slices -> batch it
                inplane_small = smallshape[:axis] + smallshape[axis+1:]
                inplane_full = fullshape[:axis] + fullshape[axis+1:]
                b = self.make_field(
                    batch * extent, inplane_small, inplane_full, **backend
                )
                b = b.reshape([batch, -1, *b.shape[1:]])
                b = b.movedim(1, 1+axis)

            elif extent % thickness == 0:
                # shape divisible by thickness -> unfold and batch
                b = self.make_field(
                    batch * nb_slices, smallshape, slabshape, **backend
                )
                b = b.reshape([batch, -1, *b.shape[1:]])
                b = b.movedim(1, 1+axis)
                b = b.reshape([batch, *fullshape])

            else:
                # otherwise, the input is not exactly divisible by
                # thickness: use the same strategy as before for all
                # but the last slab, then append the (thinner) last one.
                nb_full = nb_slices - 1
                headshape = list(fullshape)
                headshape[axis] = nb_full * thickness
                head = self.make_field(
                    batch * nb_full, smallshape, slabshape, **backend
                )
                head = head.reshape([batch, -1, *head.shape[1:]])
                head = head.movedim(1, 1+axis)
                head = head.reshape([batch, *headshape])

                tailshape = list(fullshape)
                tailshape[axis] = extent - headshape[axis]
                tail = self.make_field(
                    batch, smallshape, tailshape, **backend
                )

                # concatenating keeps the floating point dtype chosen
                # by `make_field`, even when the input is an integer
                # tensor
                b = torch.cat([head, tail], dim=1+axis)

        else:
            # global bias
            b = self.make_field(batch, self.shape, fullshape, **backend)

        # rescale intensities from (0, 1) to (vmin, vmax)
        batch = len(b)
        vmin, vmax = self.vmin, self.vmax
        if torch.is_tensor(vmin):
            vmin = vmin.to(b)
            # right-pad so that the bound broadcasts against the field
            while vmin.ndim < b.ndim:
                vmin = vmin.unsqueeze(-1)
            batch = max(batch, len(vmin))
        if torch.is_tensor(vmax):
            vmax = vmax.to(b)
            while vmax.ndim < b.ndim:
                vmax = vmax.unsqueeze(-1)
            batch = max(batch, len(vmax))
        if len(b) < batch:
            b = b.expand([batch, *b.shape[1:]]).clone()

        # use the broadcast-ready bounds prepared above
        b = add_(mul_(b, vmax - vmin), vmin)

        return self.Next(
            b, value_name=self.value_name, **self.get_prm()
        ).unroll(x, max_depth-1)
679
+
680
+
681
class MulFieldTransform(BaseFieldTransform):
    """Smooth multiplicative (bias) field"""

    # The sampled field is applied by multiplication.
    Final = Next = MulValueTransform
    """The transform type returned by `unroll`, `next` and `final`."""
686
+
687
+
688
class RandomMulFieldTransform(NonFinalTransform):
    """Random multiplicative bias field transform"""

    Next = MulFieldTransform
    """The transform type returned by `next`."""

    Final = MulValueTransform
    """The transform type returned by `final`."""

    def __init__(
        self,
        shape: tx.Union[Sampler, int] = 8,
        vmax: tx.Union[Sampler, float] = 1,
        order: int = 3,
        symmetric: tx.Union[bool, float] = False,
        *,
        shared: cct.SharedT = False,
        shared_field: tx.Union[str, bool, None] = None,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        shape : Sampler | int
            Sampler or Upper bound for number of control points
        vmax : Sampler | float
            Sampler or Upper bound for maximum value
        order : int
            Spline order
        symmetric : bool | float
            If a float, the bias field will take values in
            `(symmetric-vmax, symmetric+vmax)`.
            If False, it will take values in `(0, vmax)`.
            If True, it will take values in `(1-vmax, 1+vmax)`.

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        shared_field
            Whether to share random field across tensors and/or channels.
            By default: same as `shared`
        returns : [list or dict of] {'input', 'output', 'field'}
            See [`Transform`][cornucopia.base.Transform] for details.
        append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """  # noqa: E501
        super().__init__(shared=shared, **kwargs)
        self.vmax = Uniform.make(make_range(0, vmax))
        self.shape = RandInt.make(make_range(2, shape))
        self.order = Fixed.make(order)
        self.symmetric = symmetric
        self.shared_field = self._prepare_shared(shared_field)

    def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
        # same guard as the other `_unroll` implementations of this module
        if max_depth == 0:
            return self

        vmax, shape, order = self.vmax, self.shape, self.order
        shared_field = self.shared_field
        if isinstance(vmax, Sampler):
            vmax = vmax()
        if isinstance(shape, Sampler):
            # one number of control points per spatial dimension
            shape = shape(x.ndim-1)
        if isinstance(order, Sampler):
            order = order()
        if shared_field is None:
            shared_field = self.shared
        if self.symmetric is False:
            vmin = 0
        else:
            # `True` acts as a midpoint of 1
            mid = self.symmetric
            vmin, vmax = mid - vmax, mid + vmax
        return MulFieldTransform(
            shape, vmin, vmax, order, shared=shared_field, **self.get_prm()
        ).unroll(x, max_depth-1)
762
+
763
+
764
class RandomSlicewiseMulFieldTransform(NonFinalTransform):
    """Random multiplicative bias field transform, per slice or slab"""

    Next = MulFieldTransform
    """The transform type returned by `next`."""

    Final = MulValueTransform
    """The transform type returned by `final`."""

    def __init__(
        self,
        shape: tx.Union[Sampler, int] = 8,
        vmax: tx.Union[Sampler, float] = 1,
        order: int = 3,
        slice: tx.Optional[int] = None,
        thickness: tx.Union[Sampler, int] = 32,
        shape_through: tx.Optional[tx.Union[Sampler, int]] = None,
        *,
        shared: cct.SharedT = False,
        shared_field: tx.Union[str, bool, None] = None,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        shape : Sampler | int
            Sampler or Upper bound for number of control points
        vmax : Sampler | float
            Sampler or Upper bound for maximum value
        order : int
            Spline order
        slice : int | None
            Slice axis. If None, sample one randomly
        thickness : Sampler | int
            Sampler or Upper bound for slice thickness
        shape_through : Sampler | int | None
            Sampler or Upper bound for number of control points
            along the slice direction. If None, same as `shape`.

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        shared_field
            Whether to share random field across tensors and/or channels.
            By default: same as `shared`
        returns : [list or dict of] {'input', 'output', 'field'}
            See [`Transform`][cornucopia.base.Transform] for details.
        append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """  # noqa: E501
        super().__init__(shared=shared, **kwargs)
        if shape_through is not None:
            shape_through = RandInt.make(make_range(1, shape_through))
        self.vmax = Uniform.make(make_range(0, vmax))
        self.shape = RandInt.make(make_range(2, shape))
        self.order = Fixed.make(order)
        self.slice = slice
        self.thickness = RandInt.make(make_range(1, thickness))
        self.shape_through = shape_through
        self.shared_field = self._prepare_shared(shared_field)

    def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
        if max_depth == 0:
            return self
        ndim = x.ndim - 1

        vmax = self.vmax
        shape = self.shape
        order = self.order
        axis = self.slice
        thickness = self.thickness
        shape_through = self.shape_through
        shared_field = self.shared_field

        if axis is None:
            # no axis specified: pick one at random
            axis = RandInt(x.ndim-2)

        if shape_through is not None:
            # The number of nodes along the slice direction is specified
            # per slab, so the axis, shape and thickness must be
            # resolved first.
            if isinstance(axis, Sampler):
                axis = axis()
            axis = positive_index(axis, ndim)
            if isinstance(shape, Sampler):
                shape = shape(ndim)
            shape = list(ensure_list(shape, ndim))
            if isinstance(thickness, Sampler):
                thickness = thickness()
            if isinstance(shape_through, Sampler):
                shape_through = shape_through()
            # Extent of the input along the resolved slice axis
            extent = x.shape[1+axis]
            # `shape` gets divided by the number of slabs downstream,
            # so pre-multiply to end up with `shape_through` per slab.
            shape_through *= int(math.ceil(extent / thickness))
            shape[axis] = shape_through

        # resolve whatever is still a sampler
        if isinstance(vmax, Sampler):
            vmax = vmax()
        if isinstance(shape, Sampler):
            shape = shape(ndim)
        if isinstance(order, Sampler):
            order = order()
        if isinstance(axis, Sampler):
            axis = axis()
        if isinstance(thickness, Sampler):
            thickness = thickness()
        if isinstance(shape_through, Sampler):
            shape_through = shape_through()
        if shared_field is None:
            shared_field = self.shared

        return MulFieldTransform(
            shape, 0, vmax, order, axis, thickness,
            shared=shared_field, **self.get_prm()
        ).unroll(x, max_depth-1)
877
+
878
+
879
class AddFieldTransform(BaseFieldTransform):
    """Smooth additive (bias) field"""

    # The sampled field is applied by addition.
    Final = Next = AddValueTransform
    """The transform type returned by `unroll`, `next` and `final`."""
884
+
885
+
886
class RandomAddFieldTransform(NonFinalTransform):
    """Random additive bias field transform"""

    def __init__(
        self,
        shape: tx.Union[Sampler, int] = 8,
        vmin: tx.Union[Sampler, float] = -1,
        vmax: tx.Union[Sampler, float] = 1,
        order: tx.Union[Sampler, int] = 3,
        *,
        shared: cct.SharedT = False,
        shared_field: tx.Union[str, bool, None] = None,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        shape : Sampler | int
            Sampler or Upper bound for number of control points
        vmin : Sampler | float
            Sampler or Lower bound for minimum value
        vmax : Sampler | float
            Sampler or Upper bound for maximum value
        order : Sampler | int
            Spline order

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        shared_field
            Whether to share random field across tensors and/or channels.
            By default: same as `shared`
        returns : [list or dict of] {'input', 'output', 'field'}
            See [`Transform`][cornucopia.base.Transform] for details.
        append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(shared=shared, **kwargs)
        self.vmin = Uniform.make(make_range(vmin, 0))
        self.vmax = Uniform.make(make_range(0, vmax))
        self.shape = RandInt.make(make_range(2, shape))
        self.order = Fixed.make(order)
        self.shared_field = self._prepare_shared(shared_field)

    def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
        # same guard as the other `_unroll` implementations of this module
        if max_depth == 0:
            return self

        vmin, vmax, shape, order = self.vmin, self.vmax, self.shape, self.order
        shared_field = self.shared_field
        if isinstance(vmin, Sampler):
            vmin = vmin()
        if isinstance(vmax, Sampler):
            vmax = vmax()
        if isinstance(shape, Sampler):
            # one number of control points per spatial dimension
            shape = shape(x.ndim-1)
        if isinstance(order, Sampler):
            order = order()
        if shared_field is None:
            shared_field = self.shared
        return AddFieldTransform(
            shape, vmin, vmax, order, shared=shared_field, **self.get_prm()
        ).unroll(x, max_depth-1)
948
+
949
+
950
class GammaFinalTransform(FinalTransform):
    """Gamma correction with fixed parameters.

    The transform is defined as:

    ```python
    y = ((x-vmin) / (vmax-vmin)) ** gamma * (vmax-vmin) + vmin
    ```

    i.e., the input is rescaled by `(vmin, vmax)`, raised to the power
    `gamma`, and mapped back to the `(vmin, vmax)` range.

    In this transform, `vmin` and `vmax` are pre-calculated and fixed,
    whereas in `GammaTransform`, they are computed from the image intensities.
    """

    # Accepted type for `gamma`, `vmin` and `vmax`
    _ScalarOrVector = tx.Union[float, tx.Sequence[float], Tensor]

    def __init__(
        self,
        gamma: _ScalarOrVector = 1,
        vmin: _ScalarOrVector = 0,
        vmax: _ScalarOrVector = 1,
        **kwargs
    ):
        """
        Parameters
        ----------
        gamma : number | (C,) list[number] | (C,) tensor
            Exponent of the Gamma transform
        vmin : number | (C,) list[number] | (C,) tensor
            Minimum value for the transform
        vmax : number | (C,) list[number] | (C,) tensor
            Maximum value for the transform

        Other Parameters
        ----------------
        returns : [list or dict of] {'input', 'output', 'vmin', 'vmax', 'gamma'}
            See [`Transform`][cornucopia.base.Transform] for details.
        append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.gamma = gamma
        self.vmin = vmin
        self.vmax = vmax

    def __repr__(self) -> str:
        """Return `ClassName(gamma=..., vmin=..., vmax=...)`, with tensor
        parameters printed as (nested) lists."""
        gamma, vmin, vmax = self.gamma, self.vmin, self.vmax
        if torch.is_tensor(gamma):
            gamma = gamma.detach().tolist()
        if torch.is_tensor(vmin):
            vmin = vmin.detach().tolist()
        if torch.is_tensor(vmax):
            vmax = vmax.detach().tolist()
        return f"{type(self).__name__}(gamma={gamma}, vmin={vmin}, vmax={vmax})"

    def _xform(self, x: Tensor) -> Returned:
        """Apply the gamma correction to `x`.

        The returned dictionary (filtered by `self.returns`) holds the
        input, the output, and `vmin`/`vmax`/`gamma` as tensors reshaped
        to `(-1, 1, ..., 1)` with `x.ndim` dimensions.
        """
        # Move parameters to the dtype/device of x ...
        vmin = torch.as_tensor(self.vmin, dtype=x.dtype, device=x.device)
        vmax = torch.as_tensor(self.vmax, dtype=x.dtype, device=x.device)
        gamma = torch.as_tensor(self.gamma, dtype=x.dtype, device=x.device)
        # ... and give them one leading dimension followed by singleton
        # dimensions so that they broadcast against x.
        vmin = vmin.reshape([-1] + [1] * (x.ndim-1))
        vmax = vmax.reshape([-1] + [1] * (x.ndim-1))
        gamma = gamma.reshape([-1] + [1] * (x.ndim-1))

        # NOTE
        # * we add a little epsilon to the denominator to avoid
        #   division by zero.
        # * We also ensure that the rescaled input is in (0+eps, 1-eps)
        #   to ensure differentiability everywhere.
        # * The vmin/vmax may have been computed on a different image
        #   than x, so we cannot trust that x.min() < vmin.
        # * NOTE(review): `clamp_`, `div_`, `pow_`, `mul_` and `add_` are
        #   helpers defined outside this block; presumably they operate
        #   in-place when it is safe to do so -- confirm.

        den = vmax - vmin
        num = x - vmin
        num = clamp_(num, 1e-5 * den, (1.0 - 1e-5) * den)
        y = div_(num, add_(den, 1e-5))
        y = pow_(y, gamma)
        if gamma.requires_grad:
            # When gamma requires grad, mul_(y, vmax-vmin) is happy
            # to overwrite y, but we can't because we need y to
            # backprop through pow. So we need an explicit branch.
            y = torch.add(torch.mul(y, vmax - vmin), vmin)
        else:
            y = add_(mul_(y, vmax - vmin), vmin)

        return prepare_output(
            dict(input=x, output=y, vmin=vmin, vmax=vmax, gamma=gamma),
            self.returns)
1036
+
1037
+
1038
class GammaTransform(NonFinalTransform):
    """Gamma correction

    The exponent is fixed, but the intensity range `(vmin, vmax)` is
    computed from the input tensor when not provided.

    References
    ----------
    1. https://en.wikipedia.org/wiki/Gamma_correction
    """

    Final = Next = GammaFinalTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        gamma: float = 1,
        vmin: tx.Optional[float] = None,
        vmax: tx.Optional[float] = None,
        *,
        shared: cct.SharedT = False,
        **kwargs
    ) -> None:
        """

        Parameters
        ----------
        gamma : float
            Exponent of the Gamma transform
        vmin : float | None
            Value to use as the minimum (default: x.min())
        vmax : float | None
            Value to use as the maximum (default: x.max())

        Other Parameters
        ----------------
        value : float
            Alias for `gamma`. Takes precedence over `gamma` if provided.
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns : [list or dict of] {'input', 'output', 'vmin', 'vmax', 'gamma'}
            See [`Transform`][cornucopia.base.Transform] for details.
        append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        # The `value` alias must be consumed *before* the remaining
        # keywords are forwarded to the base class; otherwise it would
        # be passed along as an unexpected argument.
        gamma = kwargs.pop('value', gamma)
        super().__init__(shared=shared, **kwargs)
        self.gamma = gamma
        self.vmin = vmin
        self.vmax = vmax

    def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
        """Compute the missing `vmin`/`vmax` from `x` and return the
        corresponding fixed-parameter transform."""
        if max_depth == 0:
            return self
        ndim = x.dim() - 1

        if self.vmin is None:
            # Minimum over everything but the first dimension, kept
            # with trailing singleton dimensions for broadcasting.
            vmin = x.reshape(len(x), -1).min(-1).values
            for _ in range(ndim):
                vmin = vmin.unsqueeze(-1)
            if 'channels' in self.shared:
                # a single value for the whole tensor
                vmin = vmin.min()
        else:
            vmin = self.vmin

        if self.vmax is None:
            vmax = x.reshape(len(x), -1).max(-1).values
            for _ in range(ndim):
                vmax = vmax.unsqueeze(-1)
            if 'channels' in self.shared:
                vmax = vmax.max()
        else:
            vmax = self.vmax

        # `unroll` takes the tensor first, then the remaining depth.
        return self.Next(
            self.gamma, vmin, vmax, **self.get_prm()
        ).unroll(x, max_depth-1)
1107
+
1108
+
1109
class RandomGammaTransform(NonFinalTransform):
    """
    Random Gamma transform.

    The exponent is sampled at random, then a `GammaTransform` is
    applied with that exponent.
    """

    Next = GammaTransform
    """The transform type returned by `next`."""

    Final = GammaFinalTransform
    """The transform type returned by `final`."""

    def __init__(
        self,
        gamma: tx.Union[Sampler, float, tx.Tuple[float, float]] = (0.5, 2),
        *,
        shared: cct.SharedT = False,
        shared_minmax: tx.Optional[cct.SharedT] = None,
        **kwargs
    ):
        """
        Parameters
        ----------
        gamma : Sampler or [pair of] float
            Sampler or range for the exponent value

        Other Parameters
        ----------------
        value : Sampler or [pair of] float
            Alias for `gamma`. Takes precedence over `gamma` if provided.
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        shared_minmax
            Use the same vmin/vmax for all channels.
            Default: same as `shared`.
        returns : [list or dict of] {'input', 'output', 'vmin', 'vmax', 'gamma'}
            See [`Transform`][cornucopia.base.Transform] for details.
        append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        # The `value` alias must be consumed *before* the remaining
        # keywords are forwarded to the base class; otherwise it would
        # be passed along as an unexpected argument.
        gamma = kwargs.pop('value', gamma)
        super().__init__(shared=shared, **kwargs)
        self.gamma = Uniform.make(gamma)
        self.shared_minmax = self._prepare_shared(shared_minmax)

    def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
        """Sample an exponent and unroll the resulting `GammaTransform`
        on `x`."""
        gamma = self.gamma
        if isinstance(gamma, Sampler):
            gamma = gamma()
        shared_minmax = self.shared_minmax
        if shared_minmax is None:
            # vmin/vmax sharing defaults to the sharing mode of gamma
            shared_minmax = self.shared
        return GammaTransform(
            gamma, shared=shared_minmax, **self.get_prm()
        ).unroll(x, max_depth-1)
1161
+
1162
+
1163
class ZTransform(NonFinalTransform):
    """
    Z-transform the data -> zero mean, unit standard deviation
    """

    Final = Next = AddMulTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        mu: tx.Optional[float] = 0,
        sigma: tx.Optional[float] = 1,
        *, shared: cct.SharedT = False, **kwargs
    ):
        """
        Parameters
        ----------
        mu : float | None
            Target mean. If None, keep the input mean.
        sigma : float | None
            Target standard deviation. If None, keep the input sd.

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(shared=shared, **kwargs)
        self.mu = mu
        self.sigma = sigma

    def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
        """Measure the mean and standard deviation of `x` and return the
        `AddMulTransform` that maps them to the target values."""
        if max_depth == 0:
            return self
        if 'channels' in self.shared:
            # statistics over the whole tensor
            opt = dict()
        else:
            # statistics over all dimensions but the first one
            opt = dict(dim=list(range(1, x.ndim)), keepdim=True)
        mu0, sigma0 = x.mean(**opt), x.std(**opt)
        mu1 = self.mu if self.mu is not None else mu0
        sigma1 = self.sigma if self.sigma is not None else sigma0
        # Guard against a zero standard deviation (constant input),
        # which would otherwise yield an infinite/NaN scale and offset.
        scale = sigma1 / sigma0.clamp_min(1e-16)
        offset = mu1 - mu0 * scale
        return AddMulTransform(
            scale, offset, **self.get_prm()
        ).unroll(x, max_depth-1)
1210
+
1211
+
1212
class QuantileTransform(NonFinalTransform):
    """Affinely rescale intensities so that a lower and an upper
    quantile match target values (by default, 0 and 1)"""

    Final = Next = AddMulTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        pmin: float = 0.01,
        pmax: float = 0.99,
        vmin: float = 0,
        vmax: float = 1,
        clip: bool = False,
        max_samples: int = 10000,
        **kwargs
    ) -> None:
        """

        Parameters
        ----------
        pmin : (0..1)
            Lower quantile
        pmax : (0..1)
            Upper quantile
        vmin : float
            Target value for the lower quantile
        vmax : float
            Target value for the upper quantile
        clip : bool
            Clip values outside (vmin, vmax)
        max_samples : int
            Maximum number of pixels used to estimate the quantiles
            (for speed)

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.pmin, self.pmax = pmin, pmax
        self.vmin, self.vmax = vmin, vmax
        self.clip = clip
        self.max_samples = max_samples

    def _unroll(self, x: Tensor, max_depth: float = inf) -> Transform:
        """Estimate the quantiles of `x` and return the affine transform
        (optionally followed by clipping) that maps them to the targets."""
        if max_depth == 0:
            return self

        nb_trailing = x.ndim - 1

        # Flatten everything but the first dimension, and only keep
        # columns that are nonzero and finite along the first dimension.
        flat = x.reshape([len(x), -1])
        keep = (flat != 0).all(0) & flat.isfinite().all(0)
        flat = flat[:, keep]

        # Random subsampling to bound the cost of the quantile estimation
        if self.max_samples and self.max_samples < flat.shape[1]:
            perm = torch.randperm(flat.shape[-1], device=flat.device)
            flat = flat[:, perm[:self.max_samples]]

        # One quantile per row, unless channels are shared
        qdim = None if 'channels' in self.shared else -1
        expand = (Ellipsis,) + (None,) * nb_trailing
        qlow = torch.quantile(flat, self.pmin, dim=qdim)[expand]
        qhigh = torch.quantile(flat, self.pmax, dim=qdim)[expand]

        target_width = self.vmax - self.vmin
        input_width = (qhigh - qlow).clamp_min_(1e-16)
        slope = target_width / input_width
        offset = self.vmin - qlow * slope

        rescale = AddMulTransform(slope, offset, **self.get_prm())
        if not self.clip:
            return rescale.unroll(x, max_depth-1)

        clipper = ClipTransform(self.vmin, self.vmax, **self.get_prm())
        return SequentialTransform([rescale, clipper]).unroll(x, max_depth-1)
1294
+
1295
+
1296
class MinMaxTransform(NonFinalTransform):
    """Match min and max values to (0, 1)"""

    Final = Next = AddMulTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self, vmin: float = 0, vmax: float = 1, clip: bool = False, **kwargs
    ) -> None:
        """

        Parameters
        ----------
        vmin : float
            Lower target value
        vmax : float
            Upper target value
        clip : bool
            Clip values outside (vmin, vmax)

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.vmin = vmin
        self.vmax = vmax
        self.clip = clip

    def _unroll(self, x: Tensor, max_depth: float = inf) -> Transform:
        """Measure the min and max of `x` and return the affine transform
        (optionally followed by clipping) that maps them to the targets."""
        if max_depth == 0:
            return self

        ndim = x.ndim - 1

        # Flatten everything but the first dimension, and discard
        # columns that contain non-finite values.
        x_ = x.reshape([len(x), -1])
        x_ = x_[:, x_.isfinite().all(0)]

        if 'channels' not in self.shared:
            # one min/max per row
            pmin = torch.min(x_, dim=-1).values
            pmax = torch.max(x_, dim=-1).values
        else:
            # a single min/max for the whole tensor
            pmin = torch.min(x_)
            pmax = torch.max(x_)
        pmin = pmin[(Ellipsis,) + (None,) * ndim]
        pmax = pmax[(Ellipsis,) + (None,) * ndim]

        # Guard against a zero range (constant input), which would
        # otherwise yield an infinite/NaN slope and offset.
        # Same safeguard as in `QuantileTransform`.
        num = self.vmax - self.vmin
        den = (pmax - pmin).clamp_min(1e-16)
        slope = num / den
        offset = self.vmin - pmin * slope

        if self.clip:
            return SequentialTransform([
                AddMulTransform(slope, offset, **self.get_prm()),
                ClipTransform(self.vmin, self.vmax, **self.get_prm())
            ]).unroll(x, max_depth-1)
        else:
            return AddMulTransform(
                slope, offset, **self.get_prm()
            ).unroll(x, max_depth-1)