fastMONAI 0.5.3__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fastMONAI/vision_core.py CHANGED
@@ -1,7 +1,7 @@
1
1
  # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_vision_core.ipynb.
2
2
 
3
3
  # %% auto 0
4
- __all__ = ['med_img_reader', 'MetaResolver', 'MedBase', 'MedImage', 'MedMask', 'VSCodeProgressCallback', 'setup_vscode_progress']
4
+ __all__ = ['med_img_reader', 'MetaResolver', 'MedBase', 'MedImage', 'MedMask']
5
5
 
6
6
  # %% ../nbs/01_vision_core.ipynb 2
7
7
  from .vision_plot import *
@@ -10,26 +10,26 @@ from torchio import ScalarImage, LabelMap, ToCanonical, Resample
10
10
  import copy
11
11
 
12
12
  # %% ../nbs/01_vision_core.ipynb 5
13
- def _preprocess(obj, reorder, resample):
13
+ def _preprocess(obj, apply_reorder, target_spacing):
14
14
  """
15
15
  Preprocesses the given object.
16
16
 
17
17
  Args:
18
18
  obj: The object to preprocess.
19
- reorder: Whether to reorder the object.
20
- resample: Whether to resample the object.
19
+ apply_reorder: Whether to reorder the object.
20
+ target_spacing: Whether to resample the object.
21
21
 
22
22
  Returns:
23
23
  The preprocessed object and its original size.
24
24
  """
25
- if reorder:
25
+ if apply_reorder:
26
26
  transform = ToCanonical()
27
27
  obj = transform(obj)
28
28
 
29
29
  original_size = obj.shape[1:]
30
30
 
31
- if resample and not all(np.isclose(obj.spacing, resample)):
32
- transform = Resample(resample)
31
+ if target_spacing and not all(np.isclose(obj.spacing, target_spacing)):
32
+ transform = Resample(target_spacing)
33
33
  obj = transform(obj)
34
34
 
35
35
  if MedBase.affine_matrix is None:
@@ -38,33 +38,33 @@ def _preprocess(obj, reorder, resample):
38
38
  return obj, original_size
39
39
 
40
40
  # %% ../nbs/01_vision_core.ipynb 6
41
- def _load_and_preprocess(file_path, reorder, resample, dtype):
41
+ def _load_and_preprocess(file_path, apply_reorder, target_spacing, dtype):
42
42
  """
43
43
  Helper function to load and preprocess an image.
44
44
 
45
45
  Args:
46
46
  file_path: Image file path.
47
- reorder: Whether to reorder data for canonical (RAS+) orientation.
48
- resample: Whether to resample image to different voxel sizes and dimensions.
47
+ apply_reorder: Whether to reorder data for canonical (RAS+) orientation.
48
+ target_spacing: Whether to resample image to different voxel sizes and dimensions.
49
49
  dtype: Desired datatype for output.
50
50
 
51
51
  Returns:
52
52
  tuple: Original image, preprocessed image, and its original size.
53
53
  """
54
54
  org_img = LabelMap(file_path) if dtype is MedMask else ScalarImage(file_path) #_load(file_path, dtype=dtype)
55
- input_img, org_size = _preprocess(org_img, reorder, resample)
55
+ input_img, org_size = _preprocess(org_img, apply_reorder, target_spacing)
56
56
 
57
57
  return org_img, input_img, org_size
58
58
 
59
59
  # %% ../nbs/01_vision_core.ipynb 7
60
- def _multi_channel(image_paths: L | list, reorder: bool, resample: list, only_tensor: bool, dtype):
60
+ def _multi_channel(image_paths: L | list, apply_reorder: bool, target_spacing: list, only_tensor: bool, dtype):
61
61
  """
62
62
  Load and preprocess multisequence data.
63
63
 
64
64
  Args:
65
65
  image_paths: List of image paths (e.g., T1, T2, T1CE, DWI).
66
- reorder: Whether to reorder data for canonical (RAS+) orientation.
67
- resample: Whether to resample image to different voxel sizes and dimensions.
66
+ apply_reorder: Whether to reorder data for canonical (RAS+) orientation.
67
+ target_spacing: Whether to resample image to different voxel sizes and dimensions.
68
68
  only_tensor: Whether to return only image tensor.
69
69
  dtype: Desired datatype for output.
70
70
 
@@ -72,7 +72,7 @@ def _multi_channel(image_paths: L | list, reorder: bool, resample: list, only_te
72
72
  torch.Tensor: A stacked 4D tensor, if `only_tensor` is True.
73
73
  tuple: Original image, preprocessed image, original size, if `only_tensor` is False.
74
74
  """
75
- image_data = [_load_and_preprocess(image, reorder, resample, dtype) for image in image_paths]
75
+ image_data = [_load_and_preprocess(image, apply_reorder, target_spacing, dtype) for image in image_paths]
76
76
  org_img, input_img, org_size = image_data[-1]
77
77
 
78
78
  tensor = torch.stack([img.data[0] for _, img, _ in image_data], dim=0)
@@ -84,15 +84,15 @@ def _multi_channel(image_paths: L | list, reorder: bool, resample: list, only_te
84
84
  return org_img, input_img, org_size
85
85
 
86
86
  # %% ../nbs/01_vision_core.ipynb 8
87
- def med_img_reader(file_path: str | Path | L | list, reorder: bool = False, resample: list = None,
87
+ def med_img_reader(file_path: str | Path | L | list, apply_reorder: bool = False, target_spacing: list = None,
88
88
  only_tensor: bool = True, dtype = torch.Tensor):
89
89
  """Loads and preprocesses a medical image.
90
90
 
91
91
  Args:
92
92
  file_path: Path to the image. Can be a string, Path object or a list.
93
- reorder: Whether to reorder the data to be closest to canonical
93
+ apply_reorder: Whether to reorder the data to be closest to canonical
94
94
  (RAS+) orientation. Defaults to False.
95
- resample: Whether to resample image to different voxel sizes and
95
+ target_spacing: Whether to resample image to different voxel sizes and
96
96
  image dimensions. Defaults to None.
97
97
  only_tensor: Whether to return only image tensor. Defaults to True.
98
98
  dtype: Datatype for the return value. Defaults to torch.Tensor.
@@ -104,10 +104,10 @@ def med_img_reader(file_path: str | Path | L | list, reorder: bool = False, resa
104
104
  """
105
105
 
106
106
  if isinstance(file_path, (list, L)):
107
- return _multi_channel(file_path, reorder, resample, only_tensor, dtype)
107
+ return _multi_channel(file_path, apply_reorder, target_spacing, only_tensor, dtype)
108
108
 
109
109
  org_img, input_img, org_size = _load_and_preprocess(
110
- file_path, reorder, resample, dtype)
110
+ file_path, apply_reorder, target_spacing, dtype)
111
111
 
112
112
  if only_tensor:
113
113
  return dtype(input_img.data.type(torch.float))
@@ -129,7 +129,7 @@ class MedBase(torch.Tensor, metaclass=MetaResolver):
129
129
 
130
130
  _bypass_type = torch.Tensor
131
131
  _show_args = {'cmap':'gray'}
132
- resample, reorder = None, False
132
+ target_spacing, apply_reorder = None, False
133
133
  affine_matrix = None
134
134
 
135
135
  @classmethod
@@ -150,7 +150,7 @@ class MedBase(torch.Tensor, metaclass=MetaResolver):
150
150
  if isinstance(fn, torch.Tensor):
151
151
  return cls(fn)
152
152
 
153
- return med_img_reader(fn, resample=cls.resample, reorder=cls.reorder, dtype=cls)
153
+ return med_img_reader(fn, target_spacing=cls.target_spacing, apply_reorder=cls.apply_reorder, dtype=cls)
154
154
 
155
155
  def __new__(cls, x, **kwargs):
156
156
  """Creates a new instance of MedBase from a tensor."""
@@ -196,18 +196,18 @@ class MedBase(torch.Tensor, metaclass=MetaResolver):
196
196
  return copied
197
197
 
198
198
  @classmethod
199
- def item_preprocessing(cls, resample: (list, int, tuple), reorder: bool):
199
+ def item_preprocessing(cls, target_spacing: (list, int, tuple), apply_reorder: bool):
200
200
  """
201
- Changes the values for the class variables `resample` and `reorder`.
201
+ Changes the values for the class variables `target_spacing` and `apply_reorder`.
202
202
 
203
203
  Args:
204
- resample : (list, int, tuple)
204
+ target_spacing : (list, int, tuple)
205
205
  A list with voxel spacing.
206
- reorder : bool
206
+ apply_reorder : bool
207
207
  Whether to reorder the data to be closest to canonical (RAS+) orientation.
208
208
  """
209
- cls.resample = resample
210
- cls.reorder = reorder
209
+ cls.target_spacing = target_spacing
210
+ cls.apply_reorder = apply_reorder
211
211
 
212
212
  def show(self, ctx=None, channel: int = 0, slice_index: int = None, anatomical_plane: int = 0, **kwargs):
213
213
  """
@@ -230,7 +230,7 @@ class MedBase(torch.Tensor, metaclass=MetaResolver):
230
230
  """
231
231
  return show_med_img(
232
232
  self, ctx=ctx, channel=channel, slice_index=slice_index,
233
- anatomical_plane=anatomical_plane, voxel_size=self.resample,
233
+ anatomical_plane=anatomical_plane, voxel_size=self.target_spacing,
234
234
  **merge(self._show_args, kwargs)
235
235
  )
236
236
 
@@ -247,106 +247,3 @@ class MedImage(MedBase):
247
247
  class MedMask(MedBase):
248
248
  """Subclass of MedBase that represents an mask object."""
249
249
  _show_args = {'alpha':0.5, 'cmap':'tab20'}
250
-
251
- # %% ../nbs/01_vision_core.ipynb 14
252
- import os
253
- from fastai.callback.progress import ProgressCallback
254
- from fastai.callback.core import Callback
255
- import sys
256
- from IPython import get_ipython
257
-
258
- class VSCodeProgressCallback(ProgressCallback):
259
- """Enhanced progress callback that works better in VS Code notebooks."""
260
-
261
- def __init__(self, **kwargs):
262
- super().__init__(**kwargs)
263
- self.is_vscode = self._detect_vscode_environment()
264
- self.lr_find_progress = None
265
-
266
- def _detect_vscode_environment(self):
267
- """Detect if running in VS Code Jupyter environment."""
268
- ipython = get_ipython()
269
- if ipython is None:
270
- return True # Assume VS Code if no IPython (safer default)
271
- # VS Code detection - more comprehensive check
272
- kernel_name = str(type(ipython.kernel)).lower() if hasattr(ipython, 'kernel') else ''
273
- return ('vscode' in kernel_name or
274
- 'zmq' in kernel_name or # VS Code often uses ZMQInteractiveShell
275
- not hasattr(ipython, 'display_pub')) # Missing display publisher often indicates VS Code
276
-
277
- def before_fit(self):
278
- """Initialize progress tracking before training."""
279
- if self.is_vscode:
280
- if hasattr(self.learn, 'lr_finder') and self.learn.lr_finder:
281
- # This is lr_find, handle differently
282
- print("🔍 Starting Learning Rate Finder...")
283
- self.lr_find_progress = 0
284
- else:
285
- # Regular training
286
- print(f"🚀 Training for {self.learn.n_epoch} epochs...")
287
- super().before_fit()
288
-
289
- def before_epoch(self):
290
- """Initialize epoch progress."""
291
- if self.is_vscode:
292
- if hasattr(self.learn, 'lr_finder') and self.learn.lr_finder:
293
- print(f"📊 LR Find - Testing learning rates...")
294
- else:
295
- print(f"📈 Epoch {self.epoch+1}/{self.learn.n_epoch}")
296
- sys.stdout.flush()
297
- super().before_epoch()
298
-
299
- def after_batch(self):
300
- """Update progress after each batch."""
301
- super().after_batch()
302
- if self.is_vscode:
303
- if hasattr(self.learn, 'lr_finder') and self.learn.lr_finder:
304
- # Special handling for lr_find
305
- self.lr_find_progress = getattr(self, 'iter', 0) + 1
306
- total = getattr(self, 'n_iter', 100)
307
- if self.lr_find_progress % max(1, total // 10) == 0:
308
- progress = (self.lr_find_progress / total) * 100
309
- print(f"⏳ LR Find Progress: {self.lr_find_progress}/{total} ({progress:.1f}%)")
310
- sys.stdout.flush()
311
- else:
312
- # Regular training progress
313
- if hasattr(self, 'iter') and hasattr(self, 'n_iter'):
314
- if self.iter % max(1, self.n_iter // 20) == 0:
315
- progress = (self.iter / self.n_iter) * 100
316
- print(f"⏳ Batch {self.iter}/{self.n_iter} ({progress:.1f}%)")
317
- sys.stdout.flush()
318
-
319
- def after_fit(self):
320
- """Complete progress tracking after training."""
321
- if self.is_vscode:
322
- if hasattr(self.learn, 'lr_finder') and self.learn.lr_finder:
323
- print("✅ Learning Rate Finder completed!")
324
- else:
325
- print("✅ Training completed!")
326
- sys.stdout.flush()
327
- super().after_fit()
328
-
329
- def before_validate(self):
330
- """Update before validation."""
331
- if self.is_vscode and not (hasattr(self.learn, 'lr_finder') and self.learn.lr_finder):
332
- print("🔄 Validating...")
333
- sys.stdout.flush()
334
- super().before_validate()
335
-
336
- def after_validate(self):
337
- """Update after validation."""
338
- if self.is_vscode and not (hasattr(self.learn, 'lr_finder') and self.learn.lr_finder):
339
- print("✅ Validation completed")
340
- sys.stdout.flush()
341
- super().after_validate()
342
-
343
- def setup_vscode_progress():
344
- """Configure fastai to use VS Code-compatible progress callback."""
345
- from fastai.learner import defaults
346
-
347
- # Replace default ProgressCallback with VSCodeProgressCallback
348
- if ProgressCallback in defaults.callbacks:
349
- defaults.callbacks = [cb if cb != ProgressCallback else VSCodeProgressCallback
350
- for cb in defaults.callbacks]
351
-
352
- print("✅ Configured VS Code-compatible progress callback")
fastMONAI/vision_data.py CHANGED
@@ -27,7 +27,7 @@ def pred_to_multiclass_mask(pred: torch.Tensor) -> torch.Tensor:
27
27
 
28
28
  pred = pred.softmax(dim=0)
29
29
 
30
- return pred.argmax(dim=0, keepdims=True)
30
+ return pred.argmax(dim=0, keepdim=True)
31
31
 
32
32
  # %% ../nbs/02_vision_data.ipynb 6
33
33
  def batch_pred_to_multiclass_mask(pred: torch.Tensor) -> (torch.Tensor, int):
@@ -68,12 +68,12 @@ class MedDataBlock(DataBlock):
68
68
  #TODO add get_x
69
69
  def __init__(self, blocks: list = None, dl_type: TfmdDL = None, getters: list = None,
70
70
  n_inp: int | None = None, item_tfms: list = None, batch_tfms: list = None,
71
- reorder: bool = False, resample: (int, list) = None, **kwargs):
71
+ apply_reorder: bool = False, target_spacing: (int, list) = None, **kwargs):
72
72
 
73
73
  super().__init__(blocks, dl_type, getters, n_inp, item_tfms,
74
74
  batch_tfms, **kwargs)
75
75
 
76
- MedBase.item_preprocessing(resample, reorder)
76
+ MedBase.item_preprocessing(target_spacing, apply_reorder)
77
77
 
78
78
  # %% ../nbs/02_vision_data.ipynb 11
79
79
  def MedMaskBlock():
@@ -88,7 +88,7 @@ class MedImageDataLoaders(DataLoaders):
88
88
  @delegates(DataLoaders.from_dblock)
89
89
  def from_df(cls, df, valid_pct=0.2, seed=None, fn_col=0, folder=None, suff='',
90
90
  label_col=1, label_delim=None, y_block=None, valid_col=None,
91
- item_tfms=None, batch_tfms=None, reorder=False, resample=None, **kwargs):
91
+ item_tfms=None, batch_tfms=None, apply_reorder=False, target_spacing=None, **kwargs):
92
92
  """Create from DataFrame."""
93
93
 
94
94
  if y_block is None:
@@ -104,8 +104,8 @@ class MedImageDataLoaders(DataLoaders):
104
104
  get_y=ColReader(label_col, label_delim=label_delim),
105
105
  splitter=splitter,
106
106
  item_tfms=item_tfms,
107
- reorder=reorder,
108
- resample=resample
107
+ apply_reorder=apply_reorder,
108
+ target_spacing=target_spacing
109
109
  )
