fastMONAI 0.5.3__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1125 @@
1
+ """Patch-based training and inference for 3D medical image segmentation using TorchIO's Queue mechanism."""
2
+
3
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/10_vision_patch.ipynb.
4
+
5
+ # %% auto 0
6
+ __all__ = ['normalize_patch_transforms', 'PatchConfig', 'med_to_subject', 'create_subjects_dataset', 'create_patch_sampler',
7
+ 'MedPatchDataLoader', 'MedPatchDataLoaders', 'PatchInferenceEngine', 'patch_inference']
8
+
9
+ # %% ../nbs/10_vision_patch.ipynb 2
10
+ import torch
11
+ import torchio as tio
12
+ import pandas as pd
13
+ import numpy as np
14
+ import warnings
15
+ from pathlib import Path
16
+ from dataclasses import dataclass, field
17
+ from typing import Callable
18
+ from torch.utils.data import DataLoader
19
+ from tqdm.auto import tqdm
20
+ from fastai.data.all import *
21
+ from .vision_core import MedImage, MedMask, MedBase, med_img_reader
22
+ from .vision_inference import _to_original_orientation, _do_resize
23
+
24
+ # %% ../nbs/10_vision_patch.ipynb 3
25
def _get_default_device() -> torch.device:
    """Return the preferred compute device: CUDA when available, else CPU."""
    backend = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(backend)
28
+
29
+
30
def _warn_config_override(param_name: str, config_value, explicit_value):
    """Emit a warning when an explicit argument disagrees with the config value.

    No warning is emitted when either value is None (nothing to compare) or
    when both values agree. The explicit argument always wins.

    Args:
        param_name: Name of the parameter (e.g., 'apply_reorder', 'target_spacing')
        config_value: Value from PatchConfig
        explicit_value: Explicitly provided value
    """
    both_provided = config_value is not None and explicit_value is not None
    if both_provided and explicit_value != config_value:
        warnings.warn(
            f"{param_name} mismatch: explicit={explicit_value}, config={config_value}. "
            f"Using explicit argument."
        )
44
+
45
+
46
def _extract_tio_transform(tfm):
    """Unwrap a fastMONAI transform wrapper to its underlying TorchIO transform.

    fastMONAI wrappers (e.g., RandomAffine, RandomGamma) expose the raw TorchIO
    transform via a `.tio_transform` attribute; objects without that attribute
    (i.e., raw TorchIO transforms) are returned unchanged. This lets patch-based
    workflows feed either kind into tio.Compose().

    Args:
        tfm: fastMONAI wrapper (e.g., RandomAffine) or raw TorchIO transform

    Returns:
        The underlying TorchIO transform

    Example:
        >>> from fastMONAI.vision_augmentation import RandomAffine
        >>> wrapped = RandomAffine(degrees=10)
        >>> raw = _extract_tio_transform(wrapped)  # Returns tio.RandomAffine
    """
    if hasattr(tfm, 'tio_transform'):
        return tfm.tio_transform
    return tfm
67
+
68
+
69
def normalize_patch_transforms(tfms: list) -> list:
    """Convert a mixed list of transforms to raw TorchIO transforms.

    Each element may be a fastMONAI wrapper (unwrapped via its underlying
    TorchIO transform) or already a raw TorchIO transform (kept as-is), so the
    same transform syntax works in both standard and patch-based workflows:

    >>> from fastMONAI.vision_augmentation import RandomAffine, RandomGamma
    >>>
    >>> # Same syntax works in both contexts
    >>> item_tfms = [RandomAffine(degrees=10), RandomGamma(p=0.5)]   # Standard
    >>> patch_tfms = [RandomAffine(degrees=10), RandomGamma(p=0.5)]  # Patch-based

    Args:
        tfms: List of fastMONAI wrappers or raw TorchIO transforms

    Returns:
        List of raw TorchIO transforms suitable for tio.Compose(),
        or None when `tfms` is None.
    """
    if tfms is None:
        return None
    return list(map(_extract_tio_transform, tfms))
93
+
94
+ # %% ../nbs/10_vision_patch.ipynb 7
95
+ @dataclass
96
+ class PatchConfig:
97
+ """Configuration for patch-based training and inference.
98
+
99
+ Args:
100
+ patch_size: Size of patches [x, y, z].
101
+ patch_overlap: Overlap for inference GridSampler (int, float 0-1, or list).
102
+ - Float 0-1: fraction of patch_size (e.g., 0.5 = 50% overlap)
103
+ - Int >= 1: pixel overlap (e.g., 48 = 48 pixel overlap)
104
+ - List: per-dimension overlap in pixels
105
+ samples_per_volume: Number of patches to extract per volume during training.
106
+ sampler_type: Type of sampler ('uniform', 'label', 'weighted').
107
+ label_probabilities: For LabelSampler, dict mapping label values to probabilities.
108
+ queue_length: Maximum number of patches to store in queue.
109
+ queue_num_workers: Number of workers for parallel patch extraction.
110
+ aggregation_mode: For inference, how to combine overlapping patches ('crop', 'average', 'hann').
111
+ apply_reorder: Whether to reorder to RAS+ canonical orientation. Must match between
112
+ training and inference.
113
+ target_spacing: Target voxel spacing [x, y, z] for resampling. Must match between
114
+ training and inference.
115
+ padding_mode: Padding mode for CropOrPad when image < patch_size. Default is 0 (zero padding)
116
+ to align with nnU-Net's approach. Can be int, float, or string (e.g., 'minimum', 'mean').
117
+ keep_largest_component: If True, keep only the largest connected component
118
+ in binary segmentation predictions. Only applies during inference when
119
+ return_probabilities=False. Defaults to False.
120
+
121
+ Example:
122
+ >>> config = PatchConfig(
123
+ ... patch_size=[96, 96, 96],
124
+ ... samples_per_volume=16,
125
+ ... sampler_type='label',
126
+ ... label_probabilities={0: 0.1, 1: 0.9},
127
+ ... apply_reorder=True,
128
+ ... target_spacing=[0.5, 0.5, 0.5]
129
+ ... )
130
+ """
131
+ patch_size: list = field(default_factory=lambda: [96, 96, 96])
132
+ patch_overlap: int | float | list = 0
133
+ samples_per_volume: int = 8
134
+ sampler_type: str = 'uniform'
135
+ label_probabilities: dict = None
136
+ queue_length: int = 300
137
+ queue_num_workers: int = 4
138
+ aggregation_mode: str = 'hann'
139
+ # Preprocessing parameters - must match between training and inference
140
+ apply_reorder: bool = False
141
+ target_spacing: list = None
142
+ padding_mode: int | float | str = 0 # Zero padding (nnU-Net standard)
143
+ # Post-processing (binary segmentation only)
144
+ keep_largest_component: bool = False
145
+
146
+ def __post_init__(self):
147
+ """Validate configuration."""
148
+ valid_samplers = ['uniform', 'label', 'weighted']
149
+ if self.sampler_type not in valid_samplers:
150
+ raise ValueError(f"sampler_type must be one of {valid_samplers}")
151
+
152
+ valid_aggregation = ['crop', 'average', 'hann']
153
+ if self.aggregation_mode not in valid_aggregation:
154
+ raise ValueError(f"aggregation_mode must be one of {valid_aggregation}")
155
+
156
+ # Validate patch_overlap
157
+ # Negative overlap doesn't make sense
158
+ if isinstance(self.patch_overlap, (int, float)):
159
+ if self.patch_overlap < 0:
160
+ raise ValueError("patch_overlap cannot be negative")
161
+ # Check if overlap as pixels would exceed patch_size (causes step_size=0)
162
+ if self.patch_overlap >= 1: # Pixel value, not fraction
163
+ for ps in self.patch_size:
164
+ if self.patch_overlap >= ps:
165
+ raise ValueError(
166
+ f"patch_overlap ({self.patch_overlap}) must be less than patch_size ({ps}). "
167
+ f"Overlap >= patch_size creates step_size <= 0 (infinite patches)."
168
+ )
169
+ elif isinstance(self.patch_overlap, (list, tuple)):
170
+ for i, (overlap, ps) in enumerate(zip(self.patch_overlap, self.patch_size)):
171
+ if overlap < 0:
172
+ raise ValueError(f"patch_overlap[{i}] cannot be negative")
173
+ if overlap >= ps:
174
+ raise ValueError(
175
+ f"patch_overlap[{i}] ({overlap}) must be less than patch_size[{i}] ({ps}). "
176
+ f"Overlap >= patch_size creates step_size <= 0 (infinite patches)."
177
+ )
178
+
179
+ # %% ../nbs/10_vision_patch.ipynb 10
180
def med_to_subject(
    img: Path | str,
    mask: Path | str = None,
) -> tio.Subject:
    """Build a lazily-loaded TorchIO Subject from file paths.

    Only the paths are stored in the Subject; no tensor data is read here.
    TorchIO's Queue workers load the volumes on-demand during training, so
    memory stays flat no matter how large the dataset is.

    Args:
        img: Path to image file.
        mask: Path to mask file (optional).

    Returns:
        TorchIO Subject with 'image' and optionally 'mask' keys (lazy loaded).

    Example:
        >>> subject = med_to_subject('image.nii.gz', 'mask.nii.gz')
        >>> # Volume NOT loaded yet - only path stored
        >>> data = subject['image'].data  # NOW volume is loaded
    """
    contents = {'image': tio.ScalarImage(path=str(img))}  # path only, no load
    if mask is not None:
        contents['mask'] = tio.LabelMap(path=str(mask))  # path only, no load
    return tio.Subject(**contents)
210
+
211
+ # %% ../nbs/10_vision_patch.ipynb 11
212
def create_subjects_dataset(
    df: pd.DataFrame,
    img_col: str,
    mask_col: str = None,
    pre_tfms: list = None,
    ensure_affine_consistency: bool = True
) -> tio.SubjectsDataset:
    """Build a lazily-loaded TorchIO SubjectsDataset from a DataFrame of paths.

    Subjects hold file paths only; volumes are read on-demand by Queue workers,
    so memory use is independent of dataset size.

    Args:
        df: DataFrame with image (and optionally mask) paths.
        img_col: Column name containing image paths.
        mask_col: Column name containing mask paths (optional).
        pre_tfms: List of TorchIO transforms to apply before patch extraction.
            Use tio.ToCanonical() for reordering and tio.Resample() for resampling.
        ensure_affine_consistency: If True and mask_col is provided, automatically
            prepends tio.CopyAffine(target='image') to ensure spatial metadata
            consistency between image and mask. This prevents "More than one value
            for direction found" errors. Defaults to True.

    Returns:
        TorchIO SubjectsDataset with lazy-loaded subjects.

    Example:
        >>> pre_tfms = [
        ...     tio.ToCanonical(),               # Reorder to RAS+
        ...     tio.Resample([0.5, 0.5, 0.5]),   # Resample
        ...     tio.ZNormalization(),            # Intensity normalization
        ... ]
        >>> dataset = create_subjects_dataset(
        ...     df, img_col='image', mask_col='label',
        ...     pre_tfms=pre_tfms
        ... )
        >>> # Memory: ~0 MB (only paths stored, not volumes)
    """
    # Lazily-constructed subjects: only paths are stored at this point
    subjects = [
        med_to_subject(
            img=row[img_col],
            mask=row[mask_col] if mask_col else None,
        )
        for _, row in df.iterrows()
    ]

    # Assemble the preprocessing pipeline; CopyAffine goes first so spatial
    # metadata is reconciled before any other transform touches the subject.
    pipeline = []
    if mask_col is not None and ensure_affine_consistency:
        pipeline.append(tio.CopyAffine(target='image'))
    if pre_tfms:
        pipeline.extend(pre_tfms)

    composed = tio.Compose(pipeline) if pipeline else None
    return tio.SubjectsDataset(subjects, transform=composed)
276
+
277
+ # %% ../nbs/10_vision_patch.ipynb 13
278
def create_patch_sampler(config: PatchConfig) -> tio.data.PatchSampler:
    """Instantiate the TorchIO patch sampler described by the config.

    Args:
        config: PatchConfig with sampler settings.

    Returns:
        TorchIO PatchSampler instance.

    Raises:
        NotImplementedError: For 'weighted', which needs a probability map.

    Example:
        >>> config = PatchConfig(patch_size=[96, 96, 96], sampler_type='label')
        >>> sampler = create_patch_sampler(config)
    """
    kind = config.sampler_type

    if kind == 'uniform':
        return tio.UniformSampler(config.patch_size)

    if kind == 'label':
        return tio.LabelSampler(
            config.patch_size,
            label_name='mask',
            label_probabilities=config.label_probabilities
        )

    if kind == 'weighted':
        raise NotImplementedError(
            "WeightedSampler requires a pre-computed probability map which is not currently supported. "
            "Use 'label' sampler with label_probabilities for weighted sampling based on segmentation labels, "
            "or 'uniform' for random patch extraction."
        )

    # Defensive: PatchConfig.__post_init__ already rejects unknown types
    raise ValueError(f"Unknown sampler type: {config.sampler_type}")
311
+
312
+ # %% ../nbs/10_vision_patch.ipynb 17
313
class MedPatchDataLoader:
    """DataLoader wrapper for patch-based training with TorchIO Queue.

    This class wraps a TorchIO Queue to provide a fastai-compatible DataLoader
    interface for patch-based training.

    Args:
        subjects_dataset: TorchIO SubjectsDataset.
        config: PatchConfig with queue and sampler settings.
        batch_size: Number of patches per batch. Must be positive.
        patch_tfms: Transforms to apply to extracted patches (training only).
            Accepts both fastMONAI wrappers (e.g., RandomAffine, RandomGamma) and
            raw TorchIO transforms. fastMONAI wrappers are automatically normalized
            to raw TorchIO for internal use.
        shuffle: Whether to shuffle subjects and patches.
        drop_last: Whether to drop last incomplete batch.
    """

    def __init__(
        self,
        subjects_dataset: tio.SubjectsDataset,
        config: PatchConfig,
        batch_size: int = 4,
        patch_tfms: list = None,
        shuffle: bool = True,
        drop_last: bool = False
    ):
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        self.subjects_dataset = subjects_dataset
        self.config = config
        # fastai-style attribute name for batch size
        self.bs = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self._device = _get_default_device()

        # Create sampler from config (uniform/label; 'weighted' raises)
        self.sampler = create_patch_sampler(config)

        # Create patch transforms
        # Normalize transforms - accepts both fastMONAI wrappers and raw TorchIO
        normalized_tfms = normalize_patch_transforms(patch_tfms)
        self.patch_tfms = tio.Compose(normalized_tfms) if normalized_tfms else None

        # Create queue: the Queue's own workers do patch extraction
        self.queue = tio.Queue(
            subjects_dataset,
            max_length=config.queue_length,
            samples_per_volume=config.samples_per_volume,
            sampler=self.sampler,
            num_workers=config.queue_num_workers,
            shuffle_subjects=shuffle,
            shuffle_patches=shuffle
        )

        # Create torch DataLoader over the queue; num_workers must stay 0
        # because the Queue already parallelizes loading internally
        self._dl = DataLoader(
            self.queue,
            batch_size=batch_size,
            num_workers=0,  # Queue handles workers
            drop_last=drop_last
        )

    def __iter__(self):
        """Iterate over batches, yielding (image, mask) tuples.

        mask is None when the dataset has no 'mask' key. Patch transforms are
        applied per-sample because TorchIO transforms operate on Subjects,
        not on batched tensors.
        """
        for batch in self._dl:
            # Extract image and mask tensors
            img = batch['image'][tio.DATA]  # [B, C, H, W, D]
            has_mask = 'mask' in batch

            # Apply patch transforms if provided
            if self.patch_tfms is not None:
                # Apply transforms to each sample in batch
                transformed_imgs = []
                transformed_masks = [] if has_mask else None

                for i in range(img.shape[0]):
                    # Build subject dict with image, and mask if available;
                    # wrapping as LabelMap keeps nearest-neighbour interpolation
                    # semantics for the mask under spatial transforms
                    subject_dict = {'image': tio.ScalarImage(tensor=batch['image'][tio.DATA][i])}
                    if has_mask:
                        subject_dict['mask'] = tio.LabelMap(tensor=batch['mask'][tio.DATA][i])

                    subject = tio.Subject(subject_dict)
                    transformed = self.patch_tfms(subject)
                    transformed_imgs.append(transformed['image'].data)
                    if has_mask:
                        transformed_masks.append(transformed['mask'].data)

                img = torch.stack(transformed_imgs)
                mask = torch.stack(transformed_masks) if has_mask else None
            else:
                mask = batch['mask'][tio.DATA] if has_mask else None

            # Convert to MedImage/MedMask and move to device
            img = MedImage(img).to(self._device)
            if mask is not None:
                mask = MedMask(mask).to(self._device)

            yield img, mask

    def __len__(self) -> int:
        """Return number of batches per epoch."""
        n_patches = len(self.subjects_dataset) * self.config.samples_per_volume
        if self.drop_last:
            return n_patches // self.bs
        # Ceiling division: a final partial batch still counts
        return (n_patches + self.bs - 1) // self.bs

    @property
    def dataset(self):
        """Return the underlying queue as dataset (fastai expects `.dataset`)."""
        return self.queue

    @property
    def device(self):
        """Return current device."""
        return self._device

    def to(self, device):
        """Move DataLoader to device. Returns self for chaining."""
        self._device = device
        return self

    def one_batch(self):
        """Return one batch from the DataLoader.

        Required for fastai compatibility - used for device detection
        and batch shape validation during Learner initialization.

        Returns:
            Tuple of (image, mask) tensors on the correct device.
        """
        return next(iter(self))
446
+
447
+ # %% ../nbs/10_vision_patch.ipynb 18
448
class MedPatchDataLoaders:
    """fastai-compatible DataLoaders for patch-based training with LAZY loading.

    This class provides train and validation DataLoaders that work with
    fastai's Learner for patch-based training on 3D medical images.

    Memory-efficient: Volumes are loaded on-demand by Queue workers,
    keeping memory usage constant (~150 MB) regardless of dataset size.

    Note: Validation uses the same sampling as training (pseudo Dice).
    For true validation metrics, use PatchInferenceEngine with GridSampler
    for full-volume sliding window inference.

    Example:
        >>> import torchio as tio
        >>>
        >>> # New pattern: preprocessing params in config (DRY)
        >>> config = PatchConfig(
        ...     patch_size=[96, 96, 96],
        ...     apply_reorder=True,
        ...     target_spacing=[0.5, 0.5, 0.5]
        ... )
        >>> dls = MedPatchDataLoaders.from_df(
        ...     df, img_col='image', mask_col='label',
        ...     valid_pct=0.2,
        ...     patch_config=config,
        ...     pre_patch_tfms=[tio.ZNormalization()],
        ...     bs=4
        ... )
        >>> learn = Learner(dls, model, loss_func=DiceLoss())
    """

    def __init__(
        self,
        train_dl: MedPatchDataLoader,
        valid_dl: MedPatchDataLoader,
        device: torch.device = None
    ):
        self._train_dl = train_dl
        self._valid_dl = valid_dl
        self._device = device or _get_default_device()

        # Move both loaders to the resolved device
        self._train_dl.to(self._device)
        self._valid_dl.to(self._device)

    @classmethod
    def from_df(
        cls,
        df: pd.DataFrame,
        img_col: str,
        mask_col: str = None,
        valid_pct: float = 0.2,
        valid_col: str = None,
        patch_config: PatchConfig = None,
        pre_patch_tfms: list = None,
        patch_tfms: list = None,
        apply_reorder: bool = None,
        target_spacing: list = None,
        bs: int = 4,
        seed: int = None,
        device: torch.device = None,
        ensure_affine_consistency: bool = True
    ) -> 'MedPatchDataLoaders':
        """Create train/valid DataLoaders from DataFrame with LAZY loading.

        Memory-efficient: Only file paths are stored at creation time.
        Volumes are loaded on-demand by Queue workers during training.

        Note: Both train and valid use the same sampling strategy from patch_config.
        This gives pseudo Dice during training. For true validation metrics,
        use PatchInferenceEngine with full-volume sliding window inference.

        Args:
            df: DataFrame with image paths.
            img_col: Column name for image paths.
            mask_col: Column name for mask paths.
            valid_pct: Fraction of data for validation.
            valid_col: Column name for train/valid split (if pre-defined).
            patch_config: PatchConfig instance. Preprocessing params (apply_reorder,
                target_spacing) can be set here for DRY usage with PatchInferenceEngine.
            pre_patch_tfms: TorchIO transforms applied before patch extraction
                (after reorder/resample). Example: [tio.ZNormalization()].
            patch_tfms: TorchIO transforms applied to extracted patches (training only).
            apply_reorder: If True, reorder to RAS+ orientation. If None, uses
                patch_config.apply_reorder. Explicit value overrides config.
            target_spacing: Target voxel spacing [x, y, z]. If None, uses
                patch_config.target_spacing. Explicit value overrides config.
            bs: Batch size.
            seed: Random seed for splitting.
            device: Device to use.
            ensure_affine_consistency: If True and mask_col is provided, automatically
                adds tio.CopyAffine(target='image') as the first transform to prevent
                spatial metadata mismatch errors. Defaults to True.

        Returns:
            MedPatchDataLoaders instance.

        Example:
            >>> config = PatchConfig(
            ...     patch_size=[96, 96, 96],
            ...     apply_reorder=True,
            ...     target_spacing=[0.5, 0.5, 0.5],
            ...     label_probabilities={0: 0.1, 1: 0.9}
            ... )
            >>> dls = MedPatchDataLoaders.from_df(
            ...     df, img_col='image', mask_col='label',
            ...     patch_config=config,
            ...     pre_patch_tfms=[tio.ZNormalization()],
            ...     patch_tfms=[tio.RandomAffine(degrees=10), tio.RandomFlip()],
            ...     bs=4
            ... )
            >>> # Memory: ~150 MB (queue buffer only)
        """
        if patch_config is None:
            patch_config = PatchConfig()

        # Use config values, allow explicit overrides for backward compatibility
        _apply_reorder = apply_reorder if apply_reorder is not None else patch_config.apply_reorder
        _target_spacing = target_spacing if target_spacing is not None else patch_config.target_spacing

        # Warn if both config and explicit args provided with different values
        _warn_config_override('apply_reorder', patch_config.apply_reorder, apply_reorder)
        _warn_config_override('target_spacing', patch_config.target_spacing, target_spacing)

        # Split data: pre-defined column wins over random split
        if valid_col is not None:
            train_df = df[df[valid_col] == False].reset_index(drop=True)
            valid_df = df[df[valid_col] == True].reset_index(drop=True)
        else:
            # NOTE: seeds numpy's GLOBAL RNG, which affects other numpy consumers
            if seed is not None:
                np.random.seed(seed)
            n = len(df)
            valid_idx = np.random.choice(n, size=int(n * valid_pct), replace=False)
            train_idx = np.setdiff1d(np.arange(n), valid_idx)
            train_df = df.iloc[train_idx].reset_index(drop=True)
            valid_df = df.iloc[valid_idx].reset_index(drop=True)

        # Build preprocessing transforms (order: reorder -> resample -> user tfms)
        all_pre_tfms = []

        # Add reorder transform (reorder to RAS+ orientation)
        if _apply_reorder:
            all_pre_tfms.append(tio.ToCanonical())

        # Add resample transform
        if _target_spacing is not None:
            all_pre_tfms.append(tio.Resample(_target_spacing))

        # Add user-provided transforms
        if pre_patch_tfms:
            all_pre_tfms.extend(pre_patch_tfms)

        # Create subjects datasets with lazy loading (paths only, ~0 MB)
        train_subjects = create_subjects_dataset(
            train_df, img_col, mask_col,
            pre_tfms=all_pre_tfms if all_pre_tfms else None,
            ensure_affine_consistency=ensure_affine_consistency
        )
        valid_subjects = create_subjects_dataset(
            valid_df, img_col, mask_col,
            pre_tfms=all_pre_tfms if all_pre_tfms else None,
            ensure_affine_consistency=ensure_affine_consistency
        )

        # Create DataLoaders (both use same patch_config for consistent sampling)
        train_dl = MedPatchDataLoader(
            train_subjects, patch_config, bs,
            patch_tfms=patch_tfms, shuffle=True, drop_last=True
        )
        valid_dl = MedPatchDataLoader(
            valid_subjects, patch_config, bs,
            patch_tfms=None,  # No augmentation for validation
            shuffle=False, drop_last=False
        )

        # Create instance and store metadata for later introspection/export
        instance = cls(train_dl, valid_dl, device)
        instance._img_col = img_col
        instance._mask_col = mask_col
        instance._pre_patch_tfms = pre_patch_tfms
        instance._apply_reorder = _apply_reorder
        instance._target_spacing = _target_spacing
        instance._ensure_affine_consistency = ensure_affine_consistency
        instance._patch_config = patch_config
        return instance

    @property
    def train(self):
        """Training DataLoader."""
        return self._train_dl

    @property
    def valid(self):
        """Validation DataLoader."""
        return self._valid_dl

    @property
    def train_ds(self):
        """Training subjects dataset."""
        return self._train_dl.subjects_dataset

    @property
    def valid_ds(self):
        """Validation subjects dataset."""
        return self._valid_dl.subjects_dataset

    @property
    def device(self):
        """Current device."""
        return self._device

    @property
    def bs(self):
        """Batch size (taken from the training DataLoader)."""
        return self._train_dl.bs

    @property
    def apply_reorder(self):
        """Whether reordering to RAS+ is enabled.

        getattr guard: attribute only exists when built via from_df().
        """
        return getattr(self, '_apply_reorder', False)

    @property
    def target_spacing(self):
        """Target voxel spacing for resampling (None when not set via from_df)."""
        return getattr(self, '_target_spacing', None)

    @property
    def patch_config(self):
        """The PatchConfig used for this DataLoaders (None when not set via from_df)."""
        return getattr(self, '_patch_config', None)

    def to(self, device):
        """Move DataLoaders to device. Returns self for chaining."""
        self._device = device
        self._train_dl.to(device)
        self._valid_dl.to(device)
        return self

    def __iter__(self):
        """Iterate over training DataLoader."""
        return iter(self._train_dl)

    def one_batch(self):
        """Return one batch from the training DataLoader.

        Required for fastai Learner compatibility - used for device
        detection and batch shape validation.
        """
        return self._train_dl.one_batch()

    def __len__(self) -> int:
        """Return number of batches in training DataLoader."""
        return len(self._train_dl)

    def __getitem__(self, idx):
        """Get DataLoader by index. Required for fastai Learner compatibility.

        Args:
            idx: 0 for training DataLoader, 1 for validation DataLoader.

        Returns:
            MedPatchDataLoader instance.

        Raises:
            IndexError: For any index other than 0 or 1.
        """
        if idx == 0:
            return self._train_dl
        elif idx == 1:
            return self._valid_dl
        else:
            raise IndexError(f"Index {idx} out of range. Use 0 (train) or 1 (valid).")

    def cuda(self):
        """Move DataLoaders to CUDA device."""
        return self.to(torch.device('cuda'))

    def cpu(self):
        """Move DataLoaders to CPU."""
        return self.to(torch.device('cpu'))

    def new_empty(self):
        """Create a new empty version of self for learner export.

        Required for fastai Learner.export() compatibility - creates a
        lightweight placeholder that can be pickled without the full dataset.

        Returns:
            A minimal MedPatchDataLoaders-like object with no data.
        """
        class EmptyMedPatchDataLoaders:
            """Minimal placeholder for exported learner."""
            def __init__(self, device):
                self._device = device
            @property
            def device(self): return self._device
            def to(self, device):
                self._device = device
                return self

        return EmptyMedPatchDataLoaders(self._device)
748
+
749
+ # %% ../nbs/10_vision_patch.ipynb 20
750
+ import numbers
751
+
752
def _ensure_even(value: int) -> int:
    """Round an odd value down to the nearest even value (odd values <= 0 become 0).

    TorchIO's GridSampler requires even overlap values; even inputs (including
    negatives, which PatchConfig rejects upstream) pass through unchanged.
    """
    if value % 2 != 0:
        return value - 1 if value > 0 else 0
    return value


def _normalize_patch_overlap(patch_overlap, patch_size):
    """Convert patch_overlap to integer pixel values for TorchIO compatibility.

    TorchIO's GridSampler expects patch_overlap as a tuple of even integers.
    This function handles:
    - Fractional overlap (0-1): converted to pixel values based on patch_size
    - Numpy scalar types: converted to native Python types
    - Sequences: converted to tuple of integers

    Note: Input validation (negative values, overlap >= patch_size) is handled
    by PatchConfig.__post_init__(). This function focuses on format conversion.

    Args:
        patch_overlap: int, float (0-1 for fraction), or sequence
        patch_size: list/tuple of patch dimensions [x, y, z]

    Returns:
        Tuple of even integers suitable for TorchIO GridSampler
    """
    # numbers.Number already covers int, float, and numpy scalar types
    is_scalar = isinstance(patch_overlap, numbers.Number)

    # Fractional overlap (0 < x < 1): per-dimension fraction of patch_size.
    # Excludes 1.0 because 100% overlap creates step_size=0 (infinite patches).
    if is_scalar and 0 < float(patch_overlap) < 1:
        return tuple(
            _ensure_even(int(int(ps) * float(patch_overlap))) for ps in patch_size
        )

    # Scalar pixel count (including numpy scalars): same value on every axis
    if is_scalar:
        pixels = _ensure_even(int(patch_overlap))
        return tuple(pixels for _ in patch_size)

    # Sequence (list, tuple, ndarray): per-dimension pixel counts
    return tuple(_ensure_even(int(v)) for v in patch_overlap)
800
+
801
+
802
class PatchInferenceEngine:
    """Patch-based inference with automatic volume reconstruction.

    Uses TorchIO's GridSampler to extract overlapping patches and
    GridAggregator to reconstruct the full volume from predictions.

    Args:
        learner: PyTorch model or fastai Learner.
        config: PatchConfig with inference settings. Preprocessing params (apply_reorder,
            target_spacing, padding_mode) can be set here for DRY usage.
        apply_reorder: Whether to reorder to RAS+ orientation. If None, uses config value.
        target_spacing: Target voxel spacing. If None, uses config value.
        batch_size: Number of patches to predict at once. Must be positive.
        pre_inference_tfms: List of TorchIO transforms to apply before patch extraction.
            IMPORTANT: Should match the pre_patch_tfms used during training
            (e.g., [tio.ZNormalization()]). This ensures preprocessing consistency
            between training and inference.

    Example:
        >>> # DRY pattern: use same config for training and inference
        >>> config = PatchConfig(
        ...     patch_size=[96, 96, 96],
        ...     apply_reorder=True,
        ...     target_spacing=[1.0, 1.0, 1.0]
        ... )
        >>> # Training
        >>> dls = MedPatchDataLoaders.from_df(df, 'img', 'mask', patch_config=config)
        >>> # Inference - no need to repeat params!
        >>> engine = PatchInferenceEngine(
        ...     learn, config,
        ...     pre_inference_tfms=[tio.ZNormalization()]
        ... )
        >>> pred = engine.predict('image.nii.gz')
    """

    def __init__(
        self,
        learner,
        config: PatchConfig,
        apply_reorder: bool | None = None,
        target_spacing: list | None = None,
        batch_size: int = 4,
        pre_inference_tfms: list | None = None
    ):
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        # Extract model from Learner if needed
        self.model = learner.model if hasattr(learner, 'model') else learner
        self.config = config
        self.batch_size = batch_size
        self.pre_inference_tfms = tio.Compose(pre_inference_tfms) if pre_inference_tfms else None

        # Use config values, allow explicit overrides for backward compatibility
        self.apply_reorder = apply_reorder if apply_reorder is not None else config.apply_reorder
        self.target_spacing = target_spacing if target_spacing is not None else config.target_spacing

        # Warn if explicit args provided but differ from config (potential mistake)
        _warn_config_override('apply_reorder', config.apply_reorder, apply_reorder)
        _warn_config_override('target_spacing', config.target_spacing, target_spacing)

        # Get device from model parameters, with fallback for parameter-less models
        try:
            self._device = next(self.model.parameters()).device
        except StopIteration:
            self._device = _get_default_device()

    def _prepare_subject(self, input_img) -> tio.Subject:
        """Wrap the preprocessed image in a Subject, apply pre-inference
        transforms and pad any dimension smaller than patch_size.

        GridSampler handles images larger than patch_size via overlapping
        patches, so larger dimensions are left intact.
        """
        subject = tio.Subject(
            image=tio.ScalarImage(tensor=input_img.data.float(), affine=input_img.affine)
        )

        # Apply pre-inference transforms (e.g., ZNormalization) to match training
        if self.pre_inference_tfms is not None:
            subject = self.pre_inference_tfms(subject)

        img_shape = subject['image'].shape[1:]  # Exclude channel dim
        target_size = [max(s, p) for s, p in zip(img_shape, self.config.patch_size)]

        # Warn if volume needed padding (may cause artifacts if training didn't cover similar sizes)
        if any(s < p for s, p in zip(img_shape, self.config.patch_size)):
            padded_dims = [
                f"dim{i}: {s}<{p}"
                for i, (s, p) in enumerate(zip(img_shape, self.config.patch_size))
                if s < p
            ]
            warnings.warn(
                f"Image size {list(img_shape)} smaller than patch_size {self.config.patch_size} "
                f"in {padded_dims}. Padding with mode={self.config.padding_mode}. "
                "Ensure training data covered similar sizes to avoid artifacts."
            )

        # Use padding_mode from config (default: 0 for zero padding, nnU-Net standard)
        return tio.CropOrPad(target_size, padding_mode=self.config.padding_mode)(subject)

    def _aggregate_probabilities(self, subject: tio.Subject) -> torch.Tensor:
        """Run the model over all grid patches and reconstruct the full
        probability volume with GridAggregator.

        Returns a [C, H, W, D] tensor of probabilities (sigmoid for a single
        output channel, softmax otherwise).
        """
        # Convert patch_overlap to integer pixel values for TorchIO compatibility
        patch_overlap = _normalize_patch_overlap(self.config.patch_overlap, self.config.patch_size)

        grid_sampler = tio.GridSampler(
            subject,
            patch_size=self.config.patch_size,
            patch_overlap=patch_overlap
        )
        aggregator = tio.GridAggregator(
            grid_sampler,
            overlap_mode=self.config.aggregation_mode
        )
        patch_loader = DataLoader(grid_sampler, batch_size=self.batch_size, num_workers=0)

        self.model.eval()
        with torch.no_grad():
            for patches_batch in patch_loader:
                patch_input = patches_batch['image'][tio.DATA].to(self._device)
                locations = patches_batch[tio.LOCATION]

                logits = self.model(patch_input)

                # Convert logits to probabilities BEFORE aggregation.
                # This is critical: softmax is non-linear, so we must aggregate
                # probabilities, not logits, to get correct boundary handling.
                if logits.shape[1] == 1:
                    probs = torch.sigmoid(logits)
                else:
                    probs = torch.softmax(logits, dim=1)  # dim=1 for batch [B, C, H, W, D]

                aggregator.add_batch(probs.cpu(), locations)

        # Reconstructed output contains probabilities, not logits
        return aggregator.get_output_tensor()

    def predict(
        self,
        img_path: Path | str,
        return_probabilities: bool = False,
        return_affine: bool = False
    ) -> torch.Tensor | tuple[torch.Tensor, np.ndarray]:
        """Predict on a single volume using patch-based inference.

        Args:
            img_path: Path to input image.
            return_probabilities: If True, return probability map instead of argmax.
            return_affine: If True, return (prediction, affine) tuple instead of just prediction.

        Returns:
            Predicted segmentation mask tensor, or tuple (prediction, affine) if return_affine=True.
        """
        # Load image - keep org_img and org_size for post-processing.
        # Note: med_img_reader handles reorder/resample internally, no global state needed.
        org_img, input_img, org_size = med_img_reader(
            img_path, apply_reorder=self.apply_reorder, target_spacing=self.target_spacing, only_tensor=False
        )

        subject = self._prepare_subject(input_img)
        output = self._aggregate_probabilities(subject)

        # Convert to prediction mask (only if not returning probabilities)
        if return_probabilities:
            result = output  # Keep as float probabilities
        else:
            if output.shape[0] == 1:
                result = (output > 0.5).float()  # binary: threshold sigmoid output
            else:
                result = output.argmax(dim=0, keepdim=True).float()

            # Apply keep_largest post-processing for binary segmentation
            if self.config.keep_largest_component:
                from fastMONAI.vision_inference import keep_largest
                result = keep_largest(result.squeeze(0)).unsqueeze(0)

        # Post-processing: resize back to original size and reorient.
        # This matches the workflow in vision_inference.py.
        # Use ScalarImage for probabilities, LabelMap for masks.
        image_cls = tio.ScalarImage if return_probabilities else tio.LabelMap
        pred_img = image_cls(tensor=result.float(), affine=input_img.affine)

        # Resize back to original size (before resampling)
        pred_img = _do_resize(pred_img, org_size, image_interpolation='nearest')

        # Reorient to original orientation (if reorder was applied).
        # Use explicit .cpu() for consistent device handling.
        if self.apply_reorder:
            reoriented_array = _to_original_orientation(
                pred_img.as_sitk(),
                ''.join(org_img.orientation)
            )
            result = torch.from_numpy(reoriented_array).cpu()
        else:
            result = pred_img.data.cpu()

        # Only convert to long for masks, not probabilities
        if not return_probabilities:
            result = result.long()

        # Use original affine matrix for correct spatial alignment.
        # org_img.affine is always available from med_img_reader.
        if not (hasattr(org_img, 'affine') and org_img.affine is not None):
            raise RuntimeError(
                "org_img.affine not available. This should never happen - please report this bug."
            )
        affine = org_img.affine.copy()

        if return_affine:
            return result, affine
        return result

    def to(self, device):
        """Move engine (model and future patch batches) to device."""
        self._device = device
        self.model.to(device)
        return self
1026
+
1027
+ # %% ../nbs/10_vision_patch.ipynb 21
1028
def _derive_pred_filename(input_path: Path) -> str:
    """Build the output filename '<stem>_pred.nii[.gz]' for a prediction.

    Uses a suffix-based approach so '.nii.gz' is handled correctly without
    corrupting filenames that contain '.nii' elsewhere. Non-NIfTI inputs
    fall back to compressed NIfTI output.
    """
    stem = input_path.stem
    if input_path.suffix == '.gz' and stem.endswith('.nii'):
        # .nii.gz files: Path.stem is 'filename.nii' -> strip the trailing '.nii'
        return f"{stem[:-4]}_pred.nii.gz"
    if input_path.suffix == '.nii':
        return f"{stem}_pred.nii"
    # Fallback for other formats
    return f"{stem}_pred.nii.gz"


def patch_inference(
    learner,
    config: PatchConfig,
    file_paths: list,
    apply_reorder: bool | None = None,
    target_spacing: list | None = None,
    batch_size: int = 4,
    return_probabilities: bool = False,
    progress: bool = True,
    save_dir: str | None = None,
    pre_inference_tfms: list | None = None
) -> list:
    """Batch patch-based inference on multiple volumes.

    Args:
        learner: PyTorch model or fastai Learner.
        config: PatchConfig with inference settings. Preprocessing params (apply_reorder,
            target_spacing) can be set here for DRY usage.
        file_paths: List of image paths.
        apply_reorder: Whether to reorder to RAS+ orientation. If None, uses config value.
        target_spacing: Target voxel spacing. If None, uses config value.
        batch_size: Patches per batch.
        return_probabilities: Return probability maps.
        progress: Show progress bar.
        save_dir: Directory to save predictions as NIfTI files. If None, predictions are not saved.
        pre_inference_tfms: List of TorchIO transforms to apply before patch extraction.
            IMPORTANT: Should match the pre_patch_tfms used during training
            (e.g., [tio.ZNormalization()]).

    Returns:
        List of predicted tensors.

    Example:
        >>> # DRY pattern: use same config for training and inference
        >>> config = PatchConfig(
        ...     patch_size=[96, 96, 96],
        ...     apply_reorder=True,
        ...     target_spacing=[0.4102, 0.4102, 1.5]
        ... )
        >>> predictions = patch_inference(
        ...     learner=learn,
        ...     config=config,  # apply_reorder and target_spacing from config
        ...     file_paths=val_paths,
        ...     pre_inference_tfms=[tio.ZNormalization()],
        ...     save_dir='predictions/patch_based'
        ... )
    """
    # Use config values if not explicitly provided
    _apply_reorder = apply_reorder if apply_reorder is not None else config.apply_reorder
    _target_spacing = target_spacing if target_spacing is not None else config.target_spacing

    engine = PatchInferenceEngine(
        learner, config, _apply_reorder, _target_spacing, batch_size, pre_inference_tfms
    )

    # Create save directory if specified
    if save_dir is not None:
        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)

    predictions = []
    iterator = tqdm(file_paths, desc='Patch inference') if progress else file_paths

    for path in iterator:
        # Get prediction and affine when saving is needed
        if save_dir is not None:
            pred, affine = engine.predict(path, return_probabilities, return_affine=True)
        else:
            pred = engine.predict(path, return_probabilities)
        predictions.append(pred)

        # Save prediction if save_dir specified
        if save_dir is not None:
            out_path = save_path / _derive_pred_filename(Path(path))

            # affine is guaranteed to be valid from engine.predict() with return_affine=True.
            # Save as NIfTI using TorchIO with correct type:
            # ScalarImage for probabilities (float), LabelMap for masks (int).
            if return_probabilities:
                pred_img = tio.ScalarImage(tensor=pred, affine=affine)
            else:
                pred_img = tio.LabelMap(tensor=pred, affine=affine)
            pred_img.save(out_path)

    return predictions