fastMONAI 0.5.3__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1259 @@
1
+ """Patch-based training and inference for 3D medical image segmentation using TorchIO's Queue mechanism."""
2
+
3
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/10_vision_patch.ipynb.
4
+
5
+ # %% auto 0
6
+ __all__ = ['normalize_patch_transforms', 'PatchConfig', 'med_to_subject', 'create_subjects_dataset', 'create_patch_sampler',
7
+ 'MedPatchDataLoader', 'MedPatchDataLoaders', 'PatchInferenceEngine', 'patch_inference']
8
+
9
+ # %% ../nbs/10_vision_patch.ipynb 2
10
+ import torch
11
+ import torchio as tio
12
+ import pandas as pd
13
+ import numpy as np
14
+ import warnings
15
+ from pathlib import Path
16
+ from dataclasses import dataclass, field
17
+ from typing import Callable
18
+ from torch.utils.data import DataLoader
19
+ from tqdm.auto import tqdm
20
+ from fastai.data.all import *
21
+ from fastai.learner import Learner
22
+ from .vision_core import MedImage, MedMask, MedBase, med_img_reader
23
+ from .vision_inference import _to_original_orientation, _do_resize
24
+ from .dataset_info import MedDataset, suggest_patch_size
25
+
26
+ # %% ../nbs/10_vision_patch.ipynb 3
27
+ def _get_default_device() -> torch.device:
28
+ """Get the default device (CUDA if available, else CPU)."""
29
+ return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
30
+
31
+
32
+ def _warn_config_override(param_name: str, config_value, explicit_value):
33
+ """Warn when explicit argument overrides config value.
34
+
35
+ Args:
36
+ param_name: Name of the parameter (e.g., 'apply_reorder', 'target_spacing')
37
+ config_value: Value from PatchConfig
38
+ explicit_value: Explicitly provided value
39
+ """
40
+ if explicit_value is not None and config_value is not None:
41
+ if explicit_value != config_value:
42
+ warnings.warn(
43
+ f"{param_name} mismatch: explicit={explicit_value}, config={config_value}. "
44
+ f"Using explicit argument."
45
+ )
46
+
47
+
48
+ def _extract_tio_transform(tfm):
49
+ """Extract TorchIO transform from fastMONAI wrapper or return as-is.
50
+
51
+ This function enables using fastMONAI wrappers (e.g., RandomAffine, RandomGamma)
52
+ in patch-based workflows where raw TorchIO transforms are needed for tio.Compose().
53
+
54
+ Uses the explicit `.tio_transform` property when available on fastMONAI wrappers.
55
+ Falls back to returning the transform unchanged for raw TorchIO transforms.
56
+
57
+ Args:
58
+ tfm: fastMONAI wrapper (e.g., RandomAffine) or raw TorchIO transform
59
+
60
+ Returns:
61
+ The underlying TorchIO transform
62
+
63
+ Example:
64
+ >>> from fastMONAI.vision_augmentation import RandomAffine
65
+ >>> wrapped = RandomAffine(degrees=10)
66
+ >>> raw = _extract_tio_transform(wrapped) # Returns tio.RandomAffine
67
+ """
68
+ return getattr(tfm, 'tio_transform', tfm)
69
+
70
+
71
def normalize_patch_transforms(tfms: list) -> list:
    """Normalize a transform list for the patch-based workflow.

    Unwraps fastMONAI wrappers to their underlying TorchIO transforms while
    passing raw TorchIO transforms through unchanged, so the same transform
    syntax works in both standard and patch-based pipelines:

    >>> from fastMONAI.vision_augmentation import RandomAffine, RandomGamma
    >>>
    >>> # Same syntax works in both contexts
    >>> item_tfms = [RandomAffine(degrees=10), RandomGamma(p=0.5)]   # Standard
    >>> patch_tfms = [RandomAffine(degrees=10), RandomGamma(p=0.5)]  # Patch-based

    Args:
        tfms: List of fastMONAI wrappers or raw TorchIO transforms

    Returns:
        List of raw TorchIO transforms suitable for tio.Compose(),
        or None when `tfms` is None.
    """
    if tfms is None:
        return None
    # Wrappers expose `.tio_transform`; anything else is used as-is.
    return [getattr(tfm, 'tio_transform', tfm) for tfm in tfms]
