fastMONAI 0.5.3__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastMONAI/__init__.py +1 -1
- fastMONAI/_modidx.py +224 -28
- fastMONAI/dataset_info.py +329 -47
- fastMONAI/external_data.py +1 -1
- fastMONAI/utils.py +394 -22
- fastMONAI/vision_all.py +3 -2
- fastMONAI/vision_augmentation.py +264 -28
- fastMONAI/vision_core.py +29 -132
- fastMONAI/vision_data.py +6 -6
- fastMONAI/vision_inference.py +35 -9
- fastMONAI/vision_metrics.py +420 -19
- fastMONAI/vision_patch.py +1259 -0
- fastMONAI/vision_plot.py +90 -1
- {fastmonai-0.5.3.dist-info → fastmonai-0.6.0.dist-info}/METADATA +5 -5
- fastmonai-0.6.0.dist-info/RECORD +21 -0
- {fastmonai-0.5.3.dist-info → fastmonai-0.6.0.dist-info}/WHEEL +1 -1
- fastmonai-0.5.3.dist-info/RECORD +0 -20
- {fastmonai-0.5.3.dist-info → fastmonai-0.6.0.dist-info}/entry_points.txt +0 -0
- {fastmonai-0.5.3.dist-info → fastmonai-0.6.0.dist-info}/licenses/LICENSE +0 -0
- {fastmonai-0.5.3.dist-info → fastmonai-0.6.0.dist-info}/top_level.txt +0 -0
fastMONAI/vision_core.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_vision_core.ipynb.
|
|
2
2
|
|
|
3
3
|
# %% auto 0
|
|
4
|
-
__all__ = ['med_img_reader', 'MetaResolver', 'MedBase', 'MedImage', 'MedMask'
|
|
4
|
+
__all__ = ['med_img_reader', 'MetaResolver', 'MedBase', 'MedImage', 'MedMask']
|
|
5
5
|
|
|
6
6
|
# %% ../nbs/01_vision_core.ipynb 2
|
|
7
7
|
from .vision_plot import *
|
|
@@ -10,26 +10,26 @@ from torchio import ScalarImage, LabelMap, ToCanonical, Resample
|
|
|
10
10
|
import copy
|
|
11
11
|
|
|
12
12
|
# %% ../nbs/01_vision_core.ipynb 5
|
|
13
|
-
def _preprocess(obj,
|
|
13
|
+
def _preprocess(obj, apply_reorder, target_spacing):
|
|
14
14
|
"""
|
|
15
15
|
Preprocesses the given object.
|
|
16
16
|
|
|
17
17
|
Args:
|
|
18
18
|
obj: The object to preprocess.
|
|
19
|
-
|
|
20
|
-
|
|
19
|
+
apply_reorder: Whether to reorder the object.
|
|
20
|
+
target_spacing: Whether to resample the object.
|
|
21
21
|
|
|
22
22
|
Returns:
|
|
23
23
|
The preprocessed object and its original size.
|
|
24
24
|
"""
|
|
25
|
-
if
|
|
25
|
+
if apply_reorder:
|
|
26
26
|
transform = ToCanonical()
|
|
27
27
|
obj = transform(obj)
|
|
28
28
|
|
|
29
29
|
original_size = obj.shape[1:]
|
|
30
30
|
|
|
31
|
-
if
|
|
32
|
-
transform = Resample(
|
|
31
|
+
if target_spacing and not all(np.isclose(obj.spacing, target_spacing)):
|
|
32
|
+
transform = Resample(target_spacing)
|
|
33
33
|
obj = transform(obj)
|
|
34
34
|
|
|
35
35
|
if MedBase.affine_matrix is None:
|
|
@@ -38,33 +38,33 @@ def _preprocess(obj, reorder, resample):
|
|
|
38
38
|
return obj, original_size
|
|
39
39
|
|
|
40
40
|
# %% ../nbs/01_vision_core.ipynb 6
|
|
41
|
-
def _load_and_preprocess(file_path,
|
|
41
|
+
def _load_and_preprocess(file_path, apply_reorder, target_spacing, dtype):
|
|
42
42
|
"""
|
|
43
43
|
Helper function to load and preprocess an image.
|
|
44
44
|
|
|
45
45
|
Args:
|
|
46
46
|
file_path: Image file path.
|
|
47
|
-
|
|
48
|
-
|
|
47
|
+
apply_reorder: Whether to reorder data for canonical (RAS+) orientation.
|
|
48
|
+
target_spacing: Whether to resample image to different voxel sizes and dimensions.
|
|
49
49
|
dtype: Desired datatype for output.
|
|
50
50
|
|
|
51
51
|
Returns:
|
|
52
52
|
tuple: Original image, preprocessed image, and its original size.
|
|
53
53
|
"""
|
|
54
54
|
org_img = LabelMap(file_path) if dtype is MedMask else ScalarImage(file_path) #_load(file_path, dtype=dtype)
|
|
55
|
-
input_img, org_size = _preprocess(org_img,
|
|
55
|
+
input_img, org_size = _preprocess(org_img, apply_reorder, target_spacing)
|
|
56
56
|
|
|
57
57
|
return org_img, input_img, org_size
|
|
58
58
|
|
|
59
59
|
# %% ../nbs/01_vision_core.ipynb 7
|
|
60
|
-
def _multi_channel(image_paths: L | list,
|
|
60
|
+
def _multi_channel(image_paths: L | list, apply_reorder: bool, target_spacing: list, only_tensor: bool, dtype):
|
|
61
61
|
"""
|
|
62
62
|
Load and preprocess multisequence data.
|
|
63
63
|
|
|
64
64
|
Args:
|
|
65
65
|
image_paths: List of image paths (e.g., T1, T2, T1CE, DWI).
|
|
66
|
-
|
|
67
|
-
|
|
66
|
+
apply_reorder: Whether to reorder data for canonical (RAS+) orientation.
|
|
67
|
+
target_spacing: Whether to resample image to different voxel sizes and dimensions.
|
|
68
68
|
only_tensor: Whether to return only image tensor.
|
|
69
69
|
dtype: Desired datatype for output.
|
|
70
70
|
|
|
@@ -72,7 +72,7 @@ def _multi_channel(image_paths: L | list, reorder: bool, resample: list, only_te
|
|
|
72
72
|
torch.Tensor: A stacked 4D tensor, if `only_tensor` is True.
|
|
73
73
|
tuple: Original image, preprocessed image, original size, if `only_tensor` is False.
|
|
74
74
|
"""
|
|
75
|
-
image_data = [_load_and_preprocess(image,
|
|
75
|
+
image_data = [_load_and_preprocess(image, apply_reorder, target_spacing, dtype) for image in image_paths]
|
|
76
76
|
org_img, input_img, org_size = image_data[-1]
|
|
77
77
|
|
|
78
78
|
tensor = torch.stack([img.data[0] for _, img, _ in image_data], dim=0)
|
|
@@ -84,15 +84,15 @@ def _multi_channel(image_paths: L | list, reorder: bool, resample: list, only_te
|
|
|
84
84
|
return org_img, input_img, org_size
|
|
85
85
|
|
|
86
86
|
# %% ../nbs/01_vision_core.ipynb 8
|
|
87
|
-
def med_img_reader(file_path: str | Path | L | list,
|
|
87
|
+
def med_img_reader(file_path: str | Path | L | list, apply_reorder: bool = False, target_spacing: list = None,
|
|
88
88
|
only_tensor: bool = True, dtype = torch.Tensor):
|
|
89
89
|
"""Loads and preprocesses a medical image.
|
|
90
90
|
|
|
91
91
|
Args:
|
|
92
92
|
file_path: Path to the image. Can be a string, Path object or a list.
|
|
93
|
-
|
|
93
|
+
apply_reorder: Whether to reorder the data to be closest to canonical
|
|
94
94
|
(RAS+) orientation. Defaults to False.
|
|
95
|
-
|
|
95
|
+
target_spacing: Whether to resample image to different voxel sizes and
|
|
96
96
|
image dimensions. Defaults to None.
|
|
97
97
|
only_tensor: Whether to return only image tensor. Defaults to True.
|
|
98
98
|
dtype: Datatype for the return value. Defaults to torch.Tensor.
|
|
@@ -104,10 +104,10 @@ def med_img_reader(file_path: str | Path | L | list, reorder: bool = False, resa
|
|
|
104
104
|
"""
|
|
105
105
|
|
|
106
106
|
if isinstance(file_path, (list, L)):
|
|
107
|
-
return _multi_channel(file_path,
|
|
107
|
+
return _multi_channel(file_path, apply_reorder, target_spacing, only_tensor, dtype)
|
|
108
108
|
|
|
109
109
|
org_img, input_img, org_size = _load_and_preprocess(
|
|
110
|
-
file_path,
|
|
110
|
+
file_path, apply_reorder, target_spacing, dtype)
|
|
111
111
|
|
|
112
112
|
if only_tensor:
|
|
113
113
|
return dtype(input_img.data.type(torch.float))
|
|
@@ -129,7 +129,7 @@ class MedBase(torch.Tensor, metaclass=MetaResolver):
|
|
|
129
129
|
|
|
130
130
|
_bypass_type = torch.Tensor
|
|
131
131
|
_show_args = {'cmap':'gray'}
|
|
132
|
-
|
|
132
|
+
target_spacing, apply_reorder = None, False
|
|
133
133
|
affine_matrix = None
|
|
134
134
|
|
|
135
135
|
@classmethod
|
|
@@ -150,7 +150,7 @@ class MedBase(torch.Tensor, metaclass=MetaResolver):
|
|
|
150
150
|
if isinstance(fn, torch.Tensor):
|
|
151
151
|
return cls(fn)
|
|
152
152
|
|
|
153
|
-
return med_img_reader(fn,
|
|
153
|
+
return med_img_reader(fn, target_spacing=cls.target_spacing, apply_reorder=cls.apply_reorder, dtype=cls)
|
|
154
154
|
|
|
155
155
|
def __new__(cls, x, **kwargs):
|
|
156
156
|
"""Creates a new instance of MedBase from a tensor."""
|
|
@@ -196,18 +196,18 @@ class MedBase(torch.Tensor, metaclass=MetaResolver):
|
|
|
196
196
|
return copied
|
|
197
197
|
|
|
198
198
|
@classmethod
|
|
199
|
-
def item_preprocessing(cls,
|
|
199
|
+
def item_preprocessing(cls, target_spacing: (list, int, tuple), apply_reorder: bool):
|
|
200
200
|
"""
|
|
201
|
-
Changes the values for the class variables `
|
|
201
|
+
Changes the values for the class variables `target_spacing` and `apply_reorder`.
|
|
202
202
|
|
|
203
203
|
Args:
|
|
204
|
-
|
|
204
|
+
target_spacing : (list, int, tuple)
|
|
205
205
|
A list with voxel spacing.
|
|
206
|
-
|
|
206
|
+
apply_reorder : bool
|
|
207
207
|
Whether to reorder the data to be closest to canonical (RAS+) orientation.
|
|
208
208
|
"""
|
|
209
|
-
cls.
|
|
210
|
-
cls.
|
|
209
|
+
cls.target_spacing = target_spacing
|
|
210
|
+
cls.apply_reorder = apply_reorder
|
|
211
211
|
|
|
212
212
|
def show(self, ctx=None, channel: int = 0, slice_index: int = None, anatomical_plane: int = 0, **kwargs):
|
|
213
213
|
"""
|
|
@@ -230,7 +230,7 @@ class MedBase(torch.Tensor, metaclass=MetaResolver):
|
|
|
230
230
|
"""
|
|
231
231
|
return show_med_img(
|
|
232
232
|
self, ctx=ctx, channel=channel, slice_index=slice_index,
|
|
233
|
-
anatomical_plane=anatomical_plane, voxel_size=self.
|
|
233
|
+
anatomical_plane=anatomical_plane, voxel_size=self.target_spacing,
|
|
234
234
|
**merge(self._show_args, kwargs)
|
|
235
235
|
)
|
|
236
236
|
|
|
@@ -247,106 +247,3 @@ class MedImage(MedBase):
|
|
|
247
247
|
class MedMask(MedBase):
|
|
248
248
|
"""Subclass of MedBase that represents an mask object."""
|
|
249
249
|
_show_args = {'alpha':0.5, 'cmap':'tab20'}
|
|
250
|
-
|
|
251
|
-
# %% ../nbs/01_vision_core.ipynb 14
|
|
252
|
-
import os
|
|
253
|
-
from fastai.callback.progress import ProgressCallback
|
|
254
|
-
from fastai.callback.core import Callback
|
|
255
|
-
import sys
|
|
256
|
-
from IPython import get_ipython
|
|
257
|
-
|
|
258
|
-
class VSCodeProgressCallback(ProgressCallback):
|
|
259
|
-
"""Enhanced progress callback that works better in VS Code notebooks."""
|
|
260
|
-
|
|
261
|
-
def __init__(self, **kwargs):
|
|
262
|
-
super().__init__(**kwargs)
|
|
263
|
-
self.is_vscode = self._detect_vscode_environment()
|
|
264
|
-
self.lr_find_progress = None
|
|
265
|
-
|
|
266
|
-
def _detect_vscode_environment(self):
|
|
267
|
-
"""Detect if running in VS Code Jupyter environment."""
|
|
268
|
-
ipython = get_ipython()
|
|
269
|
-
if ipython is None:
|
|
270
|
-
return True # Assume VS Code if no IPython (safer default)
|
|
271
|
-
# VS Code detection - more comprehensive check
|
|
272
|
-
kernel_name = str(type(ipython.kernel)).lower() if hasattr(ipython, 'kernel') else ''
|
|
273
|
-
return ('vscode' in kernel_name or
|
|
274
|
-
'zmq' in kernel_name or # VS Code often uses ZMQInteractiveShell
|
|
275
|
-
not hasattr(ipython, 'display_pub')) # Missing display publisher often indicates VS Code
|
|
276
|
-
|
|
277
|
-
def before_fit(self):
|
|
278
|
-
"""Initialize progress tracking before training."""
|
|
279
|
-
if self.is_vscode:
|
|
280
|
-
if hasattr(self.learn, 'lr_finder') and self.learn.lr_finder:
|
|
281
|
-
# This is lr_find, handle differently
|
|
282
|
-
print("🔍 Starting Learning Rate Finder...")
|
|
283
|
-
self.lr_find_progress = 0
|
|
284
|
-
else:
|
|
285
|
-
# Regular training
|
|
286
|
-
print(f"🚀 Training for {self.learn.n_epoch} epochs...")
|
|
287
|
-
super().before_fit()
|
|
288
|
-
|
|
289
|
-
def before_epoch(self):
|
|
290
|
-
"""Initialize epoch progress."""
|
|
291
|
-
if self.is_vscode:
|
|
292
|
-
if hasattr(self.learn, 'lr_finder') and self.learn.lr_finder:
|
|
293
|
-
print(f"📊 LR Find - Testing learning rates...")
|
|
294
|
-
else:
|
|
295
|
-
print(f"📈 Epoch {self.epoch+1}/{self.learn.n_epoch}")
|
|
296
|
-
sys.stdout.flush()
|
|
297
|
-
super().before_epoch()
|
|
298
|
-
|
|
299
|
-
def after_batch(self):
|
|
300
|
-
"""Update progress after each batch."""
|
|
301
|
-
super().after_batch()
|
|
302
|
-
if self.is_vscode:
|
|
303
|
-
if hasattr(self.learn, 'lr_finder') and self.learn.lr_finder:
|
|
304
|
-
# Special handling for lr_find
|
|
305
|
-
self.lr_find_progress = getattr(self, 'iter', 0) + 1
|
|
306
|
-
total = getattr(self, 'n_iter', 100)
|
|
307
|
-
if self.lr_find_progress % max(1, total // 10) == 0:
|
|
308
|
-
progress = (self.lr_find_progress / total) * 100
|
|
309
|
-
print(f"⏳ LR Find Progress: {self.lr_find_progress}/{total} ({progress:.1f}%)")
|
|
310
|
-
sys.stdout.flush()
|
|
311
|
-
else:
|
|
312
|
-
# Regular training progress
|
|
313
|
-
if hasattr(self, 'iter') and hasattr(self, 'n_iter'):
|
|
314
|
-
if self.iter % max(1, self.n_iter // 20) == 0:
|
|
315
|
-
progress = (self.iter / self.n_iter) * 100
|
|
316
|
-
print(f"⏳ Batch {self.iter}/{self.n_iter} ({progress:.1f}%)")
|
|
317
|
-
sys.stdout.flush()
|
|
318
|
-
|
|
319
|
-
def after_fit(self):
|
|
320
|
-
"""Complete progress tracking after training."""
|
|
321
|
-
if self.is_vscode:
|
|
322
|
-
if hasattr(self.learn, 'lr_finder') and self.learn.lr_finder:
|
|
323
|
-
print("✅ Learning Rate Finder completed!")
|
|
324
|
-
else:
|
|
325
|
-
print("✅ Training completed!")
|
|
326
|
-
sys.stdout.flush()
|
|
327
|
-
super().after_fit()
|
|
328
|
-
|
|
329
|
-
def before_validate(self):
|
|
330
|
-
"""Update before validation."""
|
|
331
|
-
if self.is_vscode and not (hasattr(self.learn, 'lr_finder') and self.learn.lr_finder):
|
|
332
|
-
print("🔄 Validating...")
|
|
333
|
-
sys.stdout.flush()
|
|
334
|
-
super().before_validate()
|
|
335
|
-
|
|
336
|
-
def after_validate(self):
|
|
337
|
-
"""Update after validation."""
|
|
338
|
-
if self.is_vscode and not (hasattr(self.learn, 'lr_finder') and self.learn.lr_finder):
|
|
339
|
-
print("✅ Validation completed")
|
|
340
|
-
sys.stdout.flush()
|
|
341
|
-
super().after_validate()
|
|
342
|
-
|
|
343
|
-
def setup_vscode_progress():
|
|
344
|
-
"""Configure fastai to use VS Code-compatible progress callback."""
|
|
345
|
-
from fastai.learner import defaults
|
|
346
|
-
|
|
347
|
-
# Replace default ProgressCallback with VSCodeProgressCallback
|
|
348
|
-
if ProgressCallback in defaults.callbacks:
|
|
349
|
-
defaults.callbacks = [cb if cb != ProgressCallback else VSCodeProgressCallback
|
|
350
|
-
for cb in defaults.callbacks]
|
|
351
|
-
|
|
352
|
-
print("✅ Configured VS Code-compatible progress callback")
|
fastMONAI/vision_data.py
CHANGED
|
@@ -27,7 +27,7 @@ def pred_to_multiclass_mask(pred: torch.Tensor) -> torch.Tensor:
|
|
|
27
27
|
|
|
28
28
|
pred = pred.softmax(dim=0)
|
|
29
29
|
|
|
30
|
-
return pred.argmax(dim=0,
|
|
30
|
+
return pred.argmax(dim=0, keepdim=True)
|
|
31
31
|
|
|
32
32
|
# %% ../nbs/02_vision_data.ipynb 6
|
|
33
33
|
def batch_pred_to_multiclass_mask(pred: torch.Tensor) -> (torch.Tensor, int):
|
|
@@ -68,12 +68,12 @@ class MedDataBlock(DataBlock):
|
|
|
68
68
|
#TODO add get_x
|
|
69
69
|
def __init__(self, blocks: list = None, dl_type: TfmdDL = None, getters: list = None,
|
|
70
70
|
n_inp: int | None = None, item_tfms: list = None, batch_tfms: list = None,
|
|
71
|
-
|
|
71
|
+
apply_reorder: bool = False, target_spacing: (int, list) = None, **kwargs):
|
|
72
72
|
|
|
73
73
|
super().__init__(blocks, dl_type, getters, n_inp, item_tfms,
|
|
74
74
|
batch_tfms, **kwargs)
|
|
75
75
|
|
|
76
|
-
MedBase.item_preprocessing(
|
|
76
|
+
MedBase.item_preprocessing(target_spacing, apply_reorder)
|
|
77
77
|
|
|
78
78
|
# %% ../nbs/02_vision_data.ipynb 11
|
|
79
79
|
def MedMaskBlock():
|
|
@@ -88,7 +88,7 @@ class MedImageDataLoaders(DataLoaders):
|
|
|
88
88
|
@delegates(DataLoaders.from_dblock)
|
|
89
89
|
def from_df(cls, df, valid_pct=0.2, seed=None, fn_col=0, folder=None, suff='',
|
|
90
90
|
label_col=1, label_delim=None, y_block=None, valid_col=None,
|
|
91
|
-
item_tfms=None, batch_tfms=None,
|
|
91
|
+
item_tfms=None, batch_tfms=None, apply_reorder=False, target_spacing=None, **kwargs):
|
|
92
92
|
"""Create from DataFrame."""
|
|
93
93
|
|
|
94
94
|
if y_block is None:
|
|
@@ -104,8 +104,8 @@ class MedImageDataLoaders(DataLoaders):
|
|
|
104
104
|
get_y=ColReader(label_col, label_delim=label_delim),
|
|
105
105
|
splitter=splitter,
|
|
106
106
|
item_tfms=item_tfms,
|
|
107
|
-
|
|
108
|
-
|
|
107
|
+
apply_reorder=apply_reorder,
|
|
108
|
+
target_spacing=target_spacing
|
|
109
109
|
)
|
|
110
110
|
|
|
111
111
|
return cls.from_dblock(dblock, df, **kwargs)
|
fastMONAI/vision_inference.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
# %% auto 0
|
|
4
4
|
__all__ = ['save_series_pred', 'load_system_resources', 'inference', 'compute_binary_tumor_volume', 'refine_binary_pred_mask',
|
|
5
|
-
'gradio_image_classifier']
|
|
5
|
+
'keep_largest', 'gradio_image_classifier']
|
|
6
6
|
|
|
7
7
|
# %% ../nbs/06_vision_inference.ipynb 1
|
|
8
8
|
from copy import copy
|
|
@@ -67,18 +67,18 @@ def load_system_resources(models_path, learner_fn, variables_fn):
|
|
|
67
67
|
|
|
68
68
|
learn = load_learner(models_path / learner_fn, cpu=True)
|
|
69
69
|
vars_fn = models_path / variables_fn
|
|
70
|
-
_,
|
|
70
|
+
_, apply_reorder, target_spacing = load_variables(pkl_fn=vars_fn)
|
|
71
71
|
|
|
72
|
-
return learn,
|
|
72
|
+
return learn, apply_reorder, target_spacing
|
|
73
73
|
|
|
74
74
|
# %% ../nbs/06_vision_inference.ipynb 8
|
|
75
|
-
def inference(learn_inf,
|
|
75
|
+
def inference(learn_inf, apply_reorder, target_spacing, fn: (str, Path) = '',
|
|
76
76
|
save_path: (str, Path) = None, org_img=None, input_img=None,
|
|
77
77
|
org_size=None):
|
|
78
78
|
"""Predict on new data using exported model."""
|
|
79
79
|
|
|
80
80
|
if None in [org_img, input_img, org_size]:
|
|
81
|
-
org_img, input_img, org_size = med_img_reader(fn,
|
|
81
|
+
org_img, input_img, org_size = med_img_reader(fn, apply_reorder, target_spacing,
|
|
82
82
|
only_tensor=False)
|
|
83
83
|
else:
|
|
84
84
|
org_img, input_img = copy(org_img), copy(input_img)
|
|
@@ -148,6 +148,10 @@ def refine_binary_pred_mask(pred_mask,
|
|
|
148
148
|
if verbose:
|
|
149
149
|
print(n_components)
|
|
150
150
|
|
|
151
|
+
# Handle empty mask case (no foreground components)
|
|
152
|
+
if n_components == 0:
|
|
153
|
+
return torch.zeros_like(torch.Tensor(pred_mask)).float()
|
|
154
|
+
|
|
151
155
|
if remove_size is None:
|
|
152
156
|
sizes = np.bincount(labeled_mask.ravel())
|
|
153
157
|
max_label = sizes[1:].argmax() + 1
|
|
@@ -157,14 +161,36 @@ def refine_binary_pred_mask(pred_mask,
|
|
|
157
161
|
processed_mask = remove_small_objects(
|
|
158
162
|
labeled_mask, min_size=small_objects_threshold)
|
|
159
163
|
|
|
160
|
-
return torch.Tensor(processed_mask > 0).float()
|
|
164
|
+
return torch.Tensor(processed_mask > 0).float()
|
|
165
|
+
|
|
166
|
+
# %% ../nbs/06_vision_inference.ipynb 12
|
|
167
|
+
def keep_largest(pred_mask: torch.Tensor) -> torch.Tensor:
|
|
168
|
+
"""Keep only the largest connected component in a binary mask.
|
|
169
|
+
|
|
170
|
+
Args:
|
|
171
|
+
pred_mask: Binary prediction mask tensor.
|
|
172
|
+
|
|
173
|
+
Returns:
|
|
174
|
+
Binary mask with only the largest connected component.
|
|
175
|
+
"""
|
|
176
|
+
mask_np = pred_mask.numpy() if isinstance(pred_mask, torch.Tensor) else pred_mask
|
|
177
|
+
labeled_mask, n_components = label(mask_np)
|
|
178
|
+
|
|
179
|
+
if n_components == 0:
|
|
180
|
+
return torch.zeros_like(pred_mask) if isinstance(pred_mask, torch.Tensor) else mask_np
|
|
181
|
+
|
|
182
|
+
sizes = np.bincount(labeled_mask.ravel())
|
|
183
|
+
largest_label = sizes[1:].argmax() + 1 # Skip background (label 0)
|
|
184
|
+
|
|
185
|
+
result = (labeled_mask == largest_label).astype(np.float32)
|
|
186
|
+
return torch.from_numpy(result) if isinstance(pred_mask, torch.Tensor) else result
|
|
161
187
|
|
|
162
|
-
# %% ../nbs/06_vision_inference.ipynb
|
|
163
|
-
def gradio_image_classifier(file_obj, learn,
|
|
188
|
+
# %% ../nbs/06_vision_inference.ipynb 14
|
|
189
|
+
def gradio_image_classifier(file_obj, learn, apply_reorder, target_spacing):
|
|
164
190
|
"""Predict on images using exported learner and return the result as a dictionary."""
|
|
165
191
|
|
|
166
192
|
img_path = Path(file_obj.name)
|
|
167
|
-
img = med_img_reader(img_path,
|
|
193
|
+
img = med_img_reader(img_path, apply_reorder=apply_reorder, target_spacing=target_spacing)
|
|
168
194
|
|
|
169
195
|
_, _, predictions = learn.predict(img)
|
|
170
196
|
prediction_dict = {index: value.item() for index, value in enumerate(predictions)}
|