batchgenerators 0.25__tar.gz → 0.25.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. {batchgenerators-0.25 → batchgenerators-0.25.1}/PKG-INFO +10 -1
  2. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/augmentations/spatial_transformations.py +173 -1
  3. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/augmentations/utils.py +24 -1
  4. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/dataloading/multi_threaded_augmenter.py +1 -1
  5. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py +12 -6
  6. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/transforms/noise_transforms.py +15 -10
  7. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/transforms/spatial_transforms.py +193 -1
  8. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/transforms/utility_transforms.py +17 -1
  9. batchgenerators-0.25.1/batchgenerators/utilities/file_and_folder_operations.py +135 -0
  10. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators.egg-info/PKG-INFO +10 -1
  11. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators.egg-info/SOURCES.txt +12 -1
  12. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators.egg-info/requires.txt +1 -0
  13. {batchgenerators-0.25 → batchgenerators-0.25.1}/setup.py +2 -1
  14. batchgenerators-0.25.1/tests/test_DataLoader.py +262 -0
  15. batchgenerators-0.25.1/tests/test_augment_zoom.py +73 -0
  16. batchgenerators-0.25.1/tests/test_axis_mirroring.py +173 -0
  17. batchgenerators-0.25.1/tests/test_color_augmentations.py +158 -0
  18. batchgenerators-0.25.1/tests/test_crop.py +470 -0
  19. batchgenerators-0.25.1/tests/test_multithreaded_augmenter.py +222 -0
  20. batchgenerators-0.25.1/tests/test_normalizations.py +235 -0
  21. batchgenerators-0.25.1/tests/test_random_crop.py +76 -0
  22. batchgenerators-0.25.1/tests/test_resample_augmentations.py +70 -0
  23. batchgenerators-0.25.1/tests/test_sanity.py +26 -0
  24. batchgenerators-0.25.1/tests/test_spatial_transformations.py +141 -0
  25. batchgenerators-0.25/batchgenerators/utilities/file_and_folder_operations.py +0 -100
  26. {batchgenerators-0.25 → batchgenerators-0.25.1}/LICENSE +0 -0
  27. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/__init__.py +0 -0
  28. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/augmentations/__init__.py +0 -0
  29. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/augmentations/color_augmentations.py +0 -0
  30. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/augmentations/crop_and_pad_augmentations.py +0 -0
  31. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/augmentations/noise_augmentations.py +0 -0
  32. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/augmentations/normalizations.py +0 -0
  33. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/augmentations/resample_augmentations.py +0 -0
  34. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/dataloading/__init__.py +0 -0
  35. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/dataloading/data_loader.py +0 -0
  36. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/dataloading/dataset.py +0 -0
  37. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/dataloading/single_threaded_augmenter.py +0 -0
  38. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/datasets/__init__.py +0 -0
  39. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/datasets/cifar.py +0 -0
  40. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/examples/__init__.py +0 -0
  41. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/examples/brats2017/__init__.py +0 -0
  42. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/examples/brats2017/brats2017_dataloader_2D.py +0 -0
  43. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/examples/brats2017/brats2017_dataloader_3D.py +0 -0
  44. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/examples/brats2017/brats2017_preprocessing.py +0 -0
  45. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/examples/brats2017/config.py +0 -0
  46. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/examples/cifar10.py +0 -0
  47. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/examples/multithreaded_dataloading.py +0 -0
  48. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/transforms/__init__.py +0 -0
  49. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/transforms/abstract_transforms.py +0 -0
  50. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/transforms/channel_selection_transforms.py +0 -0
  51. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/transforms/color_transforms.py +0 -0
  52. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/transforms/crop_and_pad_transforms.py +0 -0
  53. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/transforms/local_transforms.py +0 -0
  54. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/transforms/resample_transforms.py +0 -0
  55. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/transforms/sample_normalization_transforms.py +0 -0
  56. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/utilities/__init__.py +0 -0
  57. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/utilities/custom_types.py +0 -0
  58. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators/utilities/data_splitting.py +0 -0
  59. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators.egg-info/dependency_links.txt +0 -0
  60. {batchgenerators-0.25 → batchgenerators-0.25.1}/batchgenerators.egg-info/top_level.txt +0 -0
  61. {batchgenerators-0.25 → batchgenerators-0.25.1}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: batchgenerators
3
- Version: 0.25
3
+ Version: 0.25.1
4
4
  Summary: Data augmentation toolkit
5
5
  Home-page: https://github.com/MIC-DKFZ/batchgenerators
6
6
  Author: Division of Medical Image Computing, German Cancer Research Center AND Applied Computer Vision Lab, Helmholtz Imaging Platform
@@ -8,3 +8,12 @@ Author-email: f.isensee@dkfz-heidelberg.de
8
8
  License: Apache License Version 2.0, January 2004
9
9
  Keywords: data augmentation,deep learning,image segmentation,image classification,medical image analysis,medical image segmentation
10
10
  License-File: LICENSE
11
+ Requires-Dist: pillow>=7.1.2
12
+ Requires-Dist: numpy>=1.10.2
13
+ Requires-Dist: scipy
14
+ Requires-Dist: scikit-image
15
+ Requires-Dist: scikit-learn
16
+ Requires-Dist: future
17
+ Requires-Dist: pandas
18
+ Requires-Dist: unittest2
19
+ Requires-Dist: threadpoolctl
@@ -16,10 +16,12 @@
16
16
  from builtins import range
17
17
 
18
18
  import numpy as np
19
+ from scipy.ndimage import map_coordinates
19
20
  from batchgenerators.augmentations.utils import create_zero_centered_coordinate_mesh, elastic_deform_coordinates, \
20
21
  interpolate_img, \
21
22
  rotate_coords_2d, rotate_coords_3d, scale_coords, resize_segmentation, resize_multichannel_image, \
22
- elastic_deform_coordinates_2
23
+ elastic_deform_coordinates_2, \
24
+ get_organ_gradient_field, ignore_anatomy
23
25
  from batchgenerators.augmentations.crop_and_pad_augmentations import random_crop as random_crop_aug
24
26
  from batchgenerators.augmentations.crop_and_pad_augmentations import center_crop as center_crop_aug
25
27
 
@@ -484,3 +486,173 @@ def augment_transpose_axes(data_sample, seg_sample, axes=(0, 1, 2)):
484
486
  if seg_sample is not None:
485
487
  seg_sample = seg_sample.transpose(*static_axes)
486
488
  return data_sample, seg_sample
489
+
490
def augment_anatomy_informed(data, seg,
                             active_organs, dilation_ranges, directions_of_trans, modalities,
                             spacing_ratio=0.3125/3.0, blur=32, anisotropy_safety=True,
                             max_annotation_value=1, replace_value=0):
    """
    Anatomy-informed deformation of a single 3D sample.

    Every active organ pushes the sampling grid along the gradient of its blurred mask, scaled by a dilation
    magnitude drawn uniformly from that organ's range. The deformed grid then resamples the selected image channels
    (linear interpolation) and the segmentation (nearest neighbour). Organ ``i`` (0-based position in
    ``active_organs``) is looked up as label ``i + 2`` in ``seg``. In every case, labels larger than
    ``max_annotation_value`` are replaced by ``replace_value``.

    :param data: image array indexed as data[channel, x, y, z]; channels in `modalities` are overwritten in place
    :param seg: 3D label array; modified in place
    :param active_organs: per-organ flags, truthy entries get deformed
    :param dilation_ranges: per-organ (low, high) range of the dilation magnitude
    :param directions_of_trans: per-organ triple of flags enabling displacement along axis 0, 1 and 2
    :param modalities: channel indices of `data` that are resampled
    :param spacing_ratio: in-plane spacing divided by slice thickness; scales the displacement along axis 0
    :param blur: Gaussian kernel constant forwarded to get_organ_gradient_field
    :param anisotropy_safety: clamp the axis-0 coordinates of the first two / last two slices into the volume
    :param max_annotation_value: largest label value that is kept
    :param replace_value: value written over labels larger than `max_annotation_value`
    :return: (data, seg), the same objects that were passed in
    """
    if sum(active_organs) <= 0:
        # nothing to deform: only strip the auxiliary anatomy labels
        seg[:, :, :] = ignore_anatomy(seg[:, :, :], max_annotation_value=max_annotation_value,
                                      replace_value=replace_value)
        return data, seg

    data_shape = data.shape
    grid = create_zero_centered_coordinate_mesh(data_shape[-3:])

    # organs are visited from last to first; random draws only happen for active organs
    for organ_idx, is_active in reversed(list(enumerate(active_organs))):
        if not is_active:
            continue

        low, high = dilation_ranges[organ_idx][0], dilation_ranges[organ_idx][1]
        magnitude = np.random.uniform(low=low, high=high)

        grad_t, grad_u, grad_v = get_organ_gradient_field(seg == organ_idx + 2,
                                                          spacing_ratio=spacing_ratio,
                                                          blur=blur)

        axis_flags = directions_of_trans[organ_idx]
        if axis_flags[0]:
            grid[0, :, :, :] = grid[0, :, :, :] + grad_t * magnitude * spacing_ratio
        if axis_flags[1]:
            grid[1, :, :, :] = grid[1, :, :, :] + grad_u * magnitude
        if axis_flags[2]:
            grid[2, :, :, :] = grid[2, :, :, :] + grad_v * magnitude

    # shift the zero-centered grid so that it indexes the array voxels
    for axis in range(3):
        grid[axis] += data.shape[axis + 1] / 2 - 0.5

    if anisotropy_safety:
        # keep the border slices from sampling outside the volume along the (thick-slice) axis 0
        last_index = data_shape[-3] - 1
        slice_coords = grid[0]
        for s in (0, 1):
            plane = slice_coords[s, :, :]
            plane[plane < 0] = 0.0
        for s in (-1, -2):
            plane = slice_coords[s, :, :]
            plane[plane > last_index] = last_index

    for channel in modalities:
        data[channel, :, :, :] = map_coordinates(data[channel, :, :, :], grid, order=1, mode='constant')

    seg[:, :, :] = ignore_anatomy(seg[:, :, :], max_annotation_value=max_annotation_value,
                                  replace_value=replace_value)
    seg[:, :, :] = map_coordinates(seg[:, :, :], grid, order=0, mode='constant')

    return data, seg
534
+
535
+
536
def _misalign_sample_per_axis(ranges, p_per_axis, default):
    """Draw one value per axis: uniform in that axis' (low, high) range with probability `p_per_axis`, else `default`."""
    values = []
    for low_high in ranges:
        if np.random.uniform() <= p_per_axis:
            values.append(np.random.uniform(low_high[0], low_high[1]))
        else:
            values.append(default)
    return values


def _misalign_resample_sample(data, seg, sample_id, coords, offsets,
                              im_channels_2_misalign, label_channels_2_misalign,
                              order_data, border_mode_data, border_cval_data,
                              order_seg, border_mode_seg, border_cval_seg):
    """Move zero-centered `coords` to the image center (+ per-axis `offsets`) and resample one sample in place."""
    for d in range(len(offsets)):
        coords[d] += data.shape[d + 2] / 2. - 0.5 + offsets[d]

    for channel_id in range(data.shape[1]):
        if channel_id in im_channels_2_misalign:
            data[sample_id, channel_id] = interpolate_img(data[sample_id, channel_id], coords, order_data,
                                                          border_mode_data, cval=border_cval_data)
    if seg is not None:
        for channel_id in range(seg.shape[1]):
            if channel_id in label_channels_2_misalign:
                seg[sample_id, channel_id] = interpolate_img(seg[sample_id, channel_id], coords, order_seg,
                                                             border_mode_seg, cval=border_cval_seg,
                                                             is_seg=True)


def augment_misalign(data, seg, data_size,
                     im_channels_2_misalign=(0, ),
                     label_channels_2_misalign=(0, ),
                     do_squeeze=False,
                     sq_x=(1.0, 1.0), sq_y=(0.9, 1.1), sq_z=(1.0, 1.0),
                     p_sq_per_sample=0.1, p_sq_per_dir=1.0,
                     do_rotation=False,
                     angle_x=(-0 / 360. * 2 * np.pi, 0 / 360. * 2 * np.pi),
                     angle_y=(-0 / 360. * 2 * np.pi, 0 / 360. * 2 * np.pi),
                     angle_z=(-15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
                     p_rot_per_sample=0.1, p_rot_per_axis=1.0,
                     tr_x=(-32, 32), tr_y=(-32, 32), tr_z=(-2, 2),
                     p_transl_per_sample=0.1, p_transl_per_dir=1.0,
                     do_transl=False,
                     border_mode_data='constant', border_cval_data=0,
                     border_mode_seg='constant', border_cval_seg=0,
                     order_data=3, order_seg=0):
    """
    Simulate misalignment (registration errors) between the channels of a multi-channel batch.

    For every sample, up to three independent resampling steps are applied in this order:
    squeezing/scaling, rotation and translation. Each step is triggered with its own per-sample
    probability. Each step only touches the image channels listed in `im_channels_2_misalign` and
    the label channels listed in `label_channels_2_misalign`; all other channels stay untouched,
    which produces the misalignment.

    :param data: batch indexed as data[sample, channel, *spatial]; modified in place
    :param seg: batch indexed as seg[sample, channel, *spatial], or None; modified in place
    :param data_size: spatial shape of one sample (2 or 3 entries)
    :param im_channels_2_misalign: image channels to transform
    :param label_channels_2_misalign: segmentation channels to transform
    :param do_squeeze, sq_x, sq_y, sq_z, p_sq_per_sample, p_sq_per_dir: scaling step; a factor of 1.0 is
        used for directions that are not sampled
    :param do_rotation, angle_x, angle_y, angle_z, p_rot_per_sample, p_rot_per_axis: rotation step,
        angles in radians; only angle_z is used for 2D data
    :param do_transl, tr_x, tr_y, tr_z, p_transl_per_sample, p_transl_per_dir: translation step, shifts
        in voxels; a shift of 0 is used for directions that are not sampled
    :param border_mode_data, border_cval_data, order_data: interpolation settings for image channels
    :param border_mode_seg, border_cval_seg, order_seg: interpolation settings for segmentation channels
    :return: (data, seg)
    :raises ValueError: if data_size is neither 2D nor 3D
    """
    dim = len(data_size)
    if dim not in (2, 3):
        raise ValueError("only 2D and 3D data are supported, got data_size=%s" % (data_size, ))

    # per-axis ranges ordered like the array axes: (z, y, x) in 3D, (y, x) in 2D
    sq_ranges = (sq_z, sq_y, sq_x)[3 - dim:]
    tr_ranges = (tr_z, tr_y, tr_x)[3 - dim:]
    no_offset = [0.0] * dim
    resample_kwargs = dict(im_channels_2_misalign=im_channels_2_misalign,
                           label_channels_2_misalign=label_channels_2_misalign,
                           order_data=order_data, border_mode_data=border_mode_data,
                           border_cval_data=border_cval_data,
                           order_seg=order_seg, border_mode_seg=border_mode_seg,
                           border_cval_seg=border_cval_seg)

    for sample_id in range(data.shape[0]):

        if do_squeeze and np.random.uniform() < p_sq_per_sample:
            coords = create_zero_centered_coordinate_mesh(data_size)
            # 1.0 is the identity scale for directions that are not squeezed
            sq = _misalign_sample_per_axis(sq_ranges, p_sq_per_dir, 1.0)
            coords = scale_coords(coords, sq)
            _misalign_resample_sample(data, seg, sample_id, coords, no_offset, **resample_kwargs)

        if do_rotation and np.random.uniform() < p_rot_per_sample:
            coords = create_zero_centered_coordinate_mesh(data_size)
            if dim == 3:
                a_z, a_y, a_x = _misalign_sample_per_axis((angle_z, angle_y, angle_x), p_rot_per_axis, 0)
                coords = rotate_coords_3d(coords, a_z, a_y, a_x)
            else:
                a_z, = _misalign_sample_per_axis((angle_z, ), p_rot_per_axis, 0)
                coords = rotate_coords_2d(coords, a_z)
            _misalign_resample_sample(data, seg, sample_id, coords, no_offset, **resample_kwargs)

        if do_transl and np.random.uniform() < p_transl_per_sample:
            coords = create_zero_centered_coordinate_mesh(data_size)
            # 0.0 is the identity shift for directions that are not translated
            tr = _misalign_sample_per_axis(tr_ranges, p_transl_per_dir, 0.0)
            _misalign_resample_sample(data, seg, sample_id, coords, tr, **resample_kwargs)

    return data, seg
@@ -21,6 +21,7 @@ from scipy.ndimage.filters import gaussian_filter, gaussian_gradient_magnitude
21
21
  from scipy.ndimage.morphology import grey_dilation
22
22
  from skimage.transform import resize
23
23
  from scipy.ndimage.measurements import label as lb
24
+ import pandas as pd
24
25
 
25
26
 
26
27
  def generate_elastic_transform_coordinates(shape, alpha, sigma):
@@ -591,13 +592,13 @@ def resize_segmentation(segmentation, new_shape, order=3):
591
592
  :return:
592
593
  '''
593
594
  tpe = segmentation.dtype
594
- unique_labels = np.unique(segmentation)
595
595
  assert len(segmentation.shape) == len(new_shape), "new shape must have same dimensionality as segmentation"
596
596
  if order == 0:
597
597
  return resize(segmentation.astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False).astype(tpe)
598
598
  else:
599
599
  reshaped = np.zeros(new_shape, dtype=segmentation.dtype)
600
600
 
601
+ unique_labels = np.sort(pd.unique(segmentation.ravel()))
601
602
  for i, c in enumerate(unique_labels):
602
603
  mask = segmentation == c
603
604
  reshaped_multihot = resize(mask.astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False)
@@ -773,3 +774,25 @@ def mask_random_squares(img, square_size, n_squares, n_val, channel_wise_n_val=F
773
774
  img = mask_random_square(img, square_size, n_val, channel_wise_n_val=channel_wise_n_val,
774
775
  square_pos=square_pos)
775
776
  return img
777
+
778
def get_organ_gradient_field(organ, spacing_ratio=0.3125/3.0, blur=32):
    """
    Gradient field of a smoothed organ mask, used by the anatomy-informed augmentation.

    The mask is converted to float and blurred with a Gaussian whose sigma along the first axis is
    ``blur * spacing_ratio`` and ``blur`` along the other two (borders handled in 'nearest' mode).
    The gradient component along the first axis is additionally scaled by ``spacing_ratio``.

    :param organ: binary organ segmentation (3D)
    :param spacing_ratio: in-plane spacing divided by slice thickness
    :param blur: Gaussian kernel constant
    :return: tuple of the three gradient components (axis 0, axis 1, axis 2)
    """
    kernel_sigmas = (blur * spacing_ratio, blur, blur)
    smoothed = gaussian_filter(organ.astype(float), sigma=kernel_sigmas, order=0, mode='nearest')

    grad_slice, grad_row, grad_col = np.gradient(smoothed)
    return grad_slice * spacing_ratio, grad_row, grad_col
795
+
796
def ignore_anatomy(segm, max_annotation_value=1, replace_value=0):
    """
    Overwrite, in place, every label of `segm` that exceeds `max_annotation_value` with `replace_value`.

    :param segm: label array (numpy), modified in place
    :param max_annotation_value: largest label that is kept
    :param replace_value: value written over larger labels
    :return: `segm` itself
    """
    above_limit = segm > max_annotation_value
    segm[above_limit] = replace_value
    return segm
@@ -80,7 +80,7 @@ def results_loop(in_queues: List[Queue], out_queue: thrQueue, abort_event: Event
80
80
  end_ctr = 0
81
81
 
82
82
  while True:
83
- # if abort_event is set we need to clean up. This is where it hangs sometimes so it makes sense to drain all
83
+ # if abort_event is set we need to clean up. This is where it hangs sometimes, so it makes sense to drain all
84
84
  # the incoming queues and ignore all the errors occuring during this process.
85
85
  try:
86
86
  if abort_event.is_set():
@@ -208,12 +208,18 @@ class NonDetMultiThreadedAugmenter(object):
208
208
  if isinstance(self.generator, DataLoader):
209
209
  self.generator.was_initialized = False
210
210
 
211
- for i in range(self.num_processes):
212
- self._processes.append(Process(target=producer, args=(
213
- self._queue, self.generator, self.transform, i, self.seeds[i], self.abort_event, self.wait_time
214
- )))
215
- self._processes[-1].daemon = True
216
- _ = [i.start() for i in self._processes]
211
+ if torch is not None:
212
+ torch_nthreads = torch.get_num_threads()
213
+ torch.set_num_threads(1)
214
+ with threadpool_limits(limits=1, user_api=None):
215
+ for i in range(self.num_processes):
216
+ self._processes.append(Process(target=producer, args=(
217
+ self._queue, self.generator, self.transform, i, self.seeds[i], self.abort_event, self.wait_time
218
+ )))
219
+ self._processes[-1].daemon = True
220
+ _ = [i.start() for i in self._processes]
221
+ if torch is not None:
222
+ torch.set_num_threads(torch_nthreads)
217
223
 
218
224
  if torch is not None and torch.cuda.is_available():
219
225
  gpu = torch.cuda.current_device()
@@ -133,6 +133,20 @@ class BlankSquareNoiseTransform(AbstractTransform):
133
133
  self.channel_wise_n_val, self.square_pos)
134
134
  return data_dict
135
135
 
136
class ColorFunctionExtractor:
    """
    Callable that produces the fill value for BlankRectangleTransform.

    `rectangle_value` may be
      - a scalar: returned as-is,
      - a callable: called with the argument passed to this object,
      - a (low, high) tuple/list: a value is drawn with np.random.uniform(low, high).

    Unlike a lambda, an instance of this class can be pickled, provided `rectangle_value` itself
    is picklable. The format is validated at construction time so that misconfiguration surfaces
    immediately rather than at the first call.
    """
    def __init__(self, rectangle_value):
        if not (np.isscalar(rectangle_value) or callable(rectangle_value)
                or isinstance(rectangle_value, (tuple, list))):
            raise RuntimeError("unrecognized format for rectangle_value")
        self.rectangle_value = rectangle_value

    def __call__(self, x):
        value = self.rectangle_value
        if np.isscalar(value):
            return value
        if callable(value):
            return value(x)
        if isinstance(value, (tuple, list)):
            return np.random.uniform(*value)
        # only reachable if rectangle_value was reassigned after construction
        raise RuntimeError("unrecognized format for rectangle_value")
149
+
136
150
 
137
151
  class BlankRectangleTransform(AbstractTransform):
138
152
  def __init__(self, rectangle_size, rectangle_value, num_rectangles, force_square=False, p_per_sample=0.5,
@@ -179,16 +193,7 @@ class BlankRectangleTransform(AbstractTransform):
179
193
  self.p_per_sample = p_per_sample
180
194
  self.p_per_channel = p_per_channel
181
195
  self.apply_to_keys = apply_to_keys
182
-
183
- # intensity value
184
- if np.isscalar(rectangle_value):
185
- self.color_fn = lambda x: rectangle_value
186
- elif callable(rectangle_value):
187
- self.color_fn = lambda x: rectangle_value(x)
188
- elif isinstance(rectangle_value, (tuple, list)):
189
- self.color_fn = lambda x: np.random.uniform(*rectangle_value)
190
- else:
191
- raise RuntimeError("unrecognized format for rectangle_value")
196
+ self.color_fn = ColorFunctionExtractor(rectangle_value)
192
197
 
193
198
  def __call__(self, **data_dict):
194
199
  for k in self.apply_to_keys:
@@ -16,8 +16,10 @@
16
16
  from batchgenerators.transforms.abstract_transforms import AbstractTransform
17
17
  from batchgenerators.augmentations.spatial_transformations import augment_spatial, augment_spatial_2, \
18
18
  augment_channel_translation, \
19
- augment_mirroring, augment_transpose_axes, augment_zoom, augment_resize, augment_rot90
19
+ augment_mirroring, augment_transpose_axes, augment_zoom, augment_resize, augment_rot90, \
20
+ augment_anatomy_informed, augment_misalign
20
21
  import numpy as np
22
+ from batchgenerators.augmentations.utils import get_organ_gradient_field
21
23
 
22
24
 
23
25
  class Rot90Transform(AbstractTransform):
@@ -524,3 +526,193 @@ class TransposeAxesTransform(AbstractTransform):
524
526
  data_dict[self.label_key] = seg
525
527
  return data_dict
526
528
 
529
class AnatomyInformedTransform(AbstractTransform):
    """
    The data augmentation is presented at MICCAI 2023 in the proceedings of 'Anatomy-informed Data Augmentation for
    enhanced Prostate Cancer Detection'.
    It simulates the distension or evacuation of the bladder or/and rectal space to mimic typical physiological soft
    tissue deformations of the prostate and generates unique lesion shapes without altering their label.
    You can find more information here: https://github.com/MIC-DKFZ/anatomy_informed_DA
    If you use this augmentation please cite it.

    Only 3D batches (5D arrays: b, c, x, y, z) are supported. The segmentation is read from and written back to
    channel 0 of data_dict['seg']. Organ i (0-based position in `p_per_sample`) is expected to carry label i + 2.

    Args:
        `dil_ranges`: dilation range (low, high) per organ
        `modalities`: on which input channels should the transformation be applied
        `directions_of_trans`: per organ, three flags selecting the axes along which the organ is dilated
        `p_per_sample`: per organ, probability that its deformation is activated. The draw happens once per call,
            so all samples of the batch share the same set of active organs.
        `spacing_ratio`: ratio of the transversal plane spacing and the slice thickness, in our case it was 0.3125/3
        `blur`: Gaussian kernel parameter, we used the value 32 for 0.3125mm transversal plane spacing
        `anisotropy_safety`: it provides a certain protection against transformation artifacts in 2 slices from the
            image border
        `max_annotation_value`: the largest segmentation value that is still relevant for the main task
        `replace_value`: the value that segmentation values larger than `max_annotation_value` are replaced with
    """
    def __init__(self, dil_ranges, modalities, directions_of_trans, p_per_sample,
                 spacing_ratio=0.3125/3.0, blur=32, anisotropy_safety=True,
                 max_annotation_value=1, replace_value=0):
        self.dil_ranges = dil_ranges
        self.modalities = modalities

        self.directions_of_trans = directions_of_trans
        self.p_per_sample = p_per_sample
        self.spacing_ratio = spacing_ratio
        self.blur = blur
        self.anisotropy_safety = anisotropy_safety

        self.max_annotation_value = max_annotation_value
        self.replace_value = replace_value

        # the underlying augmentation is 3D only
        self.dim = 3

    def __call__(self, **data_dict):
        data = data_dict['data']
        seg = data_dict['seg']

        if len(data.shape) != 5:
            # augment_anatomy_informed works on (c, x, y, z) samples; fail with a clear message instead of an
            # IndexError deep inside the slicing below
            raise ValueError("AnatomyInformedTransform only supports 3D batch data of shape (b, c, x, y, z), "
                             "got data with shape %s" % (data.shape, ))

        # one activation draw per organ, shared by the whole batch
        active_organs = [1 if np.random.uniform() < prob else 0 for prob in self.p_per_sample]

        for b in range(data.shape[0]):
            data[b, :, :, :, :], seg[b, 0, :, :, :] = augment_anatomy_informed(
                data=data[b, :, :, :, :],
                seg=seg[b, 0, :, :, :],
                active_organs=active_organs,
                dilation_ranges=self.dil_ranges,
                directions_of_trans=self.directions_of_trans,
                modalities=self.modalities,
                spacing_ratio=self.spacing_ratio,
                blur=self.blur,
                anisotropy_safety=self.anisotropy_safety,
                max_annotation_value=self.max_annotation_value,
                replace_value=self.replace_value)
        return data_dict
591
+
592
+
593
class MisalignTransform(AbstractTransform):
    """
    The misalignment data augmentation is introduced in Nature Scientific reports 2023.
    It simulates additional misalignments/registration errors between multi-channel (multi-modal, longitudinal)
    data to make neural networks robust for misalignments.
    Currently the following transformations are supported, but they can be extended easily:
    - squeezing/scaling (good approximation for misalignments between the T2w and DWI MRI sequences)
    - rotation
    - channel shift/translation
    You can find more information here: https://github.com/MIC-DKFZ/misalignmnet_DA
    If you use this augmentation please cite it: https://www.nature.com/articles/s41598-023-46747-z
    Always double check whether the directions/axes are correct!

    The data must have at least two channels and be a 2D (b, c, x, y) or 3D (b, c, x, y, z) batch.

    Additional Misalignment-related Args to the Spatial Transforms:
        `im_channels_2_misalign`: on which image channels should the transformation be applied
        `label_channels_2_misalign`: on which segmentation channels should the transformation be applied
        `do_squeeze`: whether misalignment resulting from squeezing is applied
        `sq_x`, `sq_y`, `sq_z`: squeezing/scaling ranges per direction, randomly sampled from the interval
        `p_sq_per_sample`: probability of the transformation per sample
        `p_sq_per_dir`: probability of the transformation per direction
        `do_rotation`: whether misalignment resulting from rotation is applied
        `angle_x`, `angle_y`, `angle_z`: rotation angles (radians) per axis, randomly sampled from the interval
        `p_rot_per_sample`: probability of the transformation per sample
        `p_rot_per_axis`: probability of the transformation per axis
        `do_transl`: whether misalignment resulting from translation is applied
        `tr_x`, `tr_y`, `tr_z`: shift/translation per direction, randomly sampled from the interval
        `p_transl_per_sample`: probability of the transformation per sample
        `p_transl_per_dir`: probability of the transformation per direction
    """

    def __init__(self, data_key="data", label_key="seg",
                 im_channels_2_misalign=(0, ), label_channels_2_misalign=(0, ),
                 do_squeeze=True, sq_x=(1.0, 1.0), sq_y=(0.9, 1.1), sq_z=(1.0, 1.0),
                 p_sq_per_sample=0.1, p_sq_per_dir=1.0,
                 do_rotation=True,
                 angle_x=(-0 / 360. * 2 * np.pi, 0 / 360. * 2 * np.pi),
                 angle_y=(-0 / 360. * 2 * np.pi, 0 / 360. * 2 * np.pi),
                 angle_z=(-15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
                 p_rot_per_sample=0.1, p_rot_per_axis=1.0,
                 do_transl=True, tr_x=(-32, 32), tr_y=(-32, 32), tr_z=(-2, 2),
                 p_transl_per_sample=0.1, p_transl_per_dir=1.0,
                 border_mode_data='constant', border_cval_data=0,
                 border_mode_seg='constant', border_cval_seg=0,
                 order_data=3, order_seg=0):
        # defaults are tuples so instances never share a mutable default list

        self.data_key = data_key
        self.label_key = label_key

        self.im_channels_2_misalign = im_channels_2_misalign
        self.label_channels_2_misalign = label_channels_2_misalign

        self.do_squeeze = do_squeeze
        self.sq_x = sq_x
        self.sq_y = sq_y
        self.sq_z = sq_z
        self.p_sq_per_sample = p_sq_per_sample
        self.p_sq_per_dir = p_sq_per_dir

        self.do_rotation = do_rotation
        self.angle_x = angle_x
        self.angle_y = angle_y
        self.angle_z = angle_z
        self.p_rot_per_sample = p_rot_per_sample
        self.p_rot_per_axis = p_rot_per_axis

        self.do_transl = do_transl
        self.tr_x = tr_x
        self.tr_y = tr_y
        self.tr_z = tr_z
        self.p_transl_per_sample = p_transl_per_sample
        self.p_transl_per_dir = p_transl_per_dir

        self.order_data = order_data
        self.order_seg = order_seg
        self.border_mode_data = border_mode_data
        self.border_cval_data = border_cval_data
        self.border_mode_seg = border_mode_seg
        self.border_cval_seg = border_cval_seg

    def __call__(self, **data_dict):
        data = data_dict.get(self.data_key)
        seg = data_dict.get(self.label_key)

        # misalignment only makes sense when there is more than one channel to misalign against
        if data.shape[1] < 2:
            raise ValueError("only support multi-modal images")
        if len(data.shape) not in (4, 5):
            raise ValueError("only support 2D/3D batch data.")
        data_size = tuple(data.shape[2:])

        ret_val = augment_misalign(data, seg, data_size=data_size,
                                   im_channels_2_misalign=self.im_channels_2_misalign,
                                   label_channels_2_misalign=self.label_channels_2_misalign,
                                   do_squeeze=self.do_squeeze,
                                   sq_x=self.sq_x,
                                   sq_y=self.sq_y,
                                   sq_z=self.sq_z,
                                   p_sq_per_sample=self.p_sq_per_sample,
                                   p_sq_per_dir=self.p_sq_per_dir,
                                   do_rotation=self.do_rotation,
                                   angle_x=self.angle_x,
                                   angle_y=self.angle_y,
                                   angle_z=self.angle_z,
                                   p_rot_per_sample=self.p_rot_per_sample,
                                   p_rot_per_axis=self.p_rot_per_axis,
                                   do_transl=self.do_transl,
                                   tr_x=self.tr_x,
                                   tr_y=self.tr_y,
                                   tr_z=self.tr_z,
                                   p_transl_per_sample=self.p_transl_per_sample,
                                   p_transl_per_dir=self.p_transl_per_dir,
                                   order_data=self.order_data,
                                   border_mode_data=self.border_mode_data,
                                   border_cval_data=self.border_cval_data,
                                   order_seg=self.order_seg,
                                   border_mode_seg=self.border_mode_seg,
                                   border_cval_seg=self.border_cval_seg)

        data_dict[self.data_key] = ret_val[0]
        if seg is not None:
            data_dict[self.label_key] = ret_val[1]

        return data_dict
@@ -19,7 +19,8 @@ from typing import List, Type, Union, Tuple
19
19
  import numpy as np
20
20
 
21
21
  from batchgenerators.augmentations.utils import convert_seg_image_to_one_hot_encoding, \
22
- convert_seg_to_bounding_box_coordinates, transpose_channels
22
+ convert_seg_to_bounding_box_coordinates, transpose_channels, \
23
+ ignore_anatomy
23
24
  from batchgenerators.transforms.abstract_transforms import AbstractTransform
24
25
 
25
26
 
@@ -276,6 +277,21 @@ class RemoveLabelTransform(AbstractTransform):
276
277
  return data_dict
277
278
 
278
279
 
280
class IgnoreAnatomy(AbstractTransform):
    """
    Replaces every annotation value larger than max_annotation_value with replace_value.
    This transform is used for the anatomy-informed augmentation scheme to remove all additional anatomical annotations.
    You can find more information here: https://github.com/MIC-DKFZ/anatomy_informed_DA

    Args:
        `max_annotation_value`: the largest segmentation value that is kept
        `replace_value`: the value written over larger segmentation values
        `label_key`: key of the segmentation in data_dict (default 'seg'); if the key is missing or None,
            the data_dict is returned unchanged
    """
    def __init__(self, max_annotation_value=1, replace_value=0, label_key='seg'):
        self.max_annotation_value = max_annotation_value
        self.replace_value = replace_value
        self.label_key = label_key

    def __call__(self, **data_dict):
        seg = data_dict.get(self.label_key)
        if seg is not None:
            data_dict[self.label_key] = ignore_anatomy(seg, max_annotation_value=self.max_annotation_value,
                                                       replace_value=self.replace_value)
        return data_dict
293
+
294
+
279
295
  class RenameTransform(AbstractTransform):
280
296
  '''
281
297
  Saves the value of data_dict[in_key] to data_dict[out_key]. Optionally removes data_dict[in_key] from the dict.