fastMONAI 0.8.0__tar.gz → 0.8.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32) hide show
  1. {fastmonai-0.8.0/fastMONAI.egg-info → fastmonai-0.8.2}/PKG-INFO +1 -1
  2. fastmonai-0.8.2/fastMONAI/__init__.py +1 -0
  3. {fastmonai-0.8.0 → fastmonai-0.8.2}/fastMONAI/_modidx.py +32 -0
  4. {fastmonai-0.8.0 → fastmonai-0.8.2}/fastMONAI/dataset_info.py +142 -2
  5. {fastmonai-0.8.0 → fastmonai-0.8.2}/fastMONAI/utils.py +1 -0
  6. {fastmonai-0.8.0 → fastmonai-0.8.2}/fastMONAI/vision_augmentation.py +501 -22
  7. {fastmonai-0.8.0 → fastmonai-0.8.2}/fastMONAI/vision_patch.py +71 -31
  8. {fastmonai-0.8.0 → fastmonai-0.8.2/fastMONAI.egg-info}/PKG-INFO +1 -1
  9. {fastmonai-0.8.0 → fastmonai-0.8.2}/settings.ini +1 -1
  10. fastmonai-0.8.0/fastMONAI/__init__.py +0 -1
  11. {fastmonai-0.8.0 → fastmonai-0.8.2}/CONTRIBUTING.md +0 -0
  12. {fastmonai-0.8.0 → fastmonai-0.8.2}/LICENSE +0 -0
  13. {fastmonai-0.8.0 → fastmonai-0.8.2}/MANIFEST.in +0 -0
  14. {fastmonai-0.8.0 → fastmonai-0.8.2}/README.md +0 -0
  15. {fastmonai-0.8.0 → fastmonai-0.8.2}/fastMONAI/external_data.py +0 -0
  16. {fastmonai-0.8.0 → fastmonai-0.8.2}/fastMONAI/research_utils.py +0 -0
  17. {fastmonai-0.8.0 → fastmonai-0.8.2}/fastMONAI/vision_all.py +0 -0
  18. {fastmonai-0.8.0 → fastmonai-0.8.2}/fastMONAI/vision_core.py +0 -0
  19. {fastmonai-0.8.0 → fastmonai-0.8.2}/fastMONAI/vision_data.py +0 -0
  20. {fastmonai-0.8.0 → fastmonai-0.8.2}/fastMONAI/vision_inference.py +0 -0
  21. {fastmonai-0.8.0 → fastmonai-0.8.2}/fastMONAI/vision_loss.py +0 -0
  22. {fastmonai-0.8.0 → fastmonai-0.8.2}/fastMONAI/vision_metrics.py +0 -0
  23. {fastmonai-0.8.0 → fastmonai-0.8.2}/fastMONAI/vision_plot.py +0 -0
  24. {fastmonai-0.8.0 → fastmonai-0.8.2}/fastMONAI.egg-info/SOURCES.txt +0 -0
  25. {fastmonai-0.8.0 → fastmonai-0.8.2}/fastMONAI.egg-info/dependency_links.txt +0 -0
  26. {fastmonai-0.8.0 → fastmonai-0.8.2}/fastMONAI.egg-info/entry_points.txt +0 -0
  27. {fastmonai-0.8.0 → fastmonai-0.8.2}/fastMONAI.egg-info/not-zip-safe +0 -0
  28. {fastmonai-0.8.0 → fastmonai-0.8.2}/fastMONAI.egg-info/requires.txt +0 -0
  29. {fastmonai-0.8.0 → fastmonai-0.8.2}/fastMONAI.egg-info/top_level.txt +0 -0
  30. {fastmonai-0.8.0 → fastmonai-0.8.2}/pyproject.toml +0 -0
  31. {fastmonai-0.8.0 → fastmonai-0.8.2}/setup.cfg +0 -0
  32. {fastmonai-0.8.0 → fastmonai-0.8.2}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: fastMONAI
3
- Version: 0.8.0
3
+ Version: 0.8.2
4
4
  Summary: fastMONAI library
5
5
  Home-page: https://github.com/MMIV-ML/fastMONAI
6
6
  Author: Satheshkumar Kaliyugarasan
@@ -0,0 +1 @@
1
+ __version__ = "0.8.2"
@@ -38,6 +38,8 @@ d = { 'settings': { 'branch': 'main',
38
38
  'fastMONAI/dataset_info.py'),
39
39
  'fastMONAI.dataset_info.get_class_weights': ( 'dataset_info.html#get_class_weights',
40
40
  'fastMONAI/dataset_info.py'),
41
+ 'fastMONAI.dataset_info.preprocess_dataset': ( 'dataset_info.html#preprocess_dataset',
42
+ 'fastMONAI/dataset_info.py'),
41
43
  'fastMONAI.dataset_info.suggest_patch_size': ( 'dataset_info.html#suggest_patch_size',
42
44
  'fastMONAI/dataset_info.py')},
