fastMONAI 0.5.3__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastMONAI/__init__.py +1 -1
- fastMONAI/_modidx.py +224 -28
- fastMONAI/dataset_info.py +329 -47
- fastMONAI/external_data.py +1 -1
- fastMONAI/utils.py +394 -22
- fastMONAI/vision_all.py +3 -2
- fastMONAI/vision_augmentation.py +264 -28
- fastMONAI/vision_core.py +29 -132
- fastMONAI/vision_data.py +6 -6
- fastMONAI/vision_inference.py +35 -9
- fastMONAI/vision_metrics.py +420 -19
- fastMONAI/vision_patch.py +1259 -0
- fastMONAI/vision_plot.py +90 -1
- {fastmonai-0.5.3.dist-info → fastmonai-0.6.0.dist-info}/METADATA +5 -5
- fastmonai-0.6.0.dist-info/RECORD +21 -0
- {fastmonai-0.5.3.dist-info → fastmonai-0.6.0.dist-info}/WHEEL +1 -1
- fastmonai-0.5.3.dist-info/RECORD +0 -20
- {fastmonai-0.5.3.dist-info → fastmonai-0.6.0.dist-info}/entry_points.txt +0 -0
- {fastmonai-0.5.3.dist-info → fastmonai-0.6.0.dist-info}/licenses/LICENSE +0 -0
- {fastmonai-0.5.3.dist-info → fastmonai-0.6.0.dist-info}/top_level.txt +0 -0
fastMONAI/vision_augmentation.py
CHANGED
|
@@ -3,8 +3,8 @@
|
|
|
3
3
|
# %% auto 0
|
|
4
4
|
__all__ = ['CustomDictTransform', 'do_pad_or_crop', 'PadOrCrop', 'ZNormalization', 'RescaleIntensity', 'NormalizeIntensity',
|
|
5
5
|
'BraTSMaskConverter', 'BinaryConverter', 'RandomGhosting', 'RandomSpike', 'RandomNoise', 'RandomBiasField',
|
|
6
|
-
'RandomBlur', 'RandomGamma', '
|
|
7
|
-
'OneOf']
|
|
6
|
+
'RandomBlur', 'RandomGamma', 'RandomIntensityScale', 'RandomMotion', 'RandomCutout',
|
|
7
|
+
'RandomElasticDeformation', 'RandomAffine', 'RandomFlip', 'OneOf']
|
|
8
8
|
|
|
9
9
|
# %% ../nbs/03_vision_augment.ipynb 2
|
|
10
10
|
from fastai.data.all import *
|
|
@@ -14,10 +14,10 @@ from monai.transforms import NormalizeIntensity as MonaiNormalizeIntensity
|
|
|
14
14
|
|
|
15
15
|
# %% ../nbs/03_vision_augment.ipynb 5
|
|
16
16
|
class CustomDictTransform(ItemTransform):
|
|
17
|
-
"""A class that serves as a wrapper to perform an identical transformation on both
|
|
17
|
+
"""A class that serves as a wrapper to perform an identical transformation on both
|
|
18
18
|
the image and the target (if it's a mask).
|
|
19
19
|
"""
|
|
20
|
-
|
|
20
|
+
|
|
21
21
|
split_idx = 0 # Only perform transformations on training data. Use TTA() for transformations on validation data.
|
|
22
22
|
|
|
23
23
|
def __init__(self, aug):
|
|
@@ -28,31 +28,42 @@ class CustomDictTransform(ItemTransform):
|
|
|
28
28
|
"""
|
|
29
29
|
self.aug = aug
|
|
30
30
|
|
|
31
|
+
@property
|
|
32
|
+
def tio_transform(self):
|
|
33
|
+
"""Return the underlying TorchIO transform.
|
|
34
|
+
|
|
35
|
+
This property enables using fastMONAI wrappers in patch-based workflows
|
|
36
|
+
where raw TorchIO transforms are needed for tio.Compose().
|
|
37
|
+
"""
|
|
38
|
+
return self.aug
|
|
39
|
+
|
|
31
40
|
def encodes(self, x):
|
|
32
41
|
"""
|
|
33
|
-
Applies the stored transformation to an image, and the same random transformation
|
|
42
|
+
Applies the stored transformation to an image, and the same random transformation
|
|
34
43
|
to the target if it is a mask. If the target is not a mask, it returns the target as is.
|
|
35
44
|
|
|
36
45
|
Args:
|
|
37
|
-
x (Tuple[MedImage, Union[MedMask, TensorCategory]]): A tuple containing the
|
|
46
|
+
x (Tuple[MedImage, Union[MedMask, TensorCategory]]): A tuple containing the
|
|
38
47
|
image and the target.
|
|
39
48
|
|
|
40
49
|
Returns:
|
|
41
|
-
Tuple[MedImage, Union[MedMask, TensorCategory]]: The transformed image and target.
|
|
42
|
-
If the target is a mask, it's transformed identically to the image. If the target
|
|
50
|
+
Tuple[MedImage, Union[MedMask, TensorCategory]]: The transformed image and target.
|
|
51
|
+
If the target is a mask, it's transformed identically to the image. If the target
|
|
43
52
|
is not a mask, the original target is returned.
|
|
44
53
|
"""
|
|
45
54
|
img, y_true = x
|
|
55
|
+
|
|
56
|
+
# Use identity affine if MedImage.affine_matrix is not set
|
|
57
|
+
affine = MedImage.affine_matrix if MedImage.affine_matrix is not None else np.eye(4)
|
|
46
58
|
|
|
47
59
|
if isinstance(y_true, (MedMask)):
|
|
48
|
-
aug = self.aug(tio.Subject(img=tio.ScalarImage(tensor=img, affine=
|
|
49
|
-
mask=tio.LabelMap(tensor=y_true, affine=
|
|
60
|
+
aug = self.aug(tio.Subject(img=tio.ScalarImage(tensor=img, affine=affine),
|
|
61
|
+
mask=tio.LabelMap(tensor=y_true, affine=affine)))
|
|
50
62
|
return MedImage.create(aug['img'].data), MedMask.create(aug['mask'].data)
|
|
51
63
|
|
|
52
64
|
aug = self.aug(tio.Subject(img=tio.ScalarImage(tensor=img)))
|
|
53
65
|
return MedImage.create(aug['img'].data), y_true
|
|
54
66
|
|
|
55
|
-
|
|
56
67
|
# %% ../nbs/03_vision_augment.ipynb 7
|
|
57
68
|
def do_pad_or_crop(o, target_shape, padding_mode, mask_name, dtype=torch.Tensor):
|
|
58
69
|
#TODO:refactorize
|
|
@@ -72,6 +83,11 @@ class PadOrCrop(DisplayedTransform):
|
|
|
72
83
|
padding_mode=padding_mode,
|
|
73
84
|
mask_name=mask_name)
|
|
74
85
|
|
|
86
|
+
@property
|
|
87
|
+
def tio_transform(self):
|
|
88
|
+
"""Return the underlying TorchIO transform."""
|
|
89
|
+
return self.pad_or_crop
|
|
90
|
+
|
|
75
91
|
def encodes(self, o: (MedImage, MedMask)):
|
|
76
92
|
return type(o)(self.pad_or_crop(o))
|
|
77
93
|
|
|
@@ -85,6 +101,11 @@ class ZNormalization(DisplayedTransform):
|
|
|
85
101
|
self.z_normalization = tio.ZNormalization(masking_method=masking_method)
|
|
86
102
|
self.channel_wise = channel_wise
|
|
87
103
|
|
|
104
|
+
@property
|
|
105
|
+
def tio_transform(self):
|
|
106
|
+
"""Return the underlying TorchIO transform."""
|
|
107
|
+
return self.z_normalization
|
|
108
|
+
|
|
88
109
|
def encodes(self, o: MedImage):
|
|
89
110
|
try:
|
|
90
111
|
if self.channel_wise:
|
|
@@ -132,7 +153,12 @@ class RescaleIntensity(DisplayedTransform):
|
|
|
132
153
|
|
|
133
154
|
def __init__(self, out_min_max: tuple[float, float], in_min_max: tuple[float, float]):
|
|
134
155
|
self.rescale = tio.RescaleIntensity(out_min_max=out_min_max, in_min_max=in_min_max)
|
|
135
|
-
|
|
156
|
+
|
|
157
|
+
@property
|
|
158
|
+
def tio_transform(self):
|
|
159
|
+
"""Return the underlying TorchIO transform."""
|
|
160
|
+
return self.rescale
|
|
161
|
+
|
|
136
162
|
def encodes(self, o: MedImage):
|
|
137
163
|
return MedImage.create(self.rescale(o))
|
|
138
164
|
|
|
@@ -211,8 +237,17 @@ class RandomGhosting(DisplayedTransform):
|
|
|
211
237
|
def __init__(self, intensity=(0.5, 1), p=0.5):
|
|
212
238
|
self.add_ghosts = tio.RandomGhosting(intensity=intensity, p=p)
|
|
213
239
|
|
|
240
|
+
@property
|
|
241
|
+
def tio_transform(self):
|
|
242
|
+
"""Return the underlying TorchIO transform."""
|
|
243
|
+
return self.add_ghosts
|
|
244
|
+
|
|
214
245
|
def encodes(self, o: MedImage):
|
|
215
|
-
|
|
246
|
+
result = self.add_ghosts(o)
|
|
247
|
+
# Handle potential complex values from k-space operations
|
|
248
|
+
if result.is_complex():
|
|
249
|
+
result = torch.real(result)
|
|
250
|
+
return MedImage.create(result)
|
|
216
251
|
|
|
217
252
|
def encodes(self, o: MedMask):
|
|
218
253
|
return o
|
|
@@ -221,26 +256,40 @@ class RandomGhosting(DisplayedTransform):
|
|
|
221
256
|
class RandomSpike(DisplayedTransform):
    """Wrap TorchIO's `RandomSpike` as a fastai-style transform.

    Images receive the spike artifact; masks are passed through untouched.
    """

    split_idx, order = 0, 1

    def __init__(self, num_spikes=1, intensity=(1, 3), p=0.5):
        """Build the wrapped `tio.RandomSpike` from the given arguments."""
        self.add_spikes = tio.RandomSpike(num_spikes=num_spikes, intensity=intensity, p=p)

    @property
    def tio_transform(self):
        """Return the underlying TorchIO transform."""
        return self.add_spikes

    def encodes(self, o: MedImage):
        spiked = self.add_spikes(o)
        # k-space operations may hand back complex values; only the real
        # part is kept in that case.
        if not spiked.is_complex():
            return MedImage.create(spiked)
        return MedImage.create(torch.real(spiked))

    def encodes(self, o: MedMask):
        # Masks are label maps: the intensity artifact does not apply.
        return o
|
|
234
278
|
|
|
235
279
|
# %% ../nbs/03_vision_augment.ipynb 16
|
|
236
280
|
class RandomNoise(DisplayedTransform):
|
|
237
281
|
'''Apply TorchIO `RandomNoise`.'''
|
|
238
282
|
|
|
239
|
-
split_idx,order=0,1
|
|
283
|
+
split_idx, order = 0, 1
|
|
240
284
|
|
|
241
285
|
def __init__(self, mean=0, std=(0, 0.25), p=0.5):
|
|
242
286
|
self.add_noise = tio.RandomNoise(mean=mean, std=std, p=p)
|
|
243
287
|
|
|
288
|
+
@property
|
|
289
|
+
def tio_transform(self):
|
|
290
|
+
"""Return the underlying TorchIO transform."""
|
|
291
|
+
return self.add_noise
|
|
292
|
+
|
|
244
293
|
def encodes(self, o: MedImage):
|
|
245
294
|
return MedImage.create(self.add_noise(o))
|
|
246
295
|
|
|
@@ -251,11 +300,16 @@ class RandomNoise(DisplayedTransform):
|
|
|
251
300
|
class RandomBiasField(DisplayedTransform):
|
|
252
301
|
'''Apply TorchIO `RandomBiasField`.'''
|
|
253
302
|
|
|
254
|
-
split_idx,order=0,1
|
|
303
|
+
split_idx, order = 0, 1
|
|
255
304
|
|
|
256
305
|
def __init__(self, coefficients=0.5, order=3, p=0.5):
|
|
257
306
|
self.add_biasfield = tio.RandomBiasField(coefficients=coefficients, order=order, p=p)
|
|
258
307
|
|
|
308
|
+
@property
|
|
309
|
+
def tio_transform(self):
|
|
310
|
+
"""Return the underlying TorchIO transform."""
|
|
311
|
+
return self.add_biasfield
|
|
312
|
+
|
|
259
313
|
def encodes(self, o: MedImage):
|
|
260
314
|
return MedImage.create(self.add_biasfield(o))
|
|
261
315
|
|
|
@@ -264,13 +318,18 @@ class RandomBiasField(DisplayedTransform):
|
|
|
264
318
|
|
|
265
319
|
# %% ../nbs/03_vision_augment.ipynb 18
|
|
266
320
|
class RandomBlur(DisplayedTransform):
|
|
267
|
-
'''Apply TorchIO `
|
|
321
|
+
'''Apply TorchIO `RandomBlur`.'''
|
|
268
322
|
|
|
269
|
-
split_idx,order=0,1
|
|
323
|
+
split_idx, order = 0, 1
|
|
270
324
|
|
|
271
325
|
def __init__(self, std=(0, 2), p=0.5):
|
|
272
326
|
self.add_blur = tio.RandomBlur(std=std, p=p)
|
|
273
|
-
|
|
327
|
+
|
|
328
|
+
@property
|
|
329
|
+
def tio_transform(self):
|
|
330
|
+
"""Return the underlying TorchIO transform."""
|
|
331
|
+
return self.add_blur
|
|
332
|
+
|
|
274
333
|
def encodes(self, o: MedImage):
|
|
275
334
|
return MedImage.create(self.add_blur(o))
|
|
276
335
|
|
|
@@ -281,12 +340,16 @@ class RandomBlur(DisplayedTransform):
|
|
|
281
340
|
class RandomGamma(DisplayedTransform):
|
|
282
341
|
'''Apply TorchIO `RandomGamma`.'''
|
|
283
342
|
|
|
284
|
-
|
|
285
|
-
split_idx,order=0,1
|
|
343
|
+
split_idx, order = 0, 1
|
|
286
344
|
|
|
287
345
|
def __init__(self, log_gamma=(-0.3, 0.3), p=0.5):
|
|
288
346
|
self.add_gamma = tio.RandomGamma(log_gamma=log_gamma, p=p)
|
|
289
347
|
|
|
348
|
+
@property
|
|
349
|
+
def tio_transform(self):
|
|
350
|
+
"""Return the underlying TorchIO transform."""
|
|
351
|
+
return self.add_gamma
|
|
352
|
+
|
|
290
353
|
def encodes(self, o: MedImage):
|
|
291
354
|
return MedImage.create(self.add_gamma(o))
|
|
292
355
|
|
|
@@ -294,6 +357,38 @@ class RandomGamma(DisplayedTransform):
|
|
|
294
357
|
return o
|
|
295
358
|
|
|
296
359
|
# %% ../nbs/03_vision_augment.ipynb 20
|
|
360
|
+
class RandomIntensityScale(DisplayedTransform):
    """Randomly scale image intensities by a multiplicative factor.

    Useful for domain generalization across different acquisition protocols
    with varying intensity ranges.

    Args:
        scale_range (tuple[float, float]): Range of scale factors (min, max).
            Values > 1 increase intensity, < 1 decrease intensity.
        p (float): Probability of applying the transform (default: 0.5).
            `p=0` never applies it and `p=1` always applies it.

    Example:
        # Scale intensities randomly between 0.5x and 2.0x
        transform = RandomIntensityScale(scale_range=(0.5, 2.0), p=0.3)
    """

    split_idx, order = 0, 1

    def __init__(self, scale_range: tuple[float, float] = (0.5, 2.0), p: float = 0.5):
        self.scale_range = scale_range
        self.p = p

    def encodes(self, o: MedImage):
        # torch.rand draws from [0, 1). Using `>=` (not `>`) means a draw of
        # exactly 0.0 cannot trigger the transform when p == 0, while p == 1
        # still always triggers it.
        if torch.rand(1).item() >= self.p:
            return o
        # One scalar factor is drawn per call and applied to the whole image.
        scale = torch.empty(1).uniform_(self.scale_range[0], self.scale_range[1]).item()
        return MedImage.create(o * scale)

    def encodes(self, o: MedMask):
        # Masks hold labels, not intensities; leave them unchanged.
        return o
|
|
390
|
+
|
|
391
|
+
# %% ../nbs/03_vision_augment.ipynb 21
|
|
297
392
|
class RandomMotion(DisplayedTransform):
|
|
298
393
|
"""Apply TorchIO `RandomMotion`."""
|
|
299
394
|
|
|
@@ -315,13 +410,154 @@ class RandomMotion(DisplayedTransform):
|
|
|
315
410
|
p=p
|
|
316
411
|
)
|
|
317
412
|
|
|
413
|
+
@property
|
|
414
|
+
def tio_transform(self):
|
|
415
|
+
"""Return the underlying TorchIO transform."""
|
|
416
|
+
return self.add_motion
|
|
417
|
+
|
|
318
418
|
def encodes(self, o: MedImage):
|
|
319
|
-
|
|
419
|
+
result = self.add_motion(o)
|
|
420
|
+
# Handle potential complex values from k-space operations
|
|
421
|
+
if result.is_complex():
|
|
422
|
+
result = torch.real(result)
|
|
423
|
+
return MedImage.create(result)
|
|
320
424
|
|
|
321
425
|
def encodes(self, o: MedMask):
|
|
322
426
|
return o
|
|
323
427
|
|
|
324
428
|
# %% ../nbs/03_vision_augment.ipynb 22
|
|
429
|
+
def _create_ellipsoid_mask(shape, center, radii):
    """Create a 3D boolean ellipsoid mask.

    Args:
        shape: (D, H, W) shape of the volume.
        center: (z, y, x) center of the ellipsoid, in voxel coordinates.
        radii: (rz, ry, rx) radii along each axis. Only the magnitude
            matters. A zero radius is treated as a degenerate axis: along
            that axis only voxels located exactly at the center coordinate
            can be inside (previously 0/0 produced NaN and excluded them).

    Returns:
        Boolean tensor of shape (D, H, W) where True = inside (or on the
        surface of) the ellipsoid.
    """
    # Smallest radius used in the division; keeps a zero radius from
    # producing NaN at the center while still excluding every other voxel.
    min_radius = 1e-6

    dist = None
    for axis in range(3):
        radius = max(abs(float(radii[axis])), min_radius)
        offsets = torch.arange(shape[axis], dtype=torch.float32) - float(center[axis])

        # Keep each axis 1-D and let broadcasting build the (D, H, W) volume,
        # instead of materialising three full coordinate grids.
        view = [1, 1, 1]
        view[axis] = -1
        term = ((offsets / radius) ** 2).reshape(view)

        dist = term if dist is None else dist + term

    return dist <= 1.0
|
|
450
|
+
|
|
451
|
+
# %% ../nbs/03_vision_augment.ipynb 23
|
|
452
|
+
def _resolve_cutout_fill(fill, tensor):
    """Turn a fill specification into the number written into cutout voxels.

    Args:
        fill: None or 'min' (minimum of `tensor`), 'mean' (mean of `tensor`),
            'random' (uniform draw between the minimum and maximum of
            `tensor`), or a number that is used as-is.
        tensor: Image data the statistics are computed from.

    Returns:
        The fill value for this image.

    Raises:
        ValueError: If `fill` is a string other than the three modes above.
    """
    if fill is None:
        fill = 'min'
    if not isinstance(fill, str):
        return fill
    if fill == 'min':
        return float(tensor.min())
    if fill == 'mean':
        return float(tensor.mean())
    if fill == 'random':
        return torch.empty(1).uniform_(float(tensor.min()), float(tensor.max())).item()
    raise ValueError(f"fill must be 'min', 'mean', 'random' or a number, got {fill!r}")


class _TioRandomCutout(tio.IntensityTransform):
    """TorchIO-compatible RandomCutout for patch-based workflows.

    Args:
        holes: Minimum number of cutout regions.
        spatial_size: Minimum cutout diameter in voxels. If a sequence is
            given, only its first element is used.
        fill_value: None or 'min', 'mean', 'random', or a number. None keeps
            the previous behaviour of filling with the image minimum.
        max_holes: Maximum number of regions; falls back to `holes` when
            unset or smaller than `holes`.
        max_spatial_size: Maximum cutout diameter; falls back to
            `spatial_size` when unset or smaller than it.
        p: Probability of applying the transform.
        **kwargs: Forwarded to `tio.IntensityTransform`.
    """

    def __init__(self, holes=1, spatial_size=8, fill_value=None,
                 max_holes=None, max_spatial_size=None, p=0.2, **kwargs):
        super().__init__(p=p, **kwargs)
        self.holes = holes
        self.spatial_size = spatial_size
        self.fill_value = fill_value
        self.max_holes = max_holes
        self.max_spatial_size = max_spatial_size

    def _apply_cutout(self, data, fill_val):
        """Apply spherical cutout(s) to a tensor.

        Args:
            data: Tensor whose first dimension is indexed as a whole and
                whose remaining dimensions are treated as (D, H, W).
            fill_val: Value written into every voxel inside a cutout.

        Returns:
            A modified copy of `data`; the input tensor is left untouched.
        """
        result = data.clone()

        # torch.randint requires low < high. A maximum below the minimum
        # (e.g. raising `holes` above the default `max_holes`) is therefore
        # treated as "exactly the minimum" instead of raising.
        max_holes = max(self.max_holes or self.holes, self.holes)
        n_holes = torch.randint(self.holes, max_holes + 1, (1,)).item()

        spatial_shape = data.shape[1:]  # (D, H, W)
        min_size = self.spatial_size if isinstance(self.spatial_size, int) else self.spatial_size[0]
        max_size = self.max_spatial_size or self.spatial_size
        max_size = max_size if isinstance(max_size, int) else max_size[0]
        max_size = max(max_size, min_size)

        for _ in range(n_holes):
            # Independent random diameter per axis for this hole.
            size = torch.randint(min_size, max_size + 1, (3,))
            radii = size.float() / 2

            # Random center, kept one (truncated) radius away from the
            # borders when the volume is large enough for that.
            center = []
            for extent, radius in zip(spatial_shape, radii.tolist()):
                low = int(radius)
                high = max(extent - low, low + 1)
                center.append(torch.randint(low, high, (1,)).item())

            mask = _create_ellipsoid_mask(spatial_shape, center, radii)
            result[:, mask] = fill_val

        return result

    def apply_transform(self, subject):
        for image in self.get_images(subject):
            data = image.data
            # Resolved per image so 'min'/'mean'/'random' follow each image.
            fill_val = _resolve_cutout_fill(self.fill_value, data)
            image.set_data(self._apply_cutout(data, fill_val))
        return subject

# %% ../nbs/03_vision_augment.ipynb 24
class RandomCutout(DisplayedTransform):
    """Randomly erase spherical regions in 3D medical images.

    Simulates post-operative surgical cavities by filling random ellipsoid
    volumes with specified values. Useful for training on pre-op images
    to generalize to post-op scans.

    Args:
        holes: Minimum number of cutout regions. Default: 1.
        max_holes: Maximum number of regions. Default: 3. A value below
            `holes` is treated as equal to `holes`.
        spatial_size: Minimum cutout diameter in voxels. Default: 8.
        max_spatial_size: Maximum cutout diameter. Default: 16. A value below
            `spatial_size` is treated as equal to `spatial_size`.
        fill: Fill value - 'min', 'mean', 'random', or float. Default: 'min'.
            The same setting is honoured by `tio_transform`.
        p: Probability of applying transform. Default: 0.2.

    Example:
        >>> # Simulate post-op cavities with dark spherical voids
        >>> tfm = RandomCutout(holes=1, max_holes=2, spatial_size=10,
        ...                    max_spatial_size=25, fill='min', p=0.2)
    """

    split_idx, order = 0, 1

    def __init__(self, holes=1, max_holes=3, spatial_size=8,
                 max_spatial_size=16, fill='min', p=0.2):
        self.holes = holes
        self.max_holes = max_holes
        self.spatial_size = spatial_size
        self.max_spatial_size = max_spatial_size
        self.fill = fill
        self.p = p

        # The fill mode is forwarded unchanged so the patch-based path fills
        # holes the same way as `encodes` does.
        self._tio_cutout = _TioRandomCutout(
            holes=holes, spatial_size=spatial_size,
            fill_value=fill,
            max_holes=max_holes, max_spatial_size=max_spatial_size, p=p
        )

    @property
    def tio_transform(self):
        """Return TorchIO-compatible transform for patch-based workflows."""
        return self._tio_cutout

    def _get_fill_value(self, tensor):
        """Resolve `self.fill` to a number for the given image tensor."""
        return _resolve_cutout_fill(self.fill, tensor)

    def encodes(self, o: MedImage):
        # torch.rand draws from [0, 1); `>=` makes p=0 never apply the
        # transform and p=1 always apply it.
        if torch.rand(1).item() >= self.p:
            return o
        fill_val = self._get_fill_value(o)
        # _apply_cutout works on its own clone, so no extra copy is needed.
        result = self._tio_cutout._apply_cutout(o, fill_val)
        return MedImage.create(result)

    def encodes(self, o: MedMask):
        # Masks are returned unchanged; only image intensities are erased.
        return o
|
|
559
|
+
|
|
560
|
+
# %% ../nbs/03_vision_augment.ipynb 26
|
|
325
561
|
class RandomElasticDeformation(CustomDictTransform):
|
|
326
562
|
"""Apply TorchIO `RandomElasticDeformation`."""
|
|
327
563
|
|
|
@@ -334,7 +570,7 @@ class RandomElasticDeformation(CustomDictTransform):
|
|
|
334
570
|
image_interpolation=image_interpolation,
|
|
335
571
|
p=p))
|
|
336
572
|
|
|
337
|
-
# %% ../nbs/03_vision_augment.ipynb
|
|
573
|
+
# %% ../nbs/03_vision_augment.ipynb 27
|
|
338
574
|
class RandomAffine(CustomDictTransform):
|
|
339
575
|
"""Apply TorchIO `RandomAffine`."""
|
|
340
576
|
|
|
@@ -350,14 +586,14 @@ class RandomAffine(CustomDictTransform):
|
|
|
350
586
|
default_pad_value=default_pad_value,
|
|
351
587
|
p=p))
|
|
352
588
|
|
|
353
|
-
# %% ../nbs/03_vision_augment.ipynb
|
|
589
|
+
# %% ../nbs/03_vision_augment.ipynb 28
|
|
354
590
|
class RandomFlip(CustomDictTransform):
    """Wrap TorchIO's `RandomFlip` through `CustomDictTransform`."""

    def __init__(self, axes='LR', p=0.5):
        """Create the wrapped transform.

        Args:
            axes: Axis or axes specification passed to `tio.RandomFlip`.
            p: Forwarded to `tio.RandomFlip` as `flip_probability`.
        """
        flip = tio.RandomFlip(axes=axes, flip_probability=p)
        super().__init__(flip)
|
|
359
595
|
|
|
360
|
-
# %% ../nbs/03_vision_augment.ipynb
|
|
596
|
+
# %% ../nbs/03_vision_augment.ipynb 29
|
|
361
597
|
class OneOf(CustomDictTransform):
|
|
362
598
|
"""Apply only one of the given transforms using TorchIO `OneOf`."""
|
|
363
599
|
|