fastMONAI 0.5.4__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,16 +1,17 @@
1
1
  # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/03_vision_augment.ipynb.
2
2
 
3
3
  # %% auto 0
4
# Public names exported by `from <module> import *`, in generated order.
__all__ = [
    'CustomDictTransform',
    'do_pad_or_crop',
    'PadOrCrop',
    'SpatialPad',
    'ZNormalization',
    'RescaleIntensity',
    'NormalizeIntensity',
    'BraTSMaskConverter',
    'BinaryConverter',
    'RandomGhosting',
    'RandomSpike',
    'RandomNoise',
    'RandomBiasField',
    'RandomBlur',
    'RandomGamma',
    'RandomIntensityScale',
    'RandomMotion',
    'RandomCutout',
    'RandomElasticDeformation',
    'RandomAffine',
    'RandomFlip',
    'OneOf',
]
8
8
 
9
9
  # %% ../nbs/03_vision_augment.ipynb 2
10
10
  from fastai.data.all import *
11
11
  from .vision_core import *
12
12
  import torchio as tio
13
13
  from monai.transforms import NormalizeIntensity as MonaiNormalizeIntensity
14
+ from monai.transforms import SpatialPad as MonaiSpatialPad
14
15
 
15
16
  # %% ../nbs/03_vision_augment.ipynb 5
16
17
  class CustomDictTransform(ItemTransform):
@@ -92,6 +93,74 @@ class PadOrCrop(DisplayedTransform):
92
93
  return type(o)(self.pad_or_crop(o))
93
94
 
94
95
  # %% ../nbs/03_vision_augment.ipynb 9
96
class _TioSpatialPad(tio.SpatialTransform):
    """TorchIO `SpatialTransform` that pads images with MONAI's `SpatialPad`.

    Defined at module level (rather than inside `SpatialPad.tio_transform`) so
    instances can be pickled, e.g. when a TorchIO `Queue`/`DataLoader` ships
    transforms to worker processes started with the 'spawn' method.

    For every image returned by `get_images`, the data is padded with MONAI's
    `SpatialPad` and the affine translation is shifted by the number of voxels
    prepended on each axis, so voxel-to-world coordinates stay consistent.
    """

    def __init__(self, spatial_size, mode='constant', **kwargs):
        super().__init__(**kwargs)
        if not is_listy(spatial_size):
            spatial_size = [spatial_size, spatial_size, spatial_size]
        self.spatial_size = list(spatial_size)
        self.mode = mode
        self._monai_pad = MonaiSpatialPad(spatial_size=self.spatial_size, mode=mode)

    def _leading_pad(self, spatial_shape):
        """Return the number of voxels prepended on each spatial axis."""
        # Mirrors MONAI's default 'symmetric' method: width = max(target - size, 0),
        # of which width // 2 is placed before the data. Non-positive targets mean
        # "keep the input size" in MONAI and also yield a width of 0 here.
        return [max(int(target) - int(size), 0) // 2
                for target, size in zip(self.spatial_size, spatial_shape)]

    def apply_transform(self, subject):
        for image in self.get_images(subject):
            data = image.data
            before = self._leading_pad(data.shape[1:])
            padded = self._monai_pad(data)
            # MONAI may return a MetaTensor; keep TorchIO images as plain tensors.
            image.set_data(padded.as_subclass(torch.Tensor))
            if any(before):
                # Voxel i_new maps to i_old = i_new - before, so the world origin
                # moves by -A @ before (A = rotation/zoom part of the affine).
                affine = np.array(image.affine, dtype=float, copy=True)
                affine[:3, 3] -= affine[:3, :3] @ np.asarray(before, dtype=float)
                image.affine = affine
        return subject


class SpatialPad(DisplayedTransform):
    """Pad image to minimum size without cropping using MONAI's `SpatialPad`.

    This transform pads each dimension to AT LEAST the specified size. Dimensions
    already larger than `spatial_size` are left unchanged (no cropping), making it
    ideal for patch-based training where images must be at least as large as the
    patch size.

    Uses zero padding (constant value 0) by default.

    Args:
        spatial_size: Minimum size [x, y, z] for each dimension. Can be int (same
            for all dims) or list. Usually set to patch_size for patch-based training.
        mode: Padding mode. Default 'constant' means zero padding. Other options:
            'edge', 'reflect', 'symmetric'.

    Example:
        >>> # Image 512×512×48, need at least 256×256×96 for patches
        >>> transform = SpatialPad(spatial_size=[256, 256, 96])
        >>> # Result: 512×512×96 (x,y unchanged, z padded from 48→96)

    Note:
        This differs from `PadOrCrop` which will crop dimensions larger than
        the target size. Use `SpatialPad` when you want to preserve large
        dimensions and only pad small ones.
    """

    order = 0

    def __init__(self, spatial_size, mode='constant'):
        if not is_listy(spatial_size):
            spatial_size = [spatial_size, spatial_size, spatial_size]
        self.spatial_size = spatial_size
        self.mode = mode
        # MONAI transform used by the fastai pipeline (zero padding by default).
        self.spatial_pad = MonaiSpatialPad(spatial_size=spatial_size, mode=mode)

    @property
    def tio_transform(self):
        """Return a TorchIO-compatible transform for patch-based workflows.

        The returned `_TioSpatialPad` applies the same MONAI padding inside
        TorchIO's Queue and Compose pipelines, updates each image's affine to
        account for the prepended voxels, and is picklable.
        """
        return _TioSpatialPad(spatial_size=self.spatial_size, mode=self.mode)

    def encodes(self, o: (MedImage, MedMask)):
        return type(o)(self.spatial_pad(o))
162
+
163
+ # %% ../nbs/03_vision_augment.ipynb 10
95
164
  class ZNormalization(DisplayedTransform):
96
165
  """Apply TorchIO `ZNormalization`."""
97
166
 
@@ -136,7 +205,7 @@ class ZNormalization(DisplayedTransform):
136
205
  def encodes(self, o: MedMask):
137
206
  return o
138
207
 
139
- # %% ../nbs/03_vision_augment.ipynb 10
208
+ # %% ../nbs/03_vision_augment.ipynb 11
140
209
  class RescaleIntensity(DisplayedTransform):
141
210
  """Apply TorchIO RescaleIntensity for robust intensity scaling.
142
211
 
@@ -165,7 +234,7 @@ class RescaleIntensity(DisplayedTransform):
165
234
  def encodes(self, o: MedMask):
166
235
  return o
167
236
 
168
- # %% ../nbs/03_vision_augment.ipynb 11
237
+ # %% ../nbs/03_vision_augment.ipynb 12
169
238
  class NormalizeIntensity(DisplayedTransform):
170
239
  """Apply MONAI NormalizeIntensity.
171
240
 
@@ -203,7 +272,7 @@ class NormalizeIntensity(DisplayedTransform):
203
272
  def encodes(self, o: MedMask):
204
273
  return o
205
274
 
206
- # %% ../nbs/03_vision_augment.ipynb 12
275
+ # %% ../nbs/03_vision_augment.ipynb 13
207
276
  class BraTSMaskConverter(DisplayedTransform):
208
277
  '''Convert BraTS masks.'''
209
278
 
@@ -215,7 +284,7 @@ class BraTSMaskConverter(DisplayedTransform):
215
284
  o = torch.where(o==4, 3., o)
216
285
  return MedMask.create(o)
217
286
 
218
- # %% ../nbs/03_vision_augment.ipynb 13
287
+ # %% ../nbs/03_vision_augment.ipynb 14
219
288
  class BinaryConverter(DisplayedTransform):
220
289
  '''Convert to binary mask.'''
221
290
 
@@ -228,7 +297,7 @@ class BinaryConverter(DisplayedTransform):
228
297
  o = torch.where(o>0, 1., 0)
229
298
  return MedMask.create(o)
230
299
 
231
- # %% ../nbs/03_vision_augment.ipynb 14
300
+ # %% ../nbs/03_vision_augment.ipynb 15
232
301
  class RandomGhosting(DisplayedTransform):
233
302
  """Apply TorchIO `RandomGhosting`."""
234
303
 
@@ -243,16 +312,12 @@ class RandomGhosting(DisplayedTransform):
243
312
  return self.add_ghosts
244
313
 
245
314
  def encodes(self, o: MedImage):
246
- result = self.add_ghosts(o)
247
- # Handle potential complex values from k-space operations
248
- if result.is_complex():
249
- result = torch.real(result)
250
- return MedImage.create(result)
315
+ return MedImage.create(self.add_ghosts(o))
251
316
 
