fastMONAI 0.6.0__tar.gz → 0.6.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. {fastmonai-0.6.0/fastMONAI.egg-info → fastmonai-0.6.1}/PKG-INFO +1 -1
  2. fastmonai-0.6.1/fastMONAI/__init__.py +1 -0
  3. {fastmonai-0.6.0 → fastmonai-0.6.1}/fastMONAI/_modidx.py +8 -0
  4. {fastmonai-0.6.0 → fastmonai-0.6.1}/fastMONAI/vision_augmentation.py +223 -61
  5. {fastmonai-0.6.0 → fastmonai-0.6.1}/fastMONAI/vision_patch.py +18 -0
  6. {fastmonai-0.6.0 → fastmonai-0.6.1/fastMONAI.egg-info}/PKG-INFO +1 -1
  7. {fastmonai-0.6.0 → fastmonai-0.6.1}/settings.ini +1 -1
  8. fastmonai-0.6.0/fastMONAI/__init__.py +0 -1
  9. {fastmonai-0.6.0 → fastmonai-0.6.1}/CONTRIBUTING.md +0 -0
  10. {fastmonai-0.6.0 → fastmonai-0.6.1}/LICENSE +0 -0
  11. {fastmonai-0.6.0 → fastmonai-0.6.1}/MANIFEST.in +0 -0
  12. {fastmonai-0.6.0 → fastmonai-0.6.1}/README.md +0 -0
  13. {fastmonai-0.6.0 → fastmonai-0.6.1}/fastMONAI/dataset_info.py +0 -0
  14. {fastmonai-0.6.0 → fastmonai-0.6.1}/fastMONAI/external_data.py +0 -0
  15. {fastmonai-0.6.0 → fastmonai-0.6.1}/fastMONAI/research_utils.py +0 -0
  16. {fastmonai-0.6.0 → fastmonai-0.6.1}/fastMONAI/utils.py +0 -0
  17. {fastmonai-0.6.0 → fastmonai-0.6.1}/fastMONAI/vision_all.py +0 -0
  18. {fastmonai-0.6.0 → fastmonai-0.6.1}/fastMONAI/vision_core.py +0 -0
  19. {fastmonai-0.6.0 → fastmonai-0.6.1}/fastMONAI/vision_data.py +0 -0
  20. {fastmonai-0.6.0 → fastmonai-0.6.1}/fastMONAI/vision_inference.py +0 -0
  21. {fastmonai-0.6.0 → fastmonai-0.6.1}/fastMONAI/vision_loss.py +0 -0
  22. {fastmonai-0.6.0 → fastmonai-0.6.1}/fastMONAI/vision_metrics.py +0 -0
  23. {fastmonai-0.6.0 → fastmonai-0.6.1}/fastMONAI/vision_plot.py +0 -0
  24. {fastmonai-0.6.0 → fastmonai-0.6.1}/fastMONAI.egg-info/SOURCES.txt +0 -0
  25. {fastmonai-0.6.0 → fastmonai-0.6.1}/fastMONAI.egg-info/dependency_links.txt +0 -0
  26. {fastmonai-0.6.0 → fastmonai-0.6.1}/fastMONAI.egg-info/entry_points.txt +0 -0
  27. {fastmonai-0.6.0 → fastmonai-0.6.1}/fastMONAI.egg-info/not-zip-safe +0 -0
  28. {fastmonai-0.6.0 → fastmonai-0.6.1}/fastMONAI.egg-info/requires.txt +0 -0
  29. {fastmonai-0.6.0 → fastmonai-0.6.1}/fastMONAI.egg-info/top_level.txt +0 -0
  30. {fastmonai-0.6.0 → fastmonai-0.6.1}/setup.cfg +0 -0
  31. {fastmonai-0.6.0 → fastmonai-0.6.1}/setup.py +0 -0
{fastmonai-0.6.0/fastMONAI.egg-info → fastmonai-0.6.1}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: fastMONAI
- Version: 0.6.0
+ Version: 0.6.1
  Summary: fastMONAI library
  Home-page: https://github.com/MMIV-ML/fastMONAI
  Author: Satheshkumar Kaliyugarasan
fastmonai-0.6.1/fastMONAI/__init__.py (new file)
@@ -0,0 +1 @@
+ __version__ = "0.6.1"
{fastmonai-0.6.0 → fastmonai-0.6.1}/fastMONAI/_modidx.py
@@ -231,6 +231,14 @@ d = { 'settings': { 'branch': 'main',
                      'fastMONAI/vision_augmentation.py'),
    'fastMONAI.vision_augmentation.RescaleIntensity.tio_transform': ( 'vision_augment.html#rescaleintensity.tio_transform',
                      'fastMONAI/vision_augmentation.py'),
+   'fastMONAI.vision_augmentation.SpatialPad': ( 'vision_augment.html#spatialpad',
+                     'fastMONAI/vision_augmentation.py'),
+   'fastMONAI.vision_augmentation.SpatialPad.__init__': ( 'vision_augment.html#spatialpad.__init__',
+                     'fastMONAI/vision_augmentation.py'),
+   'fastMONAI.vision_augmentation.SpatialPad.encodes': ( 'vision_augment.html#spatialpad.encodes',
+                     'fastMONAI/vision_augmentation.py'),
+   'fastMONAI.vision_augmentation.SpatialPad.tio_transform': ( 'vision_augment.html#spatialpad.tio_transform',
+                     'fastMONAI/vision_augmentation.py'),
    'fastMONAI.vision_augmentation.ZNormalization': ( 'vision_augment.html#znormalization',
                      'fastMONAI/vision_augmentation.py'),
    'fastMONAI.vision_augmentation.ZNormalization.__init__': ( 'vision_augment.html#znormalization.__init__',
{fastmonai-0.6.0 → fastmonai-0.6.1}/fastMONAI/vision_augmentation.py
@@ -1,16 +1,17 @@
  # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/03_vision_augment.ipynb.

  # %% auto 0
- __all__ = ['CustomDictTransform', 'do_pad_or_crop', 'PadOrCrop', 'ZNormalization', 'RescaleIntensity', 'NormalizeIntensity',
-            'BraTSMaskConverter', 'BinaryConverter', 'RandomGhosting', 'RandomSpike', 'RandomNoise', 'RandomBiasField',
-            'RandomBlur', 'RandomGamma', 'RandomIntensityScale', 'RandomMotion', 'RandomCutout',
-            'RandomElasticDeformation', 'RandomAffine', 'RandomFlip', 'OneOf']
+ __all__ = ['CustomDictTransform', 'do_pad_or_crop', 'PadOrCrop', 'SpatialPad', 'ZNormalization', 'RescaleIntensity',
+            'NormalizeIntensity', 'BraTSMaskConverter', 'BinaryConverter', 'RandomGhosting', 'RandomSpike',
+            'RandomNoise', 'RandomBiasField', 'RandomBlur', 'RandomGamma', 'RandomIntensityScale', 'RandomMotion',
+            'RandomCutout', 'RandomElasticDeformation', 'RandomAffine', 'RandomFlip', 'OneOf']

  # %% ../nbs/03_vision_augment.ipynb 2
  from fastai.data.all import *
  from .vision_core import *
  import torchio as tio
  from monai.transforms import NormalizeIntensity as MonaiNormalizeIntensity
+ from monai.transforms import SpatialPad as MonaiSpatialPad

  # %% ../nbs/03_vision_augment.ipynb 5
  class CustomDictTransform(ItemTransform):
@@ -92,6 +93,74 @@ class PadOrCrop(DisplayedTransform):
          return type(o)(self.pad_or_crop(o))

  # %% ../nbs/03_vision_augment.ipynb 9
+ class SpatialPad(DisplayedTransform):
+     """Pad image to minimum size without cropping using MONAI's `SpatialPad`.
+
+     This transform pads each dimension to AT LEAST the specified size. Dimensions
+     already larger than `spatial_size` are left unchanged (no cropping), making it
+     ideal for patch-based training where images must be at least as large as the
+     patch size.
+
+     Uses zero padding (constant value 0) by default.
+
+     Args:
+         spatial_size: Minimum size [x, y, z] for each dimension. Can be int (same
+             for all dims) or list. Usually set to patch_size for patch-based training.
+         mode: Padding mode. Default 'constant' means zero padding. Other options:
+             'edge', 'reflect', 'symmetric'.
+
+     Example:
+         >>> # Image 512×512×48, need at least 256×256×96 for patches
+         >>> transform = SpatialPad(spatial_size=[256, 256, 96])
+         >>> # Result: 512×512×96 (x,y unchanged, z padded from 48→96)
+
+     Note:
+         This differs from `PadOrCrop` which will crop dimensions larger than
+         the target size. Use `SpatialPad` when you want to preserve large
+         dimensions and only pad small ones.
+     """
+
+     order = 0
+
+     def __init__(self, spatial_size, mode='constant'):
+         if not is_listy(spatial_size):
+             spatial_size = [spatial_size, spatial_size, spatial_size]
+         self.spatial_size = spatial_size
+         self.mode = mode
+         # Create MONAI transform (zero padding by default)
+         self.spatial_pad = MonaiSpatialPad(
+             spatial_size=spatial_size,
+             mode=mode
+         )
+
+     @property
+     def tio_transform(self):
+         """Return TorchIO-compatible transform for patch-based workflows.
+
+         This wraps MONAI's SpatialPad in a TorchIO SpatialTransform for
+         compatibility with TorchIO's Queue and Compose pipelines.
+         """
+         import torchio as tio
+
+         class _TioSpatialPad(tio.SpatialTransform):
+             """TorchIO wrapper for MONAI's SpatialPad."""
+             def __init__(self, spatial_size, mode, **kwargs):
+                 super().__init__(**kwargs)
+                 self._monai_pad = MonaiSpatialPad(spatial_size=spatial_size, mode=mode)
+
+             def apply_transform(self, subject):
+                 for image in self.get_images(subject):
+                     data = image.data
+                     padded = self._monai_pad(data)
+                     image.set_data(padded)
+                 return subject
+
+         return _TioSpatialPad(spatial_size=self.spatial_size, mode=self.mode)
+
+     def encodes(self, o: (MedImage, MedMask)):
+         return type(o)(self.spatial_pad(o))
+
+ # %% ../nbs/03_vision_augment.ipynb 10
  class ZNormalization(DisplayedTransform):
      """Apply TorchIO `ZNormalization`."""

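The `SpatialPad` transform added above delegates to MONAI's `SpatialPad`, which only grows dimensions that fall short of `spatial_size`. A minimal sketch of that pad-to-minimum behaviour (assumes `monai` and `torch` are installed; the shapes mirror the docstring example and are illustrative, not taken from the package's tests):

```python
import torch
from monai.transforms import SpatialPad as MonaiSpatialPad

img = torch.zeros(1, 512, 512, 48)                               # (C, X, Y, Z)
pad = MonaiSpatialPad(spatial_size=[256, 256, 96], mode="constant")
out = pad(img)                                                   # zero padding, symmetric by default
print(tuple(out.shape))                                          # expected: (1, 512, 512, 96) — only Z grows
```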
@@ -136,7 +205,7 @@ class ZNormalization(DisplayedTransform):
      def encodes(self, o: MedMask):
          return o

- # %% ../nbs/03_vision_augment.ipynb 10
+ # %% ../nbs/03_vision_augment.ipynb 11
  class RescaleIntensity(DisplayedTransform):
      """Apply TorchIO RescaleIntensity for robust intensity scaling.

@@ -165,7 +234,7 @@ class RescaleIntensity(DisplayedTransform):
      def encodes(self, o: MedMask):
          return o

- # %% ../nbs/03_vision_augment.ipynb 11
+ # %% ../nbs/03_vision_augment.ipynb 12
  class NormalizeIntensity(DisplayedTransform):
      """Apply MONAI NormalizeIntensity.

@@ -203,7 +272,7 @@ class NormalizeIntensity(DisplayedTransform):
      def encodes(self, o: MedMask):
          return o

- # %% ../nbs/03_vision_augment.ipynb 12
+ # %% ../nbs/03_vision_augment.ipynb 13
  class BraTSMaskConverter(DisplayedTransform):
      '''Convert BraTS masks.'''

@@ -215,7 +284,7 @@ class BraTSMaskConverter(DisplayedTransform):
          o = torch.where(o==4, 3., o)
          return MedMask.create(o)

- # %% ../nbs/03_vision_augment.ipynb 13
+ # %% ../nbs/03_vision_augment.ipynb 14
  class BinaryConverter(DisplayedTransform):
      '''Convert to binary mask.'''

@@ -228,7 +297,7 @@ class BinaryConverter(DisplayedTransform):
          o = torch.where(o>0, 1., 0)
          return MedMask.create(o)

- # %% ../nbs/03_vision_augment.ipynb 14
+ # %% ../nbs/03_vision_augment.ipynb 15
  class RandomGhosting(DisplayedTransform):
      """Apply TorchIO `RandomGhosting`."""

@@ -243,16 +312,12 @@
          return self.add_ghosts

      def encodes(self, o: MedImage):
-         result = self.add_ghosts(o)
-         # Handle potential complex values from k-space operations
-         if result.is_complex():
-             result = torch.real(result)
-         return MedImage.create(result)
+         return MedImage.create(self.add_ghosts(o))

      def encodes(self, o: MedMask):
          return o

- # %% ../nbs/03_vision_augment.ipynb 15
+ # %% ../nbs/03_vision_augment.ipynb 16
  class RandomSpike(DisplayedTransform):
      '''Apply TorchIO `RandomSpike`.'''

@@ -267,16 +332,12 @@
          return self.add_spikes

      def encodes(self, o: MedImage):
-         result = self.add_spikes(o)
-         # Handle potential complex values from k-space operations
-         if result.is_complex():
-             result = torch.real(result)
-         return MedImage.create(result)
+         return MedImage.create(self.add_spikes(o))

      def encodes(self, o: MedMask):
          return o

- # %% ../nbs/03_vision_augment.ipynb 16
+ # %% ../nbs/03_vision_augment.ipynb 17
  class RandomNoise(DisplayedTransform):
      '''Apply TorchIO `RandomNoise`.'''

@@ -296,7 +357,7 @@
      def encodes(self, o: MedMask):
          return o

- # %% ../nbs/03_vision_augment.ipynb 17
+ # %% ../nbs/03_vision_augment.ipynb 18
  class RandomBiasField(DisplayedTransform):
      '''Apply TorchIO `RandomBiasField`.'''

@@ -316,7 +377,7 @@
      def encodes(self, o: MedMask):
          return o

- # %% ../nbs/03_vision_augment.ipynb 18
+ # %% ../nbs/03_vision_augment.ipynb 19
  class RandomBlur(DisplayedTransform):
      '''Apply TorchIO `RandomBlur`.'''

@@ -336,7 +397,7 @@
      def encodes(self, o: MedMask):
          return o

- # %% ../nbs/03_vision_augment.ipynb 19
+ # %% ../nbs/03_vision_augment.ipynb 20
  class RandomGamma(DisplayedTransform):
      '''Apply TorchIO `RandomGamma`.'''

@@ -356,7 +417,7 @@
      def encodes(self, o: MedMask):
          return o

- # %% ../nbs/03_vision_augment.ipynb 20
+ # %% ../nbs/03_vision_augment.ipynb 21
  class RandomIntensityScale(DisplayedTransform):
      """Randomly scale image intensities by a multiplicative factor.

@@ -388,7 +449,7 @@
      def encodes(self, o: MedMask):
          return o

- # %% ../nbs/03_vision_augment.ipynb 21
+ # %% ../nbs/03_vision_augment.ipynb 22
  class RandomMotion(DisplayedTransform):
      """Apply TorchIO `RandomMotion`."""

@@ -416,16 +477,12 @@
          return self.add_motion

      def encodes(self, o: MedImage):
-         result = self.add_motion(o)
-         # Handle potential complex values from k-space operations
-         if result.is_complex():
-             result = torch.real(result)
-         return MedImage.create(result)
+         return MedImage.create(self.add_motion(o))

      def encodes(self, o: MedMask):
          return o

- # %% ../nbs/03_vision_augment.ipynb 22
+ # %% ../nbs/03_vision_augment.ipynb 23
  def _create_ellipsoid_mask(shape, center, radii):
      """Create a 3D ellipsoid mask.

@@ -448,21 +505,35 @@ def _create_ellipsoid_mask(shape, center, radii):
             ((x - center[2]) / radii[2]) ** 2
      return dist <= 1.0

- # %% ../nbs/03_vision_augment.ipynb 23
+ # %% ../nbs/03_vision_augment.ipynb 24
  class _TioRandomCutout(tio.IntensityTransform):
-     """TorchIO-compatible RandomCutout for patch-based workflows."""
+     """TorchIO-compatible RandomCutout for patch-based workflows.
+
+     When mask_only=True, cutouts only affect voxels where the mask is positive.
+     The mask should be available in the Subject as 'mask' key.
+     """

      def __init__(self, holes=1, spatial_size=8, fill_value=None,
-                  max_holes=None, max_spatial_size=None, p=0.2, **kwargs):
+                  max_holes=None, max_spatial_size=None, mask_only=True, p=0.2, **kwargs):
          super().__init__(p=p, **kwargs)
          self.holes = holes
          self.spatial_size = spatial_size
          self.fill_value = fill_value
          self.max_holes = max_holes
          self.max_spatial_size = max_spatial_size
+         self.mask_only = mask_only
+
+     def _apply_cutout(self, data, fill_val, mask_tensor=None):
+         """Apply spherical cutout(s) to a tensor.

-     def _apply_cutout(self, data, fill_val):
-         """Apply spherical cutout(s) to a tensor."""
+         Args:
+             data: Input tensor of shape (C, D, H, W)
+             fill_val: Value to fill cutout regions
+             mask_tensor: Optional mask tensor for mask-only cutouts
+
+         Returns:
+             Tensor with cutout applied
+         """
          result = data.clone()
          n_holes = torch.randint(self.holes, (self.max_holes or self.holes) + 1, (1,)).item()

@@ -484,26 +555,43 @@ class _TioRandomCutout(tio.IntensityTransform):
                  for i in range(3)
              ]

-             mask = _create_ellipsoid_mask(spatial_shape, center, radii)
-             result[:, mask] = fill_val
+             ellipsoid = _create_ellipsoid_mask(spatial_shape, center, radii)
+
+             if self.mask_only and mask_tensor is not None:
+                 # INTERSECT with tumor mask - only affect tumor voxels
+                 tumor_mask = mask_tensor[0] > 0
+                 cutout_region = ellipsoid & tumor_mask
+             else:
+                 cutout_region = ellipsoid
+
+             result[:, cutout_region] = fill_val

          return result

      def apply_transform(self, subject):
+         # Get mask if available for mask-only cutouts
+         mask_tensor = None
+         if self.mask_only and 'mask' in subject:
+             mask_tensor = subject['mask'].data
+
+         # Skip if mask is empty
+         if mask_tensor is not None and not (mask_tensor > 0).any():
+             return subject
+
          for image in self.get_images(subject):
              data = image.data
              fill_val = self.fill_value if self.fill_value is not None else float(data.min())
-             result = self._apply_cutout(data, fill_val)
+             result = self._apply_cutout(data, fill_val, mask_tensor)
              image.set_data(result)
          return subject

- # %% ../nbs/03_vision_augment.ipynb 24
- class RandomCutout(DisplayedTransform):
-     """Randomly erase spherical regions in 3D medical images.
+ # %% ../nbs/03_vision_augment.ipynb 25
+ class RandomCutout(ItemTransform):
+     """Randomly erase spherical regions in 3D medical images with mask-aware placement.

      Simulates post-operative surgical cavities by filling random ellipsoid
-     volumes with specified values. Useful for training on pre-op images
-     to generalize to post-op scans.
+     volumes with specified values. When mask_only=True (default), cutouts only
+     affect voxels inside the segmentation mask, ensuring no healthy tissue is modified.

      Args:
          holes: Minimum number of cutout regions. Default: 1.
@@ -511,34 +599,45 @@ class RandomCutout(DisplayedTransform):
          spatial_size: Minimum cutout diameter in voxels. Default: 8.
          max_spatial_size: Maximum cutout diameter. Default: 16.
          fill: Fill value - 'min', 'mean', 'random', or float. Default: 'min'.
+         mask_only: If True, cutouts only affect mask-positive voxels (tumor tissue).
+             If False, cutouts can affect any voxel (original behavior). Default: True.
          p: Probability of applying transform. Default: 0.2.

      Example:
-         >>> # Simulate post-op cavities with dark spherical voids
+         >>> # Simulate post-op cavities only within tumor regions
          >>> tfm = RandomCutout(holes=1, max_holes=2, spatial_size=10,
-         ...                    max_spatial_size=25, fill='min', p=0.2)
+         ...                    max_spatial_size=25, fill='min', mask_only=True, p=0.2)
+
+         >>> # Original behavior - cutouts anywhere in the volume
+         >>> tfm = RandomCutout(mask_only=False, p=0.2)
      """

      split_idx, order = 0, 1

      def __init__(self, holes=1, max_holes=3, spatial_size=8,
-                  max_spatial_size=16, fill='min', p=0.2):
+                  max_spatial_size=16, fill='min', mask_only=True, p=0.2):
          self.holes = holes
          self.max_holes = max_holes
          self.spatial_size = spatial_size
          self.max_spatial_size = max_spatial_size
          self.fill = fill
+         self.mask_only = mask_only
          self.p = p

          self._tio_cutout = _TioRandomCutout(
              holes=holes, spatial_size=spatial_size,
              fill_value=None if isinstance(fill, str) else fill,
-             max_holes=max_holes, max_spatial_size=max_spatial_size, p=p
+             max_holes=max_holes, max_spatial_size=max_spatial_size,
+             mask_only=mask_only, p=p
          )

      @property
      def tio_transform(self):
-         """Return TorchIO-compatible transform for patch-based workflows."""
+         """Return TorchIO-compatible transform for patch-based workflows.
+
+         Note: For mask-aware cutouts in patch workflows, the mask must be
+         available in the TorchIO Subject as 'mask' key.
+         """
          return self._tio_cutout

      def _get_fill_value(self, tensor):
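As the docstring note above says, the patch-based wrapper looks for the segmentation under the Subject's `'mask'` key. A hedged sketch of the Subject layout it expects (toy tensors; `p=1.0` is used only to force the transform in this illustration, and is not a recommended training setting):

```python
import torch
import torchio as tio
from fastMONAI.vision_augmentation import RandomCutout

subject = tio.Subject(
    img=tio.ScalarImage(tensor=torch.randn(1, 64, 64, 64)),
    mask=tio.LabelMap(tensor=(torch.rand(1, 64, 64, 64) > 0.95).long()),  # key must be 'mask'
)

cutout = RandomCutout(mask_only=True, p=1.0).tio_transform  # the TorchIO-compatible wrapper
subject = cutout(subject)                                   # only mask-positive voxels are altered
```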
@@ -548,16 +647,79 @@ class RandomCutout(DisplayedTransform):
              return torch.empty(1).uniform_(float(tensor.min()), float(tensor.max())).item()
          else: return self.fill

-     def encodes(self, o: MedImage):
-         if torch.rand(1).item() > self.p: return o
-         fill_val = self._get_fill_value(o)
-         result = self._tio_cutout._apply_cutout(o.clone(), fill_val)
-         return MedImage.create(result)
+     def encodes(self, x):
+         """Apply mask-aware cutout to image.

-     def encodes(self, o: MedMask):
-         return o
+         Args:
+             x: Tuple of (MedImage, target) where target is MedMask or TensorCategory

- # %% ../nbs/03_vision_augment.ipynb 26
+         Returns:
+             Tuple of (transformed MedImage, unchanged target)
+         """
+         img, y_true = x
+
+         # Probability check
+         if torch.rand(1).item() > self.p:
+             return img, y_true
+
+         # Get mask data if available (as numpy for safe boolean operations)
+         mask_np = None
+         tumor_coords = None
+         if isinstance(y_true, MedMask):
+             mask_np = y_true.numpy()
+             # Get coordinates of tumor voxels for mask-aware center placement
+             if self.mask_only and (mask_np > 0).any():
+                 tumor_coords = np.argwhere(mask_np[0] > 0)  # Shape: (N, 3) for z, y, x
+
+         # Skip cutout if mask_only=True but no mask or empty mask
+         if self.mask_only:
+             if mask_np is None or tumor_coords is None or len(tumor_coords) == 0:
+                 return img, y_true
+
+         # Work with numpy array to avoid tensor subclass issues
+         result_np = img.numpy().copy()
+         spatial_shape = img.shape[1:]  # (D, H, W)
+         fill_val = self._get_fill_value(img)
+
+         n_holes = torch.randint(self.holes, self.max_holes + 1, (1,)).item()
+
+         min_size = self.spatial_size if isinstance(self.spatial_size, int) else self.spatial_size[0]
+         max_size = self.max_spatial_size or self.spatial_size
+         max_size = max_size if isinstance(max_size, int) else max_size[0]
+
+         for _ in range(n_holes):
+             # Random size for this hole
+             size = torch.randint(min_size, max_size + 1, (3,))
+             radii = size.float() / 2
+
+             if self.mask_only and tumor_coords is not None:
+                 # Pick center from within tumor region to ensure intersection
+                 idx = torch.randint(0, len(tumor_coords), (1,)).item()
+                 center = tumor_coords[idx].tolist()  # [z, y, x]
+             else:
+                 # Random center anywhere (ensure hole fits in volume)
+                 center = [
+                     torch.randint(int(radii[i].item()),
+                                   max(spatial_shape[i] - int(radii[i].item()), int(radii[i].item()) + 1),
+                                   (1,)).item()
+                     for i in range(3)
+                 ]
+
+             # Create ellipsoid mask (numpy)
+             ellipsoid = _create_ellipsoid_mask(spatial_shape, center, radii).numpy()
+
+             if self.mask_only and mask_np is not None:
+                 # INTERSECT with tumor mask - only affect tumor voxels
+                 tumor_mask = mask_np[0] > 0
+                 cutout_region = ellipsoid & tumor_mask
+             else:
+                 cutout_region = ellipsoid
+
+             result_np[:, cutout_region] = fill_val
+
+         return MedImage.create(torch.from_numpy(result_np)), y_true
+
+ # %% ../nbs/03_vision_augment.ipynb 27
  class RandomElasticDeformation(CustomDictTransform):
      """Apply TorchIO `RandomElasticDeformation`."""

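Because `RandomCutout` is now an `ItemTransform`, `encodes` receives the whole `(image, target)` tuple, which is what lets it place cutout centres inside the mask. A hedged usage sketch with toy tensors (not taken from the fastMONAI docs; `p=1.0` only forces the cutout for illustration):

```python
import torch
from fastMONAI.vision_core import MedImage, MedMask
from fastMONAI.vision_augmentation import RandomCutout

img = MedImage.create(torch.randn(1, 64, 64, 64))
msk = MedMask.create((torch.rand(1, 64, 64, 64) > 0.95).float())

tfm = RandomCutout(mask_only=True, p=1.0)
# encodes is called directly here; inside a DataLoader the transform runs
# on the training split only (split_idx = 0, see the class definition above).
img_out, msk_out = tfm.encodes((img, msk))   # mask comes back unchanged
```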
@@ -570,7 +732,7 @@ class RandomElasticDeformation(CustomDictTransform):
                                                 image_interpolation=image_interpolation,
                                                 p=p))

- # %% ../nbs/03_vision_augment.ipynb 27
+ # %% ../nbs/03_vision_augment.ipynb 28
  class RandomAffine(CustomDictTransform):
      """Apply TorchIO `RandomAffine`."""

@@ -586,14 +748,14 @@
                                      default_pad_value=default_pad_value,
                                      p=p))

- # %% ../nbs/03_vision_augment.ipynb 28
+ # %% ../nbs/03_vision_augment.ipynb 29
  class RandomFlip(CustomDictTransform):
      """Apply TorchIO `RandomFlip`."""

      def __init__(self, axes='LR', p=0.5):
          super().__init__(tio.RandomFlip(axes=axes, flip_probability=p))

- # %% ../nbs/03_vision_augment.ipynb 29
+ # %% ../nbs/03_vision_augment.ipynb 30
  class OneOf(CustomDictTransform):
      """Apply only one of the given transforms using TorchIO `OneOf`."""

{fastmonai-0.6.0 → fastmonai-0.6.1}/fastMONAI/vision_patch.py
@@ -22,6 +22,7 @@ from fastai.learner import Learner
  from .vision_core import MedImage, MedMask, MedBase, med_img_reader
  from .vision_inference import _to_original_orientation, _do_resize
  from .dataset_info import MedDataset, suggest_patch_size
+ from .vision_augmentation import SpatialPad

  # %% ../nbs/10_vision_patch.ipynb 3
  def _get_default_device() -> torch.device:
@@ -551,6 +552,12 @@ class MedPatchDataLoaders:
      Memory-efficient: Volumes are loaded on-demand by Queue workers,
      keeping memory usage constant (~150 MB) regardless of dataset size.

+     **Automatic padding**: Images smaller than patch_size are **automatically padded**
+     using SpatialPad (zero padding, nnU-Net standard). Dimensions larger than patch_size
+     are preserved. A message is printed at DataLoader creation to inform you that
+     automatic padding is enabled. This ensures training matches inference behavior
+     where both pad small dimensions to minimum patch_size.
+
      Note: Validation uses the same sampling as training (pseudo Dice).
      For true validation metrics, use PatchInferenceEngine with GridSampler
      for full-volume sliding window inference.
@@ -700,6 +707,16 @@
          if pre_patch_tfms:
              all_pre_tfms.extend(normalize_patch_transforms(pre_patch_tfms))

+         # Add SpatialPad to ensure minimum patch_size
+         # Pads small dimensions to patch_size while preserving large dimensions.
+         # Uses zero padding (nnU-Net standard) to match inference behavior.
+         # Placed AFTER normalization ensures consistent intensity preprocessing.
+         spatial_pad = SpatialPad(spatial_size=patch_config.patch_size)
+         all_pre_tfms.append(spatial_pad.tio_transform)
+
+         # Inform user about automatic padding (transparency)
+         print(f"ℹ️ Automatic padding enabled: dimensions smaller than patch_size {patch_config.patch_size} will be padded (larger dimensions preserved)")
+
          # Create subjects datasets with lazy loading (paths only, ~0 MB)
          train_subjects = create_subjects_dataset(
              train_df, img_col, mask_col,
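The `tio_transform` appended above runs after normalization inside the TorchIO pre-transform pipeline, so every subject is at least patch-sized before patches are sampled. A hedged sketch of that same ordering outside `MedPatchDataLoaders` (the subject tensor and `patch_size` here are placeholders, not fastMONAI internals):

```python
import torch
import torchio as tio
from fastMONAI.vision_augmentation import SpatialPad

patch_size = [96, 96, 96]
subject = tio.Subject(img=tio.ScalarImage(tensor=torch.randn(1, 120, 120, 60)))  # Z < patch depth

pre_tfms = tio.Compose([
    tio.ZNormalization(),                                # intensity preprocessing first
    SpatialPad(spatial_size=patch_size).tio_transform,   # then pad only the short dimensions
])
padded = pre_tfms(subject)
print(padded['img'].data.shape)                          # expected: (1, 120, 120, 96)
```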
@@ -873,6 +890,7 @@
              except Exception:
                  pass

+
  # %% ../nbs/10_vision_patch.ipynb 20
  import numbers

{fastmonai-0.6.0 → fastmonai-0.6.1/fastMONAI.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: fastMONAI
- Version: 0.6.0
+ Version: 0.6.1
  Summary: fastMONAI library
  Home-page: https://github.com/MMIV-ML/fastMONAI
  Author: Satheshkumar Kaliyugarasan
{fastmonai-0.6.0 → fastmonai-0.6.1}/settings.ini
@@ -5,7 +5,7 @@
  ### Python Library ###
  lib_name = fastMONAI
  min_python = 3.10
- version = 0.6.0
+ version = 0.6.1
  ### OPTIONAL ###

  requirements = fastai==2.8.3 monai==1.5.0 torchio==0.20.19 xlrd>=1.2.0 scikit-image==0.25.2 imagedata==3.8.14 mlflow==3.3.1 huggingface-hub gdown gradio opencv-python plum-dispatch
fastmonai-0.6.0/fastMONAI/__init__.py (removed)
@@ -1 +0,0 @@
- __version__ = "0.6.0"