fastMONAI 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,8 +3,8 @@
3
3
  # %% auto 0
4
4
  __all__ = ['CustomDictTransform', 'do_pad_or_crop', 'PadOrCrop', 'ZNormalization', 'RescaleIntensity', 'NormalizeIntensity',
5
5
  'BraTSMaskConverter', 'BinaryConverter', 'RandomGhosting', 'RandomSpike', 'RandomNoise', 'RandomBiasField',
6
- 'RandomBlur', 'RandomGamma', 'RandomMotion', 'RandomElasticDeformation', 'RandomAffine', 'RandomFlip',
7
- 'OneOf']
6
+ 'RandomBlur', 'RandomGamma', 'RandomIntensityScale', 'RandomMotion', 'RandomElasticDeformation',
7
+ 'RandomAffine', 'RandomFlip', 'OneOf']
8
8
 
9
9
  # %% ../nbs/03_vision_augment.ipynb 2
10
10
  from fastai.data.all import *
@@ -14,10 +14,10 @@ from monai.transforms import NormalizeIntensity as MonaiNormalizeIntensity
14
14
 
15
15
  # %% ../nbs/03_vision_augment.ipynb 5
16
16
  class CustomDictTransform(ItemTransform):
17
- """A class that serves as a wrapper to perform an identical transformation on both
17
+ """A class that serves as a wrapper to perform an identical transformation on both
18
18
  the image and the target (if it's a mask).
19
19
  """
20
-
20
+
21
21
  split_idx = 0 # Only perform transformations on training data. Use TTA() for transformations on validation data.
22
22
 
23
23
  def __init__(self, aug):
@@ -28,31 +28,42 @@ class CustomDictTransform(ItemTransform):
28
28
  """
29
29
  self.aug = aug
30
30
 
31
+ @property
32
+ def tio_transform(self):
33
+ """Return the underlying TorchIO transform.
34
+
35
+ This property enables using fastMONAI wrappers in patch-based workflows
36
+ where raw TorchIO transforms are needed for tio.Compose().
37
+ """
38
+ return self.aug
39
+
31
40
  def encodes(self, x):
32
41
  """
33
- Applies the stored transformation to an image, and the same random transformation
42
+ Applies the stored transformation to an image, and the same random transformation
34
43
  to the target if it is a mask. If the target is not a mask, it returns the target as is.
35
44
 
36
45
  Args:
37
- x (Tuple[MedImage, Union[MedMask, TensorCategory]]): A tuple containing the
46
+ x (Tuple[MedImage, Union[MedMask, TensorCategory]]): A tuple containing the
38
47
  image and the target.
39
48
 
40
49
  Returns:
41
- Tuple[MedImage, Union[MedMask, TensorCategory]]: The transformed image and target.
42
- If the target is a mask, it's transformed identically to the image. If the target
50
+ Tuple[MedImage, Union[MedMask, TensorCategory]]: The transformed image and target.
51
+ If the target is a mask, it's transformed identically to the image. If the target
43
52
  is not a mask, the original target is returned.
44
53
  """
45
54
  img, y_true = x
55
+
56
+ # Use identity affine if MedImage.affine_matrix is not set
57
+ affine = MedImage.affine_matrix if MedImage.affine_matrix is not None else np.eye(4)
46
58
 
47
59
  if isinstance(y_true, (MedMask)):
48
- aug = self.aug(tio.Subject(img=tio.ScalarImage(tensor=img, affine=MedImage.affine_matrix),
49
- mask=tio.LabelMap(tensor=y_true, affine=MedImage.affine_matrix)))
60
+ aug = self.aug(tio.Subject(img=tio.ScalarImage(tensor=img, affine=affine),
61
+ mask=tio.LabelMap(tensor=y_true, affine=affine)))
50
62
  return MedImage.create(aug['img'].data), MedMask.create(aug['mask'].data)
51
63
 
52
64
  aug = self.aug(tio.Subject(img=tio.ScalarImage(tensor=img)))
53
65
  return MedImage.create(aug['img'].data), y_true
54
66
 
55
-
56
67
  # %% ../nbs/03_vision_augment.ipynb 7
57
68
  def do_pad_or_crop(o, target_shape, padding_mode, mask_name, dtype=torch.Tensor):
58
69
  #TODO:refactorize
@@ -72,6 +83,11 @@ class PadOrCrop(DisplayedTransform):
72
83
  padding_mode=padding_mode,
73
84
  mask_name=mask_name)