252
317
  def encodes(self, o: MedMask):
253
318
  return o
254
319
 
255
- # %% ../nbs/03_vision_augment.ipynb 15
320
+ # %% ../nbs/03_vision_augment.ipynb 16
256
321
  class RandomSpike(DisplayedTransform):
257
322
  '''Apply TorchIO `RandomSpike`.'''
258
323
 
@@ -267,16 +332,12 @@ class RandomSpike(DisplayedTransform):
267
332
  return self.add_spikes
268
333
 
269
334
  def encodes(self, o: MedImage):
270
- result = self.add_spikes(o)
271
- # Handle potential complex values from k-space operations
272
- if result.is_complex():
273
- result = torch.real(result)
274
- return MedImage.create(result)
335
+ return MedImage.create(self.add_spikes(o))
275
336
 
276
337
  def encodes(self, o: MedMask):
277
338
  return o
278
339
 
279
- # %% ../nbs/03_vision_augment.ipynb 16
340
+ # %% ../nbs/03_vision_augment.ipynb 17
280
341
  class RandomNoise(DisplayedTransform):
281
342
  '''Apply TorchIO `RandomNoise`.'''
282
343
 
@@ -296,7 +357,7 @@ class RandomNoise(DisplayedTransform):
296
357
  def encodes(self, o: MedMask):
297
358
  return o
298
359
 
299
- # %% ../nbs/03_vision_augment.ipynb 17
360
+ # %% ../nbs/03_vision_augment.ipynb 18
300
361
  class RandomBiasField(DisplayedTransform):
301
362
  '''Apply TorchIO `RandomBiasField`.'''
302
363
 
@@ -316,7 +377,7 @@ class RandomBiasField(DisplayedTransform):
316
377
  def encodes(self, o: MedMask):
317
378
  return o
318
379
 
319
- # %% ../nbs/03_vision_augment.ipynb 18
380
+ # %% ../nbs/03_vision_augment.ipynb 19
320
381
  class RandomBlur(DisplayedTransform):
321
382
  '''Apply TorchIO `RandomBlur`.'''
322
383
 
@@ -336,7 +397,7 @@ class RandomBlur(DisplayedTransform):
336
397
  def encodes(self, o: MedMask):
337
398
  return o
338
399
 
339
- # %% ../nbs/03_vision_augment.ipynb 19
400
+ # %% ../nbs/03_vision_augment.ipynb 20
340
401
  class RandomGamma(DisplayedTransform):
341
402
  '''Apply TorchIO `RandomGamma`.'''
342
403
 
@@ -356,7 +417,7 @@ class RandomGamma(DisplayedTransform):
356
417
  def encodes(self, o: MedMask):
357
418
  return o
358
419
 
359
- # %% ../nbs/03_vision_augment.ipynb 20
420
+ # %% ../nbs/03_vision_augment.ipynb 21
360
421
  class RandomIntensityScale(DisplayedTransform):
