fastMONAI 0.5.3__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,8 +3,8 @@
3
3
  # %% auto 0
4
4
# Public names exported by `from <this module> import *`.
__all__ = [
    'CustomDictTransform', 'do_pad_or_crop', 'PadOrCrop', 'ZNormalization',
    'RescaleIntensity', 'NormalizeIntensity', 'BraTSMaskConverter', 'BinaryConverter',
    'RandomGhosting', 'RandomSpike', 'RandomNoise', 'RandomBiasField',
    'RandomBlur', 'RandomGamma', 'RandomIntensityScale', 'RandomMotion', 'RandomCutout',
    'RandomElasticDeformation', 'RandomAffine', 'RandomFlip', 'OneOf',
]
8
8
 
9
9
  # %% ../nbs/03_vision_augment.ipynb 2
10
10
  from fastai.data.all import *
@@ -14,10 +14,10 @@ from monai.transforms import NormalizeIntensity as MonaiNormalizeIntensity
14
14
 
15
15
  # %% ../nbs/03_vision_augment.ipynb 5
16
16
  class CustomDictTransform(ItemTransform):
17
- """A class that serves as a wrapper to perform an identical transformation on both
17
+ """A class that serves as a wrapper to perform an identical transformation on both
18
18
  the image and the target (if it's a mask).
19
19
  """
20
-
20
+
21
21
  split_idx = 0 # Only perform transformations on training data. Use TTA() for transformations on validation data.
22
22
 
23
23
  def __init__(self, aug):
@@ -28,31 +28,42 @@ class CustomDictTransform(ItemTransform):
28
28
  """
29
29
  self.aug = aug
30
30
 
31
+ @property
32
+ def tio_transform(self):
33
+ """Return the underlying TorchIO transform.
34
+
35
+ This property enables using fastMONAI wrappers in patch-based workflows
36
+ where raw TorchIO transforms are needed for tio.Compose().
37
+ """
38
+ return self.aug
39
+
31
40
  def encodes(self, x):
32
41
  """
33
- Applies the stored transformation to an image, and the same random transformation
42
+ Applies the stored transformation to an image, and the same random transformation
34
43
  to the target if it is a mask. If the target is not a mask, it returns the target as is.
35
44
 
36
45
  Args:
37
- x (Tuple[MedImage, Union[MedMask, TensorCategory]]): A tuple containing the
46
+ x (Tuple[MedImage, Union[MedMask, TensorCategory]]): A tuple containing the
38
47
  image and the target.
39
48
 
40
49
  Returns:
41
- Tuple[MedImage, Union[MedMask, TensorCategory]]: The transformed image and target.
42
- If the target is a mask, it's transformed identically to the image. If the target
50
+ Tuple[MedImage, Union[MedMask, TensorCategory]]: The transformed image and target.
51
+ If the target is a mask, it's transformed identically to the image. If the target
43
52
  is not a mask, the original target is returned.
44
53
  """
45
54
  img, y_true = x
55
+
56
+ # Use identity affine if MedImage.affine_matrix is not set
57
+ affine = MedImage.affine_matrix if MedImage.affine_matrix is not None else np.eye(4)
46
58
 
47
59
  if isinstance(y_true, (MedMask)):
48
- aug = self.aug(tio.Subject(img=tio.ScalarImage(tensor=img, affine=MedImage.affine_matrix),
49
- mask=tio.LabelMap(tensor=y_true, affine=MedImage.affine_matrix)))
60
+ aug = self.aug(tio.Subject(img=tio.ScalarImage(tensor=img, affine=affine),
61
+ mask=tio.LabelMap(tensor=y_true, affine=affine)))
50
62
  return MedImage.create(aug['img'].data), MedMask.create(aug['mask'].data)
51
63
 
52
64
  aug = self.aug(tio.Subject(img=tio.ScalarImage(tensor=img)))
53
65
  return MedImage.create(aug['img'].data), y_true
54
66
 
55
-
56
67
  # %% ../nbs/03_vision_augment.ipynb 7
57
68
  def do_pad_or_crop(o, target_shape, padding_mode, mask_name, dtype=torch.Tensor):
58
69
  #TODO:refactorize
@@ -72,6 +83,11 @@ class PadOrCrop(DisplayedTransform):
72
83
  padding_mode=padding_mode,
73
84
  mask_name=mask_name)
74
85
 
86
+ @property
87
+ def tio_transform(self):
88
+ """Return the underlying TorchIO transform."""
89
+ return self.pad_or_crop
90
+
75
91
  def encodes(self, o: (MedImage, MedMask)):
76
92
  return type(o)(self.pad_or_crop(o))
77
93
 
@@ -85,6 +101,11 @@ class ZNormalization(DisplayedTransform):
85
101
  self.z_normalization = tio.ZNormalization(masking_method=masking_method)
86
102
  self.channel_wise = channel_wise
87
103
 
104
+ @property
105
+ def tio_transform(self):
106
+ """Return the underlying TorchIO transform."""
107
+ return self.z_normalization
108
+
88
109
  def encodes(self, o: MedImage):
89
110
  try:
90
111
  if self.channel_wise:
@@ -132,7 +153,12 @@ class RescaleIntensity(DisplayedTransform):
132
153
 
133
154
  def __init__(self, out_min_max: tuple[float, float], in_min_max: tuple[float, float]):
134
155
  self.rescale = tio.RescaleIntensity(out_min_max=out_min_max, in_min_max=in_min_max)
135
-
156
+
157
+ @property
158
+ def tio_transform(self):
159
+ """Return the underlying TorchIO transform."""
160
+ return self.rescale
161
+
136
162
  def encodes(self, o: MedImage):
137
163
  return MedImage.create(self.rescale(o))
138
164
 
@@ -211,8 +237,17 @@ class RandomGhosting(DisplayedTransform):
211
237
  def __init__(self, intensity=(0.5, 1), p=0.5):
212
238
  self.add_ghosts = tio.RandomGhosting(intensity=intensity, p=p)
213
239
 
240
+ @property
241
+ def tio_transform(self):
242
+ """Return the underlying TorchIO transform."""
243
+ return self.add_ghosts
244
+
214
245
  def encodes(self, o: MedImage):
215
- return MedImage.create(self.add_ghosts(o))
246
+ result = self.add_ghosts(o)
247
+ # Handle potential complex values from k-space operations
248
+ if result.is_complex():
249
+ result = torch.real(result)
250
+ return MedImage.create(result)
216
251
 
217
252
  def encodes(self, o: MedMask):
218
253
  return o
@@ -221,26 +256,40 @@ class RandomGhosting(DisplayedTransform):
221
256
class RandomSpike(DisplayedTransform):
    """Add random spike artifacts to images using TorchIO `RandomSpike`.

    `MedImage` inputs are passed through the TorchIO transform; `MedMask`
    inputs are returned untouched. `split_idx = 0` restricts it to training data.

    Args:
        num_spikes: Forwarded to `tio.RandomSpike(num_spikes=...)`.
        intensity: Forwarded to `tio.RandomSpike(intensity=...)`.
        p: Probability of applying the transform.
    """

    split_idx, order = 0, 1

    def __init__(self, num_spikes=1, intensity=(1, 3), p=0.5):
        self.add_spikes = tio.RandomSpike(num_spikes=num_spikes, intensity=intensity, p=p)

    @property
    def tio_transform(self):
        """The wrapped TorchIO transform, usable directly inside `tio.Compose`."""
        return self.add_spikes

    def encodes(self, o: MedImage):
        spiked = self.add_spikes(o)
        # k-space based artifacts may yield a complex tensor; keep the real part only
        return MedImage.create(torch.real(spiked) if spiked.is_complex() else spiked)

    def encodes(self, o: MedMask):
        return o
234
278
 
235
279
  # %% ../nbs/03_vision_augment.ipynb 16
236
280
  class RandomNoise(DisplayedTransform):
237
281
  '''Apply TorchIO `RandomNoise`.'''
238
282
 
239
- split_idx,order=0,1
283
+ split_idx, order = 0, 1
240
284
 
241
285
  def __init__(self, mean=0, std=(0, 0.25), p=0.5):
242
286
  self.add_noise = tio.RandomNoise(mean=mean, std=std, p=p)
243
287
 
288
+ @property
289
+ def tio_transform(self):
290
+ """Return the underlying TorchIO transform."""
291
+ return self.add_noise
292
+
244
293
  def encodes(self, o: MedImage):
245
294
  return MedImage.create(self.add_noise(o))
246
295
 
@@ -251,11 +300,16 @@ class RandomNoise(DisplayedTransform):
251
300
  class RandomBiasField(DisplayedTransform):
252
301
  '''Apply TorchIO `RandomBiasField`.'''
253
302
 
254
- split_idx,order=0,1
303
+ split_idx, order = 0, 1
255
304
 
256
305
  def __init__(self, coefficients=0.5, order=3, p=0.5):
257
306
  self.add_biasfield = tio.RandomBiasField(coefficients=coefficients, order=order, p=p)
258
307
 
308
+ @property
309
+ def tio_transform(self):
310
+ """Return the underlying TorchIO transform."""
311
+ return self.add_biasfield
312
+
259
313
  def encodes(self, o: MedImage):
260
314
  return MedImage.create(self.add_biasfield(o))
261
315
 
@@ -264,13 +318,18 @@ class RandomBiasField(DisplayedTransform):
264
318
 
265
319
  # %% ../nbs/03_vision_augment.ipynb 18
266
320
  class RandomBlur(DisplayedTransform):
267
- '''Apply TorchIO `RandomBiasField`.'''
321
+ '''Apply TorchIO `RandomBlur`.'''
268
322
 
269
- split_idx,order=0,1
323
+ split_idx, order = 0, 1
270
324
 
271
325
  def __init__(self, std=(0, 2), p=0.5):
272
326
  self.add_blur = tio.RandomBlur(std=std, p=p)
273
-
327
+
328
+ @property
329
+ def tio_transform(self):
330
+ """Return the underlying TorchIO transform."""
331
+ return self.add_blur
332
+
274
333
  def encodes(self, o: MedImage):
275
334
  return MedImage.create(self.add_blur(o))
276
335
 
@@ -281,12 +340,16 @@ class RandomBlur(DisplayedTransform):
281
340
  class RandomGamma(DisplayedTransform):
282
341
  '''Apply TorchIO `RandomGamma`.'''
283
342
 
284
-
285
- split_idx,order=0,1
343
+ split_idx, order = 0, 1
286
344
 
287
345
  def __init__(self, log_gamma=(-0.3, 0.3), p=0.5):
288
346
  self.add_gamma = tio.RandomGamma(log_gamma=log_gamma, p=p)
289
347
 
348
+ @property
349
+ def tio_transform(self):
350
+ """Return the underlying TorchIO transform."""
351
+ return self.add_gamma
352
+
290
353
  def encodes(self, o: MedImage):
291
354
  return MedImage.create(self.add_gamma(o))
292
355
 
@@ -294,6 +357,38 @@ class RandomGamma(DisplayedTransform):
294
357
  return o
295
358
 
296
359
  # %% ../nbs/03_vision_augment.ipynb 20
360
class RandomIntensityScale(DisplayedTransform):
    """Randomly scale image intensities by a multiplicative factor.

    Useful for domain generalization across different acquisition protocols
    with varying intensity ranges. `MedMask` inputs are returned unchanged.

    Args:
        scale_range (tuple[float, float]): Range of scale factors (min, max).
            Values > 1 increase intensity, < 1 decrease intensity.
        p (float): Probability of applying the transform (default: 0.5).

    Raises:
        ValueError: If ``scale_range[0] > scale_range[1]`` or ``p`` is outside [0, 1].

    Example:
        # Scale intensities randomly between 0.5x and 2.0x
        transform = RandomIntensityScale(scale_range=(0.5, 2.0), p=0.3)
    """

    split_idx, order = 0, 1

    def __init__(self, scale_range: tuple[float, float] = (0.5, 2.0), p: float = 0.5):
        # Validate eagerly: an unordered range would otherwise make `uniform_`
        # raise on every call, far away from where the mistake was made.
        if scale_range[0] > scale_range[1]:
            raise ValueError(
                f"scale_range must be (min, max) with min <= max, got {scale_range}")
        if not 0 <= p <= 1:
            raise ValueError(f"p must be in [0, 1], got {p}")
        self.scale_range = scale_range
        self.p = p

    def encodes(self, o: MedImage):
        # Skip with probability 1 - p
        if torch.rand(1).item() > self.p:
            return o
        low, high = float(self.scale_range[0]), float(self.scale_range[1])
        scale = torch.empty(1).uniform_(low, high).item()
        return MedImage.create(o * scale)

    def encodes(self, o: MedMask):
        return o
+
391
+ # %% ../nbs/03_vision_augment.ipynb 21
297
392
  class RandomMotion(DisplayedTransform):
298
393
  """Apply TorchIO `RandomMotion`."""
299
394
 
@@ -315,13 +410,154 @@ class RandomMotion(DisplayedTransform):
315
410
  p=p
316
411
  )
317
412
 
413
+ @property
414
+ def tio_transform(self):
415
+ """Return the underlying TorchIO transform."""
416
+ return self.add_motion
417
+
318
418
  def encodes(self, o: MedImage):
319
- return MedImage.create(self.add_motion(o))
419
+ result = self.add_motion(o)
420
+ # Handle potential complex values from k-space operations
421
+ if result.is_complex():
422
+ result = torch.real(result)
423
+ return MedImage.create(result)
320
424
 
321
425
  def encodes(self, o: MedMask):
322
426
  return o
323
427
 
324
428
  # %% ../nbs/03_vision_augment.ipynb 22
429
+ def _create_ellipsoid_mask(shape, center, radii):
430
+ """Create a 3D ellipsoid mask.
431
+
432
+ Args:
433
+ shape: (D, H, W) shape of the volume
434
+ center: (z, y, x) center of ellipsoid
435
+ radii: (rz, ry, rx) radii along each axis
436
+
437
+ Returns:
438
+ Boolean mask where True = inside ellipsoid
439
+ """
440
+ z, y, x = torch.meshgrid(
441
+ torch.arange(shape[0]),
442
+ torch.arange(shape[1]),
443
+ torch.arange(shape[2]),
444
+ indexing='ij'
445
+ )
446
+ dist = ((z - center[0]) / radii[0]) ** 2 + \
447
+ ((y - center[1]) / radii[1]) ** 2 + \
448
+ ((x - center[2]) / radii[2]) ** 2
449
+ return dist <= 1.0
450
+
451
+ # %% ../nbs/03_vision_augment.ipynb 23
452
class _TioRandomCutout(tio.IntensityTransform):
    """TorchIO-compatible RandomCutout for patch-based workflows.

    Sets between ``holes`` and ``max_holes`` random ellipsoids to a fill value
    in every image returned by ``self.get_images(subject)``.

    Args:
        holes: Minimum number of cutouts per image.
        spatial_size: Minimum cutout diameter in voxels. If a sequence is
            given, only its first element is used.
        fill_value: ``'min'``, ``'mean'``, ``'random'`` (uniform within the
            image's [min, max]), a number, or ``None`` (same as ``'min'``).
            String modes are evaluated per image at call time.
        max_holes: Maximum number of cutouts; falsy values fall back to ``holes``.
        max_spatial_size: Maximum cutout diameter; falsy values fall back to
            ``spatial_size``. Sequences contribute their first element.
        p: Probability of applying the transform.
        **kwargs: Forwarded to ``tio.IntensityTransform``.

    Raises:
        ValueError: If ``fill_value`` is an unknown string, or if a minimum
            (holes / spatial size) exceeds its maximum.
    """

    _FILL_MODES = ('min', 'mean', 'random')

    def __init__(self, holes=1, spatial_size=8, fill_value=None,
                 max_holes=None, max_spatial_size=None, p=0.2, **kwargs):
        super().__init__(p=p, **kwargs)
        self.holes = holes
        self.spatial_size = spatial_size
        self.fill_value = fill_value
        self.max_holes = max_holes
        self.max_spatial_size = max_spatial_size

        # Fail fast: these configurations would otherwise only error inside
        # torch.randint or the tensor assignment the first time the transform fires.
        if isinstance(fill_value, str) and fill_value not in self._FILL_MODES:
            raise ValueError(
                f"fill must be one of {self._FILL_MODES} or a number, got {fill_value!r}")
        if (max_holes or holes) < holes:
            raise ValueError(f"max_holes ({max_holes}) must be >= holes ({holes})")
        min_size, max_size = self._size_bounds()
        if max_size < min_size:
            raise ValueError(
                f"max_spatial_size ({max_size}) must be >= spatial_size ({min_size})")

    @staticmethod
    def _scalar_size(size):
        """Return ``size`` if it is an int, else its first element."""
        return size if isinstance(size, int) else size[0]

    def _size_bounds(self):
        """Return the (min, max) cutout diameter in voxels."""
        min_size = self._scalar_size(self.spatial_size)
        max_size = self._scalar_size(self.max_spatial_size or self.spatial_size)
        return min_size, max_size

    @staticmethod
    def _fill_from(fill, data):
        """Return the concrete fill number for ``data`` under strategy ``fill``."""
        if fill is None:
            fill = 'min'
        if not isinstance(fill, str):
            return fill
        if fill == 'min':
            return float(data.min())
        if fill == 'mean':
            # mean() is undefined for integer tensors; average those in float64
            return float(data.mean() if data.is_floating_point() else data.double().mean())
        if fill == 'random':
            return torch.empty(1).uniform_(float(data.min()), float(data.max())).item()
        raise ValueError(
            f"fill must be one of {_TioRandomCutout._FILL_MODES} or a number, got {fill!r}")

    def _resolve_fill(self, data):
        """Return the concrete fill number for ``data`` under ``self.fill_value``."""
        return self._fill_from(self.fill_value, data)

    def _apply_cutout(self, data, fill_val):
        """Return a copy of ``data`` (channels first, then D, H, W) with random ellipsoids set to ``fill_val``."""
        result = data.clone()
        n_holes = torch.randint(self.holes, (self.max_holes or self.holes) + 1, (1,)).item()

        spatial_shape = data.shape[1:]  # (D, H, W)
        min_size, max_size = self._size_bounds()

        for _ in range(n_holes):
            # Random diameter per axis for this hole
            size = torch.randint(min_size, max_size + 1, (3,))
            radii = size.float() / 2

            # Random center: at least one radius from the low edge and, when the
            # volume is large enough, from the high edge as well
            center = []
            for dim, radius in zip(spatial_shape, radii):
                r = int(radius.item())
                center.append(torch.randint(r, max(dim - r, r + 1), (1,)).item())

            mask = _create_ellipsoid_mask(spatial_shape, center, radii)
            # Mask is built on the default device; align it with the data
            result[:, mask.to(result.device)] = fill_val

        return result

    def apply_transform(self, subject):
        for image in self.get_images(subject):
            data = image.data
            image.set_data(self._apply_cutout(data, self._resolve_fill(data)))
        return subject

# %% ../nbs/03_vision_augment.ipynb 24
class RandomCutout(DisplayedTransform):
    """Randomly erase ellipsoidal regions in 3D medical images.

    Simulates post-operative surgical cavities by filling random ellipsoid
    volumes with specified values. Useful for training on pre-op images
    to generalize to post-op scans. `MedMask` inputs are returned unchanged.

    Args:
        holes: Minimum number of cutout regions. Default: 1.
        max_holes: Maximum number of regions. Default: 3.
        spatial_size: Minimum cutout diameter in voxels. Default: 8.
        max_spatial_size: Maximum cutout diameter. Default: 16.
        fill: Fill value - 'min', 'mean', 'random', or float. Default: 'min'.
            ``tio_transform`` uses the same strategy.
        p: Probability of applying transform. Default: 0.2.

    Raises:
        ValueError: For an unknown ``fill`` string, or a minimum above its maximum.

    Example:
        >>> # Simulate post-op cavities with dark ellipsoidal voids
        >>> tfm = RandomCutout(holes=1, max_holes=2, spatial_size=10,
        ...                    max_spatial_size=25, fill='min', p=0.2)
    """

    split_idx, order = 0, 1

    def __init__(self, holes=1, max_holes=3, spatial_size=8,
                 max_spatial_size=16, fill='min', p=0.2):
        self.holes = holes
        self.max_holes = max_holes
        self.spatial_size = spatial_size
        self.max_spatial_size = max_spatial_size
        self.fill = fill
        self.p = p

        # Forward the fill strategy as-is (string modes included) so the TorchIO
        # twin used in patch-based workflows fills exactly like this wrapper.
        # The twin also validates the arguments.
        self._tio_cutout = _TioRandomCutout(
            holes=holes, spatial_size=spatial_size,
            fill_value=fill,
            max_holes=max_holes, max_spatial_size=max_spatial_size, p=p
        )

    @property
    def tio_transform(self):
        """Return TorchIO-compatible transform for patch-based workflows."""
        return self._tio_cutout

    def _get_fill_value(self, tensor):
        """Return the concrete fill number for ``tensor`` under ``self.fill``."""
        return _TioRandomCutout._fill_from(self.fill, tensor)

    def encodes(self, o: MedImage):
        if torch.rand(1).item() > self.p: return o
        fill_val = self._get_fill_value(o)
        # _apply_cutout works on its own clone, so `o` is never modified in place
        result = self._tio_cutout._apply_cutout(o, fill_val)
        return MedImage.create(result)

    def encodes(self, o: MedMask):
        return o
559
+
560
+ # %% ../nbs/03_vision_augment.ipynb 26
325
561
  class RandomElasticDeformation(CustomDictTransform):
326
562
  """Apply TorchIO `RandomElasticDeformation`."""
327
563
 
@@ -334,7 +570,7 @@ class RandomElasticDeformation(CustomDictTransform):
334
570
  image_interpolation=image_interpolation,
335
571
  p=p))
336
572
 
337
- # %% ../nbs/03_vision_augment.ipynb 23
573
+ # %% ../nbs/03_vision_augment.ipynb 27
338
574
  class RandomAffine(CustomDictTransform):
339
575
  """Apply TorchIO `RandomAffine`."""
340
576
 
@@ -350,14 +586,14 @@ class RandomAffine(CustomDictTransform):
350
586
  default_pad_value=default_pad_value,
351
587
  p=p))
352
588
 
353
- # %% ../nbs/03_vision_augment.ipynb 24
589
+ # %% ../nbs/03_vision_augment.ipynb 28
354
590
class RandomFlip(CustomDictTransform):
    """Randomly flip the image, and a mask target identically, using TorchIO `RandomFlip`.

    Args:
        axes: Axes that may be flipped, forwarded to `tio.RandomFlip(axes=...)`.
        p: Forwarded as `tio.RandomFlip(flip_probability=...)`.
    """

    def __init__(self, axes='LR', p=0.5):
        flip = tio.RandomFlip(axes=axes, flip_probability=p)
        super().__init__(flip)
359
595
 
360
- # %% ../nbs/03_vision_augment.ipynb 25
596
+ # %% ../nbs/03_vision_augment.ipynb 29
361
597
  class OneOf(CustomDictTransform):
362
598
  """Apply only one of the given transforms using TorchIO `OneOf`."""
363
599