43
45
  'fastMONAI.external_data': { 'fastMONAI.external_data.MURLs': ('external_data.html#murls', 'fastMONAI/external_data.py'),
@@ -139,6 +141,28 @@ d = { 'settings': { 'branch': 'main',
139
141
  'fastMONAI/vision_augmentation.py'),
140
142
  'fastMONAI.vision_augmentation.CustomDictTransform.tio_transform': ( 'vision_augment.html#customdicttransform.tio_transform',
141
143
  'fastMONAI/vision_augmentation.py'),
144
+ 'fastMONAI.vision_augmentation.GpuPatchAugmentation': ( 'vision_augment.html#gpupatchaugmentation',
145
+ 'fastMONAI/vision_augmentation.py'),
146
+ 'fastMONAI.vision_augmentation.GpuPatchAugmentation.__call__': ( 'vision_augment.html#gpupatchaugmentation.__call__',
147
+ 'fastMONAI/vision_augmentation.py'),
148
+ 'fastMONAI.vision_augmentation.GpuPatchAugmentation.__init__': ( 'vision_augment.html#gpupatchaugmentation.__init__',
149
+ 'fastMONAI/vision_augmentation.py'),
150
+ 'fastMONAI.vision_augmentation.GpuPatchAugmentation.__repr__': ( 'vision_augment.html#gpupatchaugmentation.__repr__',
151
+ 'fastMONAI/vision_augmentation.py'),
152
+ 'fastMONAI.vision_augmentation.GpuPatchAugmentation._apply_affine': ( 'vision_augment.html#gpupatchaugmentation._apply_affine',
153
+ 'fastMONAI/vision_augmentation.py'),
154
+ 'fastMONAI.vision_augmentation.GpuPatchAugmentation._apply_anisotropy': ( 'vision_augment.html#gpupatchaugmentation._apply_anisotropy',
155
+ 'fastMONAI/vision_augmentation.py'),
156
+ 'fastMONAI.vision_augmentation.GpuPatchAugmentation._apply_blur': ( 'vision_augment.html#gpupatchaugmentation._apply_blur',
157
+ 'fastMONAI/vision_augmentation.py'),
158
+ 'fastMONAI.vision_augmentation.GpuPatchAugmentation._apply_flip': ( 'vision_augment.html#gpupatchaugmentation._apply_flip',
159
+ 'fastMONAI/vision_augmentation.py'),
160
+ 'fastMONAI.vision_augmentation.GpuPatchAugmentation._apply_gamma': ( 'vision_augment.html#gpupatchaugmentation._apply_gamma',
161
+ 'fastMONAI/vision_augmentation.py'),
162
+ 'fastMONAI.vision_augmentation.GpuPatchAugmentation._apply_intensity_scale': ( 'vision_augment.html#gpupatchaugmentation._apply_intensity_scale',
163
+ 'fastMONAI/vision_augmentation.py'),
164
+ 'fastMONAI.vision_augmentation.GpuPatchAugmentation._apply_noise': ( 'vision_augment.html#gpupatchaugmentation._apply_noise',
165
+ 'fastMONAI/vision_augmentation.py'),
142
166
  'fastMONAI.vision_augmentation.NormalizeIntensity': ( 'vision_augment.html#normalizeintensity',
143
167
  'fastMONAI/vision_augmentation.py'),
144
168
  'fastMONAI.vision_augmentation.NormalizeIntensity.__init__': ( 'vision_augment.html#normalizeintensity.__init__',
@@ -289,10 +313,18 @@ d = { 'settings': { 'branch': 'main',
289
313
  'fastMONAI/vision_augmentation.py'),
290
314
  'fastMONAI.vision_augmentation._TioRandomIntensityScale.apply_transform': ( 'vision_augment.html#_tiorandomintensityscale.apply_transform',
291
315
  'fastMONAI/vision_augmentation.py'),
316
+ 'fastMONAI.vision_augmentation._build_rotation_matrix_3d': ( 'vision_augment.html#_build_rotation_matrix_3d',
317
+ 'fastMONAI/vision_augmentation.py'),
318
+ 'fastMONAI.vision_augmentation._compute_patch_aug_params': ( 'vision_augment.html#_compute_patch_aug_params',
319
+ 'fastMONAI/vision_augmentation.py'),
292
320
  'fastMONAI.vision_augmentation._create_ellipsoid_mask': ( 'vision_augment.html#_create_ellipsoid_mask',
293
321
  'fastMONAI/vision_augmentation.py'),
322
+ 'fastMONAI.vision_augmentation._foreground_masking': ( 'vision_augment.html#_foreground_masking',
323
+ 'fastMONAI/vision_augmentation.py'),
294
324
  'fastMONAI.vision_augmentation.do_pad_or_crop': ( 'vision_augment.html#do_pad_or_crop',
295
325
  'fastMONAI/vision_augmentation.py'),
326
+ 'fastMONAI.vision_augmentation.gpu_patch_augmentations': ( 'vision_augment.html#gpu_patch_augmentations',
327
+ 'fastMONAI/vision_augmentation.py'),
296
328
  'fastMONAI.vision_augmentation.suggest_patch_augmentations': ( 'vision_augment.html#suggest_patch_augmentations',
297
329
  'fastMONAI/vision_augmentation.py')},
298
330
  'fastMONAI.vision_core': { 'fastMONAI.vision_core.MedBase': ('vision_core.html#medbase', 'fastMONAI/vision_core.py'),
@@ -1,15 +1,17 @@
1
1
  # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/08_dataset_info.ipynb.
2
2
 
3
3
  # %% auto #0
4
- __all__ = ['MedDataset', 'suggest_patch_size', 'get_class_weights']
4
+ __all__ = ['MedDataset', 'suggest_patch_size', 'preprocess_dataset', 'get_class_weights']
5
5
 
6
6
  # %% ../nbs/08_dataset_info.ipynb #027f016a-a80c-4842-b9dc-0bddb358a00c
7
7
  from .vision_core import *
8
8
  from .vision_plot import find_max_slice
9
9
 
10
10
  from sklearn.utils.class_weight import compute_class_weight
11
- from concurrent.futures import ThreadPoolExecutor
11
+ from concurrent.futures import ThreadPoolExecutor, as_completed
12
+ from tqdm.auto import tqdm
12
13
  from pathlib import Path
14
+ import torchio as tio
13
15
  import pandas as pd
14
16
  import numpy as np
15
17
  import torch
@@ -548,6 +550,144 @@ def suggest_patch_size(
548
550
 
549
551
  return patch_size
550
552
 
553
+ # %% ../nbs/08_dataset_info.ipynb #mbn5svtmzkh
554
def preprocess_dataset(df, img_col, mask_col=None, output_dir='preprocessed',
                       target_spacing=None, apply_reorder=True, transforms=None,
                       max_workers=4, skip_existing=True):
    """Preprocess dataset to disk, creating new columns for preprocessed paths.

    Processes images (and optionally masks) through a transform pipeline,
    saves to output_dir, then creates new '{col}_preprocessed' columns in
    the DataFrame. Original columns are preserved unchanged. ``df`` is
    modified in place; the function returns None.

    Transform pipeline order:
        CopyAffine (if masks) -> ToCanonical (if apply_reorder)
        -> Resample (if target_spacing) -> user transforms

    Each output is first written to a hidden temporary file next to its
    target and then moved into place with ``os.replace``. An interrupted run
    therefore never leaves a truncated file under the final name. Re-running
    with ``skip_existing=False`` overwrites earlier outputs on every platform.

    Args:
        df: DataFrame with file paths.
        img_col: Column name for image paths.
        mask_col: Optional column name for mask paths.
        output_dir: Output directory. Creates images/ and masks/ subdirectories.
        target_spacing: Target voxel spacing for resampling (e.g., [1.0, 1.0, 1.0]).
        apply_reorder: Whether to reorder to RAS+ canonical orientation.
        transforms: Additional TorchIO or fastMONAI transforms to apply after
            reordering and resampling.
        max_workers: Number of parallel workers. Each worker loads a full 3D
            volume into memory, so reduce for large volumes.
        skip_existing: Skip files that already exist on disk (with size > 0).

    Raises:
        ValueError: If ``df`` is empty, a requested column is missing, or two
            rows share the same image (or mask) file name. Outputs are keyed
            by file name, so duplicates would overwrite each other.

    Note:
        The '*_preprocessed' columns are filled for every row, including rows
        whose processing failed (reported via ``warnings.warn`` and in the
        printed summary). Paths for failed rows will not exist on disk.
    """
    import os
    import warnings
    from collections import Counter

    # Input validation (runs before anything touches the filesystem)
    if len(df) == 0:
        raise ValueError("DataFrame is empty")
    if img_col not in df.columns:
        raise ValueError(f"Column '{img_col}' not found in DataFrame")
    if mask_col is not None and mask_col not in df.columns:
        raise ValueError(f"Column '{mask_col}' not found in DataFrame")

    def _duplicate_names(paths):
        """Return the set of file names that occur more than once in `paths`."""
        counts = Counter(Path(p).name for p in paths)
        return {name for name, count in counts.items() if count > 1}

    dupes = _duplicate_names(df[img_col])
    if dupes:
        raise ValueError(f"Duplicate image file names: {dupes}")
    if mask_col is not None:
        dupes = _duplicate_names(df[mask_col])
        if dupes:
            raise ValueError(f"Duplicate mask file names: {dupes}")

    # Build transform pipeline (canonical order)
    all_tfms = []
    if mask_col is not None:
        all_tfms.append(tio.CopyAffine(target='image'))
    if apply_reorder:
        all_tfms.append(tio.ToCanonical())
    if target_spacing is not None:
        all_tfms.append(tio.Resample(target_spacing))
    if transforms:
        # fastMONAI wrappers expose their TorchIO transform as `.tio_transform`;
        # plain TorchIO transforms are used as-is.
        all_tfms.extend([getattr(t, 'tio_transform', t) for t in transforms])
    pipeline = tio.Compose(all_tfms) if all_tfms else None

    # Create output directories
    output_dir = Path(output_dir)
    img_dir = output_dir / 'images'
    img_dir.mkdir(parents=True, exist_ok=True)
    mask_dir = None
    if mask_col is not None:
        mask_dir = output_dir / 'masks'
        mask_dir.mkdir(parents=True, exist_ok=True)

    def _nonempty(path):
        """True if `path` exists and is not a zero-byte file."""
        return path.exists() and path.stat().st_size > 0

    # Build work items, filtering skip_existing
    work_items = []
    skipped = 0
    for idx in range(len(df)):
        img_path = df[img_col].iloc[idx]
        out_img = img_dir / Path(img_path).name
        mask_path = df[mask_col].iloc[idx] if mask_col is not None else None
        out_mask = (mask_dir / Path(mask_path).name) if mask_col is not None else None

        if skip_existing and _nonempty(out_img) and (out_mask is None or _nonempty(out_mask)):
            skipped += 1
            continue

        work_items.append({
            'idx': idx, 'img_path': img_path, 'mask_path': mask_path,
            'out_img': out_img, 'out_mask': out_mask,
        })

    def _atomic_save(image, out_path):
        """Save `image` to a hidden temp file (same NIfTI suffix), then move it into place."""
        tmp_path = out_path.parent / f'.tmp_{out_path.name}'
        try:
            image.save(str(tmp_path))
            # os.replace overwrites an existing target on all platforms
            # (os.rename raises on Windows). This matters when
            # skip_existing=False, or when a previous run wrote only the image.
            os.replace(str(tmp_path), str(out_path))
        except BaseException:
            # Best-effort cleanup so failed cases leave no stray temp files;
            # the original error is always re-raised.
            try:
                tmp_path.unlink()
            except OSError:
                pass
            raise

    def _process_case(item):
        """Load one case, run the pipeline and write the outputs."""
        subject_dict = {'image': tio.ScalarImage(item['img_path'])}
        if item['mask_path'] is not None:
            subject_dict['mask'] = tio.LabelMap(item['mask_path'])

        subject = tio.Subject(**subject_dict)
        if pipeline is not None:
            subject = pipeline(subject)

        _atomic_save(subject['image'], item['out_img'])
        if item['out_mask'] is not None:
            _atomic_save(subject['mask'], item['out_mask'])

    # Process cases
    processed = 0
    failed = 0
    failed_cases = []
    if work_items:
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {executor.submit(_process_case, item): item for item in work_items}
            for future in tqdm(as_completed(futures), total=len(futures),
                               desc='Preprocessing'):
                item = futures[future]
                try:
                    future.result()
                    processed += 1
                except Exception as e:
                    # One bad case should not abort the whole dataset; report it.
                    failed += 1
                    failed_cases.append(Path(item['img_path']).name)
                    warnings.warn(f"Failed to process {item['img_path']}: {e}")

    # Create new columns for preprocessed paths (preserve originals)
    df[f'{img_col}_preprocessed'] = [str(img_dir / Path(p).name) for p in df[img_col]]
    if mask_col is not None:
        df[f'{mask_col}_preprocessed'] = [str(mask_dir / Path(p).name) for p in df[mask_col]]

    print(f"Preprocessing complete: {processed} processed, {skipped} skipped, {failed} failed")
    if failed_cases:
        print(f"Failed cases: {failed_cases}")
690
+
551
691
  # %% ../nbs/08_dataset_info.ipynb #9b81f6e8-abd7-4bf6-be4c-4118986c308a
552
692
  def get_class_weights(labels: (np.array, list), class_weight: str = 'balanced') -> torch.Tensor:
553
693
  """Calculates and returns the class weights.
@@ -237,6 +237,7 @@ def _extract_patch_config(learn) -> dict:
237
237
  'aggregation_mode': patch_config.aggregation_mode,
238
238
  'padding_mode': patch_config.padding_mode,
239
239
  'keep_largest_component': patch_config.keep_largest_component,
240
+ 'preprocessed': patch_config.preprocessed,
240
241
  }
241
242
  else:
242
243
  config['patch_config'] = None
@@ -4,11 +4,14 @@
4
4
  __all__ = ['CustomDictTransform', 'do_pad_or_crop', 'PadOrCrop', 'ZNormalization', 'RescaleIntensity', 'NormalizeIntensity',
5
5
  'BraTSMaskConverter', 'BinaryConverter', 'RandomGhosting', 'RandomSpike', 'RandomNoise', 'RandomBiasField',
6
6
  'RandomBlur', 'RandomGamma', 'RandomIntensityScale', 'RandomMotion', 'RandomAnisotropy', 'RandomCutout',
7
- 'RandomElasticDeformation', 'RandomAffine', 'RandomFlip', 'OneOf', 'suggest_patch_augmentations']
7
+ 'RandomElasticDeformation', 'RandomAffine', 'RandomFlip', 'OneOf', 'GpuPatchAugmentation',
8
+ 'gpu_patch_augmentations', 'suggest_patch_augmentations']
8
9
 
9
10
  # %% ../nbs/03_vision_augment.ipynb #2d6694aa
10
11
  from fastai.data.all import *
11
12
  from .vision_core import *
13
+ import torch.nn.functional as F
14
+ import math
12
15
  import torchio as tio
13
16
  from monai.transforms import NormalizeIntensity as MonaiNormalizeIntensity
14
17
 
@@ -91,6 +94,11 @@ class PadOrCrop(DisplayedTransform):
91
94
  def encodes(self, o: (MedImage, MedMask)):
92
95
  return type(o)(self.pad_or_crop(o))
93
96
 
97
+ # %% ../nbs/03_vision_augment.ipynb #534509q2nn
98
+ def _foreground_masking(tensor):
99
+ """Mask for non-zero voxels (nnU-Net-style foreground normalization)."""
100
+ return tensor > 0
101
+
94
102
  # %% ../nbs/03_vision_augment.ipynb #ca95a690
95
103
  class ZNormalization(DisplayedTransform):
96
104
  """Apply TorchIO `ZNormalization`."""
@@ -98,6 +106,8 @@ class ZNormalization(DisplayedTransform):
98
106
  order = 0
99
107
 
100
108
  def __init__(self, masking_method=None, channel_wise=True):
109
+ if masking_method == 'foreground':
110
+ masking_method = _foreground_masking
101
111
  self.z_normalization = tio.ZNormalization(masking_method=masking_method)
102
112
  self.channel_wise = channel_wise
103
113
 
@@ -784,19 +794,19 @@ class OneOf(CustomDictTransform):
784
794
  def __init__(self, transform_dict, p=1):
785
795
  super().__init__(tio.OneOf(transform_dict, p=p))
786
796
 
787
- # %% ../nbs/03_vision_augment.ipynb #t6hak044rc
788
- def suggest_patch_augmentations(patch_size, target_spacing,
789
- anisotropy_threshold=3.0,
790
- translation_fraction=0.15):
791
- """Suggest patch-based augmentations with nnU-Net-inspired defaults.
797
+ # %% ../nbs/03_vision_augment.ipynb #lqet5pabzy
798
+ def _compute_patch_aug_params(patch_size, target_spacing,
799
+ anisotropy_threshold=3.0,
800
+ translation_fraction=0.15):
801
+ """Compute geometry-aware augmentation parameters from patch/spacing metadata.
792
802
 
793
- Derives rotation degrees, translation, and RandomAnisotropy axes from
794
- patch geometry and voxel spacing. Returns a list of fastMONAI transform
795
- instances ready for the ``patch_tfms`` parameter in MedPatchDataLoaders.
803
+ Shared logic used by both suggest_patch_augmentations (CPU/TorchIO) and
804
+ gpu_patch_augmentations (GPU-batched). Extracts rotation degrees, translation
805
+ offsets, and RandomAnisotropy axes from the spatial configuration.
796
806
 
797
807
  Anisotropy detection: if max(spacing)/min(spacing) >= threshold, rotation
798
808
  is restricted to 5 deg out-of-plane and 30 deg in-plane. Otherwise 30 deg
799
- symmetric. Translation is patch_size * fraction per axis.
809
+ symmetric.
800
810
 
801
811
  Args:
802
812
  patch_size: List/tuple of 3 ints -- patch dimensions.
@@ -805,38 +815,507 @@ def suggest_patch_augmentations(patch_size, target_spacing,
805
815
  translation_fraction: Fraction of patch_size for translation (default 0.15).
806
816
 
807
817
  Returns:
808
- list: fastMONAI transform instances (7 normally, 6 if RandomAnisotropy omitted).
809
-
810
- Example::
811
-
812
- >>> patch_tfms = suggest_patch_augmentations([128, 128, 32], [0.5, 0.5, 1.5])
813
- >>> dls = MedPatchDataLoaders.from_config(..., patch_tfms=patch_tfms)
818
+ dict with keys:
819
+ 'degrees': tuple of 3 ints (per-axis max rotation in degrees)
820
+ 'translation': tuple of 3 ints (per-axis translation in voxels)
821
+ 'aniso_axes': tuple of ints (axes where patch_size > 1)
822
+ 'is_aniso': bool (whether spacing is anisotropic)
814
823
  """
815
824
  if len(patch_size) != 3:
816
825
  raise ValueError(f"patch_size must have 3 elements, got {len(patch_size)}")
817
826
  if len(target_spacing) != 3:
818
827
  raise ValueError(f"target_spacing must have 3 elements, got {len(target_spacing)}")
819
828
 
820
- # Determine anisotropy
821
829
  spacing = list(target_spacing)
822
830
  ratio = max(spacing) / min(spacing)
823
831
  is_aniso = ratio >= anisotropy_threshold
824
832
  aniso_axis = spacing.index(max(spacing)) if is_aniso else None
825
833
 
826
- # Rotation degrees
827
834
  if is_aniso:
828
835
  degrees = [5, 5, 5]
829
836
  degrees[aniso_axis] = 30
830
837
  degrees = tuple(degrees)
831
838
  else:
832
- degrees = 30
839
+ degrees = (30, 30, 30)
833
840
 
834
- # Translation
835
841
  translation = tuple(round(p * translation_fraction) for p in patch_size)
836
-
837
- # RandomAnisotropy axes: all axes where patch_size > 1
838
842
  aniso_axes = tuple(i for i in range(3) if patch_size[i] > 1)
839
843
 
844
+ return {
845
+ 'degrees': degrees,
846
+ 'translation': translation,
847
+ 'aniso_axes': aniso_axes,
848
+ 'is_aniso': is_aniso,
849
+ }
850
+
851
+ # %% ../nbs/03_vision_augment.ipynb #oef9rtzvbw
852
+ def _build_rotation_matrix_3d(angles_rad):
853
+ """Build [N, 3, 3] rotation matrices from [N, 3] Euler angles (XYZ extrinsic).
854
+
855
+ Computes R = Rz @ Ry @ Rx for each sample in the batch.
856
+
857
+ Args:
858
+ angles_rad: Tensor of shape [N, 3] with rotation angles in radians
859
+ for each axis (x, y, z).
860
+
861
+ Returns:
862
+ Tensor of shape [N, 3, 3] -- rotation matrices.
863
+ """
864
+ cos = torch.cos(angles_rad)
865
+ sin = torch.sin(angles_rad)
866
+ cx, cy, cz = cos[:, 0], cos[:, 1], cos[:, 2]
867
+ sx, sy, sz = sin[:, 0], sin[:, 1], sin[:, 2]
868
+
869
+ # R = Rz @ Ry @ Rx (combined formula)
870
+ r00 = cy * cz; r01 = sx * sy * cz - cx * sz; r02 = cx * sy * cz + sx * sz
871
+ r10 = cy * sz; r11 = sx * sy * sz + cx * cz; r12 = cx * sy * sz - sx * cz
872
+ r20 = -sy; r21 = sx * cy; r22 = cx * cy
873
+
874
+ R = torch.stack([
875
+ torch.stack([r00, r01, r02], dim=-1),
876
+ torch.stack([r10, r11, r12], dim=-1),
877
+ torch.stack([r20, r21, r22], dim=-1),
878
+ ], dim=-2)
879
+ return R
880
+
881
+ # %% ../nbs/03_vision_augment.ipynb #ts2kolv83z
882
class GpuPatchAugmentation:
    """GPU-batched augmentation for patch-based training.

    Operates on [B, C, D, H, W] tensors already on GPU. All operations run
    under torch.no_grad() since augmentation does not need gradient tracking.
    Input tensors are never modified in place; augmented tensors are returned.

    Transform order: spatial (affine, anisotropy, flip) then intensity
    (gamma, intensity_scale, noise, blur). Affine and flip apply the same
    parameters to image and mask. Anisotropy and the intensity transforms
    only change the image.

    Per-axis parameters ('degrees', 'translation', and the anisotropy/flip
    'axes') use tensor-dimension order: axis ``i`` is dimension ``i + 2`` of
    the [B, C, D, H, W] input, i.e. the same order as ``patch_size``.

    Each transform is controlled by a parameter dict with at minimum a 'p' key
    for per-sample probability. Pass None to disable a transform.

    Args:
        affine: dict with keys 'scales' (lo, hi), 'degrees' (scalar or 3
            per-axis maxima), 'translation' (scalar or 3 per-axis maxima in
            voxels), optional 'default_pad_value' (default 0.), and 'p'.
            None to disable.
        anisotropy: dict with keys 'axes', 'downsampling' (lo, hi), 'p'.
            None to disable.
        flip: dict with keys 'axes', 'p'. Each listed axis is flipped
            independently with probability p. None to disable.
        gamma: dict with keys 'log_gamma' (lo, hi), 'p'. None to disable.
        intensity_scale: dict with keys 'scale_range' (lo, hi), 'p'.
            None to disable.
        noise: dict with keys 'std' (a (lo, hi) range, or a scalar used as a
            fixed std), 'p'. None to disable.
        blur: dict with keys 'std' (a (lo, hi) range, or a scalar meaning
            (0, std)), 'p'. None to disable.

    Example::

        >>> gpu_aug = GpuPatchAugmentation(
        ...     affine={'scales': (0.7, 1.4), 'degrees': (30, 30, 30),
        ...             'translation': (25, 25, 10), 'default_pad_value': 0., 'p': 0.2},
        ...     gamma={'log_gamma': (-0.3, 0.3), 'p': 0.3},
        ...     flip={'axes': (0, 1, 2), 'p': 0.5},
        ... )
        >>> img_aug, mask_aug = gpu_aug(img_gpu, mask_gpu)
    """

    def __init__(self, affine=None, anisotropy=None, flip=None,
                 gamma=None, intensity_scale=None, noise=None, blur=None):
        self.affine = affine
        self.anisotropy = anisotropy
        self.flip = flip
        self.gamma = gamma
        self.intensity_scale = intensity_scale
        self.noise = noise
        self.blur = blur

    def __call__(self, img, mask=None):
        """Apply GPU augmentations to a batch.

        Args:
            img: Tensor [B, C, D, H, W] (float).
            mask: Tensor [B, C, D, H, W] (float), or None.

        Returns:
            Tuple (img, mask). mask is None if input was None. The input
            tensors themselves are left unchanged.
        """
        with torch.no_grad():
            # Spatial transforms (affine/flip share params between img and mask)
            if self.affine is not None:
                img, mask = self._apply_affine(img, mask)
            if self.anisotropy is not None:
                img, mask = self._apply_anisotropy(img, mask)
            if self.flip is not None:
                img, mask = self._apply_flip(img, mask)
            # Intensity transforms (img only)
            if self.gamma is not None:
                img = self._apply_gamma(img)
            if self.intensity_scale is not None:
                img = self._apply_intensity_scale(img)
            if self.noise is not None:
                img = self._apply_noise(img)
            if self.blur is not None:
                img = self._apply_blur(img)
        return img, mask

    def _apply_affine(self, img, mask):
        """Batched random affine via F.affine_grid + F.grid_sample.

        Scales, rotations and translations are sampled per axis in
        tensor-dimension order. They are assembled in centered voxel space as
        ``v_in = S @ R @ v_out + t`` and then converted to the normalized
        coordinates ``F.affine_grid`` expects. The conversion applies two
        corrections:

        * ``affine_grid`` orders theta's rows/columns as (x, y, z) = (W, H, D),
          the reverse of the tensor dims, so matrix and translation are
          flipped before being written into theta.
        * With ``align_corners=False`` one voxel spans ``2 / size`` normalized
          units. This differs per axis on non-cubic patches, so the linear
          part is conjugated with ``diag(size / 2)`` to keep rotations rigid.
        """
        cfg = self.affine
        B = img.shape[0]
        device = img.device
        dtype = img.dtype

        # Per-sample probability
        do_tfm = torch.rand(B, device=device) < cfg['p']
        if not do_tfm.any():
            return img, mask

        idx = do_tfm.nonzero(as_tuple=True)[0]
        n = idx.shape[0]

        # Random scales per axis
        s_lo, s_hi = cfg['scales']
        scales = torch.empty(n, 3, device=device, dtype=dtype).uniform_(s_lo, s_hi)

        # Random rotation angles per axis (degrees -> radians)
        degrees = cfg['degrees']
        if not isinstance(degrees, (list, tuple)):
            degrees = (degrees, degrees, degrees)
        angles_deg = torch.stack([
            torch.empty(n, device=device, dtype=dtype).uniform_(-degrees[i], degrees[i])
            for i in range(3)
        ], dim=1)  # [n, 3]
        R = _build_rotation_matrix_3d(angles_deg * (math.pi / 180.0))  # [n, 3, 3]

        # Linear part in centered voxel space, tensor-dim axis order
        A = torch.diag_embed(scales) @ R  # [n, 3, 3]

        # Random translation in voxels
        translation = cfg['translation']
        if not isinstance(translation, (list, tuple)):
            translation = (translation, translation, translation)
        t_vox = torch.stack([
            torch.empty(n, device=device, dtype=dtype).uniform_(-translation[i], translation[i])
            for i in range(3)
        ], dim=1)  # [n, 3]

        # Voxel -> normalized coordinates: u = v / (size / 2).
        # Conjugation diag(half)^-1 @ A @ diag(half) scales entry (i, j) by half_j / half_i.
        half = torch.tensor([s / 2.0 for s in img.shape[2:]], device=device, dtype=dtype)
        A_norm = A * half.view(1, 1, 3) / half.view(1, 3, 1)
        t_norm = t_vox / half

        # Identity theta [B, 3, 4] for inactive samples. Active samples get
        # their transform with axis order reversed: (D, H, W) -> (x=W, y=H, z=D).
        theta = torch.eye(3, 4, device=device, dtype=dtype).repeat(B, 1, 1)
        theta[idx, :, :3] = torch.flip(A_norm, dims=(1, 2))
        theta[idx, :, 3] = torch.flip(t_norm, dims=(1,))

        # Apply grid_sample
        grid = F.affine_grid(theta, img.shape, align_corners=False)
        pad_val = cfg.get('default_pad_value', 0.)
        # grid_sample pads with zeros: shift values so padding lands on pad_val
        if pad_val != 0.:
            img = img - pad_val
        img = F.grid_sample(img, grid, mode='bilinear',
                            padding_mode='zeros', align_corners=False)
        if pad_val != 0.:
            img = img + pad_val

        if mask is not None:
            # Nearest keeps label values discrete
            mask = F.grid_sample(mask.float(), grid, mode='nearest',
                                 padding_mode='zeros', align_corners=False)
        return img, mask

    def _apply_anisotropy(self, img, mask):
        """Per-sample anisotropy simulation via F.interpolate.

        Downsample along a random axis with nearest interpolation, then
        upsample back with trilinear (intended to mirror TorchIO's
        RandomAnisotropy). Only affects img, not mask. The input is copied
        once, before the first modified sample is written.
        """
        cfg = self.anisotropy
        B = img.shape[0]
        device = img.device
        ds_lo, ds_hi = cfg['downsampling']
        axes = cfg['axes']

        out = None
        for i in range(B):
            if torch.rand(1, device=device).item() >= cfg['p']:
                continue
            # Pick random axis and downsampling factor
            axis = axes[torch.randint(len(axes), (1,), device=device).item()]
            factor = torch.empty(1, device=device).uniform_(ds_lo, ds_hi).item()

            if out is None:
                out = img.clone()
            sample = img[i:i+1]  # [1, C, D, H, W]
            orig_size = list(sample.shape[2:])
            down_size = list(orig_size)
            down_size[axis] = max(1, round(orig_size[axis] / factor))

            down = F.interpolate(sample, size=down_size, mode='nearest')
            out[i:i+1] = F.interpolate(down, size=orig_size, mode='trilinear',
                                       align_corners=False)

        return (img if out is None else out), mask

    def _apply_flip(self, img, mask):
        """Per-sample random flip along configured axes (same flips for img and mask)."""
        cfg = self.flip
        B = img.shape[0]
        device = img.device
        axes = cfg['axes']
        p = cfg['p']

        copied = False
        for i in range(B):
            # Each axis is flipped independently with probability p. img[i] is
            # [C, D, H, W], so spatial axis `a` is dim `a + 1` there.
            dims = [axis + 1 for axis in axes
                    if torch.rand(1, device=device).item() < p]
            if not dims:
                continue
            if not copied:
                # Copy once so the caller's tensors are not modified
                img = img.clone()
                mask = mask.clone() if mask is not None else None
                copied = True
            img[i] = torch.flip(img[i], dims=dims)
            if mask is not None:
                mask[i] = torch.flip(mask[i], dims=dims)

        return img, mask

    def _apply_gamma(self, img):
        """Batched gamma correction with per-sample random gamma.

        Uses the sign-preserving form ``sign(x) * |x| ** gamma``, which equals
        ``x ** gamma`` for non-negative inputs and matches how TorchIO's
        RandomGamma handles negative values. Clamping to 0 instead would
        erase every negative voxel, e.g. much of a z-normalized patch.
        """
        cfg = self.gamma
        B = img.shape[0]
        device = img.device
        dtype = img.dtype
        log_lo, log_hi = cfg['log_gamma']

        active = torch.rand(B, device=device) < cfg['p']
        if not active.any():
            return img

        active_idx = active.nonzero(as_tuple=True)[0]
        log_gamma = torch.empty(active_idx.shape[0], device=device, dtype=dtype).uniform_(log_lo, log_hi)
        gamma = torch.exp(log_gamma).view(-1, 1, 1, 1, 1)
        selected = img[active_idx]
        out = img.clone()
        out[active_idx] = torch.sign(selected) * selected.abs().pow(gamma)
        return out

    def _apply_intensity_scale(self, img):
        """Batched intensity scaling with per-sample random factors."""
        cfg = self.intensity_scale
        B = img.shape[0]
        device = img.device
        dtype = img.dtype
        s_lo, s_hi = cfg['scale_range']

        # Per-sample scale (inactive get scale=1)
        scale = torch.empty(B, device=device, dtype=dtype).uniform_(s_lo, s_hi)
        active = torch.rand(B, device=device) < cfg['p']
        scale = torch.where(active, scale, torch.ones_like(scale))

        return img * scale.view(B, 1, 1, 1, 1)

    def _apply_noise(self, img):
        """Batched additive Gaussian noise with per-sample std.

        A (lo, hi) 'std' is sampled uniformly per sample; a scalar 'std' is
        used as-is for every active sample.
        """
        cfg = self.noise
        B = img.shape[0]
        device = img.device
        dtype = img.dtype

        std_val = cfg['std']
        if isinstance(std_val, (list, tuple)):
            std_lo, std_hi = std_val
            per_std = torch.empty(B, device=device, dtype=dtype).uniform_(std_lo, std_hi)
        else:
            per_std = torch.full((B,), std_val, device=device, dtype=dtype)

        # Zero std for inactive samples
        active = torch.rand(B, device=device) < cfg['p']
        per_std = torch.where(active, per_std, torch.zeros_like(per_std))

        noise = torch.randn_like(img) * per_std.view(B, 1, 1, 1, 1)
        return img + noise

    def _apply_blur(self, img):
        """Batched separable 3D Gaussian blur via F.conv3d with groups trick."""
        cfg = self.blur
        B, C, D, H, W = img.shape
        device = img.device
        dtype = img.dtype

        std_val = cfg['std']
        if isinstance(std_val, (list, tuple)):
            std_lo, std_hi = std_val
        else:
            std_lo, std_hi = 0.0, std_val

        # Per-sample sigma
        sigma = torch.empty(B, device=device, dtype=dtype).uniform_(std_lo, std_hi)
        active = torch.rand(B, device=device) < cfg['p']
        if not active.any():
            return img

        # Fixed kernel size from max sigma
        max_sigma = max(std_hi, 0.01)
        kernel_radius = int(math.ceil(3 * max_sigma))
        kernel_size = 2 * kernel_radius + 1

        # Build per-sample 1D Gaussian kernels [B, kernel_size]
        x = torch.arange(-kernel_radius, kernel_radius + 1,
                         device=device, dtype=dtype)
        # A sampled sigma can be exactly 0 (e.g. std=(0, 0)), which would give
        # 0/0 = NaN below. Floor it; a tiny sigma yields a delta kernel.
        # Inactive samples get a dummy sigma and are replaced by a delta below.
        safe_sigma = torch.where(active, sigma.clamp(min=1e-6), torch.ones_like(sigma))
        kernels = torch.exp(-x.unsqueeze(0)**2 / (2 * safe_sigma.unsqueeze(1)**2))
        kernels = kernels / kernels.sum(dim=1, keepdim=True)

        # For inactive samples, use delta kernel
        delta = torch.zeros(B, kernel_size, device=device, dtype=dtype)
        delta[:, kernel_radius] = 1.0
        kernels = torch.where(active.unsqueeze(1), kernels, delta)

        # Expand kernels for all channels: [B*C, kernel_size]
        kernels_bc = kernels.unsqueeze(1).expand(B, C, kernel_size).reshape(B * C, kernel_size)

        # Reshape img for grouped convolution: [1, B*C, D, H, W]
        img_grouped = img.reshape(1, B * C, D, H, W)

        # Separable 3D convolution: D-axis, H-axis, W-axis
        pad = kernel_radius

        # D-axis: kernel shape [B*C, 1, K, 1, 1]
        k_d = kernels_bc.reshape(B * C, 1, kernel_size, 1, 1)
        img_grouped = F.pad(img_grouped, (0, 0, 0, 0, pad, pad), mode='replicate')
        img_grouped = F.conv3d(img_grouped, k_d, groups=B * C)

        # H-axis: kernel shape [B*C, 1, 1, K, 1]
        k_h = kernels_bc.reshape(B * C, 1, 1, kernel_size, 1)
        img_grouped = F.pad(img_grouped, (0, 0, pad, pad, 0, 0), mode='replicate')
        img_grouped = F.conv3d(img_grouped, k_h, groups=B * C)

        # W-axis: kernel shape [B*C, 1, 1, 1, K]
        k_w = kernels_bc.reshape(B * C, 1, 1, 1, kernel_size)
        img_grouped = F.pad(img_grouped, (pad, pad, 0, 0, 0, 0), mode='replicate')
        img_grouped = F.conv3d(img_grouped, k_w, groups=B * C)

        return img_grouped.reshape(B, C, D, H, W)

    def __repr__(self):
        parts = []
        for name in ['affine', 'anisotropy', 'flip', 'gamma',
                     'intensity_scale', 'noise', 'blur']:
            cfg = getattr(self, name)
            if cfg is not None:
                parts.append(f"{name}(p={cfg['p']})")
        return f"GpuPatchAugmentation({', '.join(parts)})"
1215
+
1216
+ # %% ../nbs/03_vision_augment.ipynb #pdbh1nqo0j7
1217
def gpu_patch_augmentations(patch_size, target_spacing,
                            anisotropy_threshold=3.0,
                            translation_fraction=0.15,
                            affine_p=0.2, anisotropy_p=0.25,
                            gamma_p=0.3, intensity_scale_p=0.1,
                            noise_p=0.1, blur_p=0.2, flip_p=0.5):
    """Build a GpuPatchAugmentation with nnU-Net-inspired defaults.

    GPU-batched counterpart of suggest_patch_augmentations. Rotation degrees,
    translation and anisotropy axes are derived from patch geometry and voxel
    spacing by the shared helper _compute_patch_aug_params. All other ranges
    are fixed:

    - affine: scales (0.7, 1.4), pad value 0.
    - anisotropy: downsampling (1.5, 4). Omitted entirely when no anisotropic
      axes are detected.
    - flip: axes (0, 1, 2).
    - gamma: log_gamma (-0.3, 0.3).
    - intensity_scale: scale_range (0.75, 1.25).
    - noise: std 0.1.
    - blur: std (0.5, 1.0).

    Args:
        patch_size: List/tuple of 3 ints -- patch dimensions.
        target_spacing: List/tuple of 3 floats -- voxel spacing.
        anisotropy_threshold: Spacing-ratio threshold for anisotropy detection
            (default 3.0).
        translation_fraction: Fraction of patch_size used as the translation
            range (default 0.15).
        affine_p: Probability of the affine transform (default 0.2).
        anisotropy_p: Probability of the anisotropy transform (default 0.25).
        gamma_p: Probability of the gamma transform (default 0.3).
        intensity_scale_p: Probability of intensity scaling (default 0.1).
        noise_p: Probability of additive noise (default 0.1).
        blur_p: Probability of Gaussian blur (default 0.2).
        flip_p: Probability of flipping, per axis (default 0.5).

    Returns:
        GpuPatchAugmentation instance.

    Example::

        >>> gpu_aug = gpu_patch_augmentations([128, 128, 32], [0.5, 0.5, 1.5])
        >>> dls = MedPatchDataLoaders.from_df(..., gpu_augmentation=gpu_aug)
    """
    geometry = _compute_patch_aug_params(
        patch_size, target_spacing, anisotropy_threshold, translation_fraction
    )

    # Anisotropy simulation only makes sense along axes flagged as anisotropic.
    detected_axes = geometry['aniso_axes']
    if len(detected_axes) > 0:
        anisotropy_settings = {
            'axes': detected_axes,
            'downsampling': (1.5, 4),
            'p': anisotropy_p,
        }
    else:
        anisotropy_settings = None

    return GpuPatchAugmentation(
        affine={
            'scales': (0.7, 1.4),
            'degrees': geometry['degrees'],
            'translation': geometry['translation'],
            'default_pad_value': 0.,
            'p': affine_p,
        },
        anisotropy=anisotropy_settings,
        flip={'axes': (0, 1, 2), 'p': flip_p},
        gamma={'log_gamma': (-0.3, 0.3), 'p': gamma_p},
        intensity_scale={'scale_range': (0.75, 1.25), 'p': intensity_scale_p},
        noise={'std': 0.1, 'p': noise_p},
        blur={'std': (0.5, 1.0), 'p': blur_p},
    )
1279
+
1280
+ # %% ../nbs/03_vision_augment.ipynb #t6hak044rc
1281
+ def suggest_patch_augmentations(patch_size, target_spacing,
1282
+ anisotropy_threshold=3.0,
1283
+ translation_fraction=0.15):
1284
+ """Suggest patch-based augmentations with nnU-Net-inspired defaults.
1285
+
1286
+ Derives rotation degrees, translation, and RandomAnisotropy axes from
1287
+ patch geometry and voxel spacing. Returns a list of fastMONAI transform
1288
+ instances ready for the ``patch_tfms`` parameter in MedPatchDataLoaders.
1289
+
1290
+ Anisotropy detection: if max(spacing)/min(spacing) >= threshold, rotation
1291
+ is restricted to 5 deg out-of-plane and 30 deg in-plane. Otherwise 30 deg
1292
+ symmetric. Translation is patch_size * fraction per axis.
1293
+
1294
+ Args:
1295
+ patch_size: List/tuple of 3 ints -- patch dimensions.
1296
+ target_spacing: List/tuple of 3 floats -- voxel spacing.
1297
+ anisotropy_threshold: Ratio threshold for anisotropy detection (default 3.0).
1298
+ translation_fraction: Fraction of patch_size for translation (default 0.15).
1299
+
1300
+ Returns:
1301
+ list: fastMONAI transform instances (7 normally, 6 if RandomAnisotropy omitted).
1302
+
1303
+ Example::
1304
+
1305
+ >>> patch_tfms = suggest_patch_augmentations([128, 128, 32], [0.5, 0.5, 1.5])
1306
+ >>> dls = MedPatchDataLoaders.from_df(..., patch_tfms=patch_tfms)
1307
+ """
1308
+ params = _compute_patch_aug_params(
1309
+ patch_size, target_spacing, anisotropy_threshold, translation_fraction
1310
+ )
1311
+ degrees = params['degrees']
1312
+ translation = params['translation']
1313
+ aniso_axes = params['aniso_axes']
1314
+
1315
+ # For TorchIO: pass scalar 30 when isotropic (TorchIO expands to symmetric)
1316
+ if not params['is_aniso']:
1317
+ degrees = 30
1318
+
840
1319
  transforms = [
841
1320
  RandomAffine(scales=(0.7, 1.4), degrees=degrees, translation=translation,
842
1321
  default_pad_value=0., p=0.2),
@@ -116,6 +116,11 @@ class PatchConfig:
116
116
  training and inference. Defaults to True (the common case).
117
117
  target_spacing: Target voxel spacing [x, y, z] for resampling. Must match between
118
118
  training and inference.
119
+ preprocessed: If True, data has been preprocessed externally (e.g., via
120
+ preprocess_dataset()). Training will skip reorder, resample, AND
121
+ pre_patch_tfms (e.g., normalization) since they were already applied.
122
+ Inference is unaffected and always applies pre_inference_tfms to raw
123
+ images. Defaults to False.
119
124
  padding_mode: Padding mode for CropOrPad when image < patch_size. Default is 0 (zero padding)
120
125
  to align with nnU-Net's approach. Can be int, float, or string (e.g., 'minimum', 'mean').
121
126
  keep_largest_component: If True, keep only the largest connected component
@@ -142,6 +147,7 @@ class PatchConfig:
142
147
  # Preprocessing parameters - must match between training and inference
143
148
  apply_reorder: bool = True # Defaults to True (the common case)
144
149
  target_spacing: list = None
150
+ preprocessed: bool = False # True = data already preprocessed, skip all preprocessing during training
145
151
  padding_mode: int | float | str = 0 # Zero padding (nnU-Net standard)
146
152
  # Post-processing (binary segmentation only)
147
153
  keep_largest_component: bool = False
@@ -403,7 +409,10 @@ class MedPatchDataLoader:
403
409
  patch_tfms: Transforms to apply to extracted patches (training only).
404
410
  Accepts both fastMONAI wrappers (e.g., RandomAffine, RandomGamma) and
405
411
  raw TorchIO transforms. fastMONAI wrappers are automatically normalized
406
- to raw TorchIO for internal use.
412
+ to raw TorchIO for internal use. Mutually exclusive with gpu_augmentation.
413
+ gpu_augmentation: GpuPatchAugmentation instance for GPU-batched augmentation.
414
+ Operates on [B,C,D,H,W] tensors already on GPU, avoiding per-sample CPU
415
+ overhead. Mutually exclusive with patch_tfms. Training only.
407
416
  shuffle: Whether to shuffle subjects and patches.
408
417
  drop_last: Whether to drop last incomplete batch.
409
418
  """
@@ -414,18 +423,20 @@ class MedPatchDataLoader:
414
423
  config: PatchConfig,
415
424
  batch_size: int = 4,
416
425
  patch_tfms: list = None,
426
+ gpu_augmentation=None,
417
427
  shuffle: bool = True,
418
428
  drop_last: bool = False
419
429
  ):
420
430
  if batch_size <= 0:
421
431
  raise ValueError(f"batch_size must be positive, got {batch_size}")
422
-
432
+
423
433
  self.subjects_dataset = subjects_dataset
424
434
  self.config = config
425
435
  self.bs = batch_size
426
436
  self.shuffle = shuffle
427
437
  self.drop_last = drop_last
428
438
  self._device = _get_default_device()
439
+ self.gpu_augmentation = gpu_augmentation
429
440
 
430
441
  # Create sampler
431
442
  self.sampler = create_patch_sampler(config)
@@ -464,14 +475,12 @@ class MedPatchDataLoader:
464
475
  img = batch['image'][tio.DATA] # [B, C, H, W, D]
465
476
  has_mask = 'mask' in batch
466
477
 
467
- # Apply patch transforms if provided
478
+ # Apply CPU patch transforms if provided (per-sample TorchIO loop)
468
479
  if self.patch_tfms is not None:
469
- # Apply transforms to each sample in batch
470
480
  transformed_imgs = []
471
481
  transformed_masks = [] if has_mask else None
472
482
 
473
483
  for i in range(img.shape[0]):
474
- # Build subject dict with image, and mask if available
475
484
  subject_dict = {'image': tio.ScalarImage(tensor=batch['image'][tio.DATA][i])}
476
485
  if has_mask:
477
486
  subject_dict['mask'] = tio.LabelMap(tensor=batch['mask'][tio.DATA][i])
@@ -487,10 +496,19 @@ class MedPatchDataLoader:
487
496
  else:
488
497
  mask = batch['mask'][tio.DATA] if has_mask else None
489
498
 
490
- # Convert to MedImage/MedMask and move to device
491
- img = MedImage(img).to(self._device)
499
+ # Move to device
500
+ img = img.to(self._device)
492
501
  if mask is not None:
493
- mask = MedMask(mask).to(self._device)
502
+ mask = mask.to(self._device)
503
+
504
+ # Apply GPU augmentation if provided (batched, on-device)
505
+ if self.gpu_augmentation is not None:
506
+ img, mask = self.gpu_augmentation(img, mask)
507
+
508
+ # Wrap as MedImage/MedMask
509
+ img = MedImage(img)
510
+ if mask is not None:
511
+ mask = MedMask(mask)
494
512
 
495
513
  yield img, mask
496
514
 
@@ -613,6 +631,7 @@ class MedPatchDataLoaders:
613
631
  patch_config: PatchConfig = None,
614
632
  pre_patch_tfms: list = None,
615
633
  patch_tfms: list = None,
634
+ gpu_augmentation=None,
616
635
  apply_reorder: bool = None,
617
636
  target_spacing: list = None,
618
637
  bs: int = 4,
@@ -640,7 +659,12 @@ class MedPatchDataLoaders:
640
659
  pre_patch_tfms: TorchIO transforms applied before patch extraction
641
660
  (after reorder/resample). Example: [tio.ZNormalization()].
642
661
  Accepts both fastMONAI wrappers and raw TorchIO transforms.
662
+ Skipped when preprocessed=True (include in preprocess_dataset()
663
+ transforms instead). Still needed for inference via pre_inference_tfms.
643
664
  patch_tfms: TorchIO transforms applied to extracted patches (training only).
665
+ Mutually exclusive with gpu_augmentation.
666
+ gpu_augmentation: GpuPatchAugmentation instance for GPU-batched augmentation
667
+ (training only). Mutually exclusive with patch_tfms.
644
668
  apply_reorder: If True, reorder to RAS+ orientation. If None, uses
645
669
  patch_config.apply_reorder. Explicit value overrides config.
646
670
  target_spacing: Target voxel spacing [x, y, z]. If None, uses
@@ -656,22 +680,32 @@ class MedPatchDataLoaders:
656
680
  MedPatchDataLoaders instance.
657
681
 
658
682
  Example:
659
- >>> # New pattern: config contains preprocessing params
660
- >>> config = PatchConfig(
661
- ... patch_size=[96, 96, 96],
662
- ... apply_reorder=True,
663
- ... target_spacing=[0.5, 0.5, 0.5],
664
- ... label_probabilities={0: 0.1, 1: 0.9}
665
- ... )
683
+ >>> # CPU augmentation path (existing)
666
684
  >>> dls = MedPatchDataLoaders.from_df(
667
685
  ... df, img_col='image', mask_col='label',
668
686
  ... patch_config=config,
669
- ... pre_patch_tfms=[tio.ZNormalization()],
670
687
  ... patch_tfms=[tio.RandomAffine(degrees=10), tio.RandomFlip()],
671
688
  ... bs=4
672
689
  ... )
673
- >>> # Memory: ~150 MB (queue buffer only)
690
+ >>>
691
+ >>> # GPU augmentation path (new, faster for long training runs)
692
+ >>> from fastMONAI.vision_augmentation import gpu_patch_augmentations
693
+ >>> gpu_aug = gpu_patch_augmentations(config.patch_size, config.target_spacing)
694
+ >>> dls = MedPatchDataLoaders.from_df(
695
+ ... df, img_col='image', mask_col='label',
696
+ ... patch_config=config,
697
+ ... gpu_augmentation=gpu_aug,
698
+ ... bs=4
699
+ ... )
674
700
  """
701
+ # Validate mutual exclusivity
702
+ if gpu_augmentation is not None and patch_tfms is not None:
703
+ raise ValueError(
704
+ "Cannot use both gpu_augmentation and patch_tfms. "
705
+ "gpu_augmentation operates on GPU tensors batch-wise, while "
706
+ "patch_tfms uses per-sample CPU TorchIO transforms. Choose one."
707
+ )
708
+
675
709
  if patch_config is None:
676
710
  patch_config = PatchConfig()
677
711
 
@@ -699,17 +733,19 @@ class MedPatchDataLoaders:
699
733
  # Build preprocessing transforms
700
734
  all_pre_tfms = []
701
735
 
702
- # Add reorder transform (reorder to RAS+ orientation)
703
- if _apply_reorder:
704
- all_pre_tfms.append(tio.ToCanonical())
736
+ # Skip all preprocessing if data was already preprocessed externally
737
+ if not patch_config.preprocessed:
738
+ # Add reorder transform (reorder to RAS+ orientation)
739
+ if _apply_reorder:
740
+ all_pre_tfms.append(tio.ToCanonical())
705
741
 
706
- # Add resample transform
707
- if _target_spacing is not None:
708
- all_pre_tfms.append(tio.Resample(_target_spacing))
742
+ # Add resample transform
743
+ if _target_spacing is not None:
744
+ all_pre_tfms.append(tio.Resample(_target_spacing))
709
745
 
710
- # Add user-provided transforms (normalize to raw TorchIO transforms)
711
- if pre_patch_tfms:
712
- all_pre_tfms.extend(normalize_patch_transforms(pre_patch_tfms))
746
+ # Add user-provided transforms (normalize to raw TorchIO transforms)
747
+ if pre_patch_tfms:
748
+ all_pre_tfms.extend(normalize_patch_transforms(pre_patch_tfms))
713
749
 
714
750
  # Create subjects datasets with lazy loading (paths only, ~0 MB)
715
751
  train_subjects = create_subjects_dataset(
@@ -726,11 +762,14 @@ class MedPatchDataLoaders:
726
762
  # Create DataLoaders (both use same patch_config for consistent sampling)
727
763
  train_dl = MedPatchDataLoader(
728
764
  train_subjects, patch_config, bs,
729
- patch_tfms=patch_tfms, shuffle=True, drop_last=True
765
+ patch_tfms=patch_tfms,
766
+ gpu_augmentation=gpu_augmentation,
767
+ shuffle=True, drop_last=True
730
768
  )
731
769
  valid_dl = MedPatchDataLoader(
732
770
  valid_subjects, patch_config, bs,
733
771
  patch_tfms=None, # No augmentation for validation
772
+ gpu_augmentation=None, # No augmentation for validation
734
773
  shuffle=False, drop_last=False
735
774
  )
736
775
 
@@ -847,7 +886,8 @@ class MedPatchDataLoaders:
847
886
  return self.to(torch.device('cpu'))
848
887
 
849
888
  def show_batch(self, dl_idx=0, max_n=6, figsize=None, channel=0,
850
- slice_index=None, anatomical_plane=0, overlay=False, **kwargs):
889
+ slice_index=None, anatomical_plane=0, overlay=False,
890
+ voxel_size=None, **kwargs):
851
891
  """Show a batch of patch samples for visualization."""
852
892
 
853
893
  dl = self[dl_idx]
@@ -886,15 +926,15 @@ class MedPatchDataLoaders:
886
926
  imgs.extend(im_channels)
887
927
  slice_idxs.extend([idx] * len(im_channels))
888
928
 
889
- voxel_size = self.target_spacing
929
+ _voxel_size = voxel_size if voxel_size is not None else self.target_spacing
890
930
  ctxs = [im.show(ax=ax, slice_index=idx, anatomical_plane=anatomical_plane,
891
- voxel_size=voxel_size)
931
+ voxel_size=_voxel_size)
892
932
  for im, ax, idx in zip(imgs, flat_axs, slice_idxs)]
893
933
 
894
934
  if overlay and has_mask:
895
935
  for mask, ax, idx in zip(masks_for_overlay, flat_axs, slice_idxs):
896
936
  mask.show(ax=ax, slice_index=idx, anatomical_plane=anatomical_plane,
897
- voxel_size=voxel_size)
937
+ voxel_size=_voxel_size)
898
938
 
899
939
  plt.tight_layout()
900
940
  plt.show()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: fastMONAI
3
- Version: 0.8.0
3
+ Version: 0.8.2
4
4
  Summary: fastMONAI library
5
5
  Home-page: https://github.com/MMIV-ML/fastMONAI
6
6
  Author: Satheshkumar Kaliyugarasan
@@ -5,7 +5,7 @@
5
5
  ### Python Library ###
6
6
  lib_name = fastMONAI
7
7
  min_python = 3.10
8
- version = 0.8.0
8
+ version = 0.8.2
9
9
  ### OPTIONAL ###
10
10
 
11
11
  requirements = fastai==2.8.6 monai==1.5.2 torchio==0.21.2 xlrd>=1.2.0 scikit-image==0.26.0 imagedata==3.8.14 mlflow==3.9.0 huggingface-hub gdown gradio opencv-python plum-dispatch
@@ -1 +0,0 @@
1
- __version__ = "0.8.0"
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes