fastMONAI 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastMONAI/__init__.py +1 -1
- fastMONAI/_modidx.py +171 -27
- fastMONAI/dataset_info.py +190 -45
- fastMONAI/external_data.py +1 -1
- fastMONAI/utils.py +101 -18
- fastMONAI/vision_all.py +3 -2
- fastMONAI/vision_augmentation.py +133 -29
- fastMONAI/vision_core.py +29 -132
- fastMONAI/vision_data.py +6 -6
- fastMONAI/vision_inference.py +35 -9
- fastMONAI/vision_metrics.py +420 -19
- fastMONAI/vision_patch.py +1125 -0
- fastMONAI/vision_plot.py +1 -0
- {fastmonai-0.5.2.dist-info → fastmonai-0.5.4.dist-info}/METADATA +6 -6
- fastmonai-0.5.4.dist-info/RECORD +21 -0
- {fastmonai-0.5.2.dist-info → fastmonai-0.5.4.dist-info}/WHEEL +1 -1
- fastmonai-0.5.2.dist-info/RECORD +0 -20
- {fastmonai-0.5.2.dist-info → fastmonai-0.5.4.dist-info}/entry_points.txt +0 -0
- {fastmonai-0.5.2.dist-info → fastmonai-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {fastmonai-0.5.2.dist-info → fastmonai-0.5.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1125 @@
"""Patch-based training and inference for 3D medical image segmentation using TorchIO's Queue mechanism."""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/10_vision_patch.ipynb.

# %% auto 0
__all__ = ['normalize_patch_transforms', 'PatchConfig', 'med_to_subject', 'create_subjects_dataset', 'create_patch_sampler',
           'MedPatchDataLoader', 'MedPatchDataLoaders', 'PatchInferenceEngine', 'patch_inference']

# %% ../nbs/10_vision_patch.ipynb 2
import torch
import torchio as tio
import pandas as pd
import numpy as np
import warnings
from pathlib import Path
from dataclasses import dataclass, field
from typing import Callable
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from fastai.data.all import *
from .vision_core import MedImage, MedMask, MedBase, med_img_reader
from .vision_inference import _to_original_orientation, _do_resize

# %% ../nbs/10_vision_patch.ipynb 3
def _get_default_device() -> torch.device:
    """Get the default device (CUDA if available, else CPU)."""
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _warn_config_override(param_name: str, config_value, explicit_value):
    """Warn when explicit argument overrides config value.

    Args:
        param_name: Name of the parameter (e.g., 'apply_reorder', 'target_spacing')
        config_value: Value from PatchConfig
        explicit_value: Explicitly provided value
    """
    if explicit_value is not None and config_value is not None:
        if explicit_value != config_value:
            warnings.warn(
                f"{param_name} mismatch: explicit={explicit_value}, config={config_value}. "
                f"Using explicit argument."
            )


def _extract_tio_transform(tfm):
    """Extract TorchIO transform from fastMONAI wrapper or return as-is.

    This function enables using fastMONAI wrappers (e.g., RandomAffine, RandomGamma)
    in patch-based workflows where raw TorchIO transforms are needed for tio.Compose().

    Uses the explicit `.tio_transform` property when available on fastMONAI wrappers.
    Falls back to returning the transform unchanged for raw TorchIO transforms.

    Args:
        tfm: fastMONAI wrapper (e.g., RandomAffine) or raw TorchIO transform

    Returns:
        The underlying TorchIO transform

    Example:
        >>> from fastMONAI.vision_augmentation import RandomAffine
        >>> wrapped = RandomAffine(degrees=10)
        >>> raw = _extract_tio_transform(wrapped)  # Returns tio.RandomAffine
    """
    return getattr(tfm, 'tio_transform', tfm)


def normalize_patch_transforms(tfms: list) -> list:
    """Normalize transforms for patch-based workflow.

    Extracts underlying TorchIO transforms from fastMONAI wrappers.
    Also accepts raw TorchIO transforms for backward compatibility.

    This enables using the same transform syntax in both standard and
    patch-based workflows:

    >>> from fastMONAI.vision_augmentation import RandomAffine, RandomGamma
    >>>
    >>> # Same syntax works in both contexts
    >>> item_tfms = [RandomAffine(degrees=10), RandomGamma(p=0.5)]  # Standard
    >>> patch_tfms = [RandomAffine(degrees=10), RandomGamma(p=0.5)]  # Patch-based

    Args:
        tfms: List of fastMONAI wrappers or raw TorchIO transforms

    Returns:
        List of raw TorchIO transforms suitable for tio.Compose()
    """
    if tfms is None:
        return None
    return [_extract_tio_transform(t) for t in tfms]
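
# A minimal usage sketch (hypothetical `_sketch_*` helper, not part of the
# module): fastMONAI wrappers and raw TorchIO transforms can be mixed freely,
# and the normalized list is safe to hand to tio.Compose().
def _sketch_normalize_transforms():
    from fastMONAI.vision_augmentation import RandomAffine  # fastMONAI wrapper
    mixed = [RandomAffine(degrees=10), tio.RandomFlip()]  # wrapper + raw TorchIO
    return tio.Compose(normalize_patch_transforms(mixed))  # raw transforms only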

# %% ../nbs/10_vision_patch.ipynb 7
@dataclass
class PatchConfig:
    """Configuration for patch-based training and inference.

    Args:
        patch_size: Size of patches [x, y, z].
        patch_overlap: Overlap for inference GridSampler (int, float 0-1, or list).
            - Float 0-1: fraction of patch_size (e.g., 0.5 = 50% overlap)
            - Int >= 1: pixel overlap (e.g., 48 = 48 pixel overlap)
            - List: per-dimension overlap in pixels
        samples_per_volume: Number of patches to extract per volume during training.
        sampler_type: Type of sampler ('uniform', 'label', 'weighted').
        label_probabilities: For LabelSampler, dict mapping label values to probabilities.
        queue_length: Maximum number of patches to store in queue.
        queue_num_workers: Number of workers for parallel patch extraction.
        aggregation_mode: For inference, how to combine overlapping patches ('crop', 'average', 'hann').
        apply_reorder: Whether to reorder to RAS+ canonical orientation. Must match between
            training and inference.
        target_spacing: Target voxel spacing [x, y, z] for resampling. Must match between
            training and inference.
        padding_mode: Padding mode for CropOrPad when image < patch_size. Default is 0 (zero padding)
            to align with nnU-Net's approach. Can be int, float, or string (e.g., 'minimum', 'mean').
        keep_largest_component: If True, keep only the largest connected component
            in binary segmentation predictions. Only applies during inference when
            return_probabilities=False. Defaults to False.

    Example:
        >>> config = PatchConfig(
        ...     patch_size=[96, 96, 96],
        ...     samples_per_volume=16,
        ...     sampler_type='label',
        ...     label_probabilities={0: 0.1, 1: 0.9},
        ...     apply_reorder=True,
        ...     target_spacing=[0.5, 0.5, 0.5]
        ... )
    """
    patch_size: list = field(default_factory=lambda: [96, 96, 96])
    patch_overlap: int | float | list = 0
    samples_per_volume: int = 8
    sampler_type: str = 'uniform'
    label_probabilities: dict = None
    queue_length: int = 300
    queue_num_workers: int = 4
    aggregation_mode: str = 'hann'
    # Preprocessing parameters - must match between training and inference
    apply_reorder: bool = False
    target_spacing: list = None
    padding_mode: int | float | str = 0  # Zero padding (nnU-Net standard)
    # Post-processing (binary segmentation only)
    keep_largest_component: bool = False

    def __post_init__(self):
        """Validate configuration."""
        valid_samplers = ['uniform', 'label', 'weighted']
        if self.sampler_type not in valid_samplers:
            raise ValueError(f"sampler_type must be one of {valid_samplers}")

        valid_aggregation = ['crop', 'average', 'hann']
        if self.aggregation_mode not in valid_aggregation:
            raise ValueError(f"aggregation_mode must be one of {valid_aggregation}")

        # Validate patch_overlap
        # Negative overlap doesn't make sense
        if isinstance(self.patch_overlap, (int, float)):
            if self.patch_overlap < 0:
                raise ValueError("patch_overlap cannot be negative")
            # Check if overlap as pixels would exceed patch_size (causes step_size=0)
            if self.patch_overlap >= 1:  # Pixel value, not fraction
                for ps in self.patch_size:
                    if self.patch_overlap >= ps:
                        raise ValueError(
                            f"patch_overlap ({self.patch_overlap}) must be less than patch_size ({ps}). "
                            f"Overlap >= patch_size creates step_size <= 0 (infinite patches)."
                        )
        elif isinstance(self.patch_overlap, (list, tuple)):
            for i, (overlap, ps) in enumerate(zip(self.patch_overlap, self.patch_size)):
                if overlap < 0:
                    raise ValueError(f"patch_overlap[{i}] cannot be negative")
                if overlap >= ps:
                    raise ValueError(
                        f"patch_overlap[{i}] ({overlap}) must be less than patch_size[{i}] ({ps}). "
                        f"Overlap >= patch_size creates step_size <= 0 (infinite patches)."
                    )
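
# Validation sketch (hypothetical `_sketch_*` helper): fractional overlap is
# accepted as-is, while an overlap equal to the patch size is rejected up
# front because the GridSampler step size would collapse to zero.
def _sketch_patch_config_validation():
    PatchConfig(patch_size=[96, 96, 96], patch_overlap=0.5)  # OK: 50% overlap
    try:
        PatchConfig(patch_size=[96, 96, 96], patch_overlap=96)
    except ValueError as err:
        print(err)  # overlap >= patch_size -> step_size <= 0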

# %% ../nbs/10_vision_patch.ipynb 10
def med_to_subject(
    img: Path | str,
    mask: Path | str = None,
) -> tio.Subject:
    """Create TorchIO Subject with LAZY loading (paths only, no tensor loading).

    This function stores file paths in the Subject, allowing TorchIO's Queue
    workers to load volumes on-demand during training. This is memory-efficient
    as volumes are not loaded into RAM until needed.

    Args:
        img: Path to image file.
        mask: Path to mask file (optional).

    Returns:
        TorchIO Subject with 'image' and optionally 'mask' keys (lazy loaded).

    Example:
        >>> subject = med_to_subject('image.nii.gz', 'mask.nii.gz')
        >>> # Volume NOT loaded yet - only path stored
        >>> data = subject['image'].data  # NOW volume is loaded
    """
    subject_dict = {
        'image': tio.ScalarImage(path=str(img))  # Lazy - stores path only
    }

    if mask is not None:
        subject_dict['mask'] = tio.LabelMap(path=str(mask))  # Lazy

    return tio.Subject(**subject_dict)
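
# Lazy-loading sketch (hypothetical helper and file paths): the Subject holds
# only paths at construction time; the volume is read on first data access.
def _sketch_lazy_subject():
    subject = med_to_subject('image.nii.gz', 'mask.nii.gz')  # paths stored only
    assert 'image' in subject and 'mask' in subject
    return subject['image'].data  # first access triggers the actual load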

# %% ../nbs/10_vision_patch.ipynb 11
def create_subjects_dataset(
    df: pd.DataFrame,
    img_col: str,
    mask_col: str = None,
    pre_tfms: list = None,
    ensure_affine_consistency: bool = True
) -> tio.SubjectsDataset:
    """Build TorchIO SubjectsDataset with LAZY loading from DataFrame.

    This function creates a SubjectsDataset that stores only file paths,
    not loaded tensors. Volumes are loaded on-demand by Queue workers,
    keeping memory usage constant regardless of dataset size.

    Args:
        df: DataFrame with image (and optionally mask) paths.
        img_col: Column name containing image paths.
        mask_col: Column name containing mask paths (optional).
        pre_tfms: List of TorchIO transforms to apply before patch extraction.
            Use tio.ToCanonical() for reordering and tio.Resample() for resampling.
        ensure_affine_consistency: If True and mask_col is provided, automatically
            prepends tio.CopyAffine(target='image') to ensure spatial metadata
            consistency between image and mask. This prevents "More than one value
            for direction found" errors. Defaults to True.

    Returns:
        TorchIO SubjectsDataset with lazy-loaded subjects.

    Example:
        >>> # Preprocessing via transforms (applied by workers on-demand)
        >>> pre_tfms = [
        ...     tio.ToCanonical(),  # Reorder to RAS+
        ...     tio.Resample([0.5, 0.5, 0.5]),  # Resample
        ...     tio.ZNormalization(),  # Intensity normalization
        ... ]
        >>> dataset = create_subjects_dataset(
        ...     df, img_col='image', mask_col='label',
        ...     pre_tfms=pre_tfms
        ... )
        >>> # Memory: ~0 MB (only paths stored, not volumes)
    """
    subjects = []
    for idx, row in df.iterrows():
        img_path = row[img_col]
        mask_path = row[mask_col] if mask_col else None

        # Create subject with lazy loading (paths only)
        subject = med_to_subject(img=img_path, mask=mask_path)
        subjects.append(subject)

    # Build transform pipeline
    all_transforms = []

    # Add CopyAffine as FIRST transform when mask is present
    # This ensures spatial metadata consistency before other transforms
    if mask_col is not None and ensure_affine_consistency:
        all_transforms.append(tio.CopyAffine(target='image'))

    # Add user-provided transforms
    if pre_tfms:
        all_transforms.extend(pre_tfms)

    transform = tio.Compose(all_transforms) if all_transforms else None

    return tio.SubjectsDataset(subjects, transform=transform)
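
# Dataset sketch with a hypothetical two-row DataFrame: because mask_col is
# given, tio.CopyAffine(target='image') is prepended before the listed
# transforms, so image/mask spatial metadata always agree.
def _sketch_subjects_dataset():
    df = pd.DataFrame({'image': ['a_img.nii.gz', 'b_img.nii.gz'],
                       'label': ['a_msk.nii.gz', 'b_msk.nii.gz']})
    return create_subjects_dataset(
        df, img_col='image', mask_col='label',
        pre_tfms=[tio.ToCanonical(), tio.ZNormalization()],
    )  # len == 2; volumes still unloaded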

# %% ../nbs/10_vision_patch.ipynb 13
def create_patch_sampler(config: PatchConfig) -> tio.data.PatchSampler:
    """Create appropriate TorchIO sampler based on config.

    Args:
        config: PatchConfig with sampler settings.

    Returns:
        TorchIO PatchSampler instance.

    Example:
        >>> config = PatchConfig(patch_size=[96, 96, 96], sampler_type='label')
        >>> sampler = create_patch_sampler(config)
    """
    patch_size = config.patch_size

    if config.sampler_type == 'uniform':
        return tio.UniformSampler(patch_size)

    elif config.sampler_type == 'label':
        return tio.LabelSampler(
            patch_size,
            label_name='mask',
            label_probabilities=config.label_probabilities
        )

    elif config.sampler_type == 'weighted':
        raise NotImplementedError(
            "WeightedSampler requires a pre-computed probability map which is not currently supported. "
            "Use 'label' sampler with label_probabilities for weighted sampling based on segmentation labels, "
            "or 'uniform' for random patch extraction."
        )

    raise ValueError(f"Unknown sampler type: {config.sampler_type}")

# %% ../nbs/10_vision_patch.ipynb 17
class MedPatchDataLoader:
    """DataLoader wrapper for patch-based training with TorchIO Queue.

    This class wraps a TorchIO Queue to provide a fastai-compatible DataLoader
    interface for patch-based training.

    Args:
        subjects_dataset: TorchIO SubjectsDataset.
        config: PatchConfig with queue and sampler settings.
        batch_size: Number of patches per batch. Must be positive.
        patch_tfms: Transforms to apply to extracted patches (training only).
            Accepts both fastMONAI wrappers (e.g., RandomAffine, RandomGamma) and
            raw TorchIO transforms. fastMONAI wrappers are automatically normalized
            to raw TorchIO for internal use.
        shuffle: Whether to shuffle subjects and patches.
        drop_last: Whether to drop last incomplete batch.
    """

    def __init__(
        self,
        subjects_dataset: tio.SubjectsDataset,
        config: PatchConfig,
        batch_size: int = 4,
        patch_tfms: list = None,
        shuffle: bool = True,
        drop_last: bool = False
    ):
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        self.subjects_dataset = subjects_dataset
        self.config = config
        self.bs = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self._device = _get_default_device()

        # Create sampler
        self.sampler = create_patch_sampler(config)

        # Create patch transforms
        # Normalize transforms - accepts both fastMONAI wrappers and raw TorchIO
        normalized_tfms = normalize_patch_transforms(patch_tfms)
        self.patch_tfms = tio.Compose(normalized_tfms) if normalized_tfms else None

        # Create queue
        self.queue = tio.Queue(
            subjects_dataset,
            max_length=config.queue_length,
            samples_per_volume=config.samples_per_volume,
            sampler=self.sampler,
            num_workers=config.queue_num_workers,
            shuffle_subjects=shuffle,
            shuffle_patches=shuffle
        )

        # Create torch DataLoader
        self._dl = DataLoader(
            self.queue,
            batch_size=batch_size,
            num_workers=0,  # Queue handles workers
            drop_last=drop_last
        )

    def __iter__(self):
        """Iterate over batches, yielding (image, mask) tuples."""
        for batch in self._dl:
            # Extract image and mask tensors
            img = batch['image'][tio.DATA]  # [B, C, H, W, D]
            has_mask = 'mask' in batch

            # Apply patch transforms if provided
            if self.patch_tfms is not None:
                # Apply transforms to each sample in batch
                transformed_imgs = []
                transformed_masks = [] if has_mask else None

                for i in range(img.shape[0]):
                    # Build subject dict with image, and mask if available
                    subject_dict = {'image': tio.ScalarImage(tensor=batch['image'][tio.DATA][i])}
                    if has_mask:
                        subject_dict['mask'] = tio.LabelMap(tensor=batch['mask'][tio.DATA][i])

                    subject = tio.Subject(subject_dict)
                    transformed = self.patch_tfms(subject)
                    transformed_imgs.append(transformed['image'].data)
                    if has_mask:
                        transformed_masks.append(transformed['mask'].data)

                img = torch.stack(transformed_imgs)
                mask = torch.stack(transformed_masks) if has_mask else None
            else:
                mask = batch['mask'][tio.DATA] if has_mask else None

            # Convert to MedImage/MedMask and move to device
            img = MedImage(img).to(self._device)
            if mask is not None:
                mask = MedMask(mask).to(self._device)

            yield img, mask

    def __len__(self):
        """Return number of batches per epoch."""
        n_patches = len(self.subjects_dataset) * self.config.samples_per_volume
        if self.drop_last:
            return n_patches // self.bs
        return (n_patches + self.bs - 1) // self.bs

    @property
    def dataset(self):
        """Return the underlying queue as dataset."""
        return self.queue

    @property
    def device(self):
        """Return current device."""
        return self._device

    def to(self, device):
        """Move DataLoader to device."""
        self._device = device
        return self

    def one_batch(self):
        """Return one batch from the DataLoader.

        Required for fastai compatibility - used for device detection
        and batch shape validation during Learner initialization.

        Returns:
            Tuple of (image, mask) tensors on the correct device.
        """
        return next(iter(self))
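
# Batch-shape sketch (hypothetical helper; `subjects` built as above): each
# batch is a (MedImage, MedMask) pair shaped [bs, C, *patch_size], already on
# the default device.
def _sketch_patch_batch(subjects):
    dl = MedPatchDataLoader(subjects, PatchConfig(patch_size=[96, 96, 96]),
                            batch_size=2)
    xb, yb = dl.one_batch()
    assert xb.shape[2:] == torch.Size([96, 96, 96])
    return xb, yb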

# %% ../nbs/10_vision_patch.ipynb 18
class MedPatchDataLoaders:
    """fastai-compatible DataLoaders for patch-based training with LAZY loading.

    This class provides train and validation DataLoaders that work with
    fastai's Learner for patch-based training on 3D medical images.

    Memory-efficient: Volumes are loaded on-demand by Queue workers,
    keeping memory usage constant (~150 MB) regardless of dataset size.

    Note: Validation uses the same sampling as training (pseudo Dice).
    For true validation metrics, use PatchInferenceEngine with GridSampler
    for full-volume sliding window inference.

    Example:
        >>> import torchio as tio
        >>>
        >>> # New pattern: preprocessing params in config (DRY)
        >>> config = PatchConfig(
        ...     patch_size=[96, 96, 96],
        ...     apply_reorder=True,
        ...     target_spacing=[0.5, 0.5, 0.5]
        ... )
        >>> dls = MedPatchDataLoaders.from_df(
        ...     df, img_col='image', mask_col='label',
        ...     valid_pct=0.2,
        ...     patch_config=config,
        ...     pre_patch_tfms=[tio.ZNormalization()],
        ...     bs=4
        ... )
        >>> learn = Learner(dls, model, loss_func=DiceLoss())
    """

    def __init__(
        self,
        train_dl: MedPatchDataLoader,
        valid_dl: MedPatchDataLoader,
        device: torch.device = None
    ):
        self._train_dl = train_dl
        self._valid_dl = valid_dl
        self._device = device or _get_default_device()

        # Move to device
        self._train_dl.to(self._device)
        self._valid_dl.to(self._device)

    @classmethod
    def from_df(
        cls,
        df: pd.DataFrame,
        img_col: str,
        mask_col: str = None,
        valid_pct: float = 0.2,
        valid_col: str = None,
        patch_config: PatchConfig = None,
        pre_patch_tfms: list = None,
        patch_tfms: list = None,
        apply_reorder: bool = None,
        target_spacing: list = None,
        bs: int = 4,
        seed: int = None,
        device: torch.device = None,
        ensure_affine_consistency: bool = True
    ) -> 'MedPatchDataLoaders':
        """Create train/valid DataLoaders from DataFrame with LAZY loading.

        Memory-efficient: Only file paths are stored at creation time.
        Volumes are loaded on-demand by Queue workers during training.

        Note: Both train and valid use the same sampling strategy from patch_config.
        This gives pseudo Dice during training. For true validation metrics,
        use PatchInferenceEngine with full-volume sliding window inference.

        Args:
            df: DataFrame with image paths.
            img_col: Column name for image paths.
            mask_col: Column name for mask paths.
            valid_pct: Fraction of data for validation.
            valid_col: Column name for train/valid split (if pre-defined).
            patch_config: PatchConfig instance. Preprocessing params (apply_reorder,
                target_spacing) can be set here for DRY usage with PatchInferenceEngine.
            pre_patch_tfms: TorchIO transforms applied before patch extraction
                (after reorder/resample). Example: [tio.ZNormalization()].
            patch_tfms: TorchIO transforms applied to extracted patches (training only).
            apply_reorder: If True, reorder to RAS+ orientation. If None, uses
                patch_config.apply_reorder. Explicit value overrides config.
            target_spacing: Target voxel spacing [x, y, z]. If None, uses
                patch_config.target_spacing. Explicit value overrides config.
            bs: Batch size.
            seed: Random seed for splitting.
            device: Device to use.
            ensure_affine_consistency: If True and mask_col is provided, automatically
                adds tio.CopyAffine(target='image') as the first transform to prevent
                spatial metadata mismatch errors. Defaults to True.

        Returns:
            MedPatchDataLoaders instance.

        Example:
            >>> # New pattern: config contains preprocessing params
            >>> config = PatchConfig(
            ...     patch_size=[96, 96, 96],
            ...     apply_reorder=True,
            ...     target_spacing=[0.5, 0.5, 0.5],
            ...     label_probabilities={0: 0.1, 1: 0.9}
            ... )
            >>> dls = MedPatchDataLoaders.from_df(
            ...     df, img_col='image', mask_col='label',
            ...     patch_config=config,
            ...     pre_patch_tfms=[tio.ZNormalization()],
            ...     patch_tfms=[tio.RandomAffine(degrees=10), tio.RandomFlip()],
            ...     bs=4
            ... )
            >>> # Memory: ~150 MB (queue buffer only)
        """
        if patch_config is None:
            patch_config = PatchConfig()

        # Use config values, allow explicit overrides for backward compatibility
        _apply_reorder = apply_reorder if apply_reorder is not None else patch_config.apply_reorder
        _target_spacing = target_spacing if target_spacing is not None else patch_config.target_spacing

        # Warn if both config and explicit args provided with different values
        _warn_config_override('apply_reorder', patch_config.apply_reorder, apply_reorder)
        _warn_config_override('target_spacing', patch_config.target_spacing, target_spacing)

        # Split data
        if valid_col is not None:
            train_df = df[df[valid_col] == False].reset_index(drop=True)
            valid_df = df[df[valid_col] == True].reset_index(drop=True)
        else:
            if seed is not None:
                np.random.seed(seed)
            n = len(df)
            valid_idx = np.random.choice(n, size=int(n * valid_pct), replace=False)
            train_idx = np.setdiff1d(np.arange(n), valid_idx)
            train_df = df.iloc[train_idx].reset_index(drop=True)
            valid_df = df.iloc[valid_idx].reset_index(drop=True)

        # Build preprocessing transforms
        all_pre_tfms = []

        # Add reorder transform (reorder to RAS+ orientation)
        if _apply_reorder:
            all_pre_tfms.append(tio.ToCanonical())

        # Add resample transform
        if _target_spacing is not None:
            all_pre_tfms.append(tio.Resample(_target_spacing))

        # Add user-provided transforms
        if pre_patch_tfms:
            all_pre_tfms.extend(pre_patch_tfms)

        # Create subjects datasets with lazy loading (paths only, ~0 MB)
        train_subjects = create_subjects_dataset(
            train_df, img_col, mask_col,
            pre_tfms=all_pre_tfms if all_pre_tfms else None,
            ensure_affine_consistency=ensure_affine_consistency
        )
        valid_subjects = create_subjects_dataset(
            valid_df, img_col, mask_col,
            pre_tfms=all_pre_tfms if all_pre_tfms else None,
            ensure_affine_consistency=ensure_affine_consistency
        )

        # Create DataLoaders (both use same patch_config for consistent sampling)
        train_dl = MedPatchDataLoader(
            train_subjects, patch_config, bs,
            patch_tfms=patch_tfms, shuffle=True, drop_last=True
        )
        valid_dl = MedPatchDataLoader(
            valid_subjects, patch_config, bs,
            patch_tfms=None,  # No augmentation for validation
            shuffle=False, drop_last=False
        )

        # Create instance and store metadata
        instance = cls(train_dl, valid_dl, device)
        instance._img_col = img_col
        instance._mask_col = mask_col
        instance._pre_patch_tfms = pre_patch_tfms
        instance._apply_reorder = _apply_reorder
        instance._target_spacing = _target_spacing
        instance._ensure_affine_consistency = ensure_affine_consistency
        instance._patch_config = patch_config
        return instance

    @property
    def train(self):
        """Training DataLoader."""
        return self._train_dl

    @property
    def valid(self):
        """Validation DataLoader."""
        return self._valid_dl

    @property
    def train_ds(self):
        """Training subjects dataset."""
        return self._train_dl.subjects_dataset

    @property
    def valid_ds(self):
        """Validation subjects dataset."""
        return self._valid_dl.subjects_dataset

    @property
    def device(self):
        """Current device."""
        return self._device

    @property
    def bs(self):
        """Batch size."""
        return self._train_dl.bs

    @property
    def apply_reorder(self):
        """Whether reordering to RAS+ is enabled."""
        return getattr(self, '_apply_reorder', False)

    @property
    def target_spacing(self):
        """Target voxel spacing for resampling."""
        return getattr(self, '_target_spacing', None)

    @property
    def patch_config(self):
        """The PatchConfig used for this DataLoaders."""
        return getattr(self, '_patch_config', None)

    def to(self, device):
        """Move DataLoaders to device."""
        self._device = device
        self._train_dl.to(device)
        self._valid_dl.to(device)
        return self

    def __iter__(self):
        """Iterate over training DataLoader."""
        return iter(self._train_dl)

    def one_batch(self):
        """Return one batch from the training DataLoader.

        Required for fastai Learner compatibility - used for device
        detection and batch shape validation.
        """
        return self._train_dl.one_batch()

    def __len__(self):
        """Return number of batches in training DataLoader."""
        return len(self._train_dl)

    def __getitem__(self, idx):
        """Get DataLoader by index. Required for fastai Learner compatibility.

        Args:
            idx: 0 for training DataLoader, 1 for validation DataLoader.

        Returns:
            MedPatchDataLoader instance.
        """
        if idx == 0:
            return self._train_dl
        elif idx == 1:
            return self._valid_dl
        else:
            raise IndexError(f"Index {idx} out of range. Use 0 (train) or 1 (valid).")

    def cuda(self):
        """Move DataLoaders to CUDA device."""
        return self.to(torch.device('cuda'))

    def cpu(self):
        """Move DataLoaders to CPU."""
        return self.to(torch.device('cpu'))

    def new_empty(self):
        """Create a new empty version of self for learner export.

        Required for fastai Learner.export() compatibility - creates a
        lightweight placeholder that can be pickled without the full dataset.

        Returns:
            A minimal MedPatchDataLoaders-like object with no data.
        """
        class EmptyMedPatchDataLoaders:
            """Minimal placeholder for exported learner."""
            def __init__(self, device):
                self._device = device
            @property
            def device(self): return self._device
            def to(self, device):
                self._device = device
                return self

        return EmptyMedPatchDataLoaders(self._device)
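
# Training sketch (hypothetical `df`, `model`, and `loss_func`): the
# DataLoaders drop into a standard fastai Learner, as in the class docstring.
def _sketch_patch_training(df, model, loss_func):
    from fastai.learner import Learner
    config = PatchConfig(patch_size=[96, 96, 96], apply_reorder=True,
                         target_spacing=[0.5, 0.5, 0.5])
    dls = MedPatchDataLoaders.from_df(
        df, img_col='image', mask_col='label',
        patch_config=config, pre_patch_tfms=[tio.ZNormalization()],
        bs=4, seed=42,
    )
    learn = Learner(dls, model, loss_func=loss_func)
    learn.fit(1)
    return learn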

# %% ../nbs/10_vision_patch.ipynb 20
import numbers

def _normalize_patch_overlap(patch_overlap, patch_size):
    """Convert patch_overlap to integer pixel values for TorchIO compatibility.

    TorchIO's GridSampler expects patch_overlap as a tuple of even integers.
    This function handles:
    - Fractional overlap (0-1): converted to pixel values based on patch_size
    - Numpy scalar types: converted to native Python types
    - Sequences: converted to tuple of integers

    Note: Input validation (negative values, overlap >= patch_size) is handled
    by PatchConfig.__post_init__(). This function focuses on format conversion.

    Args:
        patch_overlap: int, float (0-1 for fraction), or sequence
        patch_size: list/tuple of patch dimensions [x, y, z]

    Returns:
        Tuple of even integers suitable for TorchIO GridSampler
    """
    # Handle scalar fractional overlap (0 < x < 1)
    # Note: excludes 1.0 as 100% overlap creates step_size=0 (infinite patches)
    if isinstance(patch_overlap, (int, float, numbers.Number)) and 0 < float(patch_overlap) < 1:
        # Convert fraction to pixel values, ensure even
        result = []
        for ps in patch_size:
            pixels = int(int(ps) * float(patch_overlap))
            # Ensure even (required by TorchIO)
            if pixels % 2 != 0:
                pixels = pixels - 1 if pixels > 0 else 0
            result.append(pixels)
        return tuple(result)

    # Handle scalar integer (including numpy scalars) - values > 1 are pixel counts
    if isinstance(patch_overlap, (int, float, numbers.Number)):
        val = int(patch_overlap)
        # Ensure even
        if val % 2 != 0:
            val = val - 1 if val > 0 else 0
        return tuple(val for _ in patch_size)

    # Handle sequences (list, tuple, ndarray)
    result = []
    for val in patch_overlap:
        pixels = int(val)
        if pixels % 2 != 0:
            pixels = pixels - 1 if pixels > 0 else 0
        result.append(pixels)
    return tuple(result)
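
# Conversion sketch: fractions become even per-dimension pixel counts, scalars
# broadcast across dimensions, and odd values round down to the nearest even.
def _sketch_overlap_normalization():
    assert _normalize_patch_overlap(0.5, [96, 96, 96]) == (48, 48, 48)
    assert _normalize_patch_overlap(11, [96, 96, 96]) == (10, 10, 10)
    assert _normalize_patch_overlap([4, 6, 8], [96, 96, 96]) == (4, 6, 8)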


class PatchInferenceEngine:
    """Patch-based inference with automatic volume reconstruction.

    Uses TorchIO's GridSampler to extract overlapping patches and
    GridAggregator to reconstruct the full volume from predictions.

    Args:
        learner: PyTorch model or fastai Learner.
        config: PatchConfig with inference settings. Preprocessing params (apply_reorder,
            target_spacing, padding_mode) can be set here for DRY usage.
        apply_reorder: Whether to reorder to RAS+ orientation. If None, uses config value.
        target_spacing: Target voxel spacing. If None, uses config value.
        batch_size: Number of patches to predict at once. Must be positive.
        pre_inference_tfms: List of TorchIO transforms to apply before patch extraction.
            IMPORTANT: Should match the pre_patch_tfms used during training (e.g., [tio.ZNormalization()]).
            This ensures preprocessing consistency between training and inference.

    Example:
        >>> # DRY pattern: use same config for training and inference
        >>> config = PatchConfig(
        ...     patch_size=[96, 96, 96],
        ...     apply_reorder=True,
        ...     target_spacing=[1.0, 1.0, 1.0]
        ... )
        >>> # Training
        >>> dls = MedPatchDataLoaders.from_df(df, 'img', 'mask', patch_config=config)
        >>> # Inference - no need to repeat params!
        >>> engine = PatchInferenceEngine(
        ...     learn, config,
        ...     pre_inference_tfms=[tio.ZNormalization()]
        ... )
        >>> pred = engine.predict('image.nii.gz')
    """

    def __init__(
        self,
        learner,
        config: PatchConfig,
        apply_reorder: bool = None,
        target_spacing: list = None,
        batch_size: int = 4,
        pre_inference_tfms: list = None
    ):
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        # Extract model from Learner if needed
        self.model = learner.model if hasattr(learner, 'model') else learner
        self.config = config
        self.batch_size = batch_size
        self.pre_inference_tfms = tio.Compose(pre_inference_tfms) if pre_inference_tfms else None

        # Use config values, allow explicit overrides for backward compatibility
        self.apply_reorder = apply_reorder if apply_reorder is not None else config.apply_reorder
        self.target_spacing = target_spacing if target_spacing is not None else config.target_spacing

        # Warn if explicit args provided but differ from config (potential mistake)
        _warn_config_override('apply_reorder', config.apply_reorder, apply_reorder)
        _warn_config_override('target_spacing', config.target_spacing, target_spacing)

        # Get device from model parameters, with fallback for parameter-less models
        try:
            self._device = next(self.model.parameters()).device
        except StopIteration:
            self._device = _get_default_device()

    def predict(
        self,
        img_path: Path | str,
        return_probabilities: bool = False,
        return_affine: bool = False
    ) -> torch.Tensor | tuple[torch.Tensor, np.ndarray]:
        """Predict on a single volume using patch-based inference.

        Args:
            img_path: Path to input image.
            return_probabilities: If True, return probability map instead of argmax.
            return_affine: If True, return (prediction, affine) tuple instead of just prediction.

        Returns:
            Predicted segmentation mask tensor, or tuple (prediction, affine) if return_affine=True.
        """
        # Load image - keep org_img and org_size for post-processing
        # Note: med_img_reader handles reorder/resample internally, no global state needed
        org_img, input_img, org_size = med_img_reader(
            img_path, apply_reorder=self.apply_reorder, target_spacing=self.target_spacing, only_tensor=False
        )

        # Create TorchIO Subject from preprocessed image
        subject = tio.Subject(
            image=tio.ScalarImage(tensor=input_img.data.float(), affine=input_img.affine)
        )

        # Apply pre-inference transforms (e.g., ZNormalization) to match training
        if self.pre_inference_tfms is not None:
            subject = self.pre_inference_tfms(subject)

        # Pad dimensions smaller than patch_size, keep larger dimensions intact
        # GridSampler handles large images via overlapping patches
        img_shape = subject['image'].shape[1:]  # Exclude channel dim
        target_size = [max(s, p) for s, p in zip(img_shape, self.config.patch_size)]

        # Warn if volume needed padding (may cause artifacts if training didn't cover similar sizes)
        if any(s < p for s, p in zip(img_shape, self.config.patch_size)):
            padded_dims = [f"dim{i}: {s}<{p}" for i, (s, p) in enumerate(zip(img_shape, self.config.patch_size)) if s < p]
            warnings.warn(
                f"Image size {list(img_shape)} smaller than patch_size {self.config.patch_size} "
                f"in {padded_dims}. Padding with mode={self.config.padding_mode}. "
                "Ensure training data covered similar sizes to avoid artifacts."
            )

        # Use padding_mode from config (default: 0 for zero padding, nnU-Net standard)
        subject = tio.CropOrPad(target_size, padding_mode=self.config.padding_mode)(subject)

        # Convert patch_overlap to integer pixel values for TorchIO compatibility
        patch_overlap = _normalize_patch_overlap(self.config.patch_overlap, self.config.patch_size)

        # Create GridSampler
        grid_sampler = tio.GridSampler(
            subject,
            patch_size=self.config.patch_size,
            patch_overlap=patch_overlap
        )

        # Create GridAggregator
        aggregator = tio.GridAggregator(
            grid_sampler,
            overlap_mode=self.config.aggregation_mode
        )

        # Create patch loader
        patch_loader = DataLoader(
            grid_sampler,
            batch_size=self.batch_size,
            num_workers=0
        )

        # Predict patches
        self.model.eval()
        with torch.no_grad():
            for patches_batch in patch_loader:
                patch_input = patches_batch['image'][tio.DATA].to(self._device)
                locations = patches_batch[tio.LOCATION]

                # Forward pass - get logits
                logits = self.model(patch_input)

                # Convert logits to probabilities BEFORE aggregation
                # This is critical: softmax is non-linear, so we must aggregate
                # probabilities, not logits, to get correct boundary handling
                n_classes = logits.shape[1]
                if n_classes == 1:
                    probs = torch.sigmoid(logits)
                else:
                    probs = torch.softmax(logits, dim=1)  # dim=1 for batch [B, C, H, W, D]

                # Add probabilities to aggregator
                aggregator.add_batch(probs.cpu(), locations)

        # Get reconstructed output (now contains probabilities, not logits)
        output = aggregator.get_output_tensor()

        # Convert to prediction mask (only if not returning probabilities)
        if return_probabilities:
            result = output  # Keep as float probabilities
        else:
            n_classes = output.shape[0]
            if n_classes == 1:
                result = (output > 0.5).float()
            else:
                result = output.argmax(dim=0, keepdim=True).float()

        # Apply keep_largest post-processing for binary segmentation
        if not return_probabilities and self.config.keep_largest_component:
            from fastMONAI.vision_inference import keep_largest
            result = keep_largest(result.squeeze(0)).unsqueeze(0)

        # Post-processing: resize back to original size and reorient
        # This matches the workflow in vision_inference.py

        # Wrap result in TorchIO Image for resizing
        # Use ScalarImage for probabilities, LabelMap for masks
        if return_probabilities:
            pred_img = tio.ScalarImage(tensor=result.float(), affine=input_img.affine)
        else:
            pred_img = tio.LabelMap(tensor=result.float(), affine=input_img.affine)

        # Resize back to original size (before resampling)
        pred_img = _do_resize(pred_img, org_size, image_interpolation='nearest')

        # Reorient to original orientation (if reorder was applied)
        # Use explicit .cpu() for consistent device handling
        if self.apply_reorder:
            reoriented_array = _to_original_orientation(
                pred_img.as_sitk(),
                ('').join(org_img.orientation)
            )
            result = torch.from_numpy(reoriented_array).cpu()
            # Only convert to long for masks, not probabilities
            if not return_probabilities:
                result = result.long()
        else:
            result = pred_img.data.cpu()
            # Only convert to long for masks, not probabilities
            if not return_probabilities:
                result = result.long()

        # Use original affine matrix for correct spatial alignment
        # org_img.affine is always available from med_img_reader
        if not (hasattr(org_img, 'affine') and org_img.affine is not None):
            raise RuntimeError(
                "org_img.affine not available. This should never happen - please report this bug."
            )
        affine = org_img.affine.copy()

        if return_affine:
            return result, affine
        return result

    def to(self, device):
        """Move engine to device."""
        self._device = device
        self.model.to(device)
        return self
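
# Inference sketch (hypothetical paths and trained `learn`): a hard mask, a
# probability map, and the affine needed to write a spatially aligned NIfTI.
def _sketch_engine_predict(learn):
    config = PatchConfig(patch_size=[96, 96, 96], patch_overlap=0.5,
                         apply_reorder=True, target_spacing=[1.0, 1.0, 1.0])
    engine = PatchInferenceEngine(learn, config,
                                  pre_inference_tfms=[tio.ZNormalization()])
    mask = engine.predict('image.nii.gz')  # argmax/threshold labels
    probs = engine.predict('image.nii.gz', return_probabilities=True)
    mask, affine = engine.predict('image.nii.gz', return_affine=True)
    return mask, probs, affine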

# %% ../nbs/10_vision_patch.ipynb 21
def patch_inference(
    learner,
    config: PatchConfig,
    file_paths: list,
    apply_reorder: bool = None,
    target_spacing: list = None,
    batch_size: int = 4,
    return_probabilities: bool = False,
    progress: bool = True,
    save_dir: str = None,
    pre_inference_tfms: list = None
) -> list:
    """Batch patch-based inference on multiple volumes.

    Args:
        learner: PyTorch model or fastai Learner.
        config: PatchConfig with inference settings. Preprocessing params (apply_reorder,
            target_spacing) can be set here for DRY usage.
        file_paths: List of image paths.
        apply_reorder: Whether to reorder to RAS+ orientation. If None, uses config value.
        target_spacing: Target voxel spacing. If None, uses config value.
        batch_size: Patches per batch.
        return_probabilities: Return probability maps.
        progress: Show progress bar.
        save_dir: Directory to save predictions as NIfTI files. If None, predictions are not saved.
        pre_inference_tfms: List of TorchIO transforms to apply before patch extraction.
            IMPORTANT: Should match the pre_patch_tfms used during training (e.g., [tio.ZNormalization()]).

    Returns:
        List of predicted tensors.

    Example:
        >>> # DRY pattern: use same config for training and inference
        >>> config = PatchConfig(
        ...     patch_size=[96, 96, 96],
        ...     apply_reorder=True,
        ...     target_spacing=[0.4102, 0.4102, 1.5]
        ... )
        >>> predictions = patch_inference(
        ...     learner=learn,
        ...     config=config,  # apply_reorder and target_spacing from config
        ...     file_paths=val_paths,
        ...     pre_inference_tfms=[tio.ZNormalization()],
        ...     save_dir='predictions/patch_based'
        ... )
    """
    # Use config values if not explicitly provided
    _apply_reorder = apply_reorder if apply_reorder is not None else config.apply_reorder
    _target_spacing = target_spacing if target_spacing is not None else config.target_spacing

    engine = PatchInferenceEngine(
        learner, config, _apply_reorder, _target_spacing, batch_size, pre_inference_tfms
    )

    # Create save directory if specified
    if save_dir is not None:
        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)

    predictions = []
    iterator = tqdm(file_paths, desc='Patch inference') if progress else file_paths

    for path in iterator:
        # Get prediction and affine when saving is needed
        if save_dir is not None:
            pred, affine = engine.predict(path, return_probabilities, return_affine=True)
        else:
            pred = engine.predict(path, return_probabilities)
        predictions.append(pred)

        # Save prediction if save_dir specified
        if save_dir is not None:
            input_path = Path(path)
            # Create output filename based on input using suffix-based approach
            # This handles .nii.gz correctly without corrupting filenames with .nii elsewhere
            stem = input_path.stem
            if input_path.suffix == '.gz' and stem.endswith('.nii'):
                # Handle .nii.gz files: stem is "filename.nii", strip the .nii
                stem = stem[:-4]
                out_name = f"{stem}_pred.nii.gz"
            elif input_path.suffix == '.nii':
                # Handle .nii files
                out_name = f"{stem}_pred.nii"
            else:
                # Fallback for other formats
                out_name = f"{stem}_pred.nii.gz"
            out_path = save_path / out_name

            # affine is guaranteed to be valid from engine.predict() with return_affine=True
            # Save as NIfTI using TorchIO with correct type
            # Use ScalarImage for probabilities (float), LabelMap for masks (int)
            if return_probabilities:
                pred_img = tio.ScalarImage(tensor=pred, affine=affine)
            else:
                pred_img = tio.LabelMap(tensor=pred, affine=affine)
            pred_img.save(out_path)

    return predictions
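
# End-to-end batch sketch (hypothetical `learn` and `val_paths`): predictions
# are returned as tensors and also written under save_dir as <stem>_pred.nii.gz.
def _sketch_batch_inference(learn, val_paths):
    config = PatchConfig(patch_size=[96, 96, 96], apply_reorder=True,
                         target_spacing=[1.0, 1.0, 1.0])
    return patch_inference(
        learn, config, file_paths=val_paths,
        pre_inference_tfms=[tio.ZNormalization()],
        save_dir='predictions/patch_based',
    )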