95
+
96
+ # %% ../nbs/10_vision_patch.ipynb 7
97
+ @dataclass
98
+ class PatchConfig:
99
+ """Configuration for patch-based training and inference.
100
+
101
+ Args:
102
+ patch_size: Size of patches [x, y, z].
103
+ patch_overlap: Overlap for inference GridSampler (int, float 0-1, or list).
104
+ - Float 0-1: fraction of patch_size (e.g., 0.5 = 50% overlap)
105
+ - Int >= 1: pixel overlap (e.g., 48 = 48 pixel overlap)
106
+ - List: per-dimension overlap in pixels
107
+ samples_per_volume: Number of patches to extract per volume during training.
108
+ sampler_type: Type of sampler ('uniform', 'label', 'weighted').
109
+ label_probabilities: For LabelSampler, dict mapping label values to probabilities.
110
+ queue_length: Maximum number of patches to store in queue.
111
+ queue_num_workers: Number of workers for parallel patch extraction.
112
+ aggregation_mode: For inference, how to combine overlapping patches ('crop', 'average', 'hann').
113
+ apply_reorder: Whether to reorder to RAS+ canonical orientation. Must match between
114
+ training and inference. Defaults to True (the common case).
115
+ target_spacing: Target voxel spacing [x, y, z] for resampling. Must match between
116
+ training and inference.
117
+ padding_mode: Padding mode for CropOrPad when image < patch_size. Default is 0 (zero padding)
118
+ to align with nnU-Net's approach. Can be int, float, or string (e.g., 'minimum', 'mean').
119
+ keep_largest_component: If True, keep only the largest connected component
120
+ in binary segmentation predictions. Only applies during inference when
121
+ return_probabilities=False. Defaults to False.
122
+
123
+ Example:
124
+ >>> config = PatchConfig(
125
+ ... patch_size=[96, 96, 96],
126
+ ... samples_per_volume=16,
127
+ ... sampler_type='label',
128
+ ... label_probabilities={0: 0.1, 1: 0.9},
129
+ ... target_spacing=[0.5, 0.5, 0.5]
130
+ ... )
131
+ """
132
+ patch_size: list = field(default_factory=lambda: [96, 96, 96])
133
+ patch_overlap: int | float | list = 0
134
+ samples_per_volume: int = 8
135
+ sampler_type: str = 'uniform'
136
+ label_probabilities: dict = None
137
+ queue_length: int = 300
138
+ queue_num_workers: int = 4
139
+ aggregation_mode: str = 'hann'
140
+ # Preprocessing parameters - must match between training and inference
141
+ apply_reorder: bool = True # Defaults to True (the common case)
142
+ target_spacing: list = None
143
+ padding_mode: int | float | str = 0 # Zero padding (nnU-Net standard)
144
+ # Post-processing (binary segmentation only)
145
+ keep_largest_component: bool = False
146
+
147
+ def __post_init__(self):
148
+ """Validate configuration."""
149
+ valid_samplers = ['uniform', 'label', 'weighted']
150
+ if self.sampler_type not in valid_samplers:
151
+ raise ValueError(f"sampler_type must be one of {valid_samplers}")
152
+
153
+ valid_aggregation = ['crop', 'average', 'hann']
154
+ if self.aggregation_mode not in valid_aggregation:
155
+ raise ValueError(f"aggregation_mode must be one of {valid_aggregation}")
156
+
157
+ # Validate patch_overlap
158
+ # Negative overlap doesn't make sense
159
+ if isinstance(self.patch_overlap, (int, float)):
160
+ if self.patch_overlap < 0:
161
+ raise ValueError("patch_overlap cannot be negative")
162
+ # Check if overlap as pixels would exceed patch_size (causes step_size=0)
163
+ if self.patch_overlap >= 1: # Pixel value, not fraction
164
+ for ps in self.patch_size:
165
+ if self.patch_overlap >= ps:
166
+ raise ValueError(
167
+ f"patch_overlap ({self.patch_overlap}) must be less than patch_size ({ps}). "
168
+ f"Overlap >= patch_size creates step_size <= 0 (infinite patches)."
169
+ )
170
+ elif isinstance(self.patch_overlap, (list, tuple)):
171
+ for i, (overlap, ps) in enumerate(zip(self.patch_overlap, self.patch_size)):
172
+ if overlap < 0:
173
+ raise ValueError(f"patch_overlap[{i}] cannot be negative")
174
+ if overlap >= ps:
175
+ raise ValueError(
176
+ f"patch_overlap[{i}] ({overlap}) must be less than patch_size[{i}] ({ps}). "
177
+ f"Overlap >= patch_size creates step_size <= 0 (infinite patches)."
178
+ )
179
+
180
+ @classmethod
181
+ def from_dataset(
182
+ cls,
183
+ dataset: 'MedDataset',
184
+ target_spacing: list = None,
185
+ min_patch_size: list = None,
186
+ max_patch_size: list = None,
187
+ divisor: int = 16,
188
+ **kwargs
189
+ ) -> 'PatchConfig':
190
+ """Create PatchConfig with automatic patch_size from dataset analysis.
191
+
192
+ Combines dataset preprocessing suggestions with patch size calculation
193
+ for a complete, DRY configuration.
194
+
195
+ Args:
196
+ dataset: MedDataset instance with analyzed images.
197
+ target_spacing: Target voxel spacing [x, y, z]. If None, uses
198
+ dataset.get_suggestion()['target_spacing'].
199
+ min_patch_size: Minimum per dimension [32, 32, 32].
200
+ max_patch_size: Maximum per dimension [256, 256, 256].
201
+ divisor: Divisibility constraint (default 16 for UNet compatibility).
202
+ **kwargs: Additional PatchConfig parameters (samples_per_volume,
203
+ sampler_type, label_probabilities, etc.).
204
+
205
+ Returns:
206
+ PatchConfig with suggested patch_size, apply_reorder, target_spacing.
207
+
208
+ Example:
209
+ >>> from fastMONAI.dataset_info import MedDataset
210
+ >>> dataset = MedDataset(dataframe=df, mask_col='mask_path', dtype=MedMask)
211
+ >>>
212
+ >>> # Use recommended spacing
213
+ >>> config = PatchConfig.from_dataset(dataset, samples_per_volume=16)
214
+ >>>
215
+ >>> # Use custom spacing
216
+ >>> config = PatchConfig.from_dataset(
217
+ ... dataset,
218
+ ... target_spacing=[1.0, 1.0, 2.0],
219
+ ... samples_per_volume=16
220
+ ... )
221
+ """
222
+ # Get preprocessing suggestion from dataset
223
+ suggestion = dataset.get_suggestion()
224
+
225
+ # Use explicit spacing or dataset suggestion
226
+ _target_spacing = target_spacing if target_spacing is not None else suggestion['target_spacing']
227
+
228
+ # Calculate patch size for the target spacing
229
+ patch_size = suggest_patch_size(
230
+ dataset,
231
+ target_spacing=_target_spacing,
232
+ min_patch_size=min_patch_size,
233
+ max_patch_size=max_patch_size,
234
+ divisor=divisor
235
+ )
236
+
237
+ # Merge with explicit kwargs (kwargs override defaults)
238
+ # Use dataset.apply_reorder directly (not from get_suggestion() since it's not data-derived)
239
+ config_kwargs = {
240
+ 'patch_size': patch_size,
241
+ 'apply_reorder': dataset.apply_reorder,
242
+ 'target_spacing': _target_spacing,
243
+ }
244
+ config_kwargs.update(kwargs)
245
+
246
+ return cls(**config_kwargs)
247
+
248
+ # %% ../nbs/10_vision_patch.ipynb 10
249
def med_to_subject(
    img: Path | str,
    mask: Path | str = None,
) -> tio.Subject:
    """Create a TorchIO Subject with LAZY loading (paths only, no tensor loading).

    Only file paths are stored in the Subject; TorchIO's Queue workers load
    the volumes on demand during training, so no pixel data enters RAM until
    it is actually needed.

    Args:
        img: Path to image file.
        mask: Path to mask file (optional).

    Returns:
        TorchIO Subject with 'image' and optionally 'mask' keys (lazy loaded).

    Example:
        >>> subject = med_to_subject('image.nii.gz', 'mask.nii.gz')
        >>> # Volume NOT loaded yet - only path stored
        >>> data = subject['image'].data  # NOW volume is loaded
    """
    # tio images constructed with `path=` defer reading until .data is accessed.
    components = {'image': tio.ScalarImage(path=str(img))}
    if mask is not None:
        components['mask'] = tio.LabelMap(path=str(mask))
    return tio.Subject(**components)
279
+
280
+ # %% ../nbs/10_vision_patch.ipynb 11
281
def create_subjects_dataset(
    df: pd.DataFrame,
    img_col: str,
    mask_col: str = None,
    pre_tfms: list = None,
    ensure_affine_consistency: bool = True
) -> tio.SubjectsDataset:
    """Build a TorchIO SubjectsDataset with LAZY loading from a DataFrame.

    The returned dataset holds only file paths; volumes are read on demand by
    Queue workers, so memory usage stays constant regardless of dataset size.

    Args:
        df: DataFrame with image (and optionally mask) paths.
        img_col: Column name containing image paths.
        mask_col: Column name containing mask paths (optional).
        pre_tfms: List of TorchIO transforms to apply before patch extraction.
            Use tio.ToCanonical() for reordering and tio.Resample() for resampling.
        ensure_affine_consistency: If True and mask_col is provided, automatically
            prepends tio.CopyAffine(target='image') to ensure spatial metadata
            consistency between image and mask. This prevents "More than one value
            for direction found" errors. Defaults to True.

    Returns:
        TorchIO SubjectsDataset with lazy-loaded subjects.

    Example:
        >>> # Preprocessing via transforms (applied by workers on-demand)
        >>> pre_tfms = [
        ...     tio.ToCanonical(),               # Reorder to RAS+
        ...     tio.Resample([0.5, 0.5, 0.5]),   # Resample
        ...     tio.ZNormalization(),            # Intensity normalization
        ... ]
        >>> dataset = create_subjects_dataset(
        ...     df, img_col='image', mask_col='label',
        ...     pre_tfms=pre_tfms
        ... )
        >>> # Memory: ~0 MB (only paths stored, not volumes)
    """
    # One lazy Subject per DataFrame row (paths only, nothing loaded yet).
    subjects = [
        med_to_subject(
            img=row[img_col],
            mask=row[mask_col] if mask_col else None,
        )
        for _, row in df.iterrows()
    ]

    pipeline = []

    # CopyAffine must run FIRST so image/mask spatial metadata agree before
    # any other transform sees them.
    if mask_col is not None and ensure_affine_consistency:
        pipeline.append(tio.CopyAffine(target='image'))

    # Then the caller's preprocessing transforms, in order.
    if pre_tfms:
        pipeline.extend(pre_tfms)

    transform = tio.Compose(pipeline) if pipeline else None
    return tio.SubjectsDataset(subjects, transform=transform)
345
+
346
+ # %% ../nbs/10_vision_patch.ipynb 13
347
+ def create_patch_sampler(config: PatchConfig) -> tio.data.PatchSampler:
348
+ """Create appropriate TorchIO sampler based on config.
349
+
350
+ Args:
351
+ config: PatchConfig with sampler settings.
352
+
353
+ Returns:
354
+ TorchIO PatchSampler instance.
355
+
356
+ Example:
357
+ >>> config = PatchConfig(patch_size=[96, 96, 96], sampler_type='label')
358
+ >>> sampler = create_patch_sampler(config)
359
+ """
360
+ patch_size = config.patch_size
361
+
362
+ if config.sampler_type == 'uniform':
363
+ return tio.UniformSampler(patch_size)
364
+
365
+ elif config.sampler_type == 'label':
366
+ return tio.LabelSampler(
367
+ patch_size,
368
+ label_name='mask',
369
+ label_probabilities=config.label_probabilities
370
+ )
371
+
372
+ elif config.sampler_type == 'weighted':
373
+ raise NotImplementedError(
374
+ "WeightedSampler requires a pre-computed probability map which is not currently supported. "
375
+ "Use 'label' sampler with label_probabilities for weighted sampling based on segmentation labels, "
376
+ "or 'uniform' for random patch extraction."
377
+ )
378
+
379
+ raise ValueError(f"Unknown sampler type: {config.sampler_type}")
380
+
381
+ # %% ../nbs/10_vision_patch.ipynb 17
382
class MedPatchDataLoader:
    """DataLoader wrapper for patch-based training with TorchIO Queue.

    This class wraps a TorchIO Queue to provide a fastai-compatible DataLoader
    interface for patch-based training.

    Args:
        subjects_dataset: TorchIO SubjectsDataset.
        config: PatchConfig with queue and sampler settings.
        batch_size: Number of patches per batch. Must be positive.
        patch_tfms: Transforms to apply to extracted patches (training only).
            Accepts both fastMONAI wrappers (e.g., RandomAffine, RandomGamma) and
            raw TorchIO transforms. fastMONAI wrappers are automatically normalized
            to raw TorchIO for internal use.
        shuffle: Whether to shuffle subjects and patches.
        drop_last: Whether to drop last incomplete batch.

    Raises:
        ValueError: If batch_size is not positive.
    """

    def __init__(
        self,
        subjects_dataset: tio.SubjectsDataset,
        config: PatchConfig,
        batch_size: int = 4,
        patch_tfms: list = None,
        shuffle: bool = True,
        drop_last: bool = False
    ):
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        self.subjects_dataset = subjects_dataset
        self.config = config
        self.bs = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self._device = _get_default_device()

        # Create sampler
        self.sampler = create_patch_sampler(config)

        # Create patch transforms
        # Normalize transforms - accepts both fastMONAI wrappers and raw TorchIO
        normalized_tfms = normalize_patch_transforms(patch_tfms)
        self.patch_tfms = tio.Compose(normalized_tfms) if normalized_tfms else None

        # Create queue: the Queue's own workers (config.queue_num_workers)
        # extract patches from subjects in the background.
        self.queue = tio.Queue(
            subjects_dataset,
            max_length=config.queue_length,
            samples_per_volume=config.samples_per_volume,
            sampler=self.sampler,
            num_workers=config.queue_num_workers,
            shuffle_subjects=shuffle,
            shuffle_patches=shuffle
        )

        # Create torch DataLoader
        self._dl = DataLoader(
            self.queue,
            batch_size=batch_size,
            num_workers=0,  # Queue handles workers
            drop_last=drop_last
        )

        # Track cleanup state
        self._closed = False

    def __iter__(self):
        """Iterate over batches, yielding (image, mask) tuples.

        Yields:
            Tuple of (MedImage, MedMask | None), both moved to self._device.
        """
        for batch in self._dl:
            # Extract image and mask tensors
            img = batch['image'][tio.DATA]  # [B, C, H, W, D]
            has_mask = 'mask' in batch

            # Apply patch transforms if provided
            if self.patch_tfms is not None:
                # Apply transforms to each sample in batch. Image and mask are
                # wrapped in a single Subject so the same (possibly spatial)
                # transform is applied consistently to both.
                transformed_imgs = []
                transformed_masks = [] if has_mask else None

                for i in range(img.shape[0]):
                    # Build subject dict with image, and mask if available
                    subject_dict = {'image': tio.ScalarImage(tensor=batch['image'][tio.DATA][i])}
                    if has_mask:
                        subject_dict['mask'] = tio.LabelMap(tensor=batch['mask'][tio.DATA][i])

                    subject = tio.Subject(subject_dict)
                    transformed = self.patch_tfms(subject)
                    transformed_imgs.append(transformed['image'].data)
                    if has_mask:
                        transformed_masks.append(transformed['mask'].data)

                # Re-batch the per-sample results.
                img = torch.stack(transformed_imgs)
                mask = torch.stack(transformed_masks) if has_mask else None
            else:
                mask = batch['mask'][tio.DATA] if has_mask else None

            # Convert to MedImage/MedMask and move to device
            img = MedImage(img).to(self._device)
            if mask is not None:
                mask = MedMask(mask).to(self._device)

            yield img, mask

    def __len__(self):
        """Return number of batches per epoch."""
        n_patches = len(self.subjects_dataset) * self.config.samples_per_volume
        if self.drop_last:
            return n_patches // self.bs
        # Ceiling division: the final partial batch counts when kept.
        return (n_patches + self.bs - 1) // self.bs

    @property
    def dataset(self):
        """Return the underlying queue as dataset."""
        return self.queue

    @property
    def device(self):
        """Return current device."""
        return self._device

    def to(self, device):
        """Move DataLoader to device (applied to tensors as batches are yielded)."""
        self._device = device
        return self

    def one_batch(self):
        """Return one batch from the DataLoader.

        Required for fastai compatibility - used for device detection
        and batch shape validation during Learner initialization.

        Returns:
            Tuple of (image, mask) tensors on the correct device.
        """
        return next(iter(self))

    def close(self):
        """Shut down TorchIO Queue workers. Safe to call multiple times."""
        if self._closed:
            return
        self._closed = True
        try:
            # Trigger cleanup of Queue's internal DataLoader iterator.
            # NOTE(review): dropping this reference presumably lets the worker
            # iterator be garbage-collected — confirm against TorchIO internals.
            if hasattr(self, 'queue') and hasattr(self.queue, '_subjects_iterable'):
                self.queue._subjects_iterable = None
        except Exception:
            pass  # Suppress cleanup errors

    def __enter__(self):
        # Context-manager support: `with MedPatchDataLoader(...) as dl:`.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        return False  # Do not suppress exceptions

    def __del__(self):
        # Best-effort cleanup at garbage collection; errors are ignored
        # because interpreter shutdown can leave globals in a torn-down state.
        try:
            self.close()
        except Exception:
            pass
543
+
544
+ # %% ../nbs/10_vision_patch.ipynb 18
545
class MedPatchDataLoaders:
    """fastai-compatible DataLoaders for patch-based training with LAZY loading.

    This class provides train and validation DataLoaders that work with
    fastai's Learner for patch-based training on 3D medical images.

    Memory-efficient: Volumes are loaded on-demand by Queue workers,
    keeping memory usage constant (~150 MB) regardless of dataset size.

    Note: Validation uses the same sampling as training (pseudo Dice).
    For true validation metrics, use PatchInferenceEngine with GridSampler
    for full-volume sliding window inference.

    Example:
        >>> import torchio as tio
        >>>
        >>> # New pattern: preprocessing params in config (DRY)
        >>> config = PatchConfig(
        ...     patch_size=[96, 96, 96],
        ...     apply_reorder=True,
        ...     target_spacing=[0.5, 0.5, 0.5]
        ... )
        >>> dls = MedPatchDataLoaders.from_df(
        ...     df, img_col='image', mask_col='label',
        ...     valid_pct=0.2,
        ...     patch_config=config,
        ...     pre_patch_tfms=[tio.ZNormalization()],
        ...     bs=4
        ... )
        >>> learn = Learner(dls, model, loss_func=DiceLoss())
    """

    def __init__(
        self,
        train_dl: MedPatchDataLoader,
        valid_dl: MedPatchDataLoader,
        device: torch.device = None
    ):
        self._train_dl = train_dl
        self._valid_dl = valid_dl
        self._device = device or _get_default_device()

        # Move to device
        self._train_dl.to(self._device)
        self._valid_dl.to(self._device)

        # Track cleanup state
        self._closed = False

    @classmethod
    def from_df(
        cls,
        df: pd.DataFrame,
        img_col: str,
        mask_col: str = None,
        valid_pct: float = 0.2,
        valid_col: str = None,
        patch_config: PatchConfig = None,
        pre_patch_tfms: list = None,
        patch_tfms: list = None,
        apply_reorder: bool = None,
        target_spacing: list = None,
        bs: int = 4,
        seed: int = None,
        device: torch.device = None,
        ensure_affine_consistency: bool = True
    ) -> 'MedPatchDataLoaders':
        """Create train/valid DataLoaders from DataFrame with LAZY loading.

        Memory-efficient: Only file paths are stored at creation time.
        Volumes are loaded on-demand by Queue workers during training.

        Note: Both train and valid use the same sampling strategy from patch_config.
        This gives pseudo Dice during training. For true validation metrics,
        use PatchInferenceEngine with full-volume sliding window inference.

        Args:
            df: DataFrame with image paths.
            img_col: Column name for image paths.
            mask_col: Column name for mask paths.
            valid_pct: Fraction of data for validation.
            valid_col: Column name for train/valid split (if pre-defined).
            patch_config: PatchConfig instance. Preprocessing params (apply_reorder,
                target_spacing) can be set here for DRY usage with PatchInferenceEngine.
            pre_patch_tfms: TorchIO transforms applied before patch extraction
                (after reorder/resample). Example: [tio.ZNormalization()].
                Accepts both fastMONAI wrappers and raw TorchIO transforms.
            patch_tfms: TorchIO transforms applied to extracted patches (training only).
            apply_reorder: If True, reorder to RAS+ orientation. If None, uses
                patch_config.apply_reorder. Explicit value overrides config.
            target_spacing: Target voxel spacing [x, y, z]. If None, uses
                patch_config.target_spacing. Explicit value overrides config.
            bs: Batch size.
            seed: Random seed for splitting.
            device: Device to use.
            ensure_affine_consistency: If True and mask_col is provided, automatically
                adds tio.CopyAffine(target='image') as the first transform to prevent
                spatial metadata mismatch errors. Defaults to True.

        Returns:
            MedPatchDataLoaders instance.

        Example:
            >>> # New pattern: config contains preprocessing params
            >>> config = PatchConfig(
            ...     patch_size=[96, 96, 96],
            ...     apply_reorder=True,
            ...     target_spacing=[0.5, 0.5, 0.5],
            ...     label_probabilities={0: 0.1, 1: 0.9}
            ... )
            >>> dls = MedPatchDataLoaders.from_df(
            ...     df, img_col='image', mask_col='label',
            ...     patch_config=config,
            ...     pre_patch_tfms=[tio.ZNormalization()],
            ...     patch_tfms=[tio.RandomAffine(degrees=10), tio.RandomFlip()],
            ...     bs=4
            ... )
            >>> # Memory: ~150 MB (queue buffer only)
        """
        if patch_config is None:
            patch_config = PatchConfig()

        # Use config values, allow explicit overrides for backward compatibility
        _apply_reorder = apply_reorder if apply_reorder is not None else patch_config.apply_reorder
        _target_spacing = target_spacing if target_spacing is not None else patch_config.target_spacing

        # Warn if both config and explicit args provided with different values
        _warn_config_override('apply_reorder', patch_config.apply_reorder, apply_reorder)
        _warn_config_override('target_spacing', patch_config.target_spacing, target_spacing)

        # Split data
        if valid_col is not None:
            # Element-wise pandas comparison: rows where the split flag is
            # False go to train, True to valid.
            train_df = df[df[valid_col] == False].reset_index(drop=True)
            valid_df = df[df[valid_col] == True].reset_index(drop=True)
        else:
            if seed is not None:
                # NOTE(review): seeds NumPy's *global* RNG, which affects other
                # np.random callers in the process — confirm before migrating
                # to np.random.default_rng (would change existing splits).
                np.random.seed(seed)
            n = len(df)
            valid_idx = np.random.choice(n, size=int(n * valid_pct), replace=False)
            train_idx = np.setdiff1d(np.arange(n), valid_idx)
            train_df = df.iloc[train_idx].reset_index(drop=True)
            valid_df = df.iloc[valid_idx].reset_index(drop=True)

        # Build preprocessing transforms
        all_pre_tfms = []

        # Add reorder transform (reorder to RAS+ orientation)
        if _apply_reorder:
            all_pre_tfms.append(tio.ToCanonical())

        # Add resample transform
        if _target_spacing is not None:
            all_pre_tfms.append(tio.Resample(_target_spacing))

        # Add user-provided transforms (normalize to raw TorchIO transforms)
        if pre_patch_tfms:
            all_pre_tfms.extend(normalize_patch_transforms(pre_patch_tfms))

        # Create subjects datasets with lazy loading (paths only, ~0 MB)
        train_subjects = create_subjects_dataset(
            train_df, img_col, mask_col,
            pre_tfms=all_pre_tfms if all_pre_tfms else None,
            ensure_affine_consistency=ensure_affine_consistency
        )
        valid_subjects = create_subjects_dataset(
            valid_df, img_col, mask_col,
            pre_tfms=all_pre_tfms if all_pre_tfms else None,
            ensure_affine_consistency=ensure_affine_consistency
        )

        # Create DataLoaders (both use same patch_config for consistent sampling)
        train_dl = MedPatchDataLoader(
            train_subjects, patch_config, bs,
            patch_tfms=patch_tfms, shuffle=True, drop_last=True
        )
        valid_dl = MedPatchDataLoader(
            valid_subjects, patch_config, bs,
            patch_tfms=None,  # No augmentation for validation
            shuffle=False, drop_last=False
        )

        # Create instance and store metadata (exposed via properties below)
        instance = cls(train_dl, valid_dl, device)
        instance._img_col = img_col
        instance._mask_col = mask_col
        instance._pre_patch_tfms = pre_patch_tfms
        instance._apply_reorder = _apply_reorder
        instance._target_spacing = _target_spacing
        instance._ensure_affine_consistency = ensure_affine_consistency
        instance._patch_config = patch_config
        return instance

    @property
    def train(self):
        """Training DataLoader."""
        return self._train_dl

    @property
    def valid(self):
        """Validation DataLoader."""
        return self._valid_dl

    @property
    def train_ds(self):
        """Training subjects dataset."""
        return self._train_dl.subjects_dataset

    @property
    def valid_ds(self):
        """Validation subjects dataset."""
        return self._valid_dl.subjects_dataset

    @property
    def device(self):
        """Current device."""
        return self._device

    @property
    def bs(self):
        """Batch size."""
        return self._train_dl.bs

    @property
    def apply_reorder(self):
        """Whether reordering to RAS+ is enabled.

        getattr with a default keeps this safe when the instance was built
        via __init__ directly (metadata only set by from_df).
        """
        return getattr(self, '_apply_reorder', False)

    @property
    def target_spacing(self):
        """Target voxel spacing for resampling (None when not set by from_df)."""
        return getattr(self, '_target_spacing', None)

    @property
    def patch_config(self):
        """The PatchConfig used for this DataLoaders (None when not set by from_df)."""
        return getattr(self, '_patch_config', None)

    def to(self, device):
        """Move DataLoaders to device."""
        self._device = device
        self._train_dl.to(device)
        self._valid_dl.to(device)
        return self

    def __iter__(self):
        """Iterate over training DataLoader."""
        return iter(self._train_dl)

    def one_batch(self):
        """Return one batch from the training DataLoader.

        Required for fastai Learner compatibility - used for device
        detection and batch shape validation.
        """
        return self._train_dl.one_batch()

    def __len__(self):
        """Return number of batches in training DataLoader."""
        return len(self._train_dl)

    def __getitem__(self, idx):
        """Get DataLoader by index. Required for fastai Learner compatibility.

        Args:
            idx: 0 for training DataLoader, 1 for validation DataLoader.

        Returns:
            MedPatchDataLoader instance.

        Raises:
            IndexError: If idx is not 0 or 1.
        """
        if idx == 0:
            return self._train_dl
        elif idx == 1:
            return self._valid_dl
        else:
            raise IndexError(f"Index {idx} out of range. Use 0 (train) or 1 (valid).")

    def cuda(self):
        """Move DataLoaders to CUDA device."""
        return self.to(torch.device('cuda'))

    def cpu(self):
        """Move DataLoaders to CPU."""
        return self.to(torch.device('cpu'))

    def new_empty(self):
        """Create a new empty version of self for learner export.

        Required for fastai Learner.export() compatibility - creates a
        lightweight placeholder that can be pickled without the full dataset.

        Returns:
            A minimal MedPatchDataLoaders-like object with no data.
        """
        class EmptyMedPatchDataLoaders:
            """Minimal placeholder for exported learner."""
            def __init__(self, device):
                self._device = device
            @property
            def device(self): return self._device
            def to(self, device):
                self._device = device
                return self
            def cpu(self):
                """Move to CPU. Required for load_learner compatibility."""
                return self.to(torch.device('cpu'))

        return EmptyMedPatchDataLoaders(self._device)

    def close(self):
        """Shut down all DataLoader workers. Safe to call multiple times."""
        if self._closed:
            return
        self._closed = True
        # hasattr guards: __init__ may not have completed if construction failed.
        if hasattr(self, '_train_dl') and self._train_dl is not None:
            self._train_dl.close()
        if hasattr(self, '_valid_dl') and self._valid_dl is not None:
            self._valid_dl.close()

    def __enter__(self):
        # Context-manager support: `with MedPatchDataLoaders(...) as dls:`.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        return False  # Do not suppress exceptions

    def __del__(self):
        # Best-effort cleanup at garbage collection; errors are ignored
        # because interpreter shutdown can leave globals in a torn-down state.
        try:
            self.close()
        except Exception:
            pass
