batchgenerators 0.24__tar.gz → 0.25.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. {batchgenerators-0.24 → batchgenerators-0.25.1}/PKG-INFO +10 -5
  2. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/augmentations/spatial_transformations.py +174 -2
  3. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/augmentations/utils.py +24 -1
  4. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/dataloading/multi_threaded_augmenter.py +5 -7
  5. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py +37 -33
  6. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/transforms/noise_transforms.py +15 -10
  7. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/transforms/spatial_transforms.py +194 -2
  8. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/transforms/utility_transforms.py +17 -1
  9. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/utilities/custom_types.py +4 -2
  10. batchgenerators-0.25.1/batchgenerators/utilities/file_and_folder_operations.py +135 -0
  11. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators.egg-info/PKG-INFO +10 -5
  12. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators.egg-info/SOURCES.txt +12 -1
  13. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators.egg-info/requires.txt +1 -0
  14. {batchgenerators-0.24 → batchgenerators-0.25.1}/setup.py +2 -1
  15. batchgenerators-0.25.1/tests/test_DataLoader.py +262 -0
  16. batchgenerators-0.25.1/tests/test_augment_zoom.py +73 -0
  17. batchgenerators-0.25.1/tests/test_axis_mirroring.py +173 -0
  18. batchgenerators-0.25.1/tests/test_color_augmentations.py +158 -0
  19. batchgenerators-0.25.1/tests/test_crop.py +470 -0
  20. batchgenerators-0.25.1/tests/test_multithreaded_augmenter.py +222 -0
  21. batchgenerators-0.25.1/tests/test_normalizations.py +235 -0
  22. batchgenerators-0.25.1/tests/test_random_crop.py +76 -0
  23. batchgenerators-0.25.1/tests/test_resample_augmentations.py +70 -0
  24. batchgenerators-0.25.1/tests/test_sanity.py +26 -0
  25. batchgenerators-0.25.1/tests/test_spatial_transformations.py +141 -0
  26. batchgenerators-0.24/batchgenerators/utilities/file_and_folder_operations.py +0 -100
  27. {batchgenerators-0.24 → batchgenerators-0.25.1}/LICENSE +0 -0
  28. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/__init__.py +0 -0
  29. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/augmentations/__init__.py +0 -0
  30. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/augmentations/color_augmentations.py +0 -0
  31. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/augmentations/crop_and_pad_augmentations.py +0 -0
  32. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/augmentations/noise_augmentations.py +0 -0
  33. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/augmentations/normalizations.py +0 -0
  34. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/augmentations/resample_augmentations.py +0 -0
  35. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/dataloading/__init__.py +0 -0
  36. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/dataloading/data_loader.py +0 -0
  37. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/dataloading/dataset.py +0 -0
  38. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/dataloading/single_threaded_augmenter.py +0 -0
  39. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/datasets/__init__.py +0 -0
  40. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/datasets/cifar.py +0 -0
  41. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/examples/__init__.py +0 -0
  42. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/examples/brats2017/__init__.py +0 -0
  43. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/examples/brats2017/brats2017_dataloader_2D.py +0 -0
  44. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/examples/brats2017/brats2017_dataloader_3D.py +0 -0
  45. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/examples/brats2017/brats2017_preprocessing.py +0 -0
  46. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/examples/brats2017/config.py +0 -0
  47. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/examples/cifar10.py +0 -0
  48. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/examples/multithreaded_dataloading.py +0 -0
  49. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/transforms/__init__.py +0 -0
  50. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/transforms/abstract_transforms.py +0 -0
  51. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/transforms/channel_selection_transforms.py +0 -0
  52. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/transforms/color_transforms.py +0 -0
  53. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/transforms/crop_and_pad_transforms.py +0 -0
  54. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/transforms/local_transforms.py +0 -0
  55. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/transforms/resample_transforms.py +0 -0
  56. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/transforms/sample_normalization_transforms.py +0 -0
  57. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/utilities/__init__.py +0 -0
  58. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/utilities/data_splitting.py +0 -0
  59. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators.egg-info/dependency_links.txt +0 -0
  60. {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators.egg-info/top_level.txt +0 -0
  61. {batchgenerators-0.24 → batchgenerators-0.25.1}/setup.cfg +0 -0
@@ -1,14 +1,19 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: batchgenerators
3
- Version: 0.24
3
+ Version: 0.25.1
4
4
  Summary: Data augmentation toolkit
5
5
  Home-page: https://github.com/MIC-DKFZ/batchgenerators
6
6
  Author: Division of Medical Image Computing, German Cancer Research Center AND Applied Computer Vision Lab, Helmholtz Imaging Platform
7
7
  Author-email: f.isensee@dkfz-heidelberg.de
8
8
  License: Apache License Version 2.0, January 2004
9
9
  Keywords: data augmentation,deep learning,image segmentation,image classification,medical image analysis,medical image segmentation
10
- Platform: UNKNOWN
11
10
  License-File: LICENSE
12
-
13
- UNKNOWN
14
-
11
+ Requires-Dist: pillow>=7.1.2
12
+ Requires-Dist: numpy>=1.10.2
13
+ Requires-Dist: scipy
14
+ Requires-Dist: scikit-image
15
+ Requires-Dist: scikit-learn
16
+ Requires-Dist: future
17
+ Requires-Dist: pandas
18
+ Requires-Dist: unittest2
19
+ Requires-Dist: threadpoolctl
@@ -16,10 +16,12 @@
16
16
  from builtins import range
17
17
 
18
18
  import numpy as np
19
+ from scipy.ndimage import map_coordinates
19
20
  from batchgenerators.augmentations.utils import create_zero_centered_coordinate_mesh, elastic_deform_coordinates, \
20
21
  interpolate_img, \
21
22
  rotate_coords_2d, rotate_coords_3d, scale_coords, resize_segmentation, resize_multichannel_image, \
22
- elastic_deform_coordinates_2
23
+ elastic_deform_coordinates_2, \
24
+ get_organ_gradient_field, ignore_anatomy
23
25
  from batchgenerators.augmentations.crop_and_pad_augmentations import random_crop as random_crop_aug
24
26
  from batchgenerators.augmentations.crop_and_pad_augmentations import center_crop as center_crop_aug
25
27
 
@@ -298,7 +300,7 @@ def augment_spatial_2(data, seg, patch_size, patch_center_dist_from_border=30,
298
300
  do_scale=True, scale=(0.75, 1.25), border_mode_data='nearest', border_cval_data=0, order_data=3,
299
301
  border_mode_seg='constant', border_cval_seg=0, order_seg=0, random_crop=True, p_el_per_sample=1,
300
302
  p_scale_per_sample=1, p_rot_per_sample=1, independent_scale_for_each_axis=False,
301
- p_rot_per_axis: float = 1, p_independent_scale_per_axis: int = 1):
303
+ p_rot_per_axis: float = 1, p_independent_scale_per_axis: float = 1):
302
304
  """
303
305
 
304
306
  :param data:
@@ -484,3 +486,173 @@ def augment_transpose_axes(data_sample, seg_sample, axes=(0, 1, 2)):
484
486
  if seg_sample is not None:
485
487
  seg_sample = seg_sample.transpose(*static_axes)
486
488
  return data_sample, seg_sample
489
+
490
def augment_anatomy_informed(data, seg,
                             active_organs, dilation_ranges, directions_of_trans, modalities,
                             spacing_ratio=0.3125/3.0, blur=32, anisotropy_safety= True,
                             max_annotation_value=1, replace_value=0):
    """
    Anatomy-informed deformation: warp image and labels along the gradient of blurred organ masks.

    For every active organ (processed from the last index to the first), a random magnitude is drawn from
    ``dilation_ranges[organ_idx]``. The gradient field of that organ's blurred mask (label value
    ``organ_idx + 2`` in ``seg``) is scaled by this magnitude and added to the sampling grid along the axes
    enabled in ``directions_of_trans[organ_idx]``. The first axis is additionally scaled by ``spacing_ratio``.

    :param data: image array indexed as data[channel, axis0, axis1, axis2]; channels in ``modalities`` are
        resampled in place
    :param seg: 3D label array; labels above ``max_annotation_value`` are replaced by ``replace_value`` and,
        if any organ is active, the result is resampled in place with nearest-neighbour interpolation
    :param active_organs: per-organ on/off flags
    :param dilation_ranges: per-organ (low, high) range for the deformation magnitude
    :param directions_of_trans: per-organ triple of flags enabling displacement along axes 0, 1, 2
    :param modalities: indices of the data channels to warp
    :param spacing_ratio: forwarded to get_organ_gradient_field and used to scale the axis-0 displacement
    :param blur: Gaussian kernel constant forwarded to get_organ_gradient_field
    :param anisotropy_safety: if True, axis-0 sample coordinates of the first two slices are clamped to >= 0
        and those of the last two slices to <= n_slices - 1
    :param max_annotation_value: highest label value kept in ``seg``
    :param replace_value: value written over labels above ``max_annotation_value``
    :return: (data, seg)
    """
    if not (sum(active_organs) > 0):
        # Nothing to deform: only strip the organ labels from the segmentation.
        seg[:, :, :] = ignore_anatomy(seg[:, :, :], max_annotation_value=max_annotation_value,
                                      replace_value=replace_value)
        return data, seg

    data_shape = data.shape
    coords = create_zero_centered_coordinate_mesh(data_shape[-3:])

    for organ_idx, active in reversed(list(enumerate(active_organs))):
        if not active:
            continue

        bounds = dilation_ranges[organ_idx]
        magnitude = np.random.uniform(low=bounds[0], high=bounds[1])

        grad_t, grad_u, grad_v = get_organ_gradient_field(seg == organ_idx + 2,
                                                          spacing_ratio=spacing_ratio,
                                                          blur=blur)

        enabled_axes = directions_of_trans[organ_idx]
        if enabled_axes[0]:
            coords[0, :, :, :] = coords[0, :, :, :] + grad_t * magnitude * spacing_ratio
        if enabled_axes[1]:
            coords[1, :, :, :] = coords[1, :, :, :] + grad_u * magnitude
        if enabled_axes[2]:
            coords[2, :, :, :] = coords[2, :, :, :] + grad_v * magnitude

    # Shift the zero-centred grid into voxel index space (data.shape[axis + 1] skips the channel axis).
    for axis in range(3):
        half_extent = data.shape[axis + 1] / 2
        coords[axis] += half_extent - 0.5

    if anisotropy_safety:
        # Keep the outermost two slices from sampling beyond the volume along axis 0.
        last_slice = data_shape[-3] - 1
        for slice_idx in (0, 1):
            plane = coords[0, slice_idx, :, :]
            plane[plane < 0] = 0.0
        for slice_idx in (-1, -2):
            plane = coords[0, slice_idx, :, :]
            plane[plane > last_slice] = last_slice

    for modality in modalities:
        data[modality, :, :, :] = map_coordinates(data[modality, :, :, :], coords, order=1, mode='constant')

    seg[:, :, :] = ignore_anatomy(seg[:, :, :], max_annotation_value=max_annotation_value, replace_value=replace_value)
    seg[:, :, :] = map_coordinates(seg[:, :, :], coords, order=0, mode='constant')

    return data, seg
534
+
535
+
536
def _misalign_draw_per_axis(ranges, p_per_axis, default):
    """Draw one value per axis from its (low, high) range with probability p_per_axis, else use `default`."""
    values = []
    for value_range in ranges:
        if np.random.uniform() <= p_per_axis:
            values.append(np.random.uniform(value_range[0], value_range[1]))
        else:
            values.append(default)
    return values


def _misalign_center(coords, data_shape, shifts):
    """Move zero-centred coords into voxel space of a (sample, channel, *spatial) array, plus a per-axis shift."""
    for d, shift in enumerate(shifts):
        coords[d] += data_shape[d + 2] / 2. - 0.5 + shift
    return coords


def augment_misalign(data, seg, data_size,
                     im_channels_2_misalign=(0,),
                     label_channels_2_misalign=(0,),
                     do_squeeze=False,
                     sq_x=(1.0, 1.0), sq_y=(0.9, 1.1), sq_z=(1.0, 1.0),
                     p_sq_per_sample=0.1, p_sq_per_dir=1.0,
                     do_rotation=False,
                     angle_x=(-0 / 360. * 2 * np.pi, 0 / 360. * 2 * np.pi),
                     angle_y=(-0 / 360. * 2 * np.pi, 0 / 360. * 2 * np.pi),
                     angle_z=(-15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
                     p_rot_per_sample=0.1, p_rot_per_axis=1.0,
                     tr_x=(-32, 32), tr_y=(-32, 32), tr_z=(-2, 2),
                     p_transl_per_sample=0.1, p_transl_per_dir=1.0,
                     do_transl=False,
                     border_mode_data='constant', border_cval_data=0,
                     border_mode_seg='constant', border_cval_seg=0,
                     order_data=3, order_seg=0):
    """
    Randomly misalign selected image/label channels relative to the remaining channels.

    For every sample, up to three independent spatial perturbations may be applied in this order:
    squeeze (per-axis scaling), rotation, and translation. Each has its own enable flag and per-sample
    probability. Only the data channels in ``im_channels_2_misalign`` and the seg channels in
    ``label_channels_2_misalign`` are resampled; every other channel keeps its geometry. When several
    perturbations fire, each one re-interpolates the already-warped channels.

    :param data: array indexed as data[sample, channel, *spatial]; modified in place
    :param seg: array indexed like ``data``, or None; modified in place
    :param data_size: spatial shape of one channel; len 3 selects 3D handling, otherwise the 2D (y, x) path
    :param im_channels_2_misalign: data channel indices to warp
    :param label_channels_2_misalign: seg channel indices to warp
    :param sq_x, sq_y, sq_z: (low, high) ranges of the per-axis scale factors; unselected axes use 1.0
    :param p_sq_per_sample, p_sq_per_dir: probability of squeezing a sample / of sampling each axis' factor
    :param angle_x, angle_y, angle_z: (low, high) rotation ranges in radians; the 2D path uses angle_z only
    :param p_rot_per_sample, p_rot_per_axis: probability of rotating a sample / of sampling each angle
    :param tr_x, tr_y, tr_z: (low, high) translation ranges in voxels; unselected axes are not shifted
    :param p_transl_per_sample, p_transl_per_dir: probability of translating a sample / of shifting each axis
    :param border_mode_*, border_cval_*, order_*: interpolation settings forwarded to interpolate_img
    :return: (data, seg)
    """
    dim = len(data_size)
    no_shift = [0.0] * dim

    def warp(sample_id, coords):
        # Resample the selected channels of one sample onto `coords`.
        for channel_id in range(data.shape[1]):
            if channel_id in im_channels_2_misalign:
                data[sample_id, channel_id] = interpolate_img(data[sample_id, channel_id], coords, order_data,
                                                              border_mode_data, cval=border_cval_data)
        if seg is not None:
            for channel_id in range(seg.shape[1]):
                # Previously the squeeze/rotation branches mistakenly filtered seg channels with
                # im_channels_2_misalign; label channels have their own selection.
                if channel_id in label_channels_2_misalign:
                    seg[sample_id, channel_id] = interpolate_img(seg[sample_id, channel_id], coords, order_seg,
                                                                 border_mode_seg, cval=border_cval_seg,
                                                                 is_seg=True)

    for sample_id in range(data.shape[0]):

        if do_squeeze and np.random.uniform() < p_sq_per_sample:
            coords = create_zero_centered_coordinate_mesh(data_size)
            sq_ranges = (sq_z, sq_y, sq_x) if dim == 3 else (sq_y, sq_x)
            factors = _misalign_draw_per_axis(sq_ranges, p_sq_per_dir, 1.0)
            coords = scale_coords(coords, factors)
            warp(sample_id, _misalign_center(coords, data.shape, no_shift))

        if do_rotation and np.random.uniform() < p_rot_per_sample:
            coords = create_zero_centered_coordinate_mesh(data_size)
            if dim == 3:
                a_z, a_y, a_x = _misalign_draw_per_axis((angle_z, angle_y, angle_x), p_rot_per_axis, 0)
                coords = rotate_coords_3d(coords, a_z, a_y, a_x)
            else:
                (a_z,) = _misalign_draw_per_axis((angle_z,), p_rot_per_axis, 0)
                coords = rotate_coords_2d(coords, a_z)
            warp(sample_id, _misalign_center(coords, data.shape, no_shift))

        if do_transl and np.random.uniform() < p_transl_per_sample:
            coords = create_zero_centered_coordinate_mesh(data_size)
            tr_ranges = (tr_z, tr_y, tr_x) if dim == 3 else (tr_y, tr_x)
            # An axis that is not selected must not move: the neutral shift is 0 voxels (was 1.0).
            shifts = _misalign_draw_per_axis(tr_ranges, p_transl_per_dir, 0.0)
            warp(sample_id, _misalign_center(coords, data.shape, shifts))

    return data, seg
@@ -21,6 +21,7 @@ from scipy.ndimage.filters import gaussian_filter, gaussian_gradient_magnitude
21
21
  from scipy.ndimage.morphology import grey_dilation
22
22
  from skimage.transform import resize
23
23
  from scipy.ndimage.measurements import label as lb
24
+ import pandas as pd
24
25
 
25
26
 
26
27
  def generate_elastic_transform_coordinates(shape, alpha, sigma):
@@ -591,13 +592,13 @@ def resize_segmentation(segmentation, new_shape, order=3):
591
592
  :return:
592
593
  '''
593
594
  tpe = segmentation.dtype
594
- unique_labels = np.unique(segmentation)
595
595
  assert len(segmentation.shape) == len(new_shape), "new shape must have same dimensionality as segmentation"
596
596
  if order == 0:
597
597
  return resize(segmentation.astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False).astype(tpe)
598
598
  else:
599
599
  reshaped = np.zeros(new_shape, dtype=segmentation.dtype)
600
600
 
601
+ unique_labels = np.sort(pd.unique(segmentation.ravel()))
601
602
  for i, c in enumerate(unique_labels):
602
603
  mask = segmentation == c
603
604
  reshaped_multihot = resize(mask.astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False)
@@ -773,3 +774,25 @@ def mask_random_squares(img, square_size, n_squares, n_val, channel_wise_n_val=F
773
774
  img = mask_random_square(img, square_size, n_val, channel_wise_n_val=channel_wise_n_val,
774
775
  square_pos=square_pos)
775
776
  return img
777
+
778
def get_organ_gradient_field(organ, spacing_ratio=0.3125/3.0, blur=32):
    """
    Compute the gradient field of a blurred organ mask for the anatomy-informed augmentation.

    :param organ: binary organ segmentation with three axes
    :param spacing_ratio: ratio of the axial spacing and the slice thickness; scales the Gaussian sigma
        along the first axis and the returned first-axis gradient
    :param blur: Gaussian sigma along the last two axes (the first axis uses blur * spacing_ratio)
    :return: tuple (t, u, v) of gradients along axes 0, 1 and 2; t is multiplied by spacing_ratio
    """
    sigmas = (blur * spacing_ratio, blur, blur)
    smoothed = gaussian_filter(organ.astype(float), sigma=sigmas, order=0, mode='nearest')

    grad_first, grad_second, grad_third = np.gradient(smoothed)
    return grad_first * spacing_ratio, grad_second, grad_third
795
+
796
def ignore_anatomy(segm, max_annotation_value=1, replace_value=0):
    """
    Overwrite every label greater than ``max_annotation_value`` with ``replace_value``.

    :param segm: label array; modified in place
    :param max_annotation_value: highest label value that is kept
    :param replace_value: value written over all larger labels
    :return: the same array object ``segm``
    """
    above_max = segm > max_annotation_value
    segm[above_max] = replace_value
    return segm
@@ -80,7 +80,7 @@ def results_loop(in_queues: List[Queue], out_queue: thrQueue, abort_event: Event
80
80
  end_ctr = 0
81
81
 
82
82
  while True:
83
- # if abort_event is set we need to clean up. This is where it hangs sometimes so it makes sense to drain all
83
+ # if abort_event is set we need to clean up. This is where it hangs sometimes, so it makes sense to drain all
84
84
  # the incoming queues and ignore all the errors occuring during this process.
85
85
  try:
86
86
  if abort_event.is_set():
@@ -89,9 +89,8 @@ def results_loop(in_queues: List[Queue], out_queue: thrQueue, abort_event: Event
89
89
  # check if all workers are still alive
90
90
  if not all([i.is_alive() for i in worker_list]):
91
91
  abort_event.set()
92
- raise RuntimeError("Abort event was set. So someone died and we should end this madness. \nIMPORTANT: "
93
- "This is not the actual error message! Look further up to see what caused the "
94
- "error. Please also check whether your RAM was full")
92
+ raise RuntimeError("One or more background workers are no longer alive. Exiting. Please check the print"
93
+ " statements above for the actual error message")
95
94
 
96
95
  # if we don't have an item we need to fetch it first. If the queue we want to get it from it empty, try
97
96
  # again later
@@ -187,9 +186,8 @@ class MultiThreadedAugmenter(object):
187
186
  while item is None:
188
187
  if self.abort_event.is_set():
189
188
  self._finish()
190
- raise RuntimeError("MultiThreadedAugmenter.abort_event was set, something went wrong. Maybe one of "
191
- "your workers crashed. This is not the actual error message! Look further up your "
192
- "stdout to see what caused the error. Please also check whether your RAM was full")
189
+ raise RuntimeError("One or more background workers are no longer alive. Exiting. Please check the "
190
+ "print statements above for the actual error message")
193
191
 
194
192
  if not self.pin_memory_queue.empty():
195
193
  item = self.pin_memory_queue.get()
@@ -38,39 +38,40 @@ except ImportError:
38
38
  def producer(queue: Queue, data_loader, transform, thread_id: int, seed,
39
39
  abort_event: Event, wait_time: float = 0.02):
40
40
  # the producer will set the abort event if something happens
41
- np.random.seed(seed)
42
- data_loader.set_thread_id(thread_id)
43
- item = None
44
-
45
- try:
46
- while True:
41
+ with threadpool_limits(1, None):
42
+ np.random.seed(seed)
43
+ data_loader.set_thread_id(thread_id)
44
+ item = None
47
45
 
48
- if abort_event.is_set():
49
- return
50
- else:
51
- if item is None:
52
- item = next(data_loader)
53
- if transform is not None:
54
- item = transform(**item)
46
+ try:
47
+ while True:
55
48
 
56
49
  if abort_event.is_set():
57
50
  return
58
-
59
- if not queue.full():
60
- queue.put(item)
61
- item = None
62
51
  else:
63
- sleep(wait_time)
52
+ if item is None:
53
+ item = next(data_loader)
54
+ if transform is not None:
55
+ item = transform(**item)
64
56
 
65
- except KeyboardInterrupt:
66
- abort_event.set()
67
- return
57
+ if abort_event.is_set():
58
+ return
68
59
 
69
- except Exception as e:
70
- print("Exception in background worker %d:\n" % thread_id, e)
71
- traceback.print_exc()
72
- abort_event.set()
73
- return
60
+ if not queue.full():
61
+ queue.put(item)
62
+ item = None
63
+ else:
64
+ sleep(wait_time)
65
+
66
+ except KeyboardInterrupt:
67
+ abort_event.set()
68
+ return
69
+
70
+ except Exception as e:
71
+ print("Exception in background worker %d:\n" % thread_id, e)
72
+ traceback.print_exc()
73
+ abort_event.set()
74
+ return
74
75
 
75
76
 
76
77
  def pin_memory_of_all_eligible_items_in_dict(result_dict):
@@ -99,9 +100,8 @@ def results_loop(in_queue: Queue, out_queue: thrQueue, abort_event: Event,
99
100
  # check if all workers are still alive
100
101
  if not all([i.is_alive() for i in worker_list]):
101
102
  abort_event.set()
102
- raise RuntimeError("Abort event was set. So someone died and we should end this madness. \nIMPORTANT: "
103
- "This is not the actual error message! Look further up to see what caused the "
104
- "error. Please also check whether your RAM was full")
103
+ raise RuntimeError("One or more background workers are no longer alive. Exiting. Please check the "
104
+ "print statements above for the actual error message")
105
105
 
106
106
  if item is None:
107
107
  if not in_queue.empty():
@@ -178,9 +178,8 @@ class NonDetMultiThreadedAugmenter(object):
178
178
  if self.abort_event.is_set():
179
179
  # self.communication_thread handles checking for dead workers and will set the abort event if necessary
180
180
  self._finish()
181
- raise RuntimeError("MultiThreadedAugmenter.abort_event was set, something went wrong. Maybe one of "
182
- "your workers crashed. This is not the actual error message! Look further up your "
183
- "stdout to see what caused the error. Please also check whether your RAM was full")
181
+ raise RuntimeError("One or more background workers are no longer alive. Exiting. Please check the "
182
+ "print statements above for the actual error message")
184
183
 
185
184
  if not self.results_loop_queue.empty():
186
185
  item = self.results_loop_queue.get()
@@ -209,13 +208,18 @@ class NonDetMultiThreadedAugmenter(object):
209
208
  if isinstance(self.generator, DataLoader):
210
209
  self.generator.was_initialized = False
211
210
 
211
+ if torch is not None:
212
+ torch_nthreads = torch.get_num_threads()
213
+ torch.set_num_threads(1)
212
214
  with threadpool_limits(limits=1, user_api=None):
213
215
  for i in range(self.num_processes):
214
216
  self._processes.append(Process(target=producer, args=(
215
217
  self._queue, self.generator, self.transform, i, self.seeds[i], self.abort_event, self.wait_time
216
218
  )))
217
219
  self._processes[-1].daemon = True
218
- _ = [i.start() for i in self._processes]
220
+ _ = [i.start() for i in self._processes]
221
+ if torch is not None:
222
+ torch.set_num_threads(torch_nthreads)
219
223
 
220
224
  if torch is not None and torch.cuda.is_available():
221
225
  gpu = torch.cuda.current_device()
@@ -133,6 +133,20 @@ class BlankSquareNoiseTransform(AbstractTransform):
133
133
  self.channel_wise_n_val, self.square_pos)
134
134
  return data_dict
135
135
 
136
class ColorFunctionExtractor:
    """
    Callable returning the fill value for a blank rectangle.

    ``rectangle_value`` may be:
      * a scalar (per ``np.isscalar``): returned unchanged,
      * a callable: called with the argument ``x`` and its result returned,
      * a tuple/list: unpacked into ``np.random.uniform`` (e.g. ``(low, high)``) on every call.

    Being a plain class instead of a lambda, instances can be pickled as long as ``rectangle_value`` can.
    """

    def __init__(self, rectangle_value):
        # Validate eagerly so an unsupported value fails when the transform is constructed
        # (as the previous lambda-based code did), not on the first augmentation call.
        if not (np.isscalar(rectangle_value) or callable(rectangle_value)
                or isinstance(rectangle_value, (tuple, list))):
            raise RuntimeError("unrecognized format for rectangle_value")
        self.rectangle_value = rectangle_value

    def __call__(self, x):
        if np.isscalar(self.rectangle_value):
            return self.rectangle_value
        elif callable(self.rectangle_value):
            return self.rectangle_value(x)
        elif isinstance(self.rectangle_value, (tuple, list)):
            return np.random.uniform(*self.rectangle_value)
        else:
            # Only reachable if rectangle_value was reassigned after construction.
            raise RuntimeError("unrecognized format for rectangle_value")

    def __repr__(self):
        return f"{type(self).__name__}({self.rectangle_value!r})"
149
+
136
150
 
137
151
  class BlankRectangleTransform(AbstractTransform):
138
152
  def __init__(self, rectangle_size, rectangle_value, num_rectangles, force_square=False, p_per_sample=0.5,
@@ -179,16 +193,7 @@ class BlankRectangleTransform(AbstractTransform):
179
193
  self.p_per_sample = p_per_sample
180
194
  self.p_per_channel = p_per_channel
181
195
  self.apply_to_keys = apply_to_keys
182
-
183
- # intensity value
184
- if np.isscalar(rectangle_value):
185
- self.color_fn = lambda x: rectangle_value
186
- elif callable(rectangle_value):
187
- self.color_fn = lambda x: rectangle_value(x)
188
- elif isinstance(rectangle_value, (tuple, list)):
189
- self.color_fn = lambda x: np.random.uniform(*rectangle_value)
190
- else:
191
- raise RuntimeError("unrecognized format for rectangle_value")
196
+ self.color_fn = ColorFunctionExtractor(rectangle_value)
192
197
 
193
198
  def __call__(self, **data_dict):
194
199
  for k in self.apply_to_keys: