fastMONAI 0.5.4__py3-none-any.whl → 0.6.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- fastMONAI/__init__.py +1 -1
- fastMONAI/_modidx.py +61 -1
- fastMONAI/dataset_info.py +144 -7
- fastMONAI/utils.py +296 -7
- fastMONAI/vision_augmentation.py +328 -34
- fastMONAI/vision_patch.py +175 -23
- fastMONAI/vision_plot.py +89 -1
- {fastmonai-0.5.4.dist-info → fastmonai-0.6.1.dist-info}/METADATA +1 -1
- fastmonai-0.6.1.dist-info/RECORD +21 -0
- fastmonai-0.5.4.dist-info/RECORD +0 -21
- {fastmonai-0.5.4.dist-info → fastmonai-0.6.1.dist-info}/WHEEL +0 -0
- {fastmonai-0.5.4.dist-info → fastmonai-0.6.1.dist-info}/entry_points.txt +0 -0
- {fastmonai-0.5.4.dist-info → fastmonai-0.6.1.dist-info}/licenses/LICENSE +0 -0
- {fastmonai-0.5.4.dist-info → fastmonai-0.6.1.dist-info}/top_level.txt +0 -0
fastMONAI/vision_augmentation.py
CHANGED
@@ -1,16 +1,17 @@
 # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/03_vision_augment.ipynb.
 
 # %% auto 0
-__all__ = ['CustomDictTransform', 'do_pad_or_crop', 'PadOrCrop', '
-           'BraTSMaskConverter', 'BinaryConverter', 'RandomGhosting', 'RandomSpike',
-           'RandomBlur', 'RandomGamma', 'RandomIntensityScale', 'RandomMotion',
-           'RandomAffine', 'RandomFlip', 'OneOf']
+__all__ = ['CustomDictTransform', 'do_pad_or_crop', 'PadOrCrop', 'SpatialPad', 'ZNormalization', 'RescaleIntensity',
+           'NormalizeIntensity', 'BraTSMaskConverter', 'BinaryConverter', 'RandomGhosting', 'RandomSpike',
+           'RandomNoise', 'RandomBiasField', 'RandomBlur', 'RandomGamma', 'RandomIntensityScale', 'RandomMotion',
+           'RandomCutout', 'RandomElasticDeformation', 'RandomAffine', 'RandomFlip', 'OneOf']
 
 # %% ../nbs/03_vision_augment.ipynb 2
 from fastai.data.all import *
 from .vision_core import *
 import torchio as tio
 from monai.transforms import NormalizeIntensity as MonaiNormalizeIntensity
+from monai.transforms import SpatialPad as MonaiSpatialPad
 
 # %% ../nbs/03_vision_augment.ipynb 5
 class CustomDictTransform(ItemTransform):
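
For orientation, the transforms newly listed in `__all__` above are composed like any other fastMONAI item transform. A minimal sketch, assuming fastMONAI 0.6.1 is installed; the parameter values and the no-argument `ZNormalization()` call are illustrative assumptions, not taken from this diff:

    from fastMONAI.vision_augmentation import SpatialPad, ZNormalization, RandomCutout

    item_tfms = [
        SpatialPad(spatial_size=[256, 256, 96]),   # pad-only, never crops
        ZNormalization(),                          # per-image normalization (defaults assumed)
        RandomCutout(holes=1, max_holes=2, spatial_size=10,
                     max_spatial_size=25, fill='min', mask_only=True, p=0.2),
    ]
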
@@ -92,6 +93,74 @@ class PadOrCrop(DisplayedTransform):
         return type(o)(self.pad_or_crop(o))
 
 # %% ../nbs/03_vision_augment.ipynb 9
+class SpatialPad(DisplayedTransform):
+    """Pad image to minimum size without cropping using MONAI's `SpatialPad`.
+
+    This transform pads each dimension to AT LEAST the specified size. Dimensions
+    already larger than `spatial_size` are left unchanged (no cropping), making it
+    ideal for patch-based training where images must be at least as large as the
+    patch size.
+
+    Uses zero padding (constant value 0) by default.
+
+    Args:
+        spatial_size: Minimum size [x, y, z] for each dimension. Can be int (same
+            for all dims) or list. Usually set to patch_size for patch-based training.
+        mode: Padding mode. Default 'constant' means zero padding. Other options:
+            'edge', 'reflect', 'symmetric'.
+
+    Example:
+        >>> # Image 512×512×48, need at least 256×256×96 for patches
+        >>> transform = SpatialPad(spatial_size=[256, 256, 96])
+        >>> # Result: 512×512×96 (x,y unchanged, z padded from 48→96)
+
+    Note:
+        This differs from `PadOrCrop` which will crop dimensions larger than
+        the target size. Use `SpatialPad` when you want to preserve large
+        dimensions and only pad small ones.
+    """
+
+    order = 0
+
+    def __init__(self, spatial_size, mode='constant'):
+        if not is_listy(spatial_size):
+            spatial_size = [spatial_size, spatial_size, spatial_size]
+        self.spatial_size = spatial_size
+        self.mode = mode
+        # Create MONAI transform (zero padding by default)
+        self.spatial_pad = MonaiSpatialPad(
+            spatial_size=spatial_size,
+            mode=mode
+        )
+
+    @property
+    def tio_transform(self):
+        """Return TorchIO-compatible transform for patch-based workflows.
+
+        This wraps MONAI's SpatialPad in a TorchIO SpatialTransform for
+        compatibility with TorchIO's Queue and Compose pipelines.
+        """
+        import torchio as tio
+
+        class _TioSpatialPad(tio.SpatialTransform):
+            """TorchIO wrapper for MONAI's SpatialPad."""
+            def __init__(self, spatial_size, mode, **kwargs):
+                super().__init__(**kwargs)
+                self._monai_pad = MonaiSpatialPad(spatial_size=spatial_size, mode=mode)
+
+            def apply_transform(self, subject):
+                for image in self.get_images(subject):
+                    data = image.data
+                    padded = self._monai_pad(data)
+                    image.set_data(padded)
+                return subject
+
+        return _TioSpatialPad(spatial_size=self.spatial_size, mode=self.mode)
+
+    def encodes(self, o: (MedImage, MedMask)):
+        return type(o)(self.spatial_pad(o))
+
+# %% ../nbs/03_vision_augment.ipynb 10
 class ZNormalization(DisplayedTransform):
     """Apply TorchIO `ZNormalization`."""
 
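
The pad-only behavior documented above can be checked directly against MONAI's `SpatialPad`; a small sketch (shapes follow the docstring example, and the zero-filled tensor is only a stand-in for a real volume):

    import torch
    from monai.transforms import SpatialPad as MonaiSpatialPad

    img = torch.zeros(1, 512, 512, 48)             # channel-first stand-in volume
    pad = MonaiSpatialPad(spatial_size=[256, 256, 96], mode='constant')
    print(pad(img).shape)                          # (1, 512, 512, 96): x/y untouched, z padded to 96
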
@@ -136,7 +205,7 @@ class ZNormalization(DisplayedTransform):
     def encodes(self, o: MedMask):
         return o
 
-# %% ../nbs/03_vision_augment.ipynb
+# %% ../nbs/03_vision_augment.ipynb 11
 class RescaleIntensity(DisplayedTransform):
     """Apply TorchIO RescaleIntensity for robust intensity scaling.
 
@@ -165,7 +234,7 @@ class RescaleIntensity(DisplayedTransform):
     def encodes(self, o: MedMask):
         return o
 
-# %% ../nbs/03_vision_augment.ipynb
+# %% ../nbs/03_vision_augment.ipynb 12
 class NormalizeIntensity(DisplayedTransform):
     """Apply MONAI NormalizeIntensity.
 
@@ -203,7 +272,7 @@ class NormalizeIntensity(DisplayedTransform):
     def encodes(self, o: MedMask):
         return o
 
-# %% ../nbs/03_vision_augment.ipynb
+# %% ../nbs/03_vision_augment.ipynb 13
 class BraTSMaskConverter(DisplayedTransform):
     '''Convert BraTS masks.'''
 
@@ -215,7 +284,7 @@ class BraTSMaskConverter(DisplayedTransform):
         o = torch.where(o==4, 3., o)
         return MedMask.create(o)
 
-# %% ../nbs/03_vision_augment.ipynb
+# %% ../nbs/03_vision_augment.ipynb 14
 class BinaryConverter(DisplayedTransform):
     '''Convert to binary mask.'''
 
@@ -228,7 +297,7 @@ class BinaryConverter(DisplayedTransform):
         o = torch.where(o>0, 1., 0)
         return MedMask.create(o)
 
-# %% ../nbs/03_vision_augment.ipynb
+# %% ../nbs/03_vision_augment.ipynb 15
 class RandomGhosting(DisplayedTransform):
     """Apply TorchIO `RandomGhosting`."""
 
@@ -243,16 +312,12 @@ class RandomGhosting(DisplayedTransform):
         return self.add_ghosts
 
     def encodes(self, o: MedImage):
-
-        # Handle potential complex values from k-space operations
-        if result.is_complex():
-            result = torch.real(result)
-        return MedImage.create(result)
+        return MedImage.create(self.add_ghosts(o))
 
     def encodes(self, o: MedMask):
         return o
 
-# %% ../nbs/03_vision_augment.ipynb
+# %% ../nbs/03_vision_augment.ipynb 16
 class RandomSpike(DisplayedTransform):
     '''Apply TorchIO `RandomSpike`.'''
 
@@ -267,16 +332,12 @@ class RandomSpike(DisplayedTransform):
         return self.add_spikes
 
     def encodes(self, o: MedImage):
-
-        # Handle potential complex values from k-space operations
-        if result.is_complex():
-            result = torch.real(result)
-        return MedImage.create(result)
+        return MedImage.create(self.add_spikes(o))
 
     def encodes(self, o: MedMask):
         return o
 
-# %% ../nbs/03_vision_augment.ipynb
+# %% ../nbs/03_vision_augment.ipynb 17
 class RandomNoise(DisplayedTransform):
     '''Apply TorchIO `RandomNoise`.'''
 
@@ -296,7 +357,7 @@ class RandomNoise(DisplayedTransform):
     def encodes(self, o: MedMask):
         return o
 
-# %% ../nbs/03_vision_augment.ipynb
+# %% ../nbs/03_vision_augment.ipynb 18
 class RandomBiasField(DisplayedTransform):
     '''Apply TorchIO `RandomBiasField`.'''
 
@@ -316,7 +377,7 @@ class RandomBiasField(DisplayedTransform):
     def encodes(self, o: MedMask):
         return o
 
-# %% ../nbs/03_vision_augment.ipynb
+# %% ../nbs/03_vision_augment.ipynb 19
 class RandomBlur(DisplayedTransform):
     '''Apply TorchIO `RandomBlur`.'''
 
@@ -336,7 +397,7 @@ class RandomBlur(DisplayedTransform):
     def encodes(self, o: MedMask):
         return o
 
-# %% ../nbs/03_vision_augment.ipynb
+# %% ../nbs/03_vision_augment.ipynb 20
 class RandomGamma(DisplayedTransform):
     '''Apply TorchIO `RandomGamma`.'''
 
@@ -356,7 +417,7 @@ class RandomGamma(DisplayedTransform):
     def encodes(self, o: MedMask):
         return o
 
-# %% ../nbs/03_vision_augment.ipynb
+# %% ../nbs/03_vision_augment.ipynb 21
 class RandomIntensityScale(DisplayedTransform):
     """Randomly scale image intensities by a multiplicative factor.
 
@@ -388,7 +449,7 @@ class RandomIntensityScale(DisplayedTransform):
     def encodes(self, o: MedMask):
         return o
 
-# %% ../nbs/03_vision_augment.ipynb
+# %% ../nbs/03_vision_augment.ipynb 22
 class RandomMotion(DisplayedTransform):
     """Apply TorchIO `RandomMotion`."""
 
@@ -416,16 +477,249 @@ class RandomMotion(DisplayedTransform):
         return self.add_motion
 
     def encodes(self, o: MedImage):
-
-        # Handle potential complex values from k-space operations
-        if result.is_complex():
-            result = torch.real(result)
-        return MedImage.create(result)
+        return MedImage.create(self.add_motion(o))
 
     def encodes(self, o: MedMask):
         return o
 
 # %% ../nbs/03_vision_augment.ipynb 23
+def _create_ellipsoid_mask(shape, center, radii):
+    """Create a 3D ellipsoid mask.
+
+    Args:
+        shape: (D, H, W) shape of the volume
+        center: (z, y, x) center of ellipsoid
+        radii: (rz, ry, rx) radii along each axis
+
+    Returns:
+        Boolean mask where True = inside ellipsoid
+    """
+    z, y, x = torch.meshgrid(
+        torch.arange(shape[0]),
+        torch.arange(shape[1]),
+        torch.arange(shape[2]),
+        indexing='ij'
+    )
+    dist = ((z - center[0]) / radii[0]) ** 2 + \
+           ((y - center[1]) / radii[1]) ** 2 + \
+           ((x - center[2]) / radii[2]) ** 2
+    return dist <= 1.0
+
+# %% ../nbs/03_vision_augment.ipynb 24
+class _TioRandomCutout(tio.IntensityTransform):
+    """TorchIO-compatible RandomCutout for patch-based workflows.
+
+    When mask_only=True, cutouts only affect voxels where the mask is positive.
+    The mask should be available in the Subject as 'mask' key.
+    """
+
+    def __init__(self, holes=1, spatial_size=8, fill_value=None,
+                 max_holes=None, max_spatial_size=None, mask_only=True, p=0.2, **kwargs):
+        super().__init__(p=p, **kwargs)
+        self.holes = holes
+        self.spatial_size = spatial_size
+        self.fill_value = fill_value
+        self.max_holes = max_holes
+        self.max_spatial_size = max_spatial_size
+        self.mask_only = mask_only
+
+    def _apply_cutout(self, data, fill_val, mask_tensor=None):
+        """Apply spherical cutout(s) to a tensor.
+
+        Args:
+            data: Input tensor of shape (C, D, H, W)
+            fill_val: Value to fill cutout regions
+            mask_tensor: Optional mask tensor for mask-only cutouts
+
+        Returns:
+            Tensor with cutout applied
+        """
+        result = data.clone()
+        n_holes = torch.randint(self.holes, (self.max_holes or self.holes) + 1, (1,)).item()
+
+        spatial_shape = data.shape[1:]  # (D, H, W)
+        min_size = self.spatial_size if isinstance(self.spatial_size, int) else self.spatial_size[0]
+        max_size = self.max_spatial_size or self.spatial_size
+        max_size = max_size if isinstance(max_size, int) else max_size[0]
+
+        for _ in range(n_holes):
+            # Random size for this hole
+            size = torch.randint(min_size, max_size + 1, (3,))
+            radii = size.float() / 2
+
+            # Random center (ensure hole fits in volume)
+            center = [
+                torch.randint(int(radii[i].item()),
+                              max(spatial_shape[i] - int(radii[i].item()), int(radii[i].item()) + 1),
+                              (1,)).item()
+                for i in range(3)
+            ]
+
+            ellipsoid = _create_ellipsoid_mask(spatial_shape, center, radii)
+
+            if self.mask_only and mask_tensor is not None:
+                # INTERSECT with tumor mask - only affect tumor voxels
+                tumor_mask = mask_tensor[0] > 0
+                cutout_region = ellipsoid & tumor_mask
+            else:
+                cutout_region = ellipsoid
+
+            result[:, cutout_region] = fill_val
+
+        return result
+
+    def apply_transform(self, subject):
+        # Get mask if available for mask-only cutouts
+        mask_tensor = None
+        if self.mask_only and 'mask' in subject:
+            mask_tensor = subject['mask'].data
+
+        # Skip if mask is empty
+        if mask_tensor is not None and not (mask_tensor > 0).any():
+            return subject
+
+        for image in self.get_images(subject):
+            data = image.data
+            fill_val = self.fill_value if self.fill_value is not None else float(data.min())
+            result = self._apply_cutout(data, fill_val, mask_tensor)
+            image.set_data(result)
+        return subject
+
+# %% ../nbs/03_vision_augment.ipynb 25
+class RandomCutout(ItemTransform):
+    """Randomly erase spherical regions in 3D medical images with mask-aware placement.
+
+    Simulates post-operative surgical cavities by filling random ellipsoid
+    volumes with specified values. When mask_only=True (default), cutouts only
+    affect voxels inside the segmentation mask, ensuring no healthy tissue is modified.
+
+    Args:
+        holes: Minimum number of cutout regions. Default: 1.
+        max_holes: Maximum number of regions. Default: 3.
+        spatial_size: Minimum cutout diameter in voxels. Default: 8.
+        max_spatial_size: Maximum cutout diameter. Default: 16.
+        fill: Fill value - 'min', 'mean', 'random', or float. Default: 'min'.
+        mask_only: If True, cutouts only affect mask-positive voxels (tumor tissue).
+            If False, cutouts can affect any voxel (original behavior). Default: True.
+        p: Probability of applying transform. Default: 0.2.
+
+    Example:
+        >>> # Simulate post-op cavities only within tumor regions
+        >>> tfm = RandomCutout(holes=1, max_holes=2, spatial_size=10,
+        ...                    max_spatial_size=25, fill='min', mask_only=True, p=0.2)
+
+        >>> # Original behavior - cutouts anywhere in the volume
+        >>> tfm = RandomCutout(mask_only=False, p=0.2)
+    """
+
+    split_idx, order = 0, 1
+
+    def __init__(self, holes=1, max_holes=3, spatial_size=8,
+                 max_spatial_size=16, fill='min', mask_only=True, p=0.2):
+        self.holes = holes
+        self.max_holes = max_holes
+        self.spatial_size = spatial_size
+        self.max_spatial_size = max_spatial_size
+        self.fill = fill
+        self.mask_only = mask_only
+        self.p = p
+
+        self._tio_cutout = _TioRandomCutout(
+            holes=holes, spatial_size=spatial_size,
+            fill_value=None if isinstance(fill, str) else fill,
+            max_holes=max_holes, max_spatial_size=max_spatial_size,
+            mask_only=mask_only, p=p
+        )
+
+    @property
+    def tio_transform(self):
+        """Return TorchIO-compatible transform for patch-based workflows.
+
+        Note: For mask-aware cutouts in patch workflows, the mask must be
+        available in the TorchIO Subject as 'mask' key.
+        """
+        return self._tio_cutout
+
+    def _get_fill_value(self, tensor):
+        if self.fill == 'min': return float(tensor.min())
+        elif self.fill == 'mean': return float(tensor.mean())
+        elif self.fill == 'random':
+            return torch.empty(1).uniform_(float(tensor.min()), float(tensor.max())).item()
+        else: return self.fill
+
+    def encodes(self, x):
+        """Apply mask-aware cutout to image.
+
+        Args:
+            x: Tuple of (MedImage, target) where target is MedMask or TensorCategory
+
+        Returns:
+            Tuple of (transformed MedImage, unchanged target)
+        """
+        img, y_true = x
+
+        # Probability check
+        if torch.rand(1).item() > self.p:
+            return img, y_true
+
+        # Get mask data if available (as numpy for safe boolean operations)
+        mask_np = None
+        tumor_coords = None
+        if isinstance(y_true, MedMask):
+            mask_np = y_true.numpy()
+            # Get coordinates of tumor voxels for mask-aware center placement
+            if self.mask_only and (mask_np > 0).any():
+                tumor_coords = np.argwhere(mask_np[0] > 0)  # Shape: (N, 3) for z, y, x
+
+        # Skip cutout if mask_only=True but no mask or empty mask
+        if self.mask_only:
+            if mask_np is None or tumor_coords is None or len(tumor_coords) == 0:
+                return img, y_true
+
+        # Work with numpy array to avoid tensor subclass issues
+        result_np = img.numpy().copy()
+        spatial_shape = img.shape[1:]  # (D, H, W)
+        fill_val = self._get_fill_value(img)
+
+        n_holes = torch.randint(self.holes, self.max_holes + 1, (1,)).item()
+
+        min_size = self.spatial_size if isinstance(self.spatial_size, int) else self.spatial_size[0]
+        max_size = self.max_spatial_size or self.spatial_size
+        max_size = max_size if isinstance(max_size, int) else max_size[0]
+
+        for _ in range(n_holes):
+            # Random size for this hole
+            size = torch.randint(min_size, max_size + 1, (3,))
+            radii = size.float() / 2
+
+            if self.mask_only and tumor_coords is not None:
+                # Pick center from within tumor region to ensure intersection
+                idx = torch.randint(0, len(tumor_coords), (1,)).item()
+                center = tumor_coords[idx].tolist()  # [z, y, x]
+            else:
+                # Random center anywhere (ensure hole fits in volume)
+                center = [
+                    torch.randint(int(radii[i].item()),
+                                  max(spatial_shape[i] - int(radii[i].item()), int(radii[i].item()) + 1),
+                                  (1,)).item()
+                    for i in range(3)
+                ]
+
+            # Create ellipsoid mask (numpy)
+            ellipsoid = _create_ellipsoid_mask(spatial_shape, center, radii).numpy()
+
+            if self.mask_only and mask_np is not None:
+                # INTERSECT with tumor mask - only affect tumor voxels
+                tumor_mask = mask_np[0] > 0
+                cutout_region = ellipsoid & tumor_mask
+            else:
+                cutout_region = ellipsoid
+
+            result_np[:, cutout_region] = fill_val
+
+        return MedImage.create(torch.from_numpy(result_np)), y_true
+
+# %% ../nbs/03_vision_augment.ipynb 27
 class RandomElasticDeformation(CustomDictTransform):
     """Apply TorchIO `RandomElasticDeformation`."""
 
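
A small sketch of how the new mask-aware `RandomCutout` could be exercised on a synthetic image/mask pair; the shapes, the 0.7 threshold, and the direct call to `encodes` (bypassing fastai's split handling) are illustrative assumptions, not taken from this diff:

    import torch
    from fastMONAI.vision_core import MedImage, MedMask
    from fastMONAI.vision_augmentation import RandomCutout

    img  = MedImage.create(torch.rand(1, 32, 32, 32))
    mask = MedMask.create((torch.rand(1, 32, 32, 32) > 0.7).float())

    tfm = RandomCutout(holes=1, max_holes=2, spatial_size=6,
                       max_spatial_size=10, fill='min', mask_only=True, p=1.0)
    img_out, mask_out = tfm.encodes((img, mask))  # cutouts land only inside mask-positive voxels
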
@@ -438,7 +732,7 @@ class RandomElasticDeformation(CustomDictTransform):
                                                       image_interpolation=image_interpolation,
                                                       p=p))
 
-# %% ../nbs/03_vision_augment.ipynb
+# %% ../nbs/03_vision_augment.ipynb 28
 class RandomAffine(CustomDictTransform):
     """Apply TorchIO `RandomAffine`."""
 
@@ -454,14 +748,14 @@ class RandomAffine(CustomDictTransform):
                                          default_pad_value=default_pad_value,
                                          p=p))
 
-# %% ../nbs/03_vision_augment.ipynb
+# %% ../nbs/03_vision_augment.ipynb 29
 class RandomFlip(CustomDictTransform):
     """Apply TorchIO `RandomFlip`."""
 
     def __init__(self, axes='LR', p=0.5):
         super().__init__(tio.RandomFlip(axes=axes, flip_probability=p))
 
-# %% ../nbs/03_vision_augment.ipynb
+# %% ../nbs/03_vision_augment.ipynb 30
 class OneOf(CustomDictTransform):
     """Apply only one of the given transforms using TorchIO `OneOf`."""
 