875
+
876
+ # %% ../nbs/10_vision_patch.ipynb 20
877
+ import numbers
878
+
879
+ def _normalize_patch_overlap(patch_overlap, patch_size):
880
+ """Convert patch_overlap to integer pixel values for TorchIO compatibility.
881
+
882
+ TorchIO's GridSampler expects patch_overlap as a tuple of even integers.
883
+ This function handles:
884
+ - Fractional overlap (0-1): converted to pixel values based on patch_size
885
+ - Numpy scalar types: converted to native Python types
886
+ - Sequences: converted to tuple of integers
887
+
888
+ Note: Input validation (negative values, overlap >= patch_size) is handled
889
+ by PatchConfig.__post_init__(). This function focuses on format conversion.
890
+
891
+ Args:
892
+ patch_overlap: int, float (0-1 for fraction), or sequence
893
+ patch_size: list/tuple of patch dimensions [x, y, z]
894
+
895
+ Returns:
896
+ Tuple of even integers suitable for TorchIO GridSampler
897
+ """
898
+ # Handle scalar fractional overlap (0 < x < 1)
899
+ # Note: excludes 1.0 as 100% overlap creates step_size=0 (infinite patches)
900
+ if isinstance(patch_overlap, (int, float, numbers.Number)) and 0 < float(patch_overlap) < 1:
901
+ # Convert fraction to pixel values, ensure even
902
+ result = []
903
+ for ps in patch_size:
904
+ pixels = int(int(ps) * float(patch_overlap))
905
+ # Ensure even (required by TorchIO)
906
+ if pixels % 2 != 0:
907
+ pixels = pixels - 1 if pixels > 0 else 0
908
+ result.append(pixels)
909
+ return tuple(result)
910
+
911
+ # Handle scalar integer (including numpy scalars) - values > 1 are pixel counts
912
+ if isinstance(patch_overlap, (int, float, numbers.Number)):
913
+ val = int(patch_overlap)
914
+ # Ensure even
915
+ if val % 2 != 0:
916
+ val = val - 1 if val > 0 else 0
917
+ return tuple(val for _ in patch_size)
918
+
919
+ # Handle sequences (list, tuple, ndarray)
920
+ result = []
921
+ for val in patch_overlap:
922
+ pixels = int(val)
923
+ if pixels % 2 != 0:
924
+ pixels = pixels - 1 if pixels > 0 else 0
925
+ result.append(pixels)
926
+ return tuple(result)
927
+
928
+
929
class PatchInferenceEngine:
    """Patch-based inference with automatic volume reconstruction.

    Uses TorchIO's GridSampler to extract overlapping patches and
    GridAggregator to reconstruct the full volume from predictions.

    Args:
        learner: fastai Learner or PyTorch model (nn.Module). When passing a raw
            PyTorch model, load weights first with model.load_state_dict().
        config: PatchConfig with inference settings. Preprocessing params (apply_reorder,
            target_spacing, padding_mode) can be set here for DRY usage.
        apply_reorder: Whether to reorder to RAS+ orientation. If None, uses config value.
        target_spacing: Target voxel spacing. If None, uses config value.
        batch_size: Number of patches to predict at once. Must be positive.
        pre_inference_tfms: List of TorchIO transforms to apply before patch extraction.
            IMPORTANT: Should match the pre_patch_tfms used during training (e.g., [tio.ZNormalization()]).
            This ensures preprocessing consistency between training and inference.
            Accepts both fastMONAI wrappers and raw TorchIO transforms.

    Example:
        >>> # Option 1: From fastai Learner
        >>> engine = PatchInferenceEngine(learn, config, pre_inference_tfms=[ZNormalization()])
        >>> pred = engine.predict('image.nii.gz')

        >>> # Option 2: From raw PyTorch model (recommended for deployment)
        >>> model = UNet(spatial_dims=3, in_channels=1, out_channels=2, ...)
        >>> model.load_state_dict(torch.load('weights.pth'))
        >>> model.cuda().eval()
        >>> engine = PatchInferenceEngine(model, config, pre_inference_tfms=[ZNormalization()])
        >>> pred = engine.predict('image.nii.gz')
    """

    def __init__(
        self,
        learner,
        config: PatchConfig,
        apply_reorder: bool | None = None,
        target_spacing: list | None = None,
        batch_size: int = 4,
        pre_inference_tfms: list | None = None
    ):
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        # Extract model from Learner if needed (use isinstance for robust detection)
        # Note: We check for Learner explicitly because some models (e.g., MONAI UNet)
        # have a .model attribute that is NOT the full model but an internal Sequential.
        if isinstance(learner, Learner):
            self.model = learner.model
        else:
            self.model = learner  # Assume it's already a PyTorch model

        self.config = config
        self.batch_size = batch_size

        # Normalize transforms to raw TorchIO (accepts both fastMONAI wrappers and raw TorchIO)
        normalized_tfms = normalize_patch_transforms(pre_inference_tfms)
        self.pre_inference_tfms = tio.Compose(normalized_tfms) if normalized_tfms else None

        # Use config values, allow explicit overrides for backward compatibility
        self.apply_reorder = apply_reorder if apply_reorder is not None else config.apply_reorder
        self.target_spacing = target_spacing if target_spacing is not None else config.target_spacing

        # Warn if explicit args provided but differ from config (potential mistake)
        _warn_config_override('apply_reorder', config.apply_reorder, apply_reorder)
        _warn_config_override('target_spacing', config.target_spacing, target_spacing)

        # Get device from model parameters, with fallback for parameter-less models
        try:
            self._device = next(self.model.parameters()).device
        except StopIteration:
            self._device = _get_default_device()

    def predict(
        self,
        img_path: Path | str,
        return_probabilities: bool = False,
        return_affine: bool = False
    ) -> torch.Tensor | tuple[torch.Tensor, np.ndarray]:
        """Predict on a single volume using patch-based inference.

        Args:
            img_path: Path to input image.
            return_probabilities: If True, return probability map instead of argmax.
            return_affine: If True, return (prediction, affine) tuple instead of just prediction.

        Returns:
            Predicted segmentation mask tensor, or tuple (prediction, affine) if return_affine=True.

        Raises:
            RuntimeError: If the loaded image unexpectedly lacks an affine matrix.
        """
        # Load image - keep org_img and org_size for post-processing
        # Note: med_img_reader handles reorder/resample internally, no global state needed
        org_img, input_img, org_size = med_img_reader(
            img_path, apply_reorder=self.apply_reorder, target_spacing=self.target_spacing, only_tensor=False
        )

        # Create TorchIO Subject from preprocessed image
        subject = tio.Subject(
            image=tio.ScalarImage(tensor=input_img.data.float(), affine=input_img.affine)
        )

        # Apply pre-inference transforms (e.g., ZNormalization) to match training
        if self.pre_inference_tfms is not None:
            subject = self.pre_inference_tfms(subject)

        # Pad dimensions smaller than patch_size, keep larger dimensions intact
        # GridSampler handles large images via overlapping patches
        img_shape = subject['image'].shape[1:]  # Exclude channel dim
        target_size = [max(s, p) for s, p in zip(img_shape, self.config.patch_size)]

        # Warn if volume needed padding (may cause artifacts if training didn't cover similar sizes)
        if any(s < p for s, p in zip(img_shape, self.config.patch_size)):
            padded_dims = [f"dim{i}: {s}<{p}" for i, (s, p) in enumerate(zip(img_shape, self.config.patch_size)) if s < p]
            warnings.warn(
                f"Image size {list(img_shape)} smaller than patch_size {self.config.patch_size} "
                f"in {padded_dims}. Padding with mode={self.config.padding_mode}. "
                "Ensure training data covered similar sizes to avoid artifacts."
            )

        # Use padding_mode from config (default: 0 for zero padding, nnU-Net standard)
        subject = tio.CropOrPad(target_size, padding_mode=self.config.padding_mode)(subject)

        # Convert patch_overlap to integer pixel values for TorchIO compatibility
        patch_overlap = _normalize_patch_overlap(self.config.patch_overlap, self.config.patch_size)

        # Create GridSampler
        grid_sampler = tio.GridSampler(
            subject,
            patch_size=self.config.patch_size,
            patch_overlap=patch_overlap
        )

        # Create GridAggregator
        aggregator = tio.GridAggregator(
            grid_sampler,
            overlap_mode=self.config.aggregation_mode
        )

        # Create patch loader
        patch_loader = DataLoader(
            grid_sampler,
            batch_size=self.batch_size,
            num_workers=0
        )

        # Predict patches
        self.model.eval()
        with torch.no_grad():
            for patches_batch in patch_loader:
                patch_input = patches_batch['image'][tio.DATA].to(self._device)
                locations = patches_batch[tio.LOCATION]

                # Forward pass - get logits
                logits = self.model(patch_input)

                # Convert logits to probabilities BEFORE aggregation
                # This is critical: softmax is non-linear, so we must aggregate
                # probabilities, not logits, to get correct boundary handling
                n_classes = logits.shape[1]
                if n_classes == 1:
                    probs = torch.sigmoid(logits)
                else:
                    probs = torch.softmax(logits, dim=1)  # dim=1 for batch [B, C, H, W, D]

                # Add probabilities to aggregator
                aggregator.add_batch(probs.cpu(), locations)

        # Get reconstructed output (now contains probabilities, not logits)
        output = aggregator.get_output_tensor()

        # Convert to prediction mask (only if not returning probabilities)
        if return_probabilities:
            result = output  # Keep as float probabilities
        else:
            n_classes = output.shape[0]
            if n_classes == 1:
                result = (output > 0.5).float()
            else:
                result = output.argmax(dim=0, keepdim=True).float()

        # Apply keep_largest post-processing for binary segmentation
        if not return_probabilities and self.config.keep_largest_component:
            from fastMONAI.vision_inference import keep_largest
            result = keep_largest(result.squeeze(0)).unsqueeze(0)

        # Post-processing: resize back to original size and reorient
        # This matches the workflow in vision_inference.py

        # Wrap result in TorchIO Image for resizing
        # Use ScalarImage for probabilities, LabelMap for masks
        if return_probabilities:
            pred_img = tio.ScalarImage(tensor=result.float(), affine=input_img.affine)
        else:
            pred_img = tio.LabelMap(tensor=result.float(), affine=input_img.affine)

        # Resize back to original size (before resampling)
        pred_img = _do_resize(pred_img, org_size, image_interpolation='nearest')

        # Reorient to original orientation (if reorder was applied)
        # Use explicit .cpu() for consistent device handling
        if self.apply_reorder:
            reoriented_array = _to_original_orientation(
                pred_img.as_sitk(),
                ''.join(org_img.orientation)
            )
            result = torch.from_numpy(reoriented_array).cpu()
        else:
            result = pred_img.data.cpu()

        # Masks become integer labels; probability maps stay float
        if not return_probabilities:
            result = result.long()

        # Use original affine matrix for correct spatial alignment
        # org_img.affine is always available from med_img_reader
        if not (hasattr(org_img, 'affine') and org_img.affine is not None):
            raise RuntimeError(
                "org_img.affine not available. This should never happen - please report this bug."
            )
        affine = org_img.affine.copy()

        if return_affine:
            return result, affine
        return result

    def to(self, device):
        """Move engine (model and future patch batches) to the given device."""
        self._device = device
        self.model.to(device)
        return self
+ # %% ../nbs/10_vision_patch.ipynb 22
1162
def patch_inference(
    learner,
    config: PatchConfig,
    file_paths: list,
    apply_reorder: bool = None,
    target_spacing: list = None,
    batch_size: int = 4,
    return_probabilities: bool = False,
    progress: bool = True,
    save_dir: str = None,
    pre_inference_tfms: list = None
) -> list:
    """Batch patch-based inference on multiple volumes.

    Args:
        learner: PyTorch model or fastai Learner.
        config: PatchConfig with inference settings. Preprocessing params (apply_reorder,
            target_spacing) can be set here for DRY usage.
        file_paths: List of image paths.
        apply_reorder: Whether to reorder to RAS+ orientation. If None, uses config value.
        target_spacing: Target voxel spacing. If None, uses config value.
        batch_size: Patches per batch.
        return_probabilities: Return probability maps.
        progress: Show progress bar.
        save_dir: Directory to save predictions as NIfTI files. If None, predictions are not saved.
        pre_inference_tfms: List of TorchIO transforms to apply before patch extraction.
            IMPORTANT: Should match the pre_patch_tfms used during training (e.g., [tio.ZNormalization()]).

    Returns:
        List of predicted tensors.

    Example:
        >>> # DRY pattern: one config drives both training and inference
        >>> config = PatchConfig(
        ...     patch_size=[96, 96, 96],
        ...     apply_reorder=True,
        ...     target_spacing=[0.4102, 0.4102, 1.5]
        ... )
        >>> predictions = patch_inference(
        ...     learner=learn,
        ...     config=config,  # apply_reorder and target_spacing from config
        ...     file_paths=val_paths,
        ...     pre_inference_tfms=[tio.ZNormalization()],
        ...     save_dir='predictions/patch_based'
        ... )
    """
    # Resolve preprocessing params: explicit arguments win over config values.
    _apply_reorder = config.apply_reorder if apply_reorder is None else apply_reorder
    _target_spacing = config.target_spacing if target_spacing is None else target_spacing

    engine = PatchInferenceEngine(
        learner, config, _apply_reorder, _target_spacing, batch_size, pre_inference_tfms
    )

    # Prepare output directory up front when saving is requested.
    save_path = None
    if save_dir is not None:
        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)

    predictions = []
    iterator = tqdm(file_paths, desc='Patch inference') if progress else file_paths

    for path in iterator:
        if save_path is None:
            predictions.append(engine.predict(path, return_probabilities))
            continue

        # The affine is needed to write a spatially-correct NIfTI.
        pred, affine = engine.predict(path, return_probabilities, return_affine=True)
        predictions.append(pred)

        # Build the output filename via suffix inspection so .nii.gz is
        # handled correctly without corrupting names containing .nii elsewhere.
        src = Path(path)
        stem = src.stem
        if src.suffix == '.gz' and stem.endswith('.nii'):
            # .nii.gz files: stem is "filename.nii" — strip the trailing .nii
            out_name = f"{stem[:-4]}_pred.nii.gz"
        elif src.suffix == '.nii':
            out_name = f"{stem}_pred.nii"
        else:
            # Fallback for other formats
            out_name = f"{stem}_pred.nii.gz"

        # ScalarImage for float probability maps, LabelMap for integer masks.
        image_cls = tio.ScalarImage if return_probabilities else tio.LabelMap
        image_cls(tensor=pred, affine=affine).save(save_path / out_name)

    return predictions