cornucopia 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cornucopia/__init__.py +73 -0
- cornucopia/base.py +1915 -0
- cornucopia/baseutils.py +575 -0
- cornucopia/contrast.py +260 -0
- cornucopia/ctx.py +25 -0
- cornucopia/fov.py +707 -0
- cornucopia/geometric.py +2068 -0
- cornucopia/intensity.py +1358 -0
- cornucopia/io.py +161 -0
- cornucopia/kspace.py +505 -0
- cornucopia/labels.py +1872 -0
- cornucopia/noise.py +508 -0
- cornucopia/psf.py +463 -0
- cornucopia/qmri.py +1288 -0
- cornucopia/random.py +1480 -0
- cornucopia/special.py +159 -0
- cornucopia/synth.py +708 -0
- cornucopia/tests/__init__.py +0 -0
- cornucopia/tests/test_backward_geometric.py +173 -0
- cornucopia/tests/test_backward_intensity.py +243 -0
- cornucopia/tests/test_backward_kspace.py +115 -0
- cornucopia/tests/test_backward_noise.py +169 -0
- cornucopia/tests/test_backward_psf.py +142 -0
- cornucopia/tests/test_backward_qmri.py +249 -0
- cornucopia/tests/test_backward_random.py +44 -0
- cornucopia/tests/test_backward_synth.py +72 -0
- cornucopia/tests/test_base.py +401 -0
- cornucopia/tests/test_geometric.py +26 -0
- cornucopia/tests/test_intensity.py +9 -0
- cornucopia/tests/test_random.py +722 -0
- cornucopia/tests/test_run_contrast.py +28 -0
- cornucopia/tests/test_run_fov.py +132 -0
- cornucopia/tests/test_run_geometric.py +157 -0
- cornucopia/tests/test_run_intensity.py +192 -0
- cornucopia/tests/test_run_kspace.py +70 -0
- cornucopia/tests/test_run_labels.py +224 -0
- cornucopia/tests/test_run_noise.py +127 -0
- cornucopia/tests/test_run_psf.py +115 -0
- cornucopia/tests/test_run_qmri.py +114 -0
- cornucopia/tests/test_run_synth.py +67 -0
- cornucopia/typing.py +97 -0
- cornucopia/utils/__init__.py +0 -0
- cornucopia/utils/b0.py +745 -0
- cornucopia/utils/bounds.py +412 -0
- cornucopia/utils/compat.py +47 -0
- cornucopia/utils/conv.py +305 -0
- cornucopia/utils/gmm.py +169 -0
- cornucopia/utils/indexing.py +911 -0
- cornucopia/utils/io.py +258 -0
- cornucopia/utils/jit.py +128 -0
- cornucopia/utils/kernels.py +288 -0
- cornucopia/utils/morpho.py +234 -0
- cornucopia/utils/mrf.py +574 -0
- cornucopia/utils/padding.py +173 -0
- cornucopia/utils/patch.py +302 -0
- cornucopia/utils/pool.py +282 -0
- cornucopia/utils/py.py +348 -0
- cornucopia/utils/smart_inplace.py +163 -0
- cornucopia/utils/version.py +57 -0
- cornucopia/utils/warps.py +606 -0
- cornucopia-0.0.0.dist-info/METADATA +92 -0
- cornucopia-0.0.0.dist-info/RECORD +65 -0
- cornucopia-0.0.0.dist-info/WHEEL +5 -0
- cornucopia-0.0.0.dist-info/licenses/LICENSE +21 -0
- cornucopia-0.0.0.dist-info/top_level.txt +1 -0
cornucopia/intensity.py
ADDED
|
@@ -0,0 +1,1358 @@
|
|
|
1
|
+
"""This module contains transforms that operate on image intensities."""
# Public API of this module, in declaration-independent order.
__all__ = [
    # transforms applying a fixed value / fixed affine map
    'AddValueTransform',
    'MulValueTransform',
    'AddMulTransform',
    'ReturnValueTransform',
    'FillValueTransform',
    'ClipTransform',
    # smooth-field transforms
    'BaseFieldTransform',
    'AddFieldTransform',
    'MulFieldTransform',
    'RandomAddFieldTransform',
    'RandomMulFieldTransform',
    'RandomSlicewiseMulFieldTransform',
    # randomized counterparts of the fixed-value transforms
    'RandomMulTransform',
    'RandomAddTransform',
    'RandomAddMulTransform',
    # remaining transforms (defined further down in this module)
    'GammaFinalTransform',
    'GammaTransform',
    'RandomGammaTransform',
    'ZTransform',
    'QuantileTransform',
    'MinMaxTransform',
]
|
|
25
|
+
# stdlib
|
|
26
|
+
import math
|
|
27
|
+
from math import inf
|
|
28
|
+
from numbers import Number
|
|
29
|
+
|
|
30
|
+
# dependencies
|
|
31
|
+
import interpol
|
|
32
|
+
import torch
|
|
33
|
+
import typing_extensions as tx
|
|
34
|
+
from torch import Tensor
|
|
35
|
+
from torch.nn.functional import interpolate
|
|
36
|
+
|
|
37
|
+
# internals
|
|
38
|
+
from .baseutils import Returned, prepare_output
|
|
39
|
+
from .base import Transform, FinalTransform, NonFinalTransform
|
|
40
|
+
from .special import RandomizedTransform, SequentialTransform
|
|
41
|
+
from .random import Sampler, Uniform, RandInt, Fixed, make_range
|
|
42
|
+
from .utils.py import ensure_list, positive_index
|
|
43
|
+
from .utils.smart_inplace import add_, mul_, div_, pow_
|
|
44
|
+
from .utils.compat import clamp, clamp_
|
|
45
|
+
from . import typing as cct
|
|
46
|
+
|
|
47
|
+
# typing
|
|
48
|
+
# A Python scalar or a torch tensor (used for operands and bounds).
_NumberOrTensor = tx.Union[Number, Tensor]
# Signature of a tensor -> tensor function (e.g. the inverse-value builders).
_UnaryOperator = tx.Callable[[Tensor], Tensor]
# Signature of a (tensor, operand) -> tensor function, e.g. torch.add.
_BinaryOperator = tx.Callable[[Tensor, _NumberOrTensor], Tensor]
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class OpConstTransform(FinalTransform):
    """Base class for arithmetic operations with a constant value.

    The right-hand side is stored in ``self.value`` and can also be read or
    written through the alias ``self.<value_name>``.
    """

    # Default operator; subclasses override it (e.g. torch.add, torch.mul).
    _op: tx.Optional[_BinaryOperator] = None
    # Maps an operator to the function that turns `value` into the value
    # of the inverse operation: x + v -> x + (-v), x * v -> x * (1/v).
    _inv: tx.Dict[_BinaryOperator, _UnaryOperator] = {
        torch.add: lambda x: -x,
        torch.mul: lambda x: 1/x,
    }

    def __init__(
        self,
        value: _NumberOrTensor,
        op: tx.Optional[_BinaryOperator] = None,
        value_name: str = 'value',
        **kwargs
    ):
        """
        Parameters
        ----------
        value : number or tensor
            right-hand side of the operation
        op : {torch.add, torch.mul}
            Arithmetic operation. Defaults to the class-level `_op`.
        value_name : str
            Name used when returning the rhs value

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.value = value
        self.op = op or self._op
        self.value_name = value_name

    def __getattr__(self, name: str) -> _NumberOrTensor:
        # Only reached when normal lookup fails: expose `value` under
        # its user-chosen alias.
        if name == self.__dict__.get("value_name"):
            return self.__dict__.get("value")
        return super().__getattr__(name)

    def __setattr__(self, name: str, value: _NumberOrTensor) -> None:
        # Writing to the alias updates the canonical `value` attribute.
        if name == self.__dict__.get("value_name"):
            name = 'value'
        super().__setattr__(name, value)

    def _xform(self, x: Tensor) -> Returned:
        value = self.value
        if torch.is_tensor(value):
            # match the input's dtype and device
            value = value.to(x)
        y = self.op(x, value)
        return prepare_output(
            {'input': x, 'output': y, self.value_name: value}, self.returns
        )

    def make_inverse(self) -> Transform:
        """Return the transform that undoes this operation.

        The operator is forwarded explicitly so that instances created with
        a custom `op` (rather than through a subclass's `_op`) keep it.
        Raises KeyError if `op` has no registered inverse in `_inv`.
        """
        inv = self._inv[self.op]
        return type(self)(
            inv(self.value), op=self.op, value_name=self.value_name,
            **self.get_prm()
        )
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class AddValueTransform(OpConstTransform):
    """Add a constant value: ``y = x + value``."""

    # operator used by `OpConstTransform._xform` (inverted via `_inv`)
    _op: _BinaryOperator = torch.add
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class MulValueTransform(OpConstTransform):
    """Multiply with a constant value: ``y = x * value``."""

    # operator used by `OpConstTransform._xform` (inverted via `_inv`)
    _op: _BinaryOperator = torch.mul
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class FillValueTransform(FinalTransform):
    """Replace the voxels selected by a mask with a given value."""

    def __init__(
        self,
        mask: Tensor,
        value: _NumberOrTensor,
        mask_name: str = 'mask',
        value_name: str = 'value',
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        mask : tensor
            Voxels where the value gets written (passed to
            `Tensor.masked_fill`).
        value : number or tensor
            Value written inside the mask.
        mask_name : str
            Key under which the mask is returned.
        value_name : str
            Key under which the fill value is returned.

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.mask = mask
        self.value = value
        self.mask_name = mask_name
        self.value_name = value_name

    def _xform(self, x: Tensor) -> Returned:
        # the mask only needs the input's device; the value also its dtype
        where = self.mask.to(x.device)
        fill = self.value
        if torch.is_tensor(fill):
            fill = fill.to(x)
        outputs = {
            'input': x,
            'output': x.masked_fill(where, fill),
            self.mask_name: where,
            self.value_name: fill,
        }
        return prepare_output(outputs, self.returns)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class ReturnValueTransform(FinalTransform):
    """Return a constant value, ignoring the content of the input.

    The value is returned as a tensor on the input's device, with dtype
    `dtype` if given, otherwise the input's dtype.
    """

    def __init__(
        self,
        value: _NumberOrTensor,
        value_name: str = 'output',
        dtype: tx.Optional[torch.dtype] = None,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        value : number or tensor
            Value to return
        value_name : str
            Alias under which `value` can also be accessed as an attribute
        dtype : torch.dtype, optional
            Data type of the returned tensor (default: input's dtype)

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.value = value
        self.value_name = value_name
        self.dtype = dtype

    def __getattr__(self, name: str) -> _NumberOrTensor:
        # Only reached when normal lookup fails: expose `value` under
        # its alias.
        if name == self.__dict__.get("value_name"):
            return self.__dict__.get("value")
        return super().__getattr__(name)

    def __setattr__(self, name: str, value: _NumberOrTensor) -> None:
        # Writing to the alias updates the canonical `value` attribute.
        if name == self.__dict__.get("value_name"):
            name = 'value'
        super().__setattr__(name, value)

    def _xform(self, x: Tensor) -> Returned:
        # Only the input's dtype/device are used; its values are discarded.
        dtype = self.dtype or x.dtype
        return torch.as_tensor(self.value, dtype=dtype, device=x.device)
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
class AddMulTransform(FinalTransform):
    """Fixed intensity affine map: `y = x * slope + offset`"""

    def __init__(
        self,
        slope: _NumberOrTensor = 1,
        offset: _NumberOrTensor = 0,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        slope : number or tensor
            Multiplicative factor of the affine map
        offset : number or tensor
            Additive term of the affine map

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.slope = slope
        self.offset = offset

    def _xform(self, x: Tensor) -> Returned:
        def _like_input(v):
            # tensors are cast to the input's dtype/device; scalars pass as is
            return v.to(x) if torch.is_tensor(v) else v

        slope = _like_input(self.slope)
        offset = _like_input(self.offset)
        outputs = {
            'input': x,
            'output': slope * x + offset,
            'slope': slope,
            'offset': offset,
        }
        return prepare_output(outputs, self.returns)

    def make_inverse(self) -> 'AddMulTransform':
        # x = (y - offset) / slope = y * (1/slope) + (-offset/slope)
        inv_slope = 1/self.slope
        inv_offset = -self.offset/self.slope
        return AddMulTransform(inv_slope, inv_offset, **self.get_prm())
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
class ClipTransform(FinalTransform):
    """Clamp intensities to `[vmin, vmax]` (either bound may be omitted)."""

    def __init__(
        self,
        vmin: tx.Optional[_NumberOrTensor] = None,
        vmax: tx.Optional[_NumberOrTensor] = None,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        vmin : number or tensor, optional
            Lower bound (None: no lower bound)
        vmax : number or tensor, optional
            Upper bound (None: no upper bound)

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.vmin = vmin
        self.vmax = vmax

    def _xform(self, x: Tensor) -> Returned:
        lo, hi = (
            bound.to(x) if torch.is_tensor(bound) else bound
            for bound in (self.vmin, self.vmax)
        )
        outputs = {
            'input': x,
            'output': clamp(x, lo, hi),
            'vmin': lo,
            'vmax': hi,
        }
        return prepare_output(outputs, self.returns)
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
class RandomMulTransform(RandomizedTransform):
    """
    Multiply by a randomly sampled constant.
    """

    Final = Next = MulValueTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        value: tx.Union[Sampler, float, tx.Tuple[float, float]] = (0.5, 2),
        *,
        shared: cct.SharedT = False,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        value : Sampler | [pair of] float
            Sampler, or range from which the multiplicative factor is
            drawn uniformly (built with `make_range(0, value)`).

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        factor_sampler = Uniform.make(make_range(0, value))
        super().__init__(
            MulValueTransform, factor_sampler, shared=shared, **kwargs
        )
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
class RandomAddTransform(RandomizedTransform):
    """
    Add a randomly sampled constant.
    """

    Final = Next = AddValueTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        value: tx.Union[Sampler, float, tx.Tuple[float, float]] = 1,
        *,
        shared: cct.SharedT = False,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        value : Sampler | [pair of] float
            Sampler, or range from which the additive term is drawn
            uniformly (built with `make_range(value)`).

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        term_sampler = Uniform.make(make_range(value))
        super().__init__(
            AddValueTransform, term_sampler, shared=shared, **kwargs
        )
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
class RandomAddMulTransform(RandomizedTransform):
    """
    Affine intensity map with randomly sampled slope and offset.
    """

    Final = Next = AddMulTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        slope: tx.Union[Sampler, float, tx.Tuple[float, float]] = 1,
        offset: tx.Union[Sampler, float, tx.Tuple[float, float]] = 0.5,
        *,
        shared: cct.SharedT = False,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        slope : Sampler | [pair of] float
            Sampler, or range from which the slope is drawn uniformly
        offset : Sampler | [pair of] float
            Sampler, or range from which the offset is drawn uniformly

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        slope_sampler = Uniform.make(make_range(slope))
        offset_sampler = Uniform.make(make_range(offset))
        super().__init__(
            AddMulTransform,
            (slope_sampler, offset_sampler),
            shared=shared,
            **kwargs
        )
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
class SplineUpsampleTransform(FinalTransform):
    """Resample a field with spline interpolation"""

    def __init__(
        self,
        order: int = 3,
        prefilter: bool = False,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        order : int
            Spline order (1 uses torch's (bi/tri)linear `interpolate`)
        prefilter : bool
            Whether to prefilter before resampling
            (True for interpolation, False for spline evaluation).
            Not used when `order == 1`.

        Other Parameters
        ----------------
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.order = order
        self.prefilter = prefilter

    def _xform(self, x: Tensor) -> Tensor:
        # NOTE(review): the target shape is the input's own spatial shape,
        # so this resamples onto the same grid — confirm this is intended.
        target = x.shape[1:]
        if self.order != 1:
            return interpol.resize(
                x, shape=target, interpolation=self.order,
                prefilter=self.prefilter
            )
        linear_modes = {3: 'trilinear', 2: 'bilinear'}
        mode = linear_modes.get(len(target), 'linear')
        # `interpolate` expects a leading batch dimension
        resized = interpolate(
            x.unsqueeze(0), target, mode=mode, align_corners=True
        )
        return resized.squeeze(0)
|
|
458
|
+
|
|
459
|
+
|
|
460
|
+
class BaseFieldTransform(NonFinalTransform):
    """Base class for transforms that sample a smooth field"""

    Final = Next = AddValueTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    # Key under which the sampled field is returned by the final transform.
    value_name: str = 'field'

    def __init__(
        self,
        shape: tx.Union[int, tx.Sequence[int]] = 5,
        vmin: float = 0,
        vmax: float = 1,
        order: int = 3,
        slice: tx.Optional[int] = None,
        thickness: tx.Optional[int] = None,
        *,
        shared: cct.SharedT = False,
        **kwargs
    ) -> None:
        """

        Parameters
        ----------
        shape : [list of] int
            Number of spline control points
        vmin : float or tensor
            Minimum value
        vmax : float or tensor
            Maximum value
        order : int
            Spline order
        slice : int
            Slice direction, if slicewise.
        thickness : int
            Slice thickness, if slicewise.
            Note that `shape` will be scaled along the slice direction
            so that the number of nodes is approximately preserved.

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns : [list or dict of] {'input', 'output', 'field'}
            See [`Transform`][cornucopia.base.Transform] for details.
        append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.

        """
        super().__init__(shared=shared, **kwargs)
        self.shape = shape
        self.vmax = vmax
        self.vmin = vmin
        self.order = order
        self.slice = slice
        self.thickness = thickness

    def make_field(
        self,
        batch: int,
        smallshape: tx.Sequence[int],
        fullshape: tx.Optional[tx.Sequence[int]] = None,
        **backend
    ) -> Tensor:
        """Generate the random coefficients.

        Parameters
        ----------
        batch : int
            Number of fields to generate
        smallshape : [list of] int
            Number of spline control points. Must be a list when
            `fullshape` is not given.
        fullshape : list of int, optional
            If given, the coefficients will be upsampled to this shape.

        Other Parameters
        ----------------
        dtype : torch.dtype
            Data type of the generated field. Non floating point types
            are replaced by the default floating point type.
        device : torch.device | str
            Device on which to generate the field.

        Returns
        -------
        field : (batch, *smallshape) tensor | (batch, *fullshape) tensor
            If `fullshape` is given, returns the upsampled field of values.
            Otherwise, returns the spline coefficients (uniform in [0, 1)).

        """
        if fullshape is not None:
            smallshape = ensure_list(smallshape, len(fullshape))
            # never use more control points than output voxels
            smallshape = [min(small, full) for small, full
                          in zip(smallshape, fullshape)]
        else:
            smallshape = list(smallshape)
        dtype = backend.get('dtype', None)
        if dtype is None or not dtype.is_floating_point:
            backend['dtype'] = torch.get_default_dtype()
        b = torch.rand([batch, *smallshape], **backend)
        if fullshape:
            b = self.upsample_field(b, fullshape)
        return b

    def upsample_field(self, coeff: Tensor, shape: tx.Sequence[int]) -> Tensor:
        """Compute the full-sized field from its spline coefficients.

        Parameters
        ----------
        coeff : (batch, *smallshape) tensor
            Spline coefficients
        shape : list of int
            Target shape for the upsampled field

        Returns
        -------
        field : (batch, *shape) tensor
            Upsampled field of values
        """
        if self.order == 1:
            mode = ('trilinear' if len(shape) == 3 else
                    'bilinear' if len(shape) == 2 else
                    'linear')
            # `interpolate` expects (N, C, *spatial): the field batch is
            # used as channels of a single-item batch.
            b = interpolate(
                coeff.unsqueeze(0), shape, mode=mode,
                align_corners=True
            ).squeeze(0)
        else:
            b = interpol.resize(
                coeff, shape=shape, interpolation=self.order,
                prefilter=False
            )
        return b

    def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
        if max_depth == 0:
            return self

        ndim = x.ndim - 1
        fullshape = list(x.shape[1:])
        batch = 1 if 'channels' in self.shared else len(x)
        # The field is floating point even when the input is not; using a
        # float dtype here also keeps the placeholder buffer below from
        # truncating the field when the input has an integer dtype.
        dtype = (x.dtype if x.dtype.is_floating_point
                 else torch.get_default_dtype())
        backend = dict(dtype=dtype, device=x.device)

        # slicewise bias
        if self.slice is not None:
            slice = positive_index(self.slice, ndim)
            thickness = self.thickness or 1
            thickness = min(thickness, x.shape[1+slice])
            nb_slices = int(math.ceil(x.shape[1+slice] / thickness))

            # copy: the list is modified in place below and must not alias
            # `self.shape`
            smallshape = list(ensure_list(self.shape, ndim))
            # spread the control points across slabs
            smallshape[slice] = int(math.ceil(smallshape[slice] / nb_slices))
            smallshape = [min(small, full) for small, full
                          in zip(smallshape, fullshape)]

            if thickness == 1:
                # bias independent across slices -> batch it
                batch1 = batch * fullshape[slice]
                del smallshape[slice]
                del fullshape[slice]
                b = self.make_field(batch1, smallshape, fullshape, **backend)
                b = b.reshape([batch, -1, *b.shape[1:]])
                b = b.movedim(1, 1+slice)

            elif fullshape[slice] % thickness == 0:
                # shape divisible by thickness -> unfold and batch
                fullshape0 = list(fullshape)
                _, *fullshape = x.shape
                batch1 = batch * nb_slices
                fullshape[slice] = thickness
                b = self.make_field(batch1, smallshape, fullshape, **backend)
                b = b.reshape([batch, -1, *b.shape[1:]])
                # (batch, ..., nb_slices, thickness, ...) -> merge slabs
                b = b.movedim(1, 1+slice)
                b = b.reshape([batch, *fullshape0])

            else:
                # otherwise, the input is not exactly divisible by thickness
                b = x.new_empty([batch, *fullshape], **backend)

                # use same strategy as before for all but last slice
                fullshape0 = list(fullshape)
                _, *fullshape = x.shape
                batch1 = batch * (nb_slices - 1)
                fullshape[slice] = thickness
                fullshape0[slice] = (nb_slices - 1) * thickness
                b1 = self.make_field(batch1, smallshape, fullshape, **backend)
                b1 = b1.reshape([batch, -1, *b1.shape[1:]])
                b1 = b1.movedim(1, 1+slice)
                b1 = b1.reshape([batch, *fullshape0])

                # copy into the larger placeholder
                b1 = b1.movedim(1+slice, 0)
                b.movedim(1+slice, 0)[:len(b1)].copy_(b1)

                # process last (thinner) slab
                fullshape[slice] = b.shape[1+slice] - len(b1)
                b1 = self.make_field(batch, smallshape, fullshape, **backend)
                b1 = b1.movedim(1+slice, 0)
                b.movedim(1+slice, 0)[-len(b1):].copy_(b1)

        else:
            # global bias
            b = self.make_field(batch, self.shape, fullshape, **backend)

        # rescale intensities from [0, 1) to [vmin, vmax)
        batch = len(b)
        vmin, vmax = self.vmin, self.vmax
        if torch.is_tensor(vmin):
            vmin = vmin.to(b)
            # append singleton dims so that vmin broadcasts along the batch
            while vmin.ndim < b.ndim:
                vmin = vmin.unsqueeze(-1)
            batch = max(batch, len(vmin))
        if torch.is_tensor(vmax):
            vmax = vmax.to(b)
            while vmax.ndim < b.ndim:
                vmax = vmax.unsqueeze(-1)
            batch = max(batch, len(vmax))
        if len(b) < batch:
            # one field per bound value; clone so in-place ops are allowed
            b = b.expand([batch, *b.shape[1:]]).clone()

        # use the broadcast-ready local bounds (not the raw attributes)
        b = add_(mul_(b, vmax - vmin), vmin)

        return self.Next(
            b, value_name=self.value_name, **self.get_prm()
        ).unroll(x, max_depth-1)
|
|
679
|
+
|
|
680
|
+
|
|
681
|
+
class MulFieldTransform(BaseFieldTransform):
    """Multiply the input by a smooth (bias) field"""

    # Same field sampling as the base class; only the final operation
    # differs (multiplication instead of addition).
    Final = Next = MulValueTransform
    """The transform type returned by `unroll`, `next` and `final`."""
|
|
686
|
+
|
|
687
|
+
|
|
688
|
+
class RandomMulFieldTransform(NonFinalTransform):
    """Random multiplicative bias field transform"""

    Next = MulFieldTransform
    """The transform type returned by `next`."""

    Final = MulValueTransform
    """The transform type returned by `final`."""

    def __init__(
        self,
        shape: tx.Union[Sampler, int] = 8,
        vmax: tx.Union[Sampler, float] = 1,
        order: int = 3,
        symmetric: tx.Union[bool, float] = False,
        *,
        shared: cct.SharedT = False,
        shared_field: tx.Union[str, bool, None] = None,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        shape : Sampler | int
            Sampler or Upper bound for number of control points
        vmax : Sampler | float
            Sampler or Upper bound for maximum value
        order : int
            Spline order
        symmetric : bool | float
            If a float, the bias field will take values in
            `(symmetric-vmax, symmetric+vmax)`.
            If False, it will take values in `(0, vmax)`.
            If True, it will take values in `(1-vmax, 1+vmax)`.

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        shared_field
            Whether to share random field across tensors and/or channels.
            By default: same as `shared`
        returns : [list or dict of] {'input', 'output', 'field'}
            See [`Transform`][cornucopia.base.Transform] for details.
        append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """  # noqa: E501
        super().__init__(shared=shared, **kwargs)
        self.vmax = Uniform.make(make_range(0, vmax))
        self.shape = RandInt.make(make_range(2, shape))
        self.order = Fixed.make(order)
        self.symmetric = symmetric
        self.shared_field = self._prepare_shared(shared_field)

    def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
        # Same early exit as the other field transforms: with no depth left,
        # do not sample anything.
        if max_depth == 0:
            return self

        vmax, shape, order = self.vmax, self.shape, self.order
        shared_field = self.shared_field
        if isinstance(vmax, Sampler):
            vmax = vmax()
        if isinstance(shape, Sampler):
            shape = shape(x.ndim-1)
        if isinstance(order, Sampler):
            order = order()
        if shared_field is None:
            shared_field = self.shared
        if self.symmetric is False:
            vmin = 0
        else:
            # `True` acts as 1 in the arithmetic below -> (1-vmax, 1+vmax)
            mid = self.symmetric
            vmin, vmax = mid - vmax, mid + vmax
        return MulFieldTransform(
            shape, vmin, vmax, order, shared=shared_field, **self.get_prm()
        ).unroll(x, max_depth-1)
|
|
762
|
+
|
|
763
|
+
|
|
764
|
+
class RandomSlicewiseMulFieldTransform(NonFinalTransform):
|
|
765
|
+
"""Random multiplicative bias field transform, per slice or slab"""
|
|
766
|
+
|
|
767
|
+
Next = MulFieldTransform
|
|
768
|
+
"""The transform type returned by `next`."""
|
|
769
|
+
|
|
770
|
+
Final = MulValueTransform
|
|
771
|
+
"""The transform type returned by `final`."""
|
|
772
|
+
|
|
773
|
+
def __init__(
|
|
774
|
+
self,
|
|
775
|
+
shape: tx.Union[Sampler, int] = 8,
|
|
776
|
+
vmax: tx.Union[Sampler, float] = 1,
|
|
777
|
+
order: int = 3,
|
|
778
|
+
slice: tx.Optional[int] = None,
|
|
779
|
+
thickness: tx.Union[Sampler, int] = 32,
|
|
780
|
+
shape_through: tx.Optional[tx.Union[Sampler, int]] = None,
|
|
781
|
+
*,
|
|
782
|
+
shared: cct.SharedT = False,
|
|
783
|
+
shared_field: tx.Union[str, bool, None] = None,
|
|
784
|
+
**kwargs
|
|
785
|
+
) -> None:
|
|
786
|
+
"""
|
|
787
|
+
Parameters
|
|
788
|
+
----------
|
|
789
|
+
shape : Sampler | int
|
|
790
|
+
Sampler or Upper bound for number of control points
|
|
791
|
+
vmax : Sampler | float
|
|
792
|
+
Sampler or Upper bound for maximum value
|
|
793
|
+
order : int
|
|
794
|
+
Spline order
|
|
795
|
+
slice : int | None
|
|
796
|
+
Slice axis. If None, sample one randomly
|
|
797
|
+
thickness : Sampler | int
|
|
798
|
+
Sampler or Upper bound for slice thickness
|
|
799
|
+
shape_through : Sampler | int | None
|
|
800
|
+
Sampler or Upper bound for number of control points
|
|
801
|
+
along the slice direction. If None, same as `shape`.
|
|
802
|
+
|
|
803
|
+
Other Parameters
|
|
804
|
+
----------------
|
|
805
|
+
shared
|
|
806
|
+
See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
|
|
807
|
+
for details.
|
|
808
|
+
shared_field
|
|
809
|
+
Whether to share random field across tensors and/or channels.
|
|
810
|
+
By default: same as `shared`
|
|
811
|
+
returns : [list or dict of] {'input', 'output', 'field'}
|
|
812
|
+
See [`Transform`][cornucopia.base.Transform] for details.
|
|
813
|
+
append, prefix, include, exclude, consume
|
|
814
|
+
See [`Transform`][cornucopia.base.Transform] for details.
|
|
815
|
+
""" # noqa: E501
|
|
816
|
+
super().__init__(shared=shared, **kwargs)
|
|
817
|
+
if shape_through is not None:
|
|
818
|
+
shape_through = RandInt.make(make_range(1, shape_through))
|
|
819
|
+
self.vmax = Uniform.make(make_range(0, vmax))
|
|
820
|
+
self.shape = RandInt.make(make_range(2, shape))
|
|
821
|
+
self.order = Fixed.make(order)
|
|
822
|
+
self.slice = slice
|
|
823
|
+
self.thickness = RandInt.make(make_range(1, thickness))
|
|
824
|
+
self.shape_through = shape_through
|
|
825
|
+
self.shared_field = self._prepare_shared(shared_field)
|
|
826
|
+
|
|
827
|
+
def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
    """Sample all random parameters and build the corresponding field.

    Parameters
    ----------
    x : tensor
        Input image. It is treated as channel-first: the first
        dimension is excluded when counting spatial dimensions
        (`ndim = x.ndim - 1`) and when reading the input size along
        the slice axis (`x.shape[1 + axis]`).
    max_depth : int
        Maximum number of unrolling levels. If 0, `self` is returned.

    Returns
    -------
    Transform
        `self` if `max_depth == 0`, otherwise a `MulFieldTransform`
        built from the sampled parameters and unrolled with
        `max_depth - 1`.
    """
    if max_depth == 0:
        return self
    ndim = x.ndim - 1

    vmax = self.vmax
    shape = self.shape
    order = self.order
    axis = self.slice
    thickness = self.thickness
    shape_through = self.shape_through
    shared_field = self.shared_field

    if axis is None:
        # No slice axis was provided: draw one among the spatial dims.
        axis = RandInt(x.ndim-2)

    if shape_through is not None:
        # The number of control points along the slice axis depends on
        # the slice axis, the in-plane shape and the slice thickness, so
        # all of them must be fixed now.
        if isinstance(axis, Sampler):
            axis = axis()
        axis = positive_index(axis, ndim)
        if isinstance(shape, Sampler):
            shape = shape(ndim)
        shape = list(ensure_list(shape, ndim))
        if isinstance(thickness, Sampler):
            thickness = thickness()
        if isinstance(shape_through, Sampler):
            shape_through = shape_through()
        # Input size along the sampled slice axis (`+1` skips the
        # channel dimension).
        shape_through0 = x.shape[1+axis]
        # Scale by the number of slabs of `thickness` voxels along the
        # slice axis.
        shape_through *= int(math.ceil(shape_through0 / thickness))
        shape[axis] = shape_through

    if isinstance(vmax, Sampler):
        vmax = vmax()
    if isinstance(shape, Sampler):
        shape = shape(ndim)
    if isinstance(order, Sampler):
        order = order()
    if isinstance(axis, Sampler):
        axis = axis()
    if isinstance(thickness, Sampler):
        thickness = thickness()
    if isinstance(shape_through, Sampler):
        shape_through = shape_through()
    if shared_field is None:
        # By default, the field is shared in the same way as the
        # transform itself.
        shared_field = self.shared

    return MulFieldTransform(
        shape, 0, vmax, order, axis, thickness,
        shared=shared_field, **self.get_prm()
    ).unroll(x, max_depth-1)
|
|
877
|
+
|
|
878
|
+
|
|
879
|
+
class AddFieldTransform(BaseFieldTransform):
    """Smooth additive (bias) field.

    This class defines no constructor or methods of its own: construction
    and field generation are inherited from `BaseFieldTransform`. It only
    declares `AddValueTransform` as the type it unrolls into.
    """

    Final = Next = AddValueTransform
    """The transform type returned by `unroll`, `next` and `final`."""
|
|
884
|
+
|
|
885
|
+
|
|
886
|
+
class RandomAddFieldTransform(NonFinalTransform):
    """Random additive bias field transform.

    Parameters of the field (extrema, number of control points, spline
    order) are sampled at unroll time, and an `AddFieldTransform` is
    built from them.
    """

    def __init__(
        self,
        shape: tx.Union[Sampler, int] = 8,
        vmin: tx.Union[Sampler, float] = -1,
        vmax: tx.Union[Sampler, float] = 1,
        order: tx.Union[Sampler, int] = 3,
        *,
        shared: cct.SharedT = False,
        shared_field: tx.Union[str, bool, None] = None,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        shape : Sampler | int
            Sampler or Upper bound for number of control points
            (the lower bound is 2).
        vmin : Sampler | float
            Sampler or Lower bound for minimum value
            (the sampled range is `(vmin, 0)`).
        vmax : Sampler | float
            Sampler or Upper bound for maximum value
            (the sampled range is `(0, vmax)`).
        order : Sampler | int
            Spline order

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        shared_field
            Whether to share random field across tensors and/or channels.
            By default: same as `shared`
        returns : [list or dict of] {'input', 'output', 'field'}
            See [`Transform`][cornucopia.base.Transform] for details.
        append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(shared=shared, **kwargs)
        self.vmin = Uniform.make(make_range(vmin, 0))
        self.vmax = Uniform.make(make_range(0, vmax))
        self.shape = RandInt.make(make_range(2, shape))
        self.order = Fixed.make(order)
        self.shared_field = self._prepare_shared(shared_field)

    def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
        """Sample the field parameters and return an `AddFieldTransform`.

        Parameters
        ----------
        x : tensor
            Input image; `x.ndim - 1` is used as the number of spatial
            dimensions (channel-first layout assumed).
        max_depth : int
            Maximum number of unrolling levels. If 0, `self` is returned
            unchanged, as in the other transforms of this module.
        """
        if max_depth == 0:
            return self
        vmin, vmax, shape, order = self.vmin, self.vmax, self.shape, self.order
        shared_field = self.shared_field
        if isinstance(vmin, Sampler):
            vmin = vmin()
        if isinstance(vmax, Sampler):
            vmax = vmax()
        if isinstance(shape, Sampler):
            # one value per spatial dimension (presumably)
            shape = shape(x.ndim-1)
        if isinstance(order, Sampler):
            order = order()
        if shared_field is None:
            shared_field = self.shared
        return AddFieldTransform(
            shape, vmin, vmax, order, shared=shared_field, **self.get_prm()
        ).unroll(x, max_depth-1)
|
|
948
|
+
|
|
949
|
+
|
|
950
|
+
class GammaFinalTransform(FinalTransform):
    """Gamma correction with fixed parameters.

    The transform computes:

    ```python
    y = ((x - vmin) / (vmax - vmin)) ** gamma * (vmax - vmin) + vmin
    ```

    Here `vmin` and `vmax` are provided explicitly and stay fixed, whereas
    `GammaTransform` derives them from the intensities of the image.
    """

    _ScalarOrVector = tx.Union[float, tx.Sequence[float], Tensor]

    def __init__(
        self,
        gamma: _ScalarOrVector = 1,
        vmin: _ScalarOrVector = 0,
        vmax: _ScalarOrVector = 1,
        **kwargs
    ):
        """
        Parameters
        ----------
        gamma : number | (C,) list[number] | (C,) tensor
            Exponent applied to the rescaled intensities
        vmin : number | (C,) list[number] | (C,) tensor
            Lower end of the intensity range used for rescaling
        vmax : number | (C,) list[number] | (C,) tensor
            Upper end of the intensity range used for rescaling

        Other Parameters
        ----------------
        returns : [list or dict of] {'input', 'output', 'vmin', 'vmax', 'gamma'}
            See [`Transform`][cornucopia.base.Transform] for details.
        append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.gamma, self.vmin, self.vmax = gamma, vmin, vmax

    def __repr__(self) -> str:
        def _printable(value):
            # Tensors are displayed as (nested) python lists.
            return value.detach().tolist() if torch.is_tensor(value) else value

        gamma, vmin, vmax = map(_printable, (self.gamma, self.vmin, self.vmax))
        return f"{type(self).__name__}(gamma={gamma}, vmin={vmin}, vmax={vmax})"

    def _xform(self, x: Tensor) -> Returned:
        # Move each parameter to the dtype/device of `x` and give it a
        # leading dimension followed by singletons so that it broadcasts
        # against `x`.
        bshape = [-1] + [1] * (x.ndim - 1)
        vmin, vmax, gamma = (
            torch.as_tensor(prm, dtype=x.dtype, device=x.device).reshape(bshape)
            for prm in (self.vmin, self.vmax, self.gamma)
        )

        # NOTE
        # * A small epsilon is added to the denominator to avoid a
        #   division by zero.
        # * The rescaled input is kept inside (0+eps, 1-eps) so that the
        #   power function is differentiable everywhere.
        # * vmin/vmax may have been computed on another image than `x`,
        #   so `x` is not guaranteed to lie inside [vmin, vmax].
        width = vmax - vmin
        y = clamp_(x - vmin, 1e-5 * width, (1.0 - 1e-5) * width)
        # `add_` modifies `width` in place: it must not be reused below.
        y = div_(y, add_(width, 1e-5))
        y = pow_(y, gamma)
        if not gamma.requires_grad:
            y = add_(mul_(y, vmax - vmin), vmin)
        else:
            # The output of the power is needed to backpropagate into
            # `gamma`, so the rescaling must not overwrite it in place.
            y = torch.add(torch.mul(y, vmax - vmin), vmin)

        outputs = {
            'input': x, 'output': y, 'vmin': vmin, 'vmax': vmax, 'gamma': gamma
        }
        return prepare_output(outputs, self.returns)
|
|
1036
|
+
|
|
1037
|
+
|
|
1038
|
+
class GammaTransform(NonFinalTransform):
    """Gamma correction

    Missing intensity bounds (`vmin`/`vmax`) are computed from the image
    at unroll time, and a `GammaFinalTransform` is returned.

    References
    ----------
    1. https://en.wikipedia.org/wiki/Gamma_correction
    """

    Final = Next = GammaFinalTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        gamma: float = 1,
        vmin: tx.Optional[float] = None,
        vmax: tx.Optional[float] = None,
        *,
        shared: cct.SharedT = False,
        **kwargs
    ) -> None:
        """

        Parameters
        ----------
        gamma : float
            Exponent of the Gamma transform. The keyword `value` is also
            accepted in place of `gamma`.
        vmin : float | None
            Value to use as the minimum (default: x.min(), computed per
            channel unless channels are shared)
        vmax : float | None
            Value to use as the maximum (default: x.max(), computed per
            channel unless channels are shared)

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns : [list or dict of] {'input', 'output', 'vmin', 'vmax', 'gamma'}
            See [`Transform`][cornucopia.base.Transform] for details.
        append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        # `value` must be removed from kwargs *before* they are forwarded
        # to the parent constructor, otherwise it is passed on to it.
        gamma = kwargs.pop('value', gamma)
        super().__init__(shared=shared, **kwargs)
        self.gamma = gamma
        self.vmin = vmin
        self.vmax = vmax

    def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
        """Compute missing bounds from `x` and return a `GammaFinalTransform`.

        Parameters
        ----------
        x : tensor
            Input image, channel-first (`len(x)` is the number of
            channels and `x.dim() - 1` the number of spatial dims).
        max_depth : int
            Maximum number of unrolling levels. If 0, `self` is returned.
        """
        if max_depth == 0:
            return self
        ndim = x.dim() - 1
        if self.vmin is None:
            # per-channel minimum, shaped (C, 1, ..., 1) for broadcasting
            vmin = x.reshape(len(x), -1).min(-1).values
            for _ in range(ndim):
                vmin = vmin.unsqueeze(-1)
            if 'channels' in self.shared:
                vmin = vmin.min()
        else:
            vmin = self.vmin
        if self.vmax is None:
            # per-channel maximum, shaped (C, 1, ..., 1) for broadcasting
            vmax = x.reshape(len(x), -1).max(-1).values
            for _ in range(ndim):
                vmax = vmax.unsqueeze(-1)
            if 'channels' in self.shared:
                vmax = vmax.max()
        else:
            vmax = self.vmax
        # `unroll` takes the input tensor first, like every other call
        # site in this module.
        return self.Next(
            self.gamma, vmin, vmax, **self.get_prm()
        ).unroll(x, max_depth-1)
|
|
1107
|
+
|
|
1108
|
+
|
|
1109
|
+
class RandomGammaTransform(NonFinalTransform):
    """
    Random Gamma transform.

    The exponent is sampled at unroll time and a `GammaTransform` is
    built from it.
    """

    Next = GammaTransform
    """The transform type returned by `next`."""

    Final = GammaFinalTransform
    """The transform type returned by `final`."""

    def __init__(
        self,
        gamma: tx.Union[Sampler, float, tx.Tuple[float, float]] = (0.5, 2),
        *,
        shared: cct.SharedT = False,
        shared_minmax: tx.Optional[cct.SharedT] = None,
        **kwargs
    ):
        """
        Parameters
        ----------
        gamma : Sampler or [pair of] float
            Sampler or range for the exponent value. The keyword `value`
            is also accepted in place of `gamma`.

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        shared_minmax
            Use the same vmin/vmax for all channels.
            Default: same as `shared`.
        returns : [list or dict of] {'input', 'output', 'vmin', 'vmax', 'gamma'}
            See [`Transform`][cornucopia.base.Transform] for details.
        append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        # `value` must be removed from kwargs *before* they are forwarded
        # to the parent constructor, otherwise it is passed on to it.
        gamma = kwargs.pop('value', gamma)
        super().__init__(shared=shared, **kwargs)
        self.gamma = Uniform.make(gamma)
        self.shared_minmax = self._prepare_shared(shared_minmax)

    def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
        """Sample the exponent and return an unrolled `GammaTransform`.

        Parameters
        ----------
        x : tensor
            Input image.
        max_depth : int
            Maximum number of unrolling levels. If 0, `self` is returned
            unchanged, as in the other transforms of this module.
        """
        if max_depth == 0:
            return self
        gamma = self.gamma
        if isinstance(gamma, Sampler):
            gamma = gamma()
        shared_minmax = self.shared_minmax
        if shared_minmax is None:
            shared_minmax = self.shared
        return GammaTransform(
            gamma, shared=shared_minmax, **self.get_prm()
        ).unroll(x, max_depth-1)
|
|
1161
|
+
|
|
1162
|
+
|
|
1163
|
+
class ZTransform(NonFinalTransform):
    """
    Standardize intensities to a target mean and standard deviation
    (by default zero mean and unit standard deviation).

    Statistics are computed per channel, or over all channels at once
    when channels are shared, and turned into an affine
    (`AddMulTransform`) transform.
    """

    Final = Next = AddMulTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self, mu: float = 0, sigma: float = 1,
        *, shared: cct.SharedT = False, **kwargs
    ):
        """
        Parameters
        ----------
        mu : float
            Target mean. If None, the input mean is preserved.
        sigma : float
            Target standard deviation. If None, the input standard
            deviation is preserved.

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(shared=shared, **kwargs)
        self.mu, self.sigma = mu, sigma

    def _unroll(self, x: Tensor, max_depth: int = inf) -> Transform:
        if max_depth == 0:
            return self

        # Shared channels: pool statistics over the whole tensor.
        # Otherwise: one mean/std per channel (dim 0), kept broadcastable.
        if 'channels' not in self.shared:
            reduce_opt = {'dim': list(range(1, x.ndim)), 'keepdim': True}
        else:
            reduce_opt = {}
        mean_in = x.mean(**reduce_opt)
        std_in = x.std(**reduce_opt)

        mean_out = mean_in if self.mu is None else self.mu
        std_out = std_in if self.sigma is None else self.sigma

        # y = x * scale + offset maps (mean_in, std_in) to (mean_out, std_out)
        scale = std_out / std_in
        offset = mean_out - mean_in * scale
        affine = AddMulTransform(scale, offset, **self.get_prm())
        return affine.unroll(x, max_depth-1)
|
|
1210
|
+
|
|
1211
|
+
|
|
1212
|
+
class QuantileTransform(NonFinalTransform):
    """Map the lower and upper quantiles of the image to (vmin, vmax).

    With the default parameters, the 1st and 99th percentiles are mapped
    to 0 and 1, respectively.
    """

    Final = Next = AddMulTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self,
        pmin: float = 0.01,
        pmax: float = 0.99,
        vmin: float = 0,
        vmax: float = 1,
        clip: bool = False,
        max_samples: int = 10000,
        **kwargs
    ) -> None:
        """

        Parameters
        ----------
        pmin : (0..1)
            Quantile mapped to `vmin`
        pmax : (0..1)
            Quantile mapped to `vmax`
        vmin : float
            Target value of the lower quantile
        vmax : float
            Target value of the upper quantile
        clip : bool
            Clip values outside (vmin, vmax)
        max_samples : int
            Maximum number of voxels randomly selected to estimate the
            quantiles (for speed). Falsy values disable subsampling.

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.pmin, self.pmax = pmin, pmax
        self.vmin, self.vmax = vmin, vmax
        self.clip = clip
        self.max_samples = max_samples

    def _unroll(self, x: Tensor, max_depth: float = inf) -> Transform:
        if max_depth == 0:
            return self

        nb_spatial = x.ndim - 1

        # (C, N) view of the voxels; only voxels that are nonzero and
        # finite in every channel take part in the estimation.
        flat = x.reshape([len(x), -1])
        keep = (flat != 0).all(0) & flat.isfinite().all(0)
        flat = flat[:, keep]
        nb_vox = flat.shape[1]
        if self.max_samples and self.max_samples < nb_vox:
            perm = torch.randperm(nb_vox, device=flat.device)
            flat = flat[:, perm[:self.max_samples]]

        # Per-channel quantiles, unless channels are shared, in which
        # case all values are pooled (dim=None).
        qdim = None if 'channels' in self.shared else -1
        qmin = torch.quantile(flat, self.pmin, dim=qdim)
        qmax = torch.quantile(flat, self.pmax, dim=qdim)
        bcast = (Ellipsis,) + (None,) * nb_spatial
        qmin, qmax = qmin[bcast], qmax[bcast]

        # Affine map sending qmin -> vmin and qmax -> vmax; the
        # denominator is floored to avoid a division by zero.
        slope = (self.vmax - self.vmin) / (qmax - qmin).clamp_min_(1e-16)
        offset = self.vmin - qmin * slope

        affine = AddMulTransform(slope, offset, **self.get_prm())
        if not self.clip:
            return affine.unroll(x, max_depth-1)
        clipper = ClipTransform(self.vmin, self.vmax, **self.get_prm())
        return SequentialTransform([affine, clipper]).unroll(x, max_depth-1)
|
|
1294
|
+
|
|
1295
|
+
|
|
1296
|
+
class MinMaxTransform(NonFinalTransform):
    """Map the minimum and maximum values of the image to (vmin, vmax).

    With the default parameters, the minimum is mapped to 0 and the
    maximum to 1.
    """

    Final = Next = AddMulTransform
    """The transform type returned by `unroll`, `next` and `final`."""

    def __init__(
        self, vmin: float = 0, vmax: float = 1, clip: bool = False, **kwargs
    ) -> None:
        """

        Parameters
        ----------
        vmin : float
            Target value of the image minimum
        vmax : float
            Target value of the image maximum
        clip : bool
            Clip values outside (vmin, vmax)

        Other Parameters
        ----------------
        shared
            See [`NonFinalTransform`][cornucopia.base.NonFinalTransform]
            for details.
        returns, append, prefix, include, exclude, consume
            See [`Transform`][cornucopia.base.Transform] for details.
        """
        super().__init__(**kwargs)
        self.vmin = vmin
        self.vmax = vmax
        self.clip = clip

    def _unroll(self, x: Tensor, max_depth: float = inf) -> Transform:
        if max_depth == 0:
            return self

        ndim = x.ndim - 1

        # (C, N) view of the voxels, restricted to voxels that are
        # finite in every channel.
        x_ = x.reshape([len(x), -1])
        x_ = x_[:, x_.isfinite().all(0)]

        # Per-channel extrema, unless channels are shared.
        if 'channels' not in self.shared:
            pmin = torch.min(x_, dim=-1).values
            pmax = torch.max(x_, dim=-1).values
        else:
            pmin = torch.min(x_)
            pmax = torch.max(x_)
        pmin = pmin[(Ellipsis,) + (None,) * ndim]
        pmax = pmax[(Ellipsis,) + (None,) * ndim]

        # Floor the range so that a constant image/channel does not
        # produce an infinite slope (and inf/nan outputs).
        den = (pmax - pmin).clamp_min(1e-16)
        slope = (self.vmax - self.vmin) / den
        offset = self.vmin - pmin * slope

        if self.clip:
            return SequentialTransform([
                AddMulTransform(slope, offset, **self.get_prm()),
                ClipTransform(self.vmin, self.vmax, **self.get_prm())
            ]).unroll(x, max_depth-1)
        else:
            return AddMulTransform(
                slope, offset, **self.get_prm()
            ).unroll(x, max_depth-1)
|