361
422
  """Randomly scale image intensities by a multiplicative factor.
362
423
 
@@ -388,7 +449,7 @@ class RandomIntensityScale(DisplayedTransform):
388
449
  def encodes(self, o: MedMask):
389
450
  return o
390
451
 
391
- # %% ../nbs/03_vision_augment.ipynb 21
452
+ # %% ../nbs/03_vision_augment.ipynb 22
392
453
  class RandomMotion(DisplayedTransform):
393
454
  """Apply TorchIO `RandomMotion`."""
394
455
 
@@ -416,16 +477,249 @@ class RandomMotion(DisplayedTransform):
416
477
  return self.add_motion
417
478
 
418
479
  def encodes(self, o: MedImage):
419
- result = self.add_motion(o)
420
- # Handle potential complex values from k-space operations
421
- if result.is_complex():
422
- result = torch.real(result)
423
- return MedImage.create(result)
480
+ return MedImage.create(self.add_motion(o))
424
481
 
425
482
  def encodes(self, o: MedMask):
426
483
  return o
427
484
 
428
485
  # %% ../nbs/03_vision_augment.ipynb 23
486
def _create_ellipsoid_mask(shape, center, radii):
    """Create a 3D boolean ellipsoid mask.

    The coordinates are three broadcastable 1D vectors instead of full-volume
    ``meshgrid`` grids, so each per-axis term stays one-dimensional and only
    the running sum and the final comparison reach full (D, H, W) size. The
    result is identical to evaluating the ellipsoid equation at every voxel.

    Args:
        shape: (D, H, W) shape of the volume.
        center: (z, y, x) center of the ellipsoid, in voxel indices.
        radii: (rz, ry, rx) radii along each axis, in voxels.

    Returns:
        Boolean tensor of shape ``shape``. An element is True when
        ((z-cz)/rz)^2 + ((y-cy)/ry)^2 + ((x-cx)/rx)^2 <= 1, i.e. the voxel lies
        inside or on the ellipsoid.
    """
    z = torch.arange(shape[0]).view(-1, 1, 1)
    y = torch.arange(shape[1]).view(1, -1, 1)
    x = torch.arange(shape[2]).view(1, 1, -1)
    dist = (((z - center[0]) / radii[0]) ** 2
            + ((y - center[1]) / radii[1]) ** 2
            + ((x - center[2]) / radii[2]) ** 2)
    return dist <= 1.0
507
+
508
+ # %% ../nbs/03_vision_augment.ipynb 24
509
class _TioRandomCutout(tio.IntensityTransform):
    """TorchIO-compatible RandomCutout for patch-based workflows.

    When ``mask_only=True``, the Subject must contain a ``'mask'`` image with
    at least one positive voxel. Otherwise the Subject is returned unchanged,
    matching the fastai `RandomCutout.encodes` path. Hole centres are sampled
    from mask-positive voxels (channel 0), and each hole is intersected with
    that mask, so only mask-positive voxels are filled. When
    ``mask_only=False``, holes are placed anywhere in the volume.

    Only the images returned by ``get_images`` are modified.

    Args:
        holes: Minimum number of holes.
        spatial_size: Minimum hole diameter in voxels (int, or sequence whose
            first element is used).
        fill_value: Fill value; ``None`` uses each image's minimum.
        max_holes: Maximum number of holes; ``None`` means exactly ``holes``.
        max_spatial_size: Maximum hole diameter; ``None`` means ``spatial_size``.
        mask_only: Restrict cutouts to mask-positive voxels.
        p: Probability of applying the transform (handled by TorchIO).
    """

    def __init__(self, holes=1, spatial_size=8, fill_value=None,
                 max_holes=None, max_spatial_size=None, mask_only=True, p=0.2, **kwargs):
        super().__init__(p=p, **kwargs)
        self.holes = holes
        self.spatial_size = spatial_size
        self.fill_value = fill_value
        self.max_holes = max_holes
        self.max_spatial_size = max_spatial_size
        self.mask_only = mask_only

    def _apply_cutout(self, data, fill_val, mask_tensor=None):
        """Apply ellipsoidal cutout(s) to a tensor.

        Args:
            data: Input tensor of shape (C, D, H, W).
            fill_val: Value to fill cutout regions.
            mask_tensor: Optional mask tensor for mask-only cutouts; channel 0
                is used.

        Returns:
            A copy of ``data`` with the cutouts applied. It is an unchanged copy
            when mask-only mode gets a mask with no positive voxels.
        """
        result = data.clone()
        n_holes = torch.randint(self.holes, (self.max_holes or self.holes) + 1, (1,)).item()

        spatial_shape = data.shape[1:]  # (D, H, W)
        min_size = self.spatial_size if isinstance(self.spatial_size, int) else self.spatial_size[0]
        max_size = self.max_spatial_size or self.spatial_size
        max_size = max_size if isinstance(max_size, int) else max_size[0]

        # Mask-only mode: compute the tumour mask and candidate centres once.
        tumor_mask = tumor_coords = None
        if self.mask_only and mask_tensor is not None:
            tumor_mask = mask_tensor[0] > 0
            tumor_coords = torch.nonzero(tumor_mask)  # (N, 3) as z, y, x
            if len(tumor_coords) == 0:
                return result

        for _ in range(n_holes):
            # Random size for this hole
            size = torch.randint(min_size, max_size + 1, (3,))
            radii = size.float() / 2

            if tumor_coords is not None:
                # Centre on a mask voxel so the hole always intersects the mask.
                idx = torch.randint(0, len(tumor_coords), (1,)).item()
                center = tumor_coords[idx].tolist()
            else:
                # Random centre anywhere (try to keep the hole inside the volume).
                center = [
                    torch.randint(int(radii[i].item()),
                                  max(spatial_shape[i] - int(radii[i].item()), int(radii[i].item()) + 1),
                                  (1,)).item()
                    for i in range(3)
                ]

            ellipsoid = _create_ellipsoid_mask(spatial_shape, center, radii)
            cutout_region = ellipsoid & tumor_mask if tumor_mask is not None else ellipsoid
            result[:, cutout_region] = fill_val

        return result

    def apply_transform(self, subject):
        mask_tensor = None
        if self.mask_only:
            # Mask-only cutouts need a non-empty mask; otherwise leave the subject
            # untouched rather than cutting healthy tissue.
            if 'mask' not in subject:
                return subject
            mask_tensor = subject['mask'].data
            if not (mask_tensor > 0).any():
                return subject

        for image in self.get_images(subject):
            data = image.data
            fill_val = self.fill_value if self.fill_value is not None else float(data.min())
            result = self._apply_cutout(data, fill_val, mask_tensor)
            image.set_data(result)
        return subject
587
+
588
+ # %% ../nbs/03_vision_augment.ipynb 25
589
class RandomCutout(ItemTransform):
    """Randomly erase spherical regions in 3D medical images with mask-aware placement.

    Simulates post-operative surgical cavities by filling random ellipsoid
    volumes with specified values. When mask_only=True (default), cutouts only
    affect voxels inside the segmentation mask, ensuring no healthy tissue is modified.

    Args:
        holes: Minimum number of cutout regions. Default: 1.
        max_holes: Maximum number of regions. Default: 3. ``None`` means
            exactly ``holes``.
        spatial_size: Minimum cutout diameter in voxels. Default: 8.
        max_spatial_size: Maximum cutout diameter. Default: 16.
        fill: Fill value - 'min', 'mean', 'random', or float. Default: 'min'.
        mask_only: If True, cutouts only affect mask-positive voxels (tumor tissue).
            If False, cutouts can affect any voxel (original behavior). Default: True.
        p: Probability of applying transform. Default: 0.2.

    Raises:
        ValueError: If ``fill`` is a string other than 'min', 'mean' or 'random'.

    Note:
        `tio_transform` is built with ``fill_value=None`` for string fills,
        which `_TioRandomCutout` treats as the per-image minimum. There,
        'mean' and 'random' therefore behave like 'min'.

    Example:
        >>> # Simulate post-op cavities only within tumor regions
        >>> tfm = RandomCutout(holes=1, max_holes=2, spatial_size=10,
        ...                    max_spatial_size=25, fill='min', mask_only=True, p=0.2)

        >>> # Original behavior - cutouts anywhere in the volume
        >>> tfm = RandomCutout(mask_only=False, p=0.2)
    """

    split_idx, order = 0, 1

    def __init__(self, holes=1, max_holes=3, spatial_size=8,
                 max_spatial_size=16, fill='min', mask_only=True, p=0.2):
        # Fail fast instead of crashing later when a string is written into the array.
        if isinstance(fill, str) and fill not in ('min', 'mean', 'random'):
            raise ValueError(
                f"fill must be 'min', 'mean', 'random' or a number, got {fill!r}")
        self.holes = holes
        self.max_holes = max_holes
        self.spatial_size = spatial_size
        self.max_spatial_size = max_spatial_size
        self.fill = fill
        self.mask_only = mask_only
        self.p = p

        self._tio_cutout = _TioRandomCutout(
            holes=holes, spatial_size=spatial_size,
            fill_value=None if isinstance(fill, str) else fill,
            max_holes=max_holes, max_spatial_size=max_spatial_size,
            mask_only=mask_only, p=p
        )

    @property
    def tio_transform(self):
        """Return TorchIO-compatible transform for patch-based workflows.

        Note: For mask-aware cutouts in patch workflows, the mask must be
        available in the TorchIO Subject as 'mask' key.
        """
        return self._tio_cutout

    def _get_fill_value(self, tensor):
        """Resolve ``self.fill`` to a concrete value for ``tensor``."""
        if self.fill == 'min': return float(tensor.min())
        elif self.fill == 'mean': return float(tensor.mean())
        elif self.fill == 'random':
            return torch.empty(1).uniform_(float(tensor.min()), float(tensor.max())).item()
        else: return self.fill

    def encodes(self, x):
        """Apply mask-aware cutout to image.

        Args:
            x: Tuple of (MedImage, target) where target is MedMask or TensorCategory

        Returns:
            Tuple of (transformed MedImage, unchanged target)
        """
        img, y_true = x

        # Probability check
        if torch.rand(1).item() > self.p:
            return img, y_true

        # Mask data as numpy, for safe boolean operations.
        mask_np = y_true.numpy() if isinstance(y_true, MedMask) else None

        tumor_mask = tumor_coords = None
        if self.mask_only:
            # Skip when mask-only mode has no mask or an empty mask (channel 0).
            if mask_np is None:
                return img, y_true
            tumor_mask = mask_np[0] > 0
            tumor_coords = np.argwhere(tumor_mask)  # Shape: (N, 3) for z, y, x
            if len(tumor_coords) == 0:
                return img, y_true

        # Work with numpy array to avoid tensor subclass issues
        result_np = img.numpy().copy()
        spatial_shape = img.shape[1:]  # (D, H, W)
        fill_val = self._get_fill_value(img)

        # max_holes=None means exactly `holes` (same convention as the TorchIO path).
        n_holes = torch.randint(self.holes, (self.max_holes or self.holes) + 1, (1,)).item()

        min_size = self.spatial_size if isinstance(self.spatial_size, int) else self.spatial_size[0]
        max_size = self.max_spatial_size or self.spatial_size
        max_size = max_size if isinstance(max_size, int) else max_size[0]

        for _ in range(n_holes):
            # Random size for this hole
            size = torch.randint(min_size, max_size + 1, (3,))
            radii = size.float() / 2

            if tumor_coords is not None:
                # Pick center from within tumor region to ensure intersection
                idx = torch.randint(0, len(tumor_coords), (1,)).item()
                center = tumor_coords[idx].tolist()  # [z, y, x]
            else:
                # Random center anywhere (ensure hole fits in volume)
                center = [
                    torch.randint(int(radii[i].item()),
                                  max(spatial_shape[i] - int(radii[i].item()), int(radii[i].item()) + 1),
                                  (1,)).item()
                    for i in range(3)
                ]

            ellipsoid = _create_ellipsoid_mask(spatial_shape, center, radii).numpy()
            # INTERSECT with tumor mask in mask-only mode - only affect tumor voxels
            cutout_region = ellipsoid & tumor_mask if tumor_mask is not None else ellipsoid
            result_np[:, cutout_region] = fill_val

        return MedImage.create(torch.from_numpy(result_np)), y_true
721
+
722
+ # %% ../nbs/03_vision_augment.ipynb 27
429
723
  class RandomElasticDeformation(CustomDictTransform):
430
724
  """Apply TorchIO `RandomElasticDeformation`."""
431
725
 
@@ -438,7 +732,7 @@ class RandomElasticDeformation(CustomDictTransform):
438
732
  image_interpolation=image_interpolation,
439
733
  p=p))
440
734
 
441
- # %% ../nbs/03_vision_augment.ipynb 24
735
+ # %% ../nbs/03_vision_augment.ipynb 28
442
736
  class RandomAffine(CustomDictTransform):
443
737
  """Apply TorchIO `RandomAffine`."""
444
738
 
@@ -454,14 +748,14 @@ class RandomAffine(CustomDictTransform):
454
748
  default_pad_value=default_pad_value,
455
749
  p=p))
456
750
 
457
- # %% ../nbs/03_vision_augment.ipynb 25
751
+ # %% ../nbs/03_vision_augment.ipynb 29
458
752
class RandomFlip(CustomDictTransform):
    """Apply TorchIO `RandomFlip`.

    Args:
        axes: Axes forwarded to `tio.RandomFlip` (default 'LR').
        p: Forwarded to `tio.RandomFlip` as its `flip_probability`.
    """

    def __init__(self, axes='LR', p=0.5):
        flip_tfm = tio.RandomFlip(axes=axes, flip_probability=p)
        super().__init__(flip_tfm)
463
757
 
464
- # %% ../nbs/03_vision_augment.ipynb 26
758
+ # %% ../nbs/03_vision_augment.ipynb 30
465
759
  class OneOf(CustomDictTransform):
466
760
  """Apply only one of the given transforms using TorchIO `OneOf`."""
467
761