74
85
 
86
+ @property
87
+ def tio_transform(self):
88
+ """Return the underlying TorchIO transform."""
89
+ return self.pad_or_crop
90
+
75
91
  def encodes(self, o: (MedImage, MedMask)):
76
92
  return type(o)(self.pad_or_crop(o))
77
93
 
@@ -85,6 +101,11 @@ class ZNormalization(DisplayedTransform):
85
101
  self.z_normalization = tio.ZNormalization(masking_method=masking_method)
86
102
  self.channel_wise = channel_wise
87
103
 
104
+ @property
105
+ def tio_transform(self):
106
+ """Return the underlying TorchIO transform."""
107
+ return self.z_normalization
108
+
88
109
  def encodes(self, o: MedImage):
89
110
  try:
90
111
  if self.channel_wise:
@@ -132,7 +153,12 @@ class RescaleIntensity(DisplayedTransform):
132
153
 
133
154
  def __init__(self, out_min_max: tuple[float, float], in_min_max: tuple[float, float]):
134
155
  self.rescale = tio.RescaleIntensity(out_min_max=out_min_max, in_min_max=in_min_max)
135
-
156
+
157
+ @property
158
+ def tio_transform(self):
159
+ """Return the underlying TorchIO transform."""
160
+ return self.rescale
161
+
136
162
  def encodes(self, o: MedImage):
137
163
  return MedImage.create(self.rescale(o))
138
164
 
@@ -211,8 +237,17 @@ class RandomGhosting(DisplayedTransform):
211
237
  def __init__(self, intensity=(0.5, 1), p=0.5):
212
238
  self.add_ghosts = tio.RandomGhosting(intensity=intensity, p=p)
213
239
 
240
+ @property
241
+ def tio_transform(self):
242
+ """Return the underlying TorchIO transform."""
243
+ return self.add_ghosts
244
+
214
245
  def encodes(self, o: MedImage):
215
- return MedImage.create(self.add_ghosts(o))
246
+ result = self.add_ghosts(o)
247
+ # Handle potential complex values from k-space operations
248
+ if result.is_complex():
249
+ result = torch.real(result)
250
+ return MedImage.create(result)
216
251
 
217
252
  def encodes(self, o: MedMask):
218
253
  return o
@@ -221,26 +256,40 @@ class RandomGhosting(DisplayedTransform):
221
256
  class RandomSpike(DisplayedTransform):
222
257
  '''Apply TorchIO `RandomSpike`.'''
223
258
 
224
- split_idx,order=0,1
259
+ split_idx, order = 0, 1
225
260
 
226
261
  def __init__(self, num_spikes=1, intensity=(1, 3), p=0.5):
227
262
  self.add_spikes = tio.RandomSpike(num_spikes=num_spikes, intensity=intensity, p=p)
228
263
 
229
- def encodes(self, o:MedImage):
230
- return MedImage.create(self.add_spikes(o))
264
+ @property
265
+ def tio_transform(self):
266
+ """Return the underlying TorchIO transform."""
267
+ return self.add_spikes
268
+
269
+ def encodes(self, o: MedImage):
270
+ result = self.add_spikes(o)
271
+ # Handle potential complex values from k-space operations
272
+ if result.is_complex():
273
+ result = torch.real(result)
274
+ return MedImage.create(result)
231
275
 
232
- def encodes(self, o:MedMask):
276
+ def encodes(self, o: MedMask):
233
277
  return o
234
278
 
235
279
  # %% ../nbs/03_vision_augment.ipynb 16
236
280
  class RandomNoise(DisplayedTransform):
237
281
  '''Apply TorchIO `RandomNoise`.'''
238
282
 
239
- split_idx,order=0,1
283
+ split_idx, order = 0, 1
240
284
 
241
285
  def __init__(self, mean=0, std=(0, 0.25), p=0.5):
242
286
  self.add_noise = tio.RandomNoise(mean=mean, std=std, p=p)
243
287
 
288
+ @property
289
+ def tio_transform(self):
290
+ """Return the underlying TorchIO transform."""
291
+ return self.add_noise
292
+
244
293
  def encodes(self, o: MedImage):
245
294
  return MedImage.create(self.add_noise(o))
246
295
 
@@ -251,11 +300,16 @@ class RandomNoise(DisplayedTransform):
251
300
  class RandomBiasField(DisplayedTransform):
252
301
  '''Apply TorchIO `RandomBiasField`.'''
253
302
 
254
- split_idx,order=0,1
303
+ split_idx, order = 0, 1
255
304
 
256
305
  def __init__(self, coefficients=0.5, order=3, p=0.5):
257
306
  self.add_biasfield = tio.RandomBiasField(coefficients=coefficients, order=order, p=p)
258
307
 
308
+ @property
309
+ def tio_transform(self):
310
+ """Return the underlying TorchIO transform."""
311
+ return self.add_biasfield
312
+
259
313
  def encodes(self, o: MedImage):
260
314
  return MedImage.create(self.add_biasfield(o))
261
315
 
@@ -264,13 +318,18 @@ class RandomBiasField(DisplayedTransform):
264
318
 
265
319
  # %% ../nbs/03_vision_augment.ipynb 18
266
320
  class RandomBlur(DisplayedTransform):
267
- '''Apply TorchIO `RandomBiasField`.'''
321
+ '''Apply TorchIO `RandomBlur`.'''
268
322
 
269
- split_idx,order=0,1
323
+ split_idx, order = 0, 1
270
324
 
271
325
  def __init__(self, std=(0, 2), p=0.5):
272
326
  self.add_blur = tio.RandomBlur(std=std, p=p)
273
-
327
+
328
+ @property
329
+ def tio_transform(self):
330
+ """Return the underlying TorchIO transform."""
331
+ return self.add_blur
332
+
274
333
  def encodes(self, o: MedImage):
275
334
  return MedImage.create(self.add_blur(o))
276
335
 
@@ -281,12 +340,16 @@ class RandomBlur(DisplayedTransform):
281
340
  class RandomGamma(DisplayedTransform):
282
341
  '''Apply TorchIO `RandomGamma`.'''
283
342
 
284
-
285
- split_idx,order=0,1
343
+ split_idx, order = 0, 1
286
344
 
287
345
  def __init__(self, log_gamma=(-0.3, 0.3), p=0.5):
288
346
  self.add_gamma = tio.RandomGamma(log_gamma=log_gamma, p=p)
289
347
 
348
+ @property
349
+ def tio_transform(self):
350
+ """Return the underlying TorchIO transform."""
351
+ return self.add_gamma
352
+
290
353
  def encodes(self, o: MedImage):
291
354
  return MedImage.create(self.add_gamma(o))
292
355
 
@@ -294,6 +357,38 @@ class RandomGamma(DisplayedTransform):
294
357
  return o
295
358
 
296
359
  # %% ../nbs/03_vision_augment.ipynb 20
360
+ class RandomIntensityScale(DisplayedTransform):
361
+ """Randomly scale image intensities by a multiplicative factor.
362
+
363
+ Useful for domain generalization across different acquisition protocols
364
+ with varying intensity ranges.
365
+
366
+ Args:
367
+ scale_range (tuple[float, float]): Range of scale factors (min, max).
368
+ Values > 1 increase intensity, < 1 decrease intensity.
369
+ p (float): Probability of applying the transform (default: 0.5)
370
+
371
+ Example:
372
+ # Scale intensities randomly between 0.5x and 2.0x
373
+ transform = RandomIntensityScale(scale_range=(0.5, 2.0), p=0.3)
374
+ """
375
+
376
+ split_idx, order = 0, 1
377
+
378
+ def __init__(self, scale_range: tuple[float, float] = (0.5, 2.0), p: float = 0.5):
379
+ self.scale_range = scale_range
380
+ self.p = p
381
+
382
+ def encodes(self, o: MedImage):
383
+ if torch.rand(1).item() > self.p:
384
+ return o
385
+ scale = torch.empty(1).uniform_(self.scale_range[0], self.scale_range[1]).item()
386
+ return MedImage.create(o * scale)
387
+
388
+ def encodes(self, o: MedMask):
389
+ return o
390
+
391
+ # %% ../nbs/03_vision_augment.ipynb 21
297
392
  class RandomMotion(DisplayedTransform):
298
393
  """Apply TorchIO `RandomMotion`."""
299
394
 
@@ -315,13 +410,22 @@ class RandomMotion(DisplayedTransform):
315
410
  p=p
316
411
  )
317
412
 
413
+ @property
414
+ def tio_transform(self):
415
+ """Return the underlying TorchIO transform."""
416
+ return self.add_motion
417
+
318
418
  def encodes(self, o: MedImage):
319
- return MedImage.create(self.add_motion(o))
419
+ result = self.add_motion(o)
420
+ # Handle potential complex values from k-space operations
421
+ if result.is_complex():
422
+ result = torch.real(result)
423
+ return MedImage.create(result)
320
424
 
321
425
  def encodes(self, o: MedMask):
322
426
  return o
323
427
 
324
- # %% ../nbs/03_vision_augment.ipynb 22
428
+ # %% ../nbs/03_vision_augment.ipynb 23
325
429
  class RandomElasticDeformation(CustomDictTransform):
326
430
  """Apply TorchIO `RandomElasticDeformation`."""
327
431
 
@@ -334,7 +438,7 @@ class RandomElasticDeformation(CustomDictTransform):
334
438
  image_interpolation=image_interpolation,
335
439
  p=p))
336
440
 
337
- # %% ../nbs/03_vision_augment.ipynb 23
441
+ # %% ../nbs/03_vision_augment.ipynb 24
338
442
  class RandomAffine(CustomDictTransform):
339
443
  """Apply TorchIO `RandomAffine`."""
340
444
 
@@ -350,14 +454,14 @@ class RandomAffine(CustomDictTransform):
350
454
  default_pad_value=default_pad_value,
351
455
  p=p))
352
456
 
353
- # %% ../nbs/03_vision_augment.ipynb 24
457
+ # %% ../nbs/03_vision_augment.ipynb 25
354
458
  class RandomFlip(CustomDictTransform):
355
459
  """Apply TorchIO `RandomFlip`."""
356
460
 
357
461
  def __init__(self, axes='LR', p=0.5):
358
462
  super().__init__(tio.RandomFlip(axes=axes, flip_probability=p))
359
463
 
360
- # %% ../nbs/03_vision_augment.ipynb 25
464
+ # %% ../nbs/03_vision_augment.ipynb 26
361
465
  class OneOf(CustomDictTransform):
362
466
  """Apply only one of the given transforms using TorchIO `OneOf`."""
363
467
 
fastMONAI/vision_core.py CHANGED
@@ -1,7 +1,7 @@
1
1
  # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_vision_core.ipynb.
2
2
 
3
3
  # %% auto 0
4
- __all__ = ['med_img_reader', 'MetaResolver', 'MedBase', 'MedImage', 'MedMask', 'VSCodeProgressCallback', 'setup_vscode_progress']
4
+ __all__ = ['med_img_reader', 'MetaResolver', 'MedBase', 'MedImage', 'MedMask']
5
5
 
6
6
  # %% ../nbs/01_vision_core.ipynb 2
7
7
  from .vision_plot import *
@@ -10,26 +10,26 @@ from torchio import ScalarImage, LabelMap, ToCanonical, Resample
10
10
  import copy
11
11
 
12
12
  # %% ../nbs/01_vision_core.ipynb 5
13
- def _preprocess(obj, reorder, resample):
13
+ def _preprocess(obj, apply_reorder, target_spacing):
14
14
  """
15
15
  Preprocesses the given object.
16
16
 
17
17
  Args:
18
18
  obj: The object to preprocess.
19
- reorder: Whether to reorder the object.
20
- resample: Whether to resample the object.
19
+ apply_reorder: Whether to reorder the object.
20
+ target_spacing: Whether to resample the object.
21
21
 
22
22
  Returns:
23
23
  The preprocessed object and its original size.
24
24
  """
25
- if reorder:
25
+ if apply_reorder:
26
26
  transform = ToCanonical()
27
27
  obj = transform(obj)
28
28
 
29
29
  original_size = obj.shape[1:]
30
30
 
31
- if resample and not all(np.isclose(obj.spacing, resample)):
32
- transform = Resample(resample)
31
+ if target_spacing and not all(np.isclose(obj.spacing, target_spacing)):
32
+ transform = Resample(target_spacing)
33
33
  obj = transform(obj)
34
34
 
35
35
  if MedBase.affine_matrix is None:
@@ -38,33 +38,33 @@ def _preprocess(obj, reorder, resample):
38
38
  return obj, original_size
39
39
 
40
40
  # %% ../nbs/01_vision_core.ipynb 6
41
- def _load_and_preprocess(file_path, reorder, resample, dtype):
41
+ def _load_and_preprocess(file_path, apply_reorder, target_spacing, dtype):
42
42
  """
43
43
  Helper function to load and preprocess an image.
44
44
 
45
45
  Args:
46
46
  file_path: Image file path.
47
- reorder: Whether to reorder data for canonical (RAS+) orientation.
48
- resample: Whether to resample image to different voxel sizes and dimensions.
47
+ apply_reorder: Whether to reorder data for canonical (RAS+) orientation.
48
+ target_spacing: Whether to resample image to different voxel sizes and dimensions.
49
49
  dtype: Desired datatype for output.
50
50
 
51
51
  Returns:
52
52
  tuple: Original image, preprocessed image, and its original size.
53
53
  """
54
54
  org_img = LabelMap(file_path) if dtype is MedMask else ScalarImage(file_path) #_load(file_path, dtype=dtype)
55
- input_img, org_size = _preprocess(org_img, reorder, resample)
55
+ input_img, org_size = _preprocess(org_img, apply_reorder, target_spacing)
56
56
 
57
57
  return org_img, input_img, org_size
58
58
 
59
59
  # %% ../nbs/01_vision_core.ipynb 7
60
- def _multi_channel(image_paths: L | list, reorder: bool, resample: list, only_tensor: bool, dtype):
60
+ def _multi_channel(image_paths: L | list, apply_reorder: bool, target_spacing: list, only_tensor: bool, dtype):
61
61
  """
62
62
  Load and preprocess multisequence data.
63
63
 
64
64
  Args:
65
65
  image_paths: List of image paths (e.g., T1, T2, T1CE, DWI).
66
- reorder: Whether to reorder data for canonical (RAS+) orientation.
67
- resample: Whether to resample image to different voxel sizes and dimensions.
66
+ apply_reorder: Whether to reorder data for canonical (RAS+) orientation.
67
+ target_spacing: Whether to resample image to different voxel sizes and dimensions.
68
68
  only_tensor: Whether to return only image tensor.
69
69
  dtype: Desired datatype for output.
70
70
 
@@ -72,7 +72,7 @@ def _multi_channel(image_paths: L | list, reorder: bool, resample: list, only_te
72
72
  torch.Tensor: A stacked 4D tensor, if `only_tensor` is True.
73
73
  tuple: Original image, preprocessed image, original size, if `only_tensor` is False.
74
74
  """
75
- image_data = [_load_and_preprocess(image, reorder, resample, dtype) for image in image_paths]
75
+ image_data = [_load_and_preprocess(image, apply_reorder, target_spacing, dtype) for image in image_paths]
76
76
  org_img, input_img, org_size = image_data[-1]
77
77
 
78
78
  tensor = torch.stack([img.data[0] for _, img, _ in image_data], dim=0)
@@ -84,15 +84,15 @@ def _multi_channel(image_paths: L | list, reorder: bool, resample: list, only_te
84
84
  return org_img, input_img, org_size
85
85
 
86
86
  # %% ../nbs/01_vision_core.ipynb 8
87
- def med_img_reader(file_path: str | Path | L | list, reorder: bool = False, resample: list = None,
87
+ def med_img_reader(file_path: str | Path | L | list, apply_reorder: bool = False, target_spacing: list = None,
88
88
  only_tensor: bool = True, dtype = torch.Tensor):
89
89
  """Loads and preprocesses a medical image.
90
90
 
91
91
  Args:
92
92
  file_path: Path to the image. Can be a string, Path object or a list.
93
- reorder: Whether to reorder the data to be closest to canonical
93
+ apply_reorder: Whether to reorder the data to be closest to canonical
94
94
  (RAS+) orientation. Defaults to False.
95
- resample: Whether to resample image to different voxel sizes and
95
+ target_spacing: Whether to resample image to different voxel sizes and
96
96
  image dimensions. Defaults to None.
97
97
  only_tensor: Whether to return only image tensor. Defaults to True.
98
98
  dtype: Datatype for the return value. Defaults to torch.Tensor.
@@ -104,10 +104,10 @@ def med_img_reader(file_path: str | Path | L | list, reorder: bool = False, resa
104
104
  """
105
105
 
106
106
  if isinstance(file_path, (list, L)):
107
- return _multi_channel(file_path, reorder, resample, only_tensor, dtype)
107
+ return _multi_channel(file_path, apply_reorder, target_spacing, only_tensor, dtype)
108
108
 
109
109
  org_img, input_img, org_size = _load_and_preprocess(
110
- file_path, reorder, resample, dtype)
110
+ file_path, apply_reorder, target_spacing, dtype)
111
111
 
112
112
  if only_tensor:
113
113
  return dtype(input_img.data.type(torch.float))
@@ -129,7 +129,7 @@ class MedBase(torch.Tensor, metaclass=MetaResolver):
129
129
 
130
130
  _bypass_type = torch.Tensor
131
131
  _show_args = {'cmap':'gray'}
132
- resample, reorder = None, False
132
+ target_spacing, apply_reorder = None, False
133
133
  affine_matrix = None
134
134
 
135
135
  @classmethod
@@ -150,7 +150,7 @@ class MedBase(torch.Tensor, metaclass=MetaResolver):
150
150
  if isinstance(fn, torch.Tensor):
151
151
  return cls(fn)
152
152
 
153
- return med_img_reader(fn, resample=cls.resample, reorder=cls.reorder, dtype=cls)
153
+ return med_img_reader(fn, target_spacing=cls.target_spacing, apply_reorder=cls.apply_reorder, dtype=cls)
154
154
 
155
155
  def __new__(cls, x, **kwargs):
156
156
  """Creates a new instance of MedBase from a tensor."""
@@ -196,18 +196,18 @@ class MedBase(torch.Tensor, metaclass=MetaResolver):
196
196
  return copied
197
197
 
198
198
  @classmethod
199
- def item_preprocessing(cls, resample: (list, int, tuple), reorder: bool):
199
+ def item_preprocessing(cls, target_spacing: (list, int, tuple), apply_reorder: bool):
200
200
  """
201
- Changes the values for the class variables `resample` and `reorder`.
201
+ Changes the values for the class variables `target_spacing` and `apply_reorder`.
202
202
 
203
203
  Args:
204
- resample : (list, int, tuple)
204
+ target_spacing : (list, int, tuple)
205
205
  A list with voxel spacing.
206
- reorder : bool
206
+ apply_reorder : bool
207
207
  Whether to reorder the data to be closest to canonical (RAS+) orientation.
208
208
  """
209
- cls.resample = resample
210
- cls.reorder = reorder
209
+ cls.target_spacing = target_spacing
210
+ cls.apply_reorder = apply_reorder
211
211
 
212
212
  def show(self, ctx=None, channel: int = 0, slice_index: int = None, anatomical_plane: int = 0, **kwargs):
213
213
  """
@@ -230,7 +230,7 @@ class MedBase(torch.Tensor, metaclass=MetaResolver):
230
230
  """
231
231
  return show_med_img(
232
232
  self, ctx=ctx, channel=channel, slice_index=slice_index,
233
- anatomical_plane=anatomical_plane, voxel_size=self.resample,
233
+ anatomical_plane=anatomical_plane, voxel_size=self.target_spacing,
234
234
  **merge(self._show_args, kwargs)
235
235
  )
236
236
 
@@ -247,106 +247,3 @@ class MedImage(MedBase):
247
247
  class MedMask(MedBase):
248
248
  """Subclass of MedBase that represents an mask object."""
249
249
  _show_args = {'alpha':0.5, 'cmap':'tab20'}
250
-
251
- # %% ../nbs/01_vision_core.ipynb 14
252
- import os
253
- from fastai.callback.progress import ProgressCallback
254
- from fastai.callback.core import Callback
255
- import sys
256
- from IPython import get_ipython
257
-
258
- class VSCodeProgressCallback(ProgressCallback):
259
- """Enhanced progress callback that works better in VS Code notebooks."""
260
-
261
- def __init__(self, **kwargs):
262
- super().__init__(**kwargs)
263
- self.is_vscode = self._detect_vscode_environment()
264
- self.lr_find_progress = None
265
-
266
- def _detect_vscode_environment(self):
267
- """Detect if running in VS Code Jupyter environment."""
268
- ipython = get_ipython()
269
- if ipython is None:
270
- return True # Assume VS Code if no IPython (safer default)
271
- # VS Code detection - more comprehensive check
272
- kernel_name = str(type(ipython.kernel)).lower() if hasattr(ipython, 'kernel') else ''
273
- return ('vscode' in kernel_name or
274
- 'zmq' in kernel_name or # VS Code often uses ZMQInteractiveShell
275
- not hasattr(ipython, 'display_pub')) # Missing display publisher often indicates VS Code
276
-
277
- def before_fit(self):
278
- """Initialize progress tracking before training."""
279
- if self.is_vscode:
280
- if hasattr(self.learn, 'lr_finder') and self.learn.lr_finder:
281
- # This is lr_find, handle differently
282
- print("🔍 Starting Learning Rate Finder...")
283
- self.lr_find_progress = 0
284
- else:
285
- # Regular training
286
- print(f"🚀 Training for {self.learn.n_epoch} epochs...")
287
- super().before_fit()
288
-
289
- def before_epoch(self):
290
- """Initialize epoch progress."""
291
- if self.is_vscode:
292
- if hasattr(self.learn, 'lr_finder') and self.learn.lr_finder:
293
- print(f"📊 LR Find - Testing learning rates...")
294
- else:
295
- print(f"📈 Epoch {self.epoch+1}/{self.learn.n_epoch}")
296
- sys.stdout.flush()
297
- super().before_epoch()
298
-
299
- def after_batch(self):
300
- """Update progress after each batch."""
301
- super().after_batch()
302
- if self.is_vscode:
303
- if hasattr(self.learn, 'lr_finder') and self.learn.lr_finder:
304
- # Special handling for lr_find
305
- self.lr_find_progress = getattr(self, 'iter', 0) + 1
306
- total = getattr(self, 'n_iter', 100)
307
- if self.lr_find_progress % max(1, total // 10) == 0:
308
- progress = (self.lr_find_progress / total) * 100
309
- print(f"⏳ LR Find Progress: {self.lr_find_progress}/{total} ({progress:.1f}%)")
310
- sys.stdout.flush()
311
- else:
312
- # Regular training progress
313
- if hasattr(self, 'iter') and hasattr(self, 'n_iter'):
314
- if self.iter % max(1, self.n_iter // 20) == 0:
315
- progress = (self.iter / self.n_iter) * 100
316
- print(f"⏳ Batch {self.iter}/{self.n_iter} ({progress:.1f}%)")
317
- sys.stdout.flush()
318
-
319
- def after_fit(self):
320
- """Complete progress tracking after training."""
321
- if self.is_vscode:
322
- if hasattr(self.learn, 'lr_finder') and self.learn.lr_finder:
323
- print("✅ Learning Rate Finder completed!")
324
- else:
325
- print("✅ Training completed!")
326
- sys.stdout.flush()
327
- super().after_fit()
328
-
329
- def before_validate(self):
330
- """Update before validation."""
331
- if self.is_vscode and not (hasattr(self.learn, 'lr_finder') and self.learn.lr_finder):
332
- print("🔄 Validating...")
333
- sys.stdout.flush()
334
- super().before_validate()
335
-
336
- def after_validate(self):
337
- """Update after validation."""
338
- if self.is_vscode and not (hasattr(self.learn, 'lr_finder') and self.learn.lr_finder):
339
- print("✅ Validation completed")
340
- sys.stdout.flush()
341
- super().after_validate()
342
-
343
- def setup_vscode_progress():
344
- """Configure fastai to use VS Code-compatible progress callback."""
345
- from fastai.learner import defaults
346
-
347
- # Replace default ProgressCallback with VSCodeProgressCallback
348
- if ProgressCallback in defaults.callbacks:
349
- defaults.callbacks = [cb if cb != ProgressCallback else VSCodeProgressCallback
350
- for cb in defaults.callbacks]
351
-
352
- print("✅ Configured VS Code-compatible progress callback")
fastMONAI/vision_data.py CHANGED
@@ -27,7 +27,7 @@ def pred_to_multiclass_mask(pred: torch.Tensor) -> torch.Tensor:
27
27
 
28
28
  pred = pred.softmax(dim=0)
29
29
 
30
- return pred.argmax(dim=0, keepdims=True)
30
+ return pred.argmax(dim=0, keepdim=True)
31
31
 
32
32
  # %% ../nbs/02_vision_data.ipynb 6
33
33
  def batch_pred_to_multiclass_mask(pred: torch.Tensor) -> (torch.Tensor, int):
@@ -68,12 +68,12 @@ class MedDataBlock(DataBlock):
68
68
  #TODO add get_x
69
69
  def __init__(self, blocks: list = None, dl_type: TfmdDL = None, getters: list = None,
70
70
  n_inp: int | None = None, item_tfms: list = None, batch_tfms: list = None,
71
- reorder: bool = False, resample: (int, list) = None, **kwargs):
71
+ apply_reorder: bool = False, target_spacing: (int, list) = None, **kwargs):
72
72
 
73
73
  super().__init__(blocks, dl_type, getters, n_inp, item_tfms,
74
74
  batch_tfms, **kwargs)
75
75
 
76
- MedBase.item_preprocessing(resample, reorder)
76
+ MedBase.item_preprocessing(target_spacing, apply_reorder)
77
77
 
78
78
  # %% ../nbs/02_vision_data.ipynb 11
79
79
  def MedMaskBlock():
@@ -88,7 +88,7 @@ class MedImageDataLoaders(DataLoaders):
88
88
  @delegates(DataLoaders.from_dblock)
89
89
  def from_df(cls, df, valid_pct=0.2, seed=None, fn_col=0, folder=None, suff='',
90
90
  label_col=1, label_delim=None, y_block=None, valid_col=None,
91
- item_tfms=None, batch_tfms=None, reorder=False, resample=None, **kwargs):
91
+ item_tfms=None, batch_tfms=None, apply_reorder=False, target_spacing=None, **kwargs):
92
92
  """Create from DataFrame."""
93
93
 
94
94
  if y_block is None:
@@ -104,8 +104,8 @@ class MedImageDataLoaders(DataLoaders):
104
104
  get_y=ColReader(label_col, label_delim=label_delim),
105
105
  splitter=splitter,
106
106
  item_tfms=item_tfms,
107
- reorder=reorder,
108
- resample=resample
107
+ apply_reorder=apply_reorder,
108
+ target_spacing=target_spacing
109
109
  )
110
110
 
111
111
  return cls.from_dblock(dblock, df, **kwargs)