fastMONAI 0.5.3__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastMONAI/__init__.py +1 -1
- fastMONAI/_modidx.py +171 -27
- fastMONAI/dataset_info.py +190 -45
- fastMONAI/external_data.py +1 -1
- fastMONAI/utils.py +101 -18
- fastMONAI/vision_all.py +3 -2
- fastMONAI/vision_augmentation.py +133 -29
- fastMONAI/vision_core.py +29 -132
- fastMONAI/vision_data.py +6 -6
- fastMONAI/vision_inference.py +35 -9
- fastMONAI/vision_metrics.py +420 -19
- fastMONAI/vision_patch.py +1125 -0
- fastMONAI/vision_plot.py +1 -0
- {fastmonai-0.5.3.dist-info → fastmonai-0.5.4.dist-info}/METADATA +5 -5
- fastmonai-0.5.4.dist-info/RECORD +21 -0
- {fastmonai-0.5.3.dist-info → fastmonai-0.5.4.dist-info}/WHEEL +1 -1
- fastmonai-0.5.3.dist-info/RECORD +0 -20
- {fastmonai-0.5.3.dist-info → fastmonai-0.5.4.dist-info}/entry_points.txt +0 -0
- {fastmonai-0.5.3.dist-info → fastmonai-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {fastmonai-0.5.3.dist-info → fastmonai-0.5.4.dist-info}/top_level.txt +0 -0
fastMONAI/vision_augmentation.py
CHANGED
```diff
@@ -3,8 +3,8 @@
 # %% auto 0
 __all__ = ['CustomDictTransform', 'do_pad_or_crop', 'PadOrCrop', 'ZNormalization', 'RescaleIntensity', 'NormalizeIntensity',
            'BraTSMaskConverter', 'BinaryConverter', 'RandomGhosting', 'RandomSpike', 'RandomNoise', 'RandomBiasField',
-           'RandomBlur', 'RandomGamma', 'RandomMotion', 'RandomElasticDeformation', 'RandomAffine', 'RandomFlip',
-           'OneOf']
+           'RandomBlur', 'RandomGamma', 'RandomIntensityScale', 'RandomMotion', 'RandomElasticDeformation',
+           'RandomAffine', 'RandomFlip', 'OneOf']

 # %% ../nbs/03_vision_augment.ipynb 2
 from fastai.data.all import *
@@ -14,10 +14,10 @@ from monai.transforms import NormalizeIntensity as MonaiNormalizeIntensity

 # %% ../nbs/03_vision_augment.ipynb 5
 class CustomDictTransform(ItemTransform):
-    """A class that serves as a wrapper to perform an identical transformation on both 
+    """A class that serves as a wrapper to perform an identical transformation on both
     the image and the target (if it's a mask).
     """
-    
+
     split_idx = 0 # Only perform transformations on training data. Use TTA() for transformations on validation data.

     def __init__(self, aug):
@@ -28,31 +28,42 @@ class CustomDictTransform(ItemTransform):
         """
         self.aug = aug

+    @property
+    def tio_transform(self):
+        """Return the underlying TorchIO transform.
+
+        This property enables using fastMONAI wrappers in patch-based workflows
+        where raw TorchIO transforms are needed for tio.Compose().
+        """
+        return self.aug
+
     def encodes(self, x):
         """
-        Applies the stored transformation to an image, and the same random transformation 
+        Applies the stored transformation to an image, and the same random transformation
         to the target if it is a mask. If the target is not a mask, it returns the target as is.

         Args:
-            x (Tuple[MedImage, Union[MedMask, TensorCategory]]): A tuple containing the 
+            x (Tuple[MedImage, Union[MedMask, TensorCategory]]): A tuple containing the
                 image and the target.

         Returns:
-            Tuple[MedImage, Union[MedMask, TensorCategory]]: The transformed image and target. 
-            If the target is a mask, it's transformed identically to the image. If the target 
+            Tuple[MedImage, Union[MedMask, TensorCategory]]: The transformed image and target.
+            If the target is a mask, it's transformed identically to the image. If the target
             is not a mask, the original target is returned.
         """
         img, y_true = x
+
+        # Use identity affine if MedImage.affine_matrix is not set
+        affine = MedImage.affine_matrix if MedImage.affine_matrix is not None else np.eye(4)

         if isinstance(y_true, (MedMask)):
-            aug = self.aug(tio.Subject(img=tio.ScalarImage(tensor=img, affine=MedImage.affine_matrix),
-                                       mask=tio.LabelMap(tensor=y_true, affine=MedImage.affine_matrix)))
+            aug = self.aug(tio.Subject(img=tio.ScalarImage(tensor=img, affine=affine),
+                                       mask=tio.LabelMap(tensor=y_true, affine=affine)))
             return MedImage.create(aug['img'].data), MedMask.create(aug['mask'].data)

         aug = self.aug(tio.Subject(img=tio.ScalarImage(tensor=img)))
         return MedImage.create(aug['img'].data), y_true

-
 # %% ../nbs/03_vision_augment.ipynb 7
 def do_pad_or_crop(o, target_shape, padding_mode, mask_name, dtype=torch.Tensor):
     #TODO:refactorize
@@ -72,6 +83,11 @@ class PadOrCrop(DisplayedTransform):
                                           padding_mode=padding_mode,
                                           mask_name=mask_name)

+    @property
+    def tio_transform(self):
+        """Return the underlying TorchIO transform."""
+        return self.pad_or_crop
+
     def encodes(self, o: (MedImage, MedMask)):
         return type(o)(self.pad_or_crop(o))

@@ -85,6 +101,11 @@ class ZNormalization(DisplayedTransform):
         self.z_normalization = tio.ZNormalization(masking_method=masking_method)
         self.channel_wise = channel_wise

+    @property
+    def tio_transform(self):
+        """Return the underlying TorchIO transform."""
+        return self.z_normalization
+
     def encodes(self, o: MedImage):
         try:
             if self.channel_wise:
@@ -132,7 +153,12 @@ class RescaleIntensity(DisplayedTransform):

     def __init__(self, out_min_max: tuple[float, float], in_min_max: tuple[float, float]):
         self.rescale = tio.RescaleIntensity(out_min_max=out_min_max, in_min_max=in_min_max)
-    
+
+    @property
+    def tio_transform(self):
+        """Return the underlying TorchIO transform."""
+        return self.rescale
+
     def encodes(self, o: MedImage):
         return MedImage.create(self.rescale(o))

@@ -211,8 +237,17 @@ class RandomGhosting(DisplayedTransform):
     def __init__(self, intensity=(0.5, 1), p=0.5):
         self.add_ghosts = tio.RandomGhosting(intensity=intensity, p=p)

+    @property
+    def tio_transform(self):
+        """Return the underlying TorchIO transform."""
+        return self.add_ghosts
+
     def encodes(self, o: MedImage):
-        return MedImage.create(self.add_ghosts(o))
+        result = self.add_ghosts(o)
+        # Handle potential complex values from k-space operations
+        if result.is_complex():
+            result = torch.real(result)
+        return MedImage.create(result)

     def encodes(self, o: MedMask):
         return o
@@ -221,26 +256,40 @@ class RandomGhosting(DisplayedTransform):
 class RandomSpike(DisplayedTransform):
     '''Apply TorchIO `RandomSpike`.'''

-    split_idx,order=0,1
+    split_idx, order = 0, 1

     def __init__(self, num_spikes=1, intensity=(1, 3), p=0.5):
         self.add_spikes = tio.RandomSpike(num_spikes=num_spikes, intensity=intensity, p=p)

-    def encodes(self, o: MedImage):
-        return MedImage.create(self.add_spikes(o))
+    @property
+    def tio_transform(self):
+        """Return the underlying TorchIO transform."""
+        return self.add_spikes
+
+    def encodes(self, o: MedImage):
+        result = self.add_spikes(o)
+        # Handle potential complex values from k-space operations
+        if result.is_complex():
+            result = torch.real(result)
+        return MedImage.create(result)

-    def encodes(self, o:MedMask):
+    def encodes(self, o: MedMask):
         return o

 # %% ../nbs/03_vision_augment.ipynb 16
 class RandomNoise(DisplayedTransform):
     '''Apply TorchIO `RandomNoise`.'''

-    split_idx,order=0,1
+    split_idx, order = 0, 1

     def __init__(self, mean=0, std=(0, 0.25), p=0.5):
         self.add_noise = tio.RandomNoise(mean=mean, std=std, p=p)

+    @property
+    def tio_transform(self):
+        """Return the underlying TorchIO transform."""
+        return self.add_noise
+
     def encodes(self, o: MedImage):
         return MedImage.create(self.add_noise(o))

@@ -251,11 +300,16 @@ class RandomNoise(DisplayedTransform):
 class RandomBiasField(DisplayedTransform):
     '''Apply TorchIO `RandomBiasField`.'''

-    split_idx,order=0,1
+    split_idx, order = 0, 1

     def __init__(self, coefficients=0.5, order=3, p=0.5):
         self.add_biasfield = tio.RandomBiasField(coefficients=coefficients, order=order, p=p)

+    @property
+    def tio_transform(self):
+        """Return the underlying TorchIO transform."""
+        return self.add_biasfield
+
     def encodes(self, o: MedImage):
         return MedImage.create(self.add_biasfield(o))

@@ -264,13 +318,18 @@ class RandomBiasField(DisplayedTransform):

 # %% ../nbs/03_vision_augment.ipynb 18
 class RandomBlur(DisplayedTransform):
-    '''Apply TorchIO `
+    '''Apply TorchIO `RandomBlur`.'''

-    split_idx,order=0,1
+    split_idx, order = 0, 1

     def __init__(self, std=(0, 2), p=0.5):
         self.add_blur = tio.RandomBlur(std=std, p=p)
-    
+
+    @property
+    def tio_transform(self):
+        """Return the underlying TorchIO transform."""
+        return self.add_blur
+
     def encodes(self, o: MedImage):
         return MedImage.create(self.add_blur(o))

@@ -281,12 +340,16 @@ class RandomBlur(DisplayedTransform):
 class RandomGamma(DisplayedTransform):
     '''Apply TorchIO `RandomGamma`.'''

-
-    split_idx,order=0,1
+    split_idx, order = 0, 1

     def __init__(self, log_gamma=(-0.3, 0.3), p=0.5):
         self.add_gamma = tio.RandomGamma(log_gamma=log_gamma, p=p)

+    @property
+    def tio_transform(self):
+        """Return the underlying TorchIO transform."""
+        return self.add_gamma
+
     def encodes(self, o: MedImage):
         return MedImage.create(self.add_gamma(o))

```
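The recurring addition in this file is the `tio_transform` property, which exposes the raw TorchIO transform each fastai-style wrapper delegates to. A minimal sketch of what this enables, assuming the wrappers are re-exported via `fastMONAI.vision_all` and that the file path is illustrative:

```python
import torchio as tio
from fastMONAI.vision_all import ZNormalization, RandomNoise

# fastai-style wrappers, as they would appear in item_tfms/batch_tfms lists.
fastai_tfms = [ZNormalization(), RandomNoise(std=(0, 0.1), p=0.5)]

# The new property unwraps each wrapper back to its TorchIO transform,
# so the same configuration can drive a patch-based tio.Compose pipeline.
patch_pipeline = tio.Compose([t.tio_transform for t in fastai_tfms])

subject = tio.Subject(img=tio.ScalarImage('scan.nii.gz'))  # hypothetical file
transformed = patch_pipeline(subject)
```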
```diff
@@ -294,6 +357,38 @@ class RandomGamma(DisplayedTransform):
         return o

 # %% ../nbs/03_vision_augment.ipynb 20
+class RandomIntensityScale(DisplayedTransform):
+    """Randomly scale image intensities by a multiplicative factor.
+
+    Useful for domain generalization across different acquisition protocols
+    with varying intensity ranges.
+
+    Args:
+        scale_range (tuple[float, float]): Range of scale factors (min, max).
+            Values > 1 increase intensity, < 1 decrease intensity.
+        p (float): Probability of applying the transform (default: 0.5)
+
+    Example:
+        # Scale intensities randomly between 0.5x and 2.0x
+        transform = RandomIntensityScale(scale_range=(0.5, 2.0), p=0.3)
+    """
+
+    split_idx, order = 0, 1
+
+    def __init__(self, scale_range: tuple[float, float] = (0.5, 2.0), p: float = 0.5):
+        self.scale_range = scale_range
+        self.p = p
+
+    def encodes(self, o: MedImage):
+        if torch.rand(1).item() > self.p:
+            return o
+        scale = torch.empty(1).uniform_(self.scale_range[0], self.scale_range[1]).item()
+        return MedImage.create(o * scale)
+
+    def encodes(self, o: MedMask):
+        return o
+
+# %% ../nbs/03_vision_augment.ipynb 21
 class RandomMotion(DisplayedTransform):
     """Apply TorchIO `RandomMotion`."""

```
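`RandomIntensityScale` is the one genuinely new transform in this file. Unlike the TorchIO-backed wrappers, it gates itself with `torch.rand` and multiplies the image directly, and masks pass through unchanged. A sketch of slotting it into a transform list; the positional `PadOrCrop` size and the parameter values are illustrative assumptions, not values recommended by the release:

```python
from fastMONAI.vision_all import PadOrCrop, ZNormalization, RandomIntensityScale

# Illustrative training-time item transforms.
item_tfms = [
    PadOrCrop([96, 96, 96]),  # standardize spatial shape (size argument assumed)
    ZNormalization(),         # zero-mean, unit-variance intensities
    RandomIntensityScale(scale_range=(0.8, 1.2), p=0.3),  # mild intensity jitter
]
```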
```diff
@@ -315,13 +410,22 @@ class RandomMotion(DisplayedTransform):
             p=p
         )

+    @property
+    def tio_transform(self):
+        """Return the underlying TorchIO transform."""
+        return self.add_motion
+
     def encodes(self, o: MedImage):
-        return MedImage.create(self.add_motion(o))
+        result = self.add_motion(o)
+        # Handle potential complex values from k-space operations
+        if result.is_complex():
+            result = torch.real(result)
+        return MedImage.create(result)

     def encodes(self, o: MedMask):
         return o

-# %% ../nbs/03_vision_augment.ipynb
+# %% ../nbs/03_vision_augment.ipynb 23
 class RandomElasticDeformation(CustomDictTransform):
     """Apply TorchIO `RandomElasticDeformation`."""

@@ -334,7 +438,7 @@ class RandomElasticDeformation(CustomDictTransform):
                                  image_interpolation=image_interpolation,
                                  p=p))

-# %% ../nbs/03_vision_augment.ipynb
+# %% ../nbs/03_vision_augment.ipynb 24
 class RandomAffine(CustomDictTransform):
     """Apply TorchIO `RandomAffine`."""

@@ -350,14 +454,14 @@ class RandomAffine(CustomDictTransform):
                                default_pad_value=default_pad_value,
                                p=p))

-# %% ../nbs/03_vision_augment.ipynb
+# %% ../nbs/03_vision_augment.ipynb 25
 class RandomFlip(CustomDictTransform):
     """Apply TorchIO `RandomFlip`."""

     def __init__(self, axes='LR', p=0.5):
         super().__init__(tio.RandomFlip(axes=axes, flip_probability=p))

-# %% ../nbs/03_vision_augment.ipynb
+# %% ../nbs/03_vision_augment.ipynb 26
 class OneOf(CustomDictTransform):
     """Apply only one of the given transforms using TorchIO `OneOf`."""

```
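The other repeated change in this file is the `is_complex()` guard in `RandomGhosting`, `RandomSpike`, and `RandomMotion`. These artifacts are simulated in k-space, and an inverse FFT can return a complex tensor whose imaginary part is numerical residue. A self-contained sketch of the guard pattern the new `encodes` methods use:

```python
import torch

# Round-trip through k-space, as the ghosting/spike/motion transforms do.
volume = torch.rand(1, 8, 8, 8)
result = torch.fft.ifftn(torch.fft.fftn(volume))  # comes back with complex dtype

# The guard added in 0.5.4: keep only the real image content.
if result.is_complex():
    result = torch.real(result)

assert not result.is_complex()
```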
fastMONAI/vision_core.py
CHANGED
```diff
@@ -1,7 +1,7 @@
 # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_vision_core.ipynb.

 # %% auto 0
-__all__ = ['med_img_reader', 'MetaResolver', 'MedBase', 'MedImage', 'MedMask', 'VSCodeProgressCallback', 'setup_vscode_progress']
+__all__ = ['med_img_reader', 'MetaResolver', 'MedBase', 'MedImage', 'MedMask']

 # %% ../nbs/01_vision_core.ipynb 2
 from .vision_plot import *
@@ -10,26 +10,26 @@ from torchio import ScalarImage, LabelMap, ToCanonical, Resample
 import copy

 # %% ../nbs/01_vision_core.ipynb 5
-def _preprocess(obj, reorder, resample):
+def _preprocess(obj, apply_reorder, target_spacing):
     """
     Preprocesses the given object.

     Args:
         obj: The object to preprocess.
-        reorder: Whether to reorder the object.
-        resample: Whether to resample the object.
+        apply_reorder: Whether to reorder the object.
+        target_spacing: Whether to resample the object.

     Returns:
         The preprocessed object and its original size.
     """
-    if reorder:
+    if apply_reorder:
         transform = ToCanonical()
         obj = transform(obj)

     original_size = obj.shape[1:]

-    if resample and not all(np.isclose(obj.spacing, resample)):
-        transform = Resample(resample)
+    if target_spacing and not all(np.isclose(obj.spacing, target_spacing)):
+        transform = Resample(target_spacing)
         obj = transform(obj)

     if MedBase.affine_matrix is None:
@@ -38,33 +38,33 @@ def _preprocess(obj, reorder, resample):
     return obj, original_size

 # %% ../nbs/01_vision_core.ipynb 6
-def _load_and_preprocess(file_path, reorder, resample, dtype):
+def _load_and_preprocess(file_path, apply_reorder, target_spacing, dtype):
     """
     Helper function to load and preprocess an image.

     Args:
         file_path: Image file path.
-        reorder: Whether to reorder data for canonical (RAS+) orientation.
-        resample: Whether to resample image to different voxel sizes and dimensions.
+        apply_reorder: Whether to reorder data for canonical (RAS+) orientation.
+        target_spacing: Whether to resample image to different voxel sizes and dimensions.
         dtype: Desired datatype for output.

     Returns:
         tuple: Original image, preprocessed image, and its original size.
     """
     org_img = LabelMap(file_path) if dtype is MedMask else ScalarImage(file_path) #_load(file_path, dtype=dtype)
-    input_img, org_size = _preprocess(org_img, reorder, resample)
+    input_img, org_size = _preprocess(org_img, apply_reorder, target_spacing)

     return org_img, input_img, org_size

 # %% ../nbs/01_vision_core.ipynb 7
-def _multi_channel(image_paths: L | list, reorder: bool, resample: list, only_tensor: bool, dtype):
+def _multi_channel(image_paths: L | list, apply_reorder: bool, target_spacing: list, only_tensor: bool, dtype):
     """
     Load and preprocess multisequence data.

     Args:
         image_paths: List of image paths (e.g., T1, T2, T1CE, DWI).
-        reorder: Whether to reorder data for canonical (RAS+) orientation.
-        resample: Whether to resample image to different voxel sizes and dimensions.
+        apply_reorder: Whether to reorder data for canonical (RAS+) orientation.
+        target_spacing: Whether to resample image to different voxel sizes and dimensions.
         only_tensor: Whether to return only image tensor.
         dtype: Desired datatype for output.

@@ -72,7 +72,7 @@ def _multi_channel(image_paths: L | list, reorder: bool, resample: list, only_te
         torch.Tensor: A stacked 4D tensor, if `only_tensor` is True.
         tuple: Original image, preprocessed image, original size, if `only_tensor` is False.
     """
-    image_data = [_load_and_preprocess(image, reorder, resample, dtype) for image in image_paths]
+    image_data = [_load_and_preprocess(image, apply_reorder, target_spacing, dtype) for image in image_paths]
     org_img, input_img, org_size = image_data[-1]

     tensor = torch.stack([img.data[0] for _, img, _ in image_data], dim=0)
@@ -84,15 +84,15 @@ def _multi_channel(image_paths: L | list, reorder: bool, resample: list, only_te
         return org_img, input_img, org_size

 # %% ../nbs/01_vision_core.ipynb 8
-def med_img_reader(file_path: str | Path | L | list, reorder: bool = False, resample: list = None,
+def med_img_reader(file_path: str | Path | L | list, apply_reorder: bool = False, target_spacing: list = None,
                    only_tensor: bool = True, dtype = torch.Tensor):
     """Loads and preprocesses a medical image.

     Args:
         file_path: Path to the image. Can be a string, Path object or a list.
-        reorder: Whether to reorder the data to be closest to canonical
+        apply_reorder: Whether to reorder the data to be closest to canonical
             (RAS+) orientation. Defaults to False.
-        resample: Whether to resample image to different voxel sizes and
+        target_spacing: Whether to resample image to different voxel sizes and
             image dimensions. Defaults to None.
         only_tensor: Whether to return only image tensor. Defaults to True.
         dtype: Datatype for the return value. Defaults to torch.Tensor.
@@ -104,10 +104,10 @@ def med_img_reader(file_path: str | Path | L | list, reorder: bool = False, resa
     """

     if isinstance(file_path, (list, L)):
-        return _multi_channel(file_path, reorder, resample, only_tensor, dtype)
+        return _multi_channel(file_path, apply_reorder, target_spacing, only_tensor, dtype)

     org_img, input_img, org_size = _load_and_preprocess(
-        file_path, reorder, resample, dtype)
+        file_path, apply_reorder, target_spacing, dtype)

     if only_tensor:
         return dtype(input_img.data.type(torch.float))
@@ -129,7 +129,7 @@ class MedBase(torch.Tensor, metaclass=MetaResolver):

     _bypass_type = torch.Tensor
     _show_args = {'cmap':'gray'}
-    resample, reorder = None, False
+    target_spacing, apply_reorder = None, False
     affine_matrix = None

     @classmethod
@@ -150,7 +150,7 @@ class MedBase(torch.Tensor, metaclass=MetaResolver):
         if isinstance(fn, torch.Tensor):
             return cls(fn)

-        return med_img_reader(fn, resample=cls.resample, reorder=cls.reorder, dtype=cls)
+        return med_img_reader(fn, target_spacing=cls.target_spacing, apply_reorder=cls.apply_reorder, dtype=cls)

     def __new__(cls, x, **kwargs):
         """Creates a new instance of MedBase from a tensor."""
@@ -196,18 +196,18 @@ class MedBase(torch.Tensor, metaclass=MetaResolver):
         return copied

     @classmethod
-    def item_preprocessing(cls, resample: (list, int, tuple), reorder: bool):
+    def item_preprocessing(cls, target_spacing: (list, int, tuple), apply_reorder: bool):
         """
-        Changes the values for the class variables `resample` and `reorder`.
+        Changes the values for the class variables `target_spacing` and `apply_reorder`.

         Args:
-            resample : (list, int, tuple)
+            target_spacing : (list, int, tuple)
                 A list with voxel spacing.
-            reorder : bool
+            apply_reorder : bool
                 Whether to reorder the data to be closest to canonical (RAS+) orientation.
         """
-        cls.resample = resample
-        cls.reorder = reorder
+        cls.target_spacing = target_spacing
+        cls.apply_reorder = apply_reorder

     def show(self, ctx=None, channel: int = 0, slice_index: int = None, anatomical_plane: int = 0, **kwargs):
         """
@@ -230,7 +230,7 @@ class MedBase(torch.Tensor, metaclass=MetaResolver):
         """
         return show_med_img(
             self, ctx=ctx, channel=channel, slice_index=slice_index,
-            anatomical_plane=anatomical_plane, voxel_size=self.resample,
+            anatomical_plane=anatomical_plane, voxel_size=self.target_spacing,
             **merge(self._show_args, kwargs)
         )

```
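The substance of this file's diff is a rename: `reorder`/`resample` become `apply_reorder`/`target_spacing` throughout. A sketch of the 0.5.4 call sites, assuming both names are re-exported via `fastMONAI.vision_all` and with an illustrative file path and spacing:

```python
from fastMONAI.vision_all import MedBase, med_img_reader

# Class-level preprocessing configuration (new names, new argument order).
MedBase.item_preprocessing(target_spacing=[1, 1, 1], apply_reorder=True)

# Direct reading with the renamed keyword arguments.
img_tensor = med_img_reader('scan.nii.gz',             # hypothetical file
                            apply_reorder=True,        # reorder to RAS+
                            target_spacing=[1, 1, 1])  # resample to 1 mm isotropic
```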
```diff
@@ -247,106 +247,3 @@ class MedImage(MedBase):
 class MedMask(MedBase):
     """Subclass of MedBase that represents an mask object."""
     _show_args = {'alpha':0.5, 'cmap':'tab20'}
-
-# %% ../nbs/01_vision_core.ipynb 14
-import os
-from fastai.callback.progress import ProgressCallback
-from fastai.callback.core import Callback
-import sys
-from IPython import get_ipython
-
-class VSCodeProgressCallback(ProgressCallback):
-    """Enhanced progress callback that works better in VS Code notebooks."""
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        self.is_vscode = self._detect_vscode_environment()
-        self.lr_find_progress = None
-
-    def _detect_vscode_environment(self):
-        """Detect if running in VS Code Jupyter environment."""
-        ipython = get_ipython()
-        if ipython is None:
-            return True  # Assume VS Code if no IPython (safer default)
-        # VS Code detection - more comprehensive check
-        kernel_name = str(type(ipython.kernel)).lower() if hasattr(ipython, 'kernel') else ''
-        return ('vscode' in kernel_name or
-                'zmq' in kernel_name or  # VS Code often uses ZMQInteractiveShell
-                not hasattr(ipython, 'display_pub'))  # Missing display publisher often indicates VS Code
-
-    def before_fit(self):
-        """Initialize progress tracking before training."""
-        if self.is_vscode:
-            if hasattr(self.learn, 'lr_finder') and self.learn.lr_finder:
-                # This is lr_find, handle differently
-                print("🔍 Starting Learning Rate Finder...")
-                self.lr_find_progress = 0
-            else:
-                # Regular training
-                print(f"🚀 Training for {self.learn.n_epoch} epochs...")
-        super().before_fit()
-
-    def before_epoch(self):
-        """Initialize epoch progress."""
-        if self.is_vscode:
-            if hasattr(self.learn, 'lr_finder') and self.learn.lr_finder:
-                print(f"📊 LR Find - Testing learning rates...")
-            else:
-                print(f"📈 Epoch {self.epoch+1}/{self.learn.n_epoch}")
-            sys.stdout.flush()
-        super().before_epoch()
-
-    def after_batch(self):
-        """Update progress after each batch."""
-        super().after_batch()
-        if self.is_vscode:
-            if hasattr(self.learn, 'lr_finder') and self.learn.lr_finder:
-                # Special handling for lr_find
-                self.lr_find_progress = getattr(self, 'iter', 0) + 1
-                total = getattr(self, 'n_iter', 100)
-                if self.lr_find_progress % max(1, total // 10) == 0:
-                    progress = (self.lr_find_progress / total) * 100
-                    print(f"⏳ LR Find Progress: {self.lr_find_progress}/{total} ({progress:.1f}%)")
-                    sys.stdout.flush()
-            else:
-                # Regular training progress
-                if hasattr(self, 'iter') and hasattr(self, 'n_iter'):
-                    if self.iter % max(1, self.n_iter // 20) == 0:
-                        progress = (self.iter / self.n_iter) * 100
-                        print(f"⏳ Batch {self.iter}/{self.n_iter} ({progress:.1f}%)")
-                        sys.stdout.flush()
-
-    def after_fit(self):
-        """Complete progress tracking after training."""
-        if self.is_vscode:
-            if hasattr(self.learn, 'lr_finder') and self.learn.lr_finder:
-                print("✅ Learning Rate Finder completed!")
-            else:
-                print("✅ Training completed!")
-            sys.stdout.flush()
-        super().after_fit()
-
-    def before_validate(self):
-        """Update before validation."""
-        if self.is_vscode and not (hasattr(self.learn, 'lr_finder') and self.learn.lr_finder):
-            print("🔄 Validating...")
-            sys.stdout.flush()
-        super().before_validate()
-
-    def after_validate(self):
-        """Update after validation."""
-        if self.is_vscode and not (hasattr(self.learn, 'lr_finder') and self.learn.lr_finder):
-            print("✅ Validation completed")
-            sys.stdout.flush()
-        super().after_validate()
-
-def setup_vscode_progress():
-    """Configure fastai to use VS Code-compatible progress callback."""
-    from fastai.learner import defaults
-
-    # Replace default ProgressCallback with VSCodeProgressCallback
-    if ProgressCallback in defaults.callbacks:
-        defaults.callbacks = [cb if cb != ProgressCallback else VSCodeProgressCallback
-                              for cb in defaults.callbacks]
-
-    print("✅ Configured VS Code-compatible progress callback")
```
fastMONAI/vision_data.py
CHANGED
```diff
@@ -27,7 +27,7 @@ def pred_to_multiclass_mask(pred: torch.Tensor) -> torch.Tensor:

     pred = pred.softmax(dim=0)

-    return pred.argmax(dim=0, 
+    return pred.argmax(dim=0, keepdim=True)

 # %% ../nbs/02_vision_data.ipynb 6
 def batch_pred_to_multiclass_mask(pred: torch.Tensor) -> (torch.Tensor, int):
```
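The new `argmax(dim=0, keepdim=True)` line pins down the output rank: the class axis is retained as a singleton channel, so the predicted mask keeps a `(channel, D, H, W)` layout. Illustrative shapes:

```python
import torch

pred = torch.rand(3, 16, 16, 16)                        # (classes, D, H, W) logits
mask = pred.softmax(dim=0).argmax(dim=0, keepdim=True)  # as in the 0.5.4 code
print(mask.shape)                                       # torch.Size([1, 16, 16, 16])
```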
```diff
@@ -68,12 +68,12 @@ class MedDataBlock(DataBlock):
     #TODO add get_x
     def __init__(self, blocks: list = None, dl_type: TfmdDL = None, getters: list = None,
                  n_inp: int | None = None, item_tfms: list = None, batch_tfms: list = None,
-                 reorder: bool = False, resample: (int, list) = None, **kwargs):
+                 apply_reorder: bool = False, target_spacing: (int, list) = None, **kwargs):

         super().__init__(blocks, dl_type, getters, n_inp, item_tfms,
                          batch_tfms, **kwargs)

-        MedBase.item_preprocessing(resample, reorder)
+        MedBase.item_preprocessing(target_spacing, apply_reorder)

 # %% ../nbs/02_vision_data.ipynb 11
 def MedMaskBlock():
@@ -88,7 +88,7 @@ class MedImageDataLoaders(DataLoaders):
     @delegates(DataLoaders.from_dblock)
     def from_df(cls, df, valid_pct=0.2, seed=None, fn_col=0, folder=None, suff='',
                 label_col=1, label_delim=None, y_block=None, valid_col=None,
-                item_tfms=None, batch_tfms=None, reorder=False, resample=None, **kwargs):
+                item_tfms=None, batch_tfms=None, apply_reorder=False, target_spacing=None, **kwargs):
         """Create from DataFrame."""

         if y_block is None:
@@ -104,8 +104,8 @@ class MedImageDataLoaders(DataLoaders):
                               get_y=ColReader(label_col, label_delim=label_delim),
                               splitter=splitter,
                               item_tfms=item_tfms,
-                              reorder=reorder,
-                              resample=resample
+                              apply_reorder=apply_reorder,
+                              target_spacing=target_spacing
                               )

         return cls.from_dblock(dblock, df, **kwargs)
```
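As in `vision_core.py`, the loader-facing keywords are renamed here. A sketch of `MedImageDataLoaders.from_df` with the 0.5.4 names; the DataFrame contents, file names, and batch size are placeholders:

```python
import pandas as pd
from fastMONAI.vision_all import MedImageDataLoaders

df = pd.DataFrame({'img_path': ['scan_0.nii.gz', 'scan_1.nii.gz'],
                   'label': [0, 1]})  # hypothetical files and labels

dls = MedImageDataLoaders.from_df(
    df, fn_col='img_path', label_col='label',
    apply_reorder=True,        # was: reorder
    target_spacing=[1, 1, 1],  # was: resample
    bs=2,
)
```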