110
110
 
111
111
  return cls.from_dblock(dblock, df, **kwargs)
@@ -2,7 +2,7 @@
2
2
 
3
3
  # %% auto 0
4
4
  __all__ = ['save_series_pred', 'load_system_resources', 'inference', 'compute_binary_tumor_volume', 'refine_binary_pred_mask',
5
- 'gradio_image_classifier']
5
+ 'keep_largest', 'gradio_image_classifier']
6
6
 
7
7
  # %% ../nbs/06_vision_inference.ipynb 1
8
8
  from copy import copy
@@ -67,18 +67,18 @@ def load_system_resources(models_path, learner_fn, variables_fn):
67
67
 
68
68
  learn = load_learner(models_path / learner_fn, cpu=True)
69
69
  vars_fn = models_path / variables_fn
70
- _, reorder, resample = load_variables(pkl_fn=vars_fn)
70
+ _, apply_reorder, target_spacing = load_variables(pkl_fn=vars_fn)
71
71
 
72
- return learn, reorder, resample
72
+ return learn, apply_reorder, target_spacing
73
73
 
74
74
  # %% ../nbs/06_vision_inference.ipynb 8
75
- def inference(learn_inf, reorder, resample, fn: (str, Path) = '',
75
+ def inference(learn_inf, apply_reorder, target_spacing, fn: (str, Path) = '',
76
76
  save_path: (str, Path) = None, org_img=None, input_img=None,
77
77
  org_size=None):
78
78
  """Predict on new data using exported model."""
79
79
 
80
80
  if None in [org_img, input_img, org_size]:
81
- org_img, input_img, org_size = med_img_reader(fn, reorder, resample,
81
+ org_img, input_img, org_size = med_img_reader(fn, apply_reorder, target_spacing,
82
82
  only_tensor=False)
83
83
  else:
84
84
  org_img, input_img = copy(org_img), copy(input_img)
@@ -148,6 +148,10 @@ def refine_binary_pred_mask(pred_mask,
148
148
  if verbose:
149
149
  print(n_components)
150
150
 
151
+ # Handle empty mask case (no foreground components)
152
+ if n_components == 0:
153
+ return torch.zeros_like(torch.Tensor(pred_mask)).float()
154
+
151
155
  if remove_size is None:
152
156
  sizes = np.bincount(labeled_mask.ravel())
153
157
  max_label = sizes[1:].argmax() + 1
@@ -157,14 +161,36 @@ def refine_binary_pred_mask(pred_mask,
157
161
  processed_mask = remove_small_objects(
158
162
  labeled_mask, min_size=small_objects_threshold)
159
163
 
160
- return torch.Tensor(processed_mask > 0).float()
164
+ return torch.Tensor(processed_mask > 0).float()
165
+
166
+ # %% ../nbs/06_vision_inference.ipynb 12
167
+ def keep_largest(pred_mask: torch.Tensor) -> torch.Tensor:
168
+ """Keep only the largest connected component in a binary mask.
169
+
170
+ Args:
171
+ pred_mask: Binary prediction mask tensor.
172
+
173
+ Returns:
174
+ Binary mask with only the largest connected component.
175
+ """
176
+ mask_np = pred_mask.numpy() if isinstance(pred_mask, torch.Tensor) else pred_mask
177
+ labeled_mask, n_components = label(mask_np)
178
+
179
+ if n_components == 0:
180
+ return torch.zeros_like(pred_mask) if isinstance(pred_mask, torch.Tensor) else mask_np
181
+
182
+ sizes = np.bincount(labeled_mask.ravel())
183
+ largest_label = sizes[1:].argmax() + 1 # Skip background (label 0)
184
+
185
+ result = (labeled_mask == largest_label).astype(np.float32)
186
+ return torch.from_numpy(result) if isinstance(pred_mask, torch.Tensor) else result
161
187
 
162
- # %% ../nbs/06_vision_inference.ipynb 13
163
- def gradio_image_classifier(file_obj, learn, reorder, resample):
188
+ # %% ../nbs/06_vision_inference.ipynb 14
189
+ def gradio_image_classifier(file_obj, learn, apply_reorder, target_spacing):
164
190
  """Predict on images using exported learner and return the result as a dictionary."""
165
191
 
166
192
  img_path = Path(file_obj.name)
167
- img = med_img_reader(img_path, reorder=reorder, resample=resample)
193
+ img = med_img_reader(img_path, apply_reorder=apply_reorder, target_spacing=target_spacing)
168
194
 
169
195
  _, _, predictions = learn.predict(img)
170
196
  prediction_dict = {index: value.item() for index, value in enumerate(predictions)}