batchgenerators 0.24__tar.gz → 0.25.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {batchgenerators-0.24 → batchgenerators-0.25.1}/PKG-INFO +10 -5
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/augmentations/spatial_transformations.py +174 -2
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/augmentations/utils.py +24 -1
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/dataloading/multi_threaded_augmenter.py +5 -7
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py +37 -33
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/transforms/noise_transforms.py +15 -10
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/transforms/spatial_transforms.py +194 -2
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/transforms/utility_transforms.py +17 -1
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/utilities/custom_types.py +4 -2
- batchgenerators-0.25.1/batchgenerators/utilities/file_and_folder_operations.py +135 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators.egg-info/PKG-INFO +10 -5
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators.egg-info/SOURCES.txt +12 -1
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators.egg-info/requires.txt +1 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/setup.py +2 -1
- batchgenerators-0.25.1/tests/test_DataLoader.py +262 -0
- batchgenerators-0.25.1/tests/test_augment_zoom.py +73 -0
- batchgenerators-0.25.1/tests/test_axis_mirroring.py +173 -0
- batchgenerators-0.25.1/tests/test_color_augmentations.py +158 -0
- batchgenerators-0.25.1/tests/test_crop.py +470 -0
- batchgenerators-0.25.1/tests/test_multithreaded_augmenter.py +222 -0
- batchgenerators-0.25.1/tests/test_normalizations.py +235 -0
- batchgenerators-0.25.1/tests/test_random_crop.py +76 -0
- batchgenerators-0.25.1/tests/test_resample_augmentations.py +70 -0
- batchgenerators-0.25.1/tests/test_sanity.py +26 -0
- batchgenerators-0.25.1/tests/test_spatial_transformations.py +141 -0
- batchgenerators-0.24/batchgenerators/utilities/file_and_folder_operations.py +0 -100
- {batchgenerators-0.24 → batchgenerators-0.25.1}/LICENSE +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/__init__.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/augmentations/__init__.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/augmentations/color_augmentations.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/augmentations/crop_and_pad_augmentations.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/augmentations/noise_augmentations.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/augmentations/normalizations.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/augmentations/resample_augmentations.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/dataloading/__init__.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/dataloading/data_loader.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/dataloading/dataset.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/dataloading/single_threaded_augmenter.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/datasets/__init__.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/datasets/cifar.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/examples/__init__.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/examples/brats2017/__init__.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/examples/brats2017/brats2017_dataloader_2D.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/examples/brats2017/brats2017_dataloader_3D.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/examples/brats2017/brats2017_preprocessing.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/examples/brats2017/config.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/examples/cifar10.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/examples/multithreaded_dataloading.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/transforms/__init__.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/transforms/abstract_transforms.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/transforms/channel_selection_transforms.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/transforms/color_transforms.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/transforms/crop_and_pad_transforms.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/transforms/local_transforms.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/transforms/resample_transforms.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/transforms/sample_normalization_transforms.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/utilities/__init__.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/utilities/data_splitting.py +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators.egg-info/dependency_links.txt +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators.egg-info/top_level.txt +0 -0
- {batchgenerators-0.24 → batchgenerators-0.25.1}/setup.cfg +0 -0
|
@@ -1,14 +1,19 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: batchgenerators
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.25.1
|
|
4
4
|
Summary: Data augmentation toolkit
|
|
5
5
|
Home-page: https://github.com/MIC-DKFZ/batchgenerators
|
|
6
6
|
Author: Division of Medical Image Computing, German Cancer Research Center AND Applied Computer Vision Lab, Helmholtz Imaging Platform
|
|
7
7
|
Author-email: f.isensee@dkfz-heidelberg.de
|
|
8
8
|
License: Apache License Version 2.0, January 2004
|
|
9
9
|
Keywords: data augmentation,deep learning,image segmentation,image classification,medical image analysis,medical image segmentation
|
|
10
|
-
Platform: UNKNOWN
|
|
11
10
|
License-File: LICENSE
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
11
|
+
Requires-Dist: pillow>=7.1.2
|
|
12
|
+
Requires-Dist: numpy>=1.10.2
|
|
13
|
+
Requires-Dist: scipy
|
|
14
|
+
Requires-Dist: scikit-image
|
|
15
|
+
Requires-Dist: scikit-learn
|
|
16
|
+
Requires-Dist: future
|
|
17
|
+
Requires-Dist: pandas
|
|
18
|
+
Requires-Dist: unittest2
|
|
19
|
+
Requires-Dist: threadpoolctl
|
|
@@ -16,10 +16,12 @@
|
|
|
16
16
|
from builtins import range
|
|
17
17
|
|
|
18
18
|
import numpy as np
|
|
19
|
+
from scipy.ndimage import map_coordinates
|
|
19
20
|
from batchgenerators.augmentations.utils import create_zero_centered_coordinate_mesh, elastic_deform_coordinates, \
|
|
20
21
|
interpolate_img, \
|
|
21
22
|
rotate_coords_2d, rotate_coords_3d, scale_coords, resize_segmentation, resize_multichannel_image, \
|
|
22
|
-
elastic_deform_coordinates_2
|
|
23
|
+
elastic_deform_coordinates_2, \
|
|
24
|
+
get_organ_gradient_field, ignore_anatomy
|
|
23
25
|
from batchgenerators.augmentations.crop_and_pad_augmentations import random_crop as random_crop_aug
|
|
24
26
|
from batchgenerators.augmentations.crop_and_pad_augmentations import center_crop as center_crop_aug
|
|
25
27
|
|
|
@@ -298,7 +300,7 @@ def augment_spatial_2(data, seg, patch_size, patch_center_dist_from_border=30,
|
|
|
298
300
|
do_scale=True, scale=(0.75, 1.25), border_mode_data='nearest', border_cval_data=0, order_data=3,
|
|
299
301
|
border_mode_seg='constant', border_cval_seg=0, order_seg=0, random_crop=True, p_el_per_sample=1,
|
|
300
302
|
p_scale_per_sample=1, p_rot_per_sample=1, independent_scale_for_each_axis=False,
|
|
301
|
-
p_rot_per_axis: float = 1, p_independent_scale_per_axis:
|
|
303
|
+
p_rot_per_axis: float = 1, p_independent_scale_per_axis: float = 1):
|
|
302
304
|
"""
|
|
303
305
|
|
|
304
306
|
:param data:
|
|
@@ -484,3 +486,173 @@ def augment_transpose_axes(data_sample, seg_sample, axes=(0, 1, 2)):
|
|
|
484
486
|
if seg_sample is not None:
|
|
485
487
|
seg_sample = seg_sample.transpose(*static_axes)
|
|
486
488
|
return data_sample, seg_sample
|
|
489
|
+
|
|
490
|
+
def augment_anatomy_informed(data, seg,
                             active_organs, dilation_ranges, directions_of_trans, modalities,
                             spacing_ratio=0.3125/3.0, blur=32, anisotropy_safety= True,
                             max_annotation_value=1, replace_value=0):
    """
    Deform ``data`` and ``seg`` along gradient fields computed around the selected organs.

    Both arrays are modified in place and also returned. ``data`` is indexed as ``data[modality, :, :, :]`` and
    ``seg`` as ``seg[:, :, :]``, i.e. one 4D image stack and one 3D label volume per call.

    :param data: image array, indexed as (modality, axis0, axis1, axis2)
    :param seg: label volume; the organ with list index ``i`` is looked up as the label value ``i + 2``
    :param active_organs: one flag per organ; if their sum is not positive only ``ignore_anatomy`` is applied
    :param dilation_ranges: per organ a (low, high) pair from which the displacement magnitude is drawn uniformly
    :param directions_of_trans: per organ three flags that enable the displacement along axis 0, 1 and 2
    :param modalities: indices of the entries of ``data`` that are resampled
    :param spacing_ratio: forwarded to ``get_organ_gradient_field`` and additionally multiplied onto the
        displacement along axis 0
    :param blur: forwarded to ``get_organ_gradient_field``
    :param anisotropy_safety: if True, clamp the axis-0 sampling coordinates of the two first and the two last
        slices into ``[0, size_of_axis_0 - 1]``
    :param max_annotation_value: forwarded to ``ignore_anatomy``
    :param replace_value: forwarded to ``ignore_anatomy``
    :return: the tuple ``(data, seg)``
    """
    if not sum(active_organs) > 0:
        # Nothing to deform: only strip the organ labels from the segmentation.
        seg[:, :, :] = ignore_anatomy(seg[:, :, :], max_annotation_value=max_annotation_value,
                                      replace_value=replace_value)
        return data, seg

    data_shape = data.shape
    coords = create_zero_centered_coordinate_mesh(data_shape[-3:])

    # Organs are visited from the last to the first; one magnitude is drawn per active organ.
    for organ_idx, is_active in list(enumerate(active_organs))[::-1]:
        if not is_active:
            continue

        magnitude = np.random.uniform(low=dilation_ranges[organ_idx][0], high=dilation_ranges[organ_idx][1])
        grad_0, grad_1, grad_2 = get_organ_gradient_field(seg == organ_idx + 2,
                                                          spacing_ratio=spacing_ratio,
                                                          blur=blur)
        enabled = directions_of_trans[organ_idx]

        if enabled[0]:
            coords[0] = coords[0] + grad_0 * magnitude * spacing_ratio
        if enabled[1]:
            coords[1] = coords[1] + grad_1 * magnitude
        if enabled[2]:
            coords[2] = coords[2] + grad_2 * magnitude

    # Move the zero-centered mesh back into array index space.
    for axis in range(3):
        coords[axis] += data.shape[axis + 1] / 2 - 0.5

    if anisotropy_safety:
        upper = data_shape[-3] - 1
        for slice_idx in (0, 1):
            plane = coords[0, slice_idx, :, :]
            plane[plane < 0] = 0.0
        for slice_idx in (-1, -2):
            plane = coords[0, slice_idx, :, :]
            plane[plane > upper] = upper

    for modality in modalities:
        data[modality, :, :, :] = map_coordinates(data[modality, :, :, :], coords, order=1, mode='constant')

    # Organ labels are removed before the segmentation itself is resampled (nearest neighbour, order 0).
    seg[:, :, :] = ignore_anatomy(seg[:, :, :], max_annotation_value=max_annotation_value,
                                  replace_value=replace_value)
    seg[:, :, :] = map_coordinates(seg[:, :, :], coords, order=0, mode='constant')

    return data, seg
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
def _misalign_draw_per_direction(value_ranges, p_per_dir, neutral_value):
    """Draw one value per direction; a direction that is not selected gets ``neutral_value``."""
    drawn = []
    for value_range in value_ranges:
        if np.random.uniform() <= p_per_dir:
            drawn.append(np.random.uniform(value_range[0], value_range[1]))
        else:
            drawn.append(neutral_value)
    return drawn


def _misalign_resample_sample(data, seg, sample_id, coords, im_channels, label_channels,
                              order_data, border_mode_data, border_cval_data,
                              order_seg, border_mode_seg, border_cval_seg):
    """Resample the selected image and label channels of one sample at ``coords`` (in place)."""
    for channel_id in range(data.shape[1]):
        if channel_id in im_channels:
            data[sample_id, channel_id] = interpolate_img(data[sample_id, channel_id], coords, order_data,
                                                          border_mode_data, cval=border_cval_data)
    if seg is not None:
        for channel_id in range(seg.shape[1]):
            if channel_id in label_channels:
                seg[sample_id, channel_id] = interpolate_img(seg[sample_id, channel_id], coords, order_seg,
                                                             border_mode_seg, cval=border_cval_seg,
                                                             is_seg=True)


def augment_misalign(data, seg, data_size,
                     im_channels_2_misalign=[0, ],
                     label_channels_2_misalign=[0, ],
                     do_squeeze=False,
                     sq_x=[1.0, 1.0], sq_y=[0.9, 1.1], sq_z=[1.0, 1.0],
                     p_sq_per_sample=0.1, p_sq_per_dir=1.0,
                     do_rotation=False,
                     angle_x=(-0 / 360. * 2 * np.pi, 0 / 360. * 2 * np.pi),
                     angle_y=(-0 / 360. * 2 * np.pi, 0 / 360. * 2 * np.pi),
                     angle_z=(-15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
                     p_rot_per_sample=0.1, p_rot_per_axis=1.0,
                     tr_x=[-32, 32], tr_y=[-32, 32], tr_z=[-2, 2],
                     p_transl_per_sample=0.1, p_transl_per_dir=1.0,
                     do_transl=False,
                     border_mode_data='constant', border_cval_data=0,
                     border_mode_seg='constant', border_cval_seg=0,
                     order_data=3, order_seg=0):
    """
    Misalign selected channels of ``data`` / ``seg`` by random squeezing, rotation and translation.

    The three stages run in that order for every sample. Each stage that fires builds its own coordinate mesh
    and resamples the selected channels, so a sample may be interpolated up to three times. ``data`` and
    ``seg`` are modified in place and also returned. Channels that are not listed are left untouched.

    Changes compared to the previous implementation:

    * the squeeze and rotation stages now select label channels with ``label_channels_2_misalign``; they used
      ``im_channels_2_misalign`` before, so the label selection only took effect during translation
    * a translation direction that is not selected (``p_transl_per_dir``) now gets a shift of 0 instead of 1

    :param data: array indexed as (sample, channel, spatial axes...)
    :param seg: array indexed like ``data``, or None
    :param data_size: spatial size used for the coordinate mesh; its length decides between 2D and 3D handling
    :param im_channels_2_misalign: channel indices of ``data`` that are resampled
    :param label_channels_2_misalign: channel indices of ``seg`` that are resampled
    :param do_squeeze: enable the scaling stage
    :param sq_x: (low, high) scale range for the last spatial axis
    :param sq_y: (low, high) scale range for the second to last spatial axis
    :param sq_z: (low, high) scale range for the first spatial axis (3D only)
    :param p_sq_per_sample: probability that the scaling stage fires for a sample
    :param p_sq_per_dir: probability per direction that a scale is drawn instead of using 1.0
    :param do_rotation: enable the rotation stage
    :param angle_x: (low, high) range for the x angle (3D only); the defaults look like radians, TODO confirm
    :param angle_y: (low, high) range for the y angle (3D only)
    :param angle_z: (low, high) range for the z angle; the only angle used in 2D
    :param p_rot_per_sample: probability that the rotation stage fires for a sample
    :param p_rot_per_axis: probability per axis that an angle is drawn instead of using 0
    :param tr_x: (low, high) shift range for the last spatial axis
    :param tr_y: (low, high) shift range for the second to last spatial axis
    :param tr_z: (low, high) shift range for the first spatial axis (3D only)
    :param p_transl_per_sample: probability that the translation stage fires for a sample
    :param p_transl_per_dir: probability per direction that a shift is drawn instead of using 0
    :param do_transl: enable the translation stage
    :param border_mode_data: forwarded to ``interpolate_img`` for image channels
    :param border_cval_data: forwarded to ``interpolate_img`` as ``cval`` for image channels
    :param border_mode_seg: forwarded to ``interpolate_img`` for label channels
    :param border_cval_seg: forwarded to ``interpolate_img`` as ``cval`` for label channels
    :param order_data: interpolation order for image channels
    :param order_seg: interpolation order for label channels
    :return: the tuple ``(data, seg)``
    """
    dim = len(data_size)
    is_3d = dim == 3

    def _center(coords, offsets=None):
        # Move the zero-centered mesh into array index space, optionally shifted per axis.
        for d in range(dim):
            ctr = data.shape[d + 2] / 2. - 0.5
            if offsets is not None:
                ctr = ctr + offsets[d]
            coords[d] += ctr
        return coords

    def _resample(coords):
        _misalign_resample_sample(data, seg, sample_id, coords,
                                  im_channels_2_misalign, label_channels_2_misalign,
                                  order_data, border_mode_data, border_cval_data,
                                  order_seg, border_mode_seg, border_cval_seg)

    for sample_id in range(data.shape[0]):

        if do_squeeze and np.random.uniform() < p_sq_per_sample:
            coords = create_zero_centered_coordinate_mesh(data_size)
            # The z range is only consulted for 3D input; draw order is z, y, x.
            sq_ranges = (sq_z, sq_y, sq_x) if is_3d else (sq_y, sq_x)
            sq = _misalign_draw_per_direction(sq_ranges, p_sq_per_dir, 1.0)
            coords = scale_coords(coords, sq)
            _resample(_center(coords))

        if do_rotation and np.random.uniform() < p_rot_per_sample:
            coords = create_zero_centered_coordinate_mesh(data_size)
            a_z = np.random.uniform(angle_z[0], angle_z[1]) if np.random.uniform() <= p_rot_per_axis else 0
            if is_3d:
                a_y = np.random.uniform(angle_y[0], angle_y[1]) if np.random.uniform() <= p_rot_per_axis else 0
                a_x = np.random.uniform(angle_x[0], angle_x[1]) if np.random.uniform() <= p_rot_per_axis else 0
                # Argument order (a_z, a_y, a_x) is kept exactly as in the previous implementation.
                coords = rotate_coords_3d(coords, a_z, a_y, a_x)
            else:
                coords = rotate_coords_2d(coords, a_z)
            _resample(_center(coords))

        if do_transl and np.random.uniform() < p_transl_per_sample:
            coords = create_zero_centered_coordinate_mesh(data_size)
            tr_ranges = (tr_z, tr_y, tr_x) if is_3d else (tr_y, tr_x)
            # 0.0 is the neutral element of a shift; a direction that is skipped must not move.
            tr = _misalign_draw_per_direction(tr_ranges, p_transl_per_dir, 0.0)
            _resample(_center(coords, offsets=tr))

    return data, seg
|
|
@@ -21,6 +21,7 @@ from scipy.ndimage.filters import gaussian_filter, gaussian_gradient_magnitude
|
|
|
21
21
|
from scipy.ndimage.morphology import grey_dilation
|
|
22
22
|
from skimage.transform import resize
|
|
23
23
|
from scipy.ndimage.measurements import label as lb
|
|
24
|
+
import pandas as pd
|
|
24
25
|
|
|
25
26
|
|
|
26
27
|
def generate_elastic_transform_coordinates(shape, alpha, sigma):
|
|
@@ -591,13 +592,13 @@ def resize_segmentation(segmentation, new_shape, order=3):
|
|
|
591
592
|
:return:
|
|
592
593
|
'''
|
|
593
594
|
tpe = segmentation.dtype
|
|
594
|
-
unique_labels = np.unique(segmentation)
|
|
595
595
|
assert len(segmentation.shape) == len(new_shape), "new shape must have same dimensionality as segmentation"
|
|
596
596
|
if order == 0:
|
|
597
597
|
return resize(segmentation.astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False).astype(tpe)
|
|
598
598
|
else:
|
|
599
599
|
reshaped = np.zeros(new_shape, dtype=segmentation.dtype)
|
|
600
600
|
|
|
601
|
+
unique_labels = np.sort(pd.unique(segmentation.ravel()))
|
|
601
602
|
for i, c in enumerate(unique_labels):
|
|
602
603
|
mask = segmentation == c
|
|
603
604
|
reshaped_multihot = resize(mask.astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False)
|
|
@@ -773,3 +774,25 @@ def mask_random_squares(img, square_size, n_squares, n_val, channel_wise_n_val=F
|
|
|
773
774
|
img = mask_random_square(img, square_size, n_val, channel_wise_n_val=channel_wise_n_val,
|
|
774
775
|
square_pos=square_pos)
|
|
775
776
|
return img
|
|
777
|
+
|
|
778
|
+
def get_organ_gradient_field(organ, spacing_ratio=0.3125/3.0, blur=32):
    """
    Compute the gradient field of a Gaussian-smoothed organ mask for the anatomy-informed augmentation.

    The mask is converted to float and smoothed with sigma ``(blur * spacing_ratio, blur, blur)``, i.e. the
    first axis is smoothed less than the other two by the factor ``spacing_ratio``. The gradient component of
    the first axis is multiplied by ``spacing_ratio`` before it is returned.

    :param organ: binary organ segmentation (3D, as three gradient components are unpacked)
    :param spacing_ratio: ratio of the axial spacing and the slice thickness, needed for the right vector field
        calculation
    :param blur: kernel constant (Gaussian sigma of the second and third axis)
    :return: the three gradient components, one per axis, in axis order
    """
    sigma = (blur * spacing_ratio, blur, blur)
    smoothed = gaussian_filter(organ.astype(float), sigma=sigma, order=0, mode='nearest')

    grad_first, grad_second, grad_third = np.gradient(smoothed)

    return grad_first * spacing_ratio, grad_second, grad_third
|
|
795
|
+
|
|
796
|
+
def ignore_anatomy(segm, max_annotation_value=1, replace_value=0):
    """
    Overwrite every label above ``max_annotation_value`` with ``replace_value``.

    The array is modified in place and the same object is returned.

    :param segm: label array
    :param max_annotation_value: largest label value that is kept
    :param replace_value: value written to all entries greater than ``max_annotation_value``
    :return: ``segm``
    """
    above_limit = segm > max_annotation_value
    segm[above_limit] = replace_value
    return segm
|
|
@@ -80,7 +80,7 @@ def results_loop(in_queues: List[Queue], out_queue: thrQueue, abort_event: Event
|
|
|
80
80
|
end_ctr = 0
|
|
81
81
|
|
|
82
82
|
while True:
|
|
83
|
-
# if abort_event is set we need to clean up. This is where it hangs sometimes so it makes sense to drain all
|
|
83
|
+
# if abort_event is set we need to clean up. This is where it hangs sometimes, so it makes sense to drain all
|
|
84
84
|
# the incoming queues and ignore all the errors occuring during this process.
|
|
85
85
|
try:
|
|
86
86
|
if abort_event.is_set():
|
|
@@ -89,9 +89,8 @@ def results_loop(in_queues: List[Queue], out_queue: thrQueue, abort_event: Event
|
|
|
89
89
|
# check if all workers are still alive
|
|
90
90
|
if not all([i.is_alive() for i in worker_list]):
|
|
91
91
|
abort_event.set()
|
|
92
|
-
raise RuntimeError("
|
|
93
|
-
"
|
|
94
|
-
"error. Please also check whether your RAM was full")
|
|
92
|
+
raise RuntimeError("One or more background workers are no longer alive. Exiting. Please check the print"
|
|
93
|
+
" statements above for the actual error message")
|
|
95
94
|
|
|
96
95
|
# if we don't have an item we need to fetch it first. If the queue we want to get it from it empty, try
|
|
97
96
|
# again later
|
|
@@ -187,9 +186,8 @@ class MultiThreadedAugmenter(object):
|
|
|
187
186
|
while item is None:
|
|
188
187
|
if self.abort_event.is_set():
|
|
189
188
|
self._finish()
|
|
190
|
-
raise RuntimeError("
|
|
191
|
-
"
|
|
192
|
-
"stdout to see what caused the error. Please also check whether your RAM was full")
|
|
189
|
+
raise RuntimeError("One or more background workers are no longer alive. Exiting. Please check the "
|
|
190
|
+
"print statements above for the actual error message")
|
|
193
191
|
|
|
194
192
|
if not self.pin_memory_queue.empty():
|
|
195
193
|
item = self.pin_memory_queue.get()
|
|
@@ -38,39 +38,40 @@ except ImportError:
|
|
|
38
38
|
def producer(queue: Queue, data_loader, transform, thread_id: int, seed,
|
|
39
39
|
abort_event: Event, wait_time: float = 0.02):
|
|
40
40
|
# the producer will set the abort event if something happens
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
try:
|
|
46
|
-
while True:
|
|
41
|
+
with threadpool_limits(1, None):
|
|
42
|
+
np.random.seed(seed)
|
|
43
|
+
data_loader.set_thread_id(thread_id)
|
|
44
|
+
item = None
|
|
47
45
|
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
else:
|
|
51
|
-
if item is None:
|
|
52
|
-
item = next(data_loader)
|
|
53
|
-
if transform is not None:
|
|
54
|
-
item = transform(**item)
|
|
46
|
+
try:
|
|
47
|
+
while True:
|
|
55
48
|
|
|
56
49
|
if abort_event.is_set():
|
|
57
50
|
return
|
|
58
|
-
|
|
59
|
-
if not queue.full():
|
|
60
|
-
queue.put(item)
|
|
61
|
-
item = None
|
|
62
51
|
else:
|
|
63
|
-
|
|
52
|
+
if item is None:
|
|
53
|
+
item = next(data_loader)
|
|
54
|
+
if transform is not None:
|
|
55
|
+
item = transform(**item)
|
|
64
56
|
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
return
|
|
57
|
+
if abort_event.is_set():
|
|
58
|
+
return
|
|
68
59
|
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
60
|
+
if not queue.full():
|
|
61
|
+
queue.put(item)
|
|
62
|
+
item = None
|
|
63
|
+
else:
|
|
64
|
+
sleep(wait_time)
|
|
65
|
+
|
|
66
|
+
except KeyboardInterrupt:
|
|
67
|
+
abort_event.set()
|
|
68
|
+
return
|
|
69
|
+
|
|
70
|
+
except Exception as e:
|
|
71
|
+
print("Exception in background worker %d:\n" % thread_id, e)
|
|
72
|
+
traceback.print_exc()
|
|
73
|
+
abort_event.set()
|
|
74
|
+
return
|
|
74
75
|
|
|
75
76
|
|
|
76
77
|
def pin_memory_of_all_eligible_items_in_dict(result_dict):
|
|
@@ -99,9 +100,8 @@ def results_loop(in_queue: Queue, out_queue: thrQueue, abort_event: Event,
|
|
|
99
100
|
# check if all workers are still alive
|
|
100
101
|
if not all([i.is_alive() for i in worker_list]):
|
|
101
102
|
abort_event.set()
|
|
102
|
-
raise RuntimeError("
|
|
103
|
-
"
|
|
104
|
-
"error. Please also check whether your RAM was full")
|
|
103
|
+
raise RuntimeError("One or more background workers are no longer alive. Exiting. Please check the "
|
|
104
|
+
"print statements above for the actual error message")
|
|
105
105
|
|
|
106
106
|
if item is None:
|
|
107
107
|
if not in_queue.empty():
|
|
@@ -178,9 +178,8 @@ class NonDetMultiThreadedAugmenter(object):
|
|
|
178
178
|
if self.abort_event.is_set():
|
|
179
179
|
# self.communication_thread handles checking for dead workers and will set the abort event if necessary
|
|
180
180
|
self._finish()
|
|
181
|
-
raise RuntimeError("
|
|
182
|
-
"
|
|
183
|
-
"stdout to see what caused the error. Please also check whether your RAM was full")
|
|
181
|
+
raise RuntimeError("One or more background workers are no longer alive. Exiting. Please check the "
|
|
182
|
+
"print statements above for the actual error message")
|
|
184
183
|
|
|
185
184
|
if not self.results_loop_queue.empty():
|
|
186
185
|
item = self.results_loop_queue.get()
|
|
@@ -209,13 +208,18 @@ class NonDetMultiThreadedAugmenter(object):
|
|
|
209
208
|
if isinstance(self.generator, DataLoader):
|
|
210
209
|
self.generator.was_initialized = False
|
|
211
210
|
|
|
211
|
+
if torch is not None:
|
|
212
|
+
torch_nthreads = torch.get_num_threads()
|
|
213
|
+
torch.set_num_threads(1)
|
|
212
214
|
with threadpool_limits(limits=1, user_api=None):
|
|
213
215
|
for i in range(self.num_processes):
|
|
214
216
|
self._processes.append(Process(target=producer, args=(
|
|
215
217
|
self._queue, self.generator, self.transform, i, self.seeds[i], self.abort_event, self.wait_time
|
|
216
218
|
)))
|
|
217
219
|
self._processes[-1].daemon = True
|
|
218
|
-
|
|
220
|
+
_ = [i.start() for i in self._processes]
|
|
221
|
+
if torch is not None:
|
|
222
|
+
torch.set_num_threads(torch_nthreads)
|
|
219
223
|
|
|
220
224
|
if torch is not None and torch.cuda.is_available():
|
|
221
225
|
gpu = torch.cuda.current_device()
|
{batchgenerators-0.24 → batchgenerators-0.25.1}/batchgenerators/transforms/noise_transforms.py
RENAMED
|
@@ -133,6 +133,20 @@ class BlankSquareNoiseTransform(AbstractTransform):
|
|
|
133
133
|
self.channel_wise_n_val, self.square_pos)
|
|
134
134
|
return data_dict
|
|
135
135
|
|
|
136
|
+
class ColorFunctionExtractor:
    """
    Callable that resolves ``rectangle_value`` into an intensity each time it is called.

    Supported formats, checked in this order:

    * scalar (``np.isscalar``): returned as is
    * callable: called with the argument ``x`` and its result returned
    * tuple or list: unpacked into ``np.random.uniform`` and the sample returned

    The format is only checked on call, not on construction.
    """

    def __init__(self, rectangle_value):
        # Stored untouched; interpretation happens in __call__.
        self.rectangle_value = rectangle_value

    def __call__(self, x):
        """
        :param x: passed through to ``rectangle_value`` if that is a callable, ignored otherwise
        :return: the resolved intensity value
        :raises RuntimeError: if ``rectangle_value`` has none of the supported formats
        """
        value = self.rectangle_value
        if np.isscalar(value):
            return value
        if callable(value):
            return value(x)
        if isinstance(value, (tuple, list)):
            return np.random.uniform(*value)
        raise RuntimeError("unrecognized format for rectangle_value")
|
|
149
|
+
|
|
136
150
|
|
|
137
151
|
class BlankRectangleTransform(AbstractTransform):
|
|
138
152
|
def __init__(self, rectangle_size, rectangle_value, num_rectangles, force_square=False, p_per_sample=0.5,
|
|
@@ -179,16 +193,7 @@ class BlankRectangleTransform(AbstractTransform):
|
|
|
179
193
|
self.p_per_sample = p_per_sample
|
|
180
194
|
self.p_per_channel = p_per_channel
|
|
181
195
|
self.apply_to_keys = apply_to_keys
|
|
182
|
-
|
|
183
|
-
# intensity value
|
|
184
|
-
if np.isscalar(rectangle_value):
|
|
185
|
-
self.color_fn = lambda x: rectangle_value
|
|
186
|
-
elif callable(rectangle_value):
|
|
187
|
-
self.color_fn = lambda x: rectangle_value(x)
|
|
188
|
-
elif isinstance(rectangle_value, (tuple, list)):
|
|
189
|
-
self.color_fn = lambda x: np.random.uniform(*rectangle_value)
|
|
190
|
-
else:
|
|
191
|
-
raise RuntimeError("unrecognized format for rectangle_value")
|
|
196
|
+
self.color_fn = ColorFunctionExtractor(rectangle_value)
|
|
192
197
|
|
|
193
198
|
def __call__(self, **data_dict):
|
|
194
199
|
for k in self.apply_to_keys:
|