biomedisa 24.5.23__py3-none-any.whl → 24.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- biomedisa/deeplearning.py +38 -33
- biomedisa/features/DataGenerator.py +1 -1
- biomedisa/features/active_contour.py +3 -10
- biomedisa/features/biomedisa_helper.py +37 -27
- biomedisa/features/create_slices.py +4 -3
- biomedisa/features/crop_helper.py +2 -1
- biomedisa/features/keras_helper.py +290 -115
- biomedisa/features/remove_outlier.py +3 -9
- biomedisa/interpolation.py +9 -15
- biomedisa/mesh.py +12 -11
- {biomedisa-24.5.23.dist-info → biomedisa-24.8.1.dist-info}/METADATA +18 -12
- {biomedisa-24.5.23.dist-info → biomedisa-24.8.1.dist-info}/RECORD +15 -15
- {biomedisa-24.5.23.dist-info → biomedisa-24.8.1.dist-info}/WHEEL +1 -1
- {biomedisa-24.5.23.dist-info → biomedisa-24.8.1.dist-info}/LICENSE +0 -0
- {biomedisa-24.5.23.dist-info → biomedisa-24.8.1.dist-info}/top_level.txt +0 -0
@@ -45,6 +45,7 @@ from biomedisa.features.biomedisa_helper import (
|
|
45
45
|
img_resize, load_data, save_data, set_labels_to_zero, id_generator, unique_file_path)
|
46
46
|
from biomedisa.features.remove_outlier import clean, fill
|
47
47
|
from biomedisa.features.active_contour import activeContour
|
48
|
+
from tifffile import TiffFile, imread
|
48
49
|
import matplotlib.pyplot as plt
|
49
50
|
import SimpleITK as sitk
|
50
51
|
import tensorflow as tf
|
@@ -100,7 +101,7 @@ def save_history(history, path_to_model, val_dice, train_dice):
|
|
100
101
|
# save history dictonary
|
101
102
|
np.save(path_to_model.replace('.h5','.npy'), history)
|
102
103
|
|
103
|
-
def predict_blocksize(labelData, x_puffer, y_puffer, z_puffer):
|
104
|
+
def predict_blocksize(labelData, x_puffer=25, y_puffer=25, z_puffer=25):
|
104
105
|
zsh, ysh, xsh = labelData.shape
|
105
106
|
argmin_z, argmax_z, argmin_y, argmax_y, argmin_x, argmax_x = zsh, 0, ysh, 0, xsh, 0
|
106
107
|
for k in range(zsh):
|
@@ -121,7 +122,7 @@ def predict_blocksize(labelData, x_puffer, y_puffer, z_puffer):
|
|
121
122
|
argmax_z = argmax_z + z_puffer if argmax_z + z_puffer < zsh else zsh
|
122
123
|
return argmin_z,argmax_z,argmin_y,argmax_y,argmin_x,argmax_x
|
123
124
|
|
124
|
-
def
|
125
|
+
def set_image_dimensions(header, data):
|
125
126
|
|
126
127
|
# read header as string
|
127
128
|
b = header.tobytes()
|
@@ -148,7 +149,7 @@ def get_image_dimensions(header, data):
|
|
148
149
|
new_header = np.frombuffer(b2, dtype=header.dtype)
|
149
150
|
return new_header
|
150
151
|
|
151
|
-
def
|
152
|
+
def set_physical_size(header, img_header):
|
152
153
|
|
153
154
|
# read img_header as string
|
154
155
|
b = img_header.tobytes()
|
@@ -354,7 +355,7 @@ def read_img_list(img_list, label_list, temp_img_dir, temp_label_dir):
|
|
354
355
|
label_names.append(label_name)
|
355
356
|
return img_names, label_names
|
356
357
|
|
357
|
-
def load_training_data(normalize, img_list, label_list, channels, x_scale, y_scale, z_scale,
|
358
|
+
def load_training_data(normalize, img_list, label_list, channels, x_scale, y_scale, z_scale, scaling,
|
358
359
|
crop_data, labels_to_compute, labels_to_remove, img_in=None, label_in=None,
|
359
360
|
normalization_parameters=None, allLabels=None, header=None, extension='.tif',
|
360
361
|
x_puffer=25, y_puffer=25, z_puffer=25):
|
@@ -386,7 +387,7 @@ def load_training_data(normalize, img_list, label_list, channels, x_scale, y_sca
|
|
386
387
|
if crop_data:
|
387
388
|
argmin_z,argmax_z,argmin_y,argmax_y,argmin_x,argmax_x = predict_blocksize(label, x_puffer, y_puffer, z_puffer)
|
388
389
|
label = np.copy(label[argmin_z:argmax_z,argmin_y:argmax_y,argmin_x:argmax_x], order='C')
|
389
|
-
if
|
390
|
+
if scaling:
|
390
391
|
label = img_resize(label, z_scale, y_scale, x_scale, labels=True)
|
391
392
|
|
392
393
|
# if header is not single data stream Amira Mesh falling back to Multi-TIFF
|
@@ -412,7 +413,7 @@ def load_training_data(normalize, img_list, label_list, channels, x_scale, y_sca
|
|
412
413
|
else:
|
413
414
|
img = img_in
|
414
415
|
img_names = ['img_1']
|
415
|
-
if label_dim != img.shape:
|
416
|
+
if label_dim != img.shape[:3]:
|
416
417
|
InputError.message = f'Dimensions of "{os.path.basename(img_names[0])}" and "{os.path.basename(label_names[0])}" do not match'
|
417
418
|
raise InputError()
|
418
419
|
|
@@ -432,7 +433,7 @@ def load_training_data(normalize, img_list, label_list, channels, x_scale, y_sca
|
|
432
433
|
|
433
434
|
# scale/resize image data
|
434
435
|
img = img.astype(np.float32)
|
435
|
-
if
|
436
|
+
if scaling:
|
436
437
|
img = img_resize(img, z_scale, y_scale, x_scale)
|
437
438
|
|
438
439
|
# normalize image data
|
@@ -469,7 +470,7 @@ def load_training_data(normalize, img_list, label_list, channels, x_scale, y_sca
|
|
469
470
|
if crop_data:
|
470
471
|
argmin_z,argmax_z,argmin_y,argmax_y,argmin_x,argmax_x = predict_blocksize(a, x_puffer, y_puffer, z_puffer)
|
471
472
|
a = np.copy(a[argmin_z:argmax_z,argmin_y:argmax_y,argmin_x:argmax_x], order='C')
|
472
|
-
if
|
473
|
+
if scaling:
|
473
474
|
a = img_resize(a, z_scale, y_scale, x_scale, labels=True)
|
474
475
|
label = np.append(label, a, axis=0)
|
475
476
|
|
@@ -481,7 +482,7 @@ def load_training_data(normalize, img_list, label_list, channels, x_scale, y_sca
|
|
481
482
|
raise InputError()
|
482
483
|
else:
|
483
484
|
a = img_in[k]
|
484
|
-
if label_dim != a.shape:
|
485
|
+
if label_dim != a.shape[:3]:
|
485
486
|
InputError.message = f'Dimensions of "{os.path.basename(img_names[k])}" and "{os.path.basename(label_names[k])}" do not match'
|
486
487
|
raise InputError()
|
487
488
|
if len(a.shape)==3:
|
@@ -493,7 +494,7 @@ def load_training_data(normalize, img_list, label_list, channels, x_scale, y_sca
|
|
493
494
|
if crop_data:
|
494
495
|
a = np.copy(a[argmin_z:argmax_z,argmin_y:argmax_y,argmin_x:argmax_x], order='C')
|
495
496
|
a = a.astype(np.float32)
|
496
|
-
if
|
497
|
+
if scaling:
|
497
498
|
a = img_resize(a, z_scale, y_scale, x_scale)
|
498
499
|
for c in range(channels):
|
499
500
|
a[:,:,:,c] -= np.amin(a[:,:,:,c])
|
@@ -541,18 +542,17 @@ class CustomCallback(Callback):
|
|
541
542
|
time_remaining = str(t // 60) + 'min'
|
542
543
|
else:
|
543
544
|
time_remaining = str(t // 3600) + 'h ' + str((t % 3600) // 60) + 'min'
|
544
|
-
|
545
|
-
|
546
|
-
|
547
|
-
|
548
|
-
image.message = 'Progress {}%, {} remaining'.format(percentage,time_remaining)
|
545
|
+
image.message = 'Progress {}%, {} remaining'.format(percentage,time_remaining)
|
546
|
+
if 'best_val_dice' in logs:
|
547
|
+
best_val_dice = round(float(logs['best_val_dice'])*100,1)
|
548
|
+
image.message += f', {best_val_dice}% accuracy'
|
549
549
|
image.save()
|
550
550
|
print("Start epoch {} of training; got log keys: {}".format(epoch, keys))
|
551
551
|
|
552
552
|
class MetaData(Callback):
|
553
553
|
def __init__(self, path_to_model, configuration_data, allLabels,
|
554
554
|
extension, header, crop_data, cropping_weights, cropping_config,
|
555
|
-
normalization_parameters, cropping_norm):
|
555
|
+
normalization_parameters, cropping_norm, patch_normalization, scaling):
|
556
556
|
|
557
557
|
self.path_to_model = path_to_model
|
558
558
|
self.configuration_data = configuration_data
|
@@ -564,6 +564,8 @@ class MetaData(Callback):
|
|
564
564
|
self.cropping_weights = cropping_weights
|
565
565
|
self.cropping_config = cropping_config
|
566
566
|
self.cropping_norm = cropping_norm
|
567
|
+
self.patch_normalization = patch_normalization
|
568
|
+
self.scaling = scaling
|
567
569
|
|
568
570
|
def on_epoch_end(self, epoch, logs={}):
|
569
571
|
hf = h5py.File(self.path_to_model, 'r')
|
@@ -574,6 +576,8 @@ class MetaData(Callback):
|
|
574
576
|
group.create_dataset('configuration', data=self.configuration_data)
|
575
577
|
group.create_dataset('normalization', data=self.normalization_parameters)
|
576
578
|
group.create_dataset('labels', data=self.allLabels)
|
579
|
+
group.create_dataset('patch_normalization', data=int(self.patch_normalization))
|
580
|
+
group.create_dataset('scaling', data=int(self.scaling))
|
577
581
|
if self.extension == '.am':
|
578
582
|
group.create_dataset('extension', data=self.extension)
|
579
583
|
group.create_dataset('header', data=self.header)
|
@@ -758,8 +762,8 @@ def dice_coef_loss(nb_labels):
|
|
758
762
|
for index in range(1,nb_labels):
|
759
763
|
dice += dice_coef(y_true[:,:,:,:,index], y_pred[:,:,:,:,index])
|
760
764
|
dice = dice / (nb_labels-1)
|
761
|
-
loss = -K.log(dice)
|
762
|
-
|
765
|
+
#loss = -K.log(dice)
|
766
|
+
loss = 1 - dice
|
763
767
|
return loss
|
764
768
|
return loss_fn
|
765
769
|
|
@@ -772,11 +776,13 @@ def train_semantic_segmentation(bm,
|
|
772
776
|
|
773
777
|
# training data
|
774
778
|
img, label, allLabels, normalization_parameters, header, extension, bm.channels = load_training_data(bm.normalize,
|
775
|
-
img_list, label_list, None, bm.x_scale, bm.y_scale, bm.z_scale, bm.
|
779
|
+
img_list, label_list, None, bm.x_scale, bm.y_scale, bm.z_scale, bm.scaling, bm.crop_data,
|
776
780
|
bm.only, bm.ignore, img, label, None, None, header, extension)
|
777
781
|
|
778
782
|
# configuration data
|
779
|
-
configuration_data = np.array([bm.channels,
|
783
|
+
configuration_data = np.array([bm.channels,
|
784
|
+
bm.x_scale, bm.y_scale, bm.z_scale, bm.normalize,
|
785
|
+
normalization_parameters[0,0], normalization_parameters[1,0]])
|
780
786
|
|
781
787
|
# img shape
|
782
788
|
zsh, ysh, xsh, _ = img.shape
|
@@ -784,7 +790,7 @@ def train_semantic_segmentation(bm,
|
|
784
790
|
# validation data
|
785
791
|
if any(val_img_list) or img_val is not None:
|
786
792
|
img_val, label_val, _, _, _, _, _ = load_training_data(bm.normalize,
|
787
|
-
val_img_list, val_label_list, bm.channels, bm.x_scale, bm.y_scale, bm.z_scale, bm.
|
793
|
+
val_img_list, val_label_list, bm.channels, bm.x_scale, bm.y_scale, bm.z_scale, bm.scaling, bm.crop_data,
|
788
794
|
bm.only, bm.ignore, img_val, label_val, normalization_parameters, allLabels)
|
789
795
|
|
790
796
|
elif bm.validation_split:
|
@@ -908,7 +914,7 @@ def train_semantic_segmentation(bm,
|
|
908
914
|
# save meta data
|
909
915
|
meta_data = MetaData(bm.path_to_model, configuration_data, allLabels,
|
910
916
|
extension, header, bm.crop_data, bm.cropping_weights, bm.cropping_config,
|
911
|
-
normalization_parameters, bm.cropping_norm)
|
917
|
+
normalization_parameters, bm.cropping_norm, bm.patch_normalization, bm.scaling)
|
912
918
|
|
913
919
|
# model checkpoint
|
914
920
|
if img_val is not None:
|
@@ -946,22 +952,33 @@ def train_semantic_segmentation(bm,
|
|
946
952
|
if img_val is not None and not bm.val_dice:
|
947
953
|
save_history(history.history, bm.path_to_model, False, bm.train_dice)
|
948
954
|
|
949
|
-
def load_prediction_data(
|
950
|
-
|
951
|
-
|
955
|
+
def load_prediction_data(bm, channels, normalize, normalization_parameters,
|
956
|
+
region_of_interest, img, img_header, load_blockwise=False, z=None):
|
957
|
+
|
952
958
|
# read image data
|
953
959
|
if img is None:
|
954
|
-
|
960
|
+
if load_blockwise:
|
961
|
+
img_header = None
|
962
|
+
tif = TiffFile(bm.path_to_image)
|
963
|
+
img = imread(bm.path_to_image, key=range(z,min(len(tif.pages),z+bm.z_patch)))
|
964
|
+
if img.shape[0] < bm.z_patch:
|
965
|
+
rest = bm.z_patch - img.shape[0]
|
966
|
+
tmp = imread(bm.path_to_image, key=range(len(tif.pages)-rest,len(tif.pages)))
|
967
|
+
img = np.append(img, tmp[::-1], axis=0)
|
968
|
+
else:
|
969
|
+
img, img_header = load_data(bm.path_to_image, 'first_queue')
|
955
970
|
|
956
971
|
# verify validity
|
957
972
|
if img is None:
|
958
|
-
InputError.message = f'Invalid image data: {os.path.basename(
|
973
|
+
InputError.message = f'Invalid image data: {os.path.basename(bm.path_to_image)}.'
|
959
974
|
raise InputError()
|
960
975
|
|
961
|
-
# preserve original image data
|
962
|
-
img_data =
|
976
|
+
# preserve original image data for post-processing
|
977
|
+
img_data = None
|
978
|
+
if bm.acwe:
|
979
|
+
img_data = img.copy()
|
963
980
|
|
964
|
-
# handle all images
|
981
|
+
# handle all images using number of channels >=1
|
965
982
|
if len(img.shape)==3:
|
966
983
|
z_shape, y_shape, x_shape = img.shape
|
967
984
|
img = img.reshape(z_shape, y_shape, x_shape, 1)
|
@@ -969,7 +986,7 @@ def load_prediction_data(path_to_img, channels, x_scale, y_scale, z_scale,
|
|
969
986
|
InputError.message = f'Number of channels must be {channels}.'
|
970
987
|
raise InputError()
|
971
988
|
|
972
|
-
# image shape
|
989
|
+
# original image shape
|
973
990
|
z_shape, y_shape, x_shape, _ = img.shape
|
974
991
|
|
975
992
|
# automatic cropping of image to region of interest
|
@@ -981,8 +998,8 @@ def load_prediction_data(path_to_img, channels, x_scale, y_scale, z_scale,
|
|
981
998
|
|
982
999
|
# scale/resize image data
|
983
1000
|
img = img.astype(np.float32)
|
984
|
-
if
|
985
|
-
img = img_resize(img, z_scale, y_scale, x_scale)
|
1001
|
+
if bm.scaling:
|
1002
|
+
img = img_resize(img, bm.z_scale, bm.y_scale, bm.x_scale)
|
986
1003
|
|
987
1004
|
# normalize image data
|
988
1005
|
for c in range(channels):
|
@@ -1000,94 +1017,256 @@ def load_prediction_data(path_to_img, channels, x_scale, y_scale, z_scale,
|
|
1000
1017
|
|
1001
1018
|
return img, img_header, z_shape, y_shape, x_shape, region_of_interest, img_data
|
1002
1019
|
|
1003
|
-
def
|
1004
|
-
|
1005
|
-
|
1006
|
-
|
1020
|
+
def append_ghost_areas(bm, img):
|
1021
|
+
# append ghost areas to make image dimensions divisible by patch size (mirror edge areas)
|
1022
|
+
zsh, ysh, xsh, _ = img.shape
|
1023
|
+
z_rest = bm.z_patch - (zsh % bm.z_patch)
|
1024
|
+
if z_rest == bm.z_patch:
|
1025
|
+
z_rest = -zsh
|
1026
|
+
else:
|
1027
|
+
img = np.append(img, img[-z_rest:][::-1], axis=0)
|
1028
|
+
y_rest = bm.y_patch - (ysh % bm.y_patch)
|
1029
|
+
if y_rest == bm.y_patch:
|
1030
|
+
y_rest = -ysh
|
1031
|
+
else:
|
1032
|
+
img = np.append(img, img[:,-y_rest:][:,::-1], axis=1)
|
1033
|
+
x_rest = bm.x_patch - (xsh % bm.x_patch)
|
1034
|
+
if x_rest == bm.x_patch:
|
1035
|
+
x_rest = -xsh
|
1036
|
+
else:
|
1037
|
+
img = np.append(img, img[:,:,-x_rest:][:,:,::-1], axis=2)
|
1038
|
+
return img, z_rest, y_rest, x_rest
|
1007
1039
|
|
1008
|
-
|
1040
|
+
def predict_semantic_segmentation(bm,
|
1041
|
+
header, img_header,
|
1042
|
+
region_of_interest, extension, img_data,
|
1043
|
+
channels, normalization_parameters):
|
1009
1044
|
|
1010
|
-
#
|
1011
|
-
|
1045
|
+
# initialize results
|
1046
|
+
results = {}
|
1012
1047
|
|
1013
1048
|
# number of labels
|
1014
|
-
nb_labels = len(allLabels)
|
1049
|
+
nb_labels = len(bm.allLabels)
|
1050
|
+
results['allLabels'] = bm.allLabels
|
1015
1051
|
|
1016
|
-
#
|
1017
|
-
|
1052
|
+
# load model
|
1053
|
+
if bm.dice_loss:
|
1054
|
+
def loss_fn(y_true, y_pred):
|
1055
|
+
dice = 0
|
1056
|
+
for index in range(1, nb_labels):
|
1057
|
+
dice += dice_coef(y_true[:,:,:,:,index], y_pred[:,:,:,:,index])
|
1058
|
+
dice = dice / (nb_labels-1)
|
1059
|
+
#loss = -K.log(dice)
|
1060
|
+
loss = 1 - dice
|
1061
|
+
return loss
|
1062
|
+
custom_objects = {'dice_coef_loss': dice_coef_loss,'loss_fn': loss_fn}
|
1063
|
+
model = load_model(bm.path_to_model, custom_objects=custom_objects)
|
1064
|
+
else:
|
1065
|
+
model = load_model(bm.path_to_model)
|
1066
|
+
|
1067
|
+
# check if data can be loaded blockwise to save host memory
|
1068
|
+
load_blockwise = False
|
1069
|
+
if not bm.scaling and not bm.normalize and bm.path_to_image and not np.any(region_of_interest) and \
|
1070
|
+
os.path.splitext(bm.path_to_image)[1] in ['.tif', '.tiff'] and not bm.acwe:
|
1071
|
+
tif = TiffFile(bm.path_to_image)
|
1072
|
+
zsh = len(tif.pages)
|
1073
|
+
ysh, xsh = tif.pages[0].shape
|
1074
|
+
|
1075
|
+
# determine new image size after appending ghost areas to make image dimensions divisible by patch size
|
1076
|
+
z_rest = bm.z_patch - (zsh % bm.z_patch)
|
1077
|
+
if z_rest == bm.z_patch:
|
1078
|
+
z_rest = -zsh
|
1079
|
+
else:
|
1080
|
+
zsh += z_rest
|
1081
|
+
y_rest = bm.y_patch - (ysh % bm.y_patch)
|
1082
|
+
if y_rest == bm.y_patch:
|
1083
|
+
y_rest = -ysh
|
1084
|
+
else:
|
1085
|
+
ysh += y_rest
|
1086
|
+
x_rest = bm.x_patch - (xsh % bm.x_patch)
|
1087
|
+
if x_rest == bm.x_patch:
|
1088
|
+
x_rest = -xsh
|
1089
|
+
else:
|
1090
|
+
xsh += x_rest
|
1091
|
+
|
1092
|
+
# get Ids of patches
|
1093
|
+
list_IDs = []
|
1094
|
+
for k in range(0, zsh-bm.z_patch+1, bm.stride_size):
|
1095
|
+
for l in range(0, ysh-bm.y_patch+1, bm.stride_size):
|
1096
|
+
for m in range(0, xsh-bm.x_patch+1, bm.stride_size):
|
1097
|
+
list_IDs.append(k*ysh*xsh+l*xsh+m)
|
1018
1098
|
|
1019
|
-
|
1020
|
-
|
1021
|
-
|
1022
|
-
for m in range(0, xsh-x_patch+1, stride_size):
|
1023
|
-
list_IDs.append(k*ysh*xsh+l*xsh+m)
|
1099
|
+
# make length of list divisible by batch size
|
1100
|
+
rest = bm.batch_size - (len(list_IDs) % bm.batch_size)
|
1101
|
+
list_IDs = list_IDs + list_IDs[:rest]
|
1024
1102
|
|
1025
|
-
|
1026
|
-
|
1027
|
-
|
1103
|
+
# prediction
|
1104
|
+
if len(list_IDs) > 400:
|
1105
|
+
load_blockwise = True
|
1028
1106
|
|
1029
|
-
#
|
1030
|
-
|
1107
|
+
# load image data and calculate patch IDs
|
1108
|
+
if not load_blockwise:
|
1031
1109
|
|
1032
|
-
|
1033
|
-
|
1034
|
-
|
1035
|
-
'batch_size': batch_size,
|
1036
|
-
'n_channels': csh,
|
1037
|
-
'patch_normalization': bm.patch_normalization}
|
1110
|
+
# load prediction data
|
1111
|
+
img, img_header, z_shape, y_shape, x_shape, region_of_interest, img_data = load_prediction_data(
|
1112
|
+
bm, channels, bm.normalize, normalization_parameters, region_of_interest, img_data, img_header)
|
1038
1113
|
|
1039
|
-
|
1040
|
-
|
1114
|
+
# append ghost areas
|
1115
|
+
img, z_rest, y_rest, x_rest = append_ghost_areas(bm, img)
|
1041
1116
|
|
1042
|
-
|
1043
|
-
|
1117
|
+
# img shape
|
1118
|
+
zsh, ysh, xsh, _ = img.shape
|
1119
|
+
|
1120
|
+
# list of IDs
|
1121
|
+
list_IDs = []
|
1122
|
+
|
1123
|
+
# get Ids of patches
|
1124
|
+
for k in range(0, zsh-bm.z_patch+1, bm.stride_size):
|
1125
|
+
for l in range(0, ysh-bm.y_patch+1, bm.stride_size):
|
1126
|
+
for m in range(0, xsh-bm.x_patch+1, bm.stride_size):
|
1127
|
+
list_IDs.append(k*ysh*xsh+l*xsh+m)
|
1128
|
+
|
1129
|
+
# make length of list divisible by batch size
|
1130
|
+
rest = bm.batch_size - (len(list_IDs) % bm.batch_size)
|
1131
|
+
list_IDs = list_IDs + list_IDs[:rest]
|
1132
|
+
|
1133
|
+
# number of patches
|
1134
|
+
nb_patches = len(list_IDs)
|
1135
|
+
|
1136
|
+
# load all patches on GPU memory
|
1137
|
+
if not load_blockwise and nb_patches < 400:
|
1138
|
+
|
1139
|
+
# parameters
|
1140
|
+
params = {'dim': (bm.z_patch, bm.y_patch, bm.x_patch),
|
1141
|
+
'dim_img': (zsh, ysh, xsh),
|
1142
|
+
'batch_size': bm.batch_size,
|
1143
|
+
'n_channels': channels,
|
1144
|
+
'patch_normalization': bm.patch_normalization}
|
1044
1145
|
|
1045
|
-
|
1046
|
-
|
1146
|
+
# data generator
|
1147
|
+
predict_generator = PredictDataGenerator(img, list_IDs, **params)
|
1148
|
+
|
1149
|
+
# predict probabilities
|
1047
1150
|
probabilities = model.predict(predict_generator, verbose=0, steps=None)
|
1151
|
+
|
1152
|
+
# create final
|
1153
|
+
final = np.zeros((zsh, ysh, xsh, nb_labels), dtype=np.float32)
|
1154
|
+
nb = 0
|
1155
|
+
for k in range(0, zsh-bm.z_patch+1, bm.stride_size):
|
1156
|
+
for l in range(0, ysh-bm.y_patch+1, bm.stride_size):
|
1157
|
+
for m in range(0, xsh-bm.x_patch+1, bm.stride_size):
|
1158
|
+
final[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch] += probabilities[nb]
|
1159
|
+
nb += 1
|
1160
|
+
|
1161
|
+
# calculate result
|
1162
|
+
label = np.argmax(final, axis=-1).astype(np.uint8)
|
1163
|
+
|
1048
1164
|
else:
|
1049
|
-
|
1050
|
-
|
1051
|
-
|
1052
|
-
#
|
1053
|
-
|
1054
|
-
|
1055
|
-
|
1056
|
-
|
1057
|
-
|
1058
|
-
|
1059
|
-
|
1060
|
-
|
1061
|
-
|
1062
|
-
|
1063
|
-
|
1064
|
-
|
1065
|
-
|
1066
|
-
|
1067
|
-
|
1068
|
-
|
1069
|
-
|
1070
|
-
|
1071
|
-
|
1072
|
-
|
1073
|
-
|
1074
|
-
|
1075
|
-
|
1076
|
-
|
1077
|
-
|
1078
|
-
|
1079
|
-
|
1080
|
-
|
1081
|
-
|
1165
|
+
# stream data batchwise to GPU to reduce memory usage
|
1166
|
+
X = np.empty((bm.batch_size, bm.z_patch, bm.y_patch, bm.x_patch, channels), dtype=np.float32)
|
1167
|
+
|
1168
|
+
# allocate final array
|
1169
|
+
if bm.return_probs:
|
1170
|
+
final = np.zeros((zsh, ysh, xsh, nb_labels), dtype=np.float32)
|
1171
|
+
|
1172
|
+
# allocate result array
|
1173
|
+
label = np.zeros((zsh, ysh, xsh), dtype=np.uint8)
|
1174
|
+
|
1175
|
+
# predict segmentation block by block
|
1176
|
+
z_indices = range(0, zsh-bm.z_patch+1, bm.stride_size)
|
1177
|
+
for j, z in enumerate(z_indices):
|
1178
|
+
|
1179
|
+
# load blockwise
|
1180
|
+
if load_blockwise:
|
1181
|
+
img, _, _, _, _, _, _ = load_prediction_data(bm,
|
1182
|
+
channels, bm.normalize, normalization_parameters,
|
1183
|
+
region_of_interest, img_data, img_header, load_blockwise, z)
|
1184
|
+
img, _, _, _ = append_ghost_areas(bm, img)
|
1185
|
+
|
1186
|
+
# list of IDs
|
1187
|
+
list_IDs = []
|
1188
|
+
|
1189
|
+
# get Ids of patches
|
1190
|
+
for l in range(0, ysh-bm.y_patch+1, bm.stride_size):
|
1191
|
+
for m in range(0, xsh-bm.x_patch+1, bm.stride_size):
|
1192
|
+
list_IDs.append(z*ysh*xsh+l*xsh+m)
|
1193
|
+
|
1194
|
+
# make length of list divisible by batch size
|
1195
|
+
max_i = len(list_IDs)
|
1196
|
+
rest = bm.batch_size - (len(list_IDs) % bm.batch_size)
|
1197
|
+
list_IDs = list_IDs + list_IDs[:rest]
|
1198
|
+
|
1199
|
+
# number of patches
|
1200
|
+
nb_patches = len(list_IDs)
|
1201
|
+
|
1202
|
+
# allocate tmp probabilities array
|
1203
|
+
probs = np.zeros((bm.z_patch, ysh, xsh, nb_labels), dtype=np.float32)
|
1204
|
+
|
1205
|
+
# get one batch of image patches
|
1206
|
+
for step in range(nb_patches//bm.batch_size):
|
1207
|
+
for i, ID in enumerate(list_IDs[step*bm.batch_size:(step+1)*bm.batch_size]):
|
1208
|
+
|
1209
|
+
# get patch indices
|
1210
|
+
k=0 if load_blockwise else ID // (ysh*xsh)
|
1211
|
+
rest = ID % (ysh*xsh)
|
1212
|
+
l = rest // xsh
|
1213
|
+
m = rest % xsh
|
1214
|
+
|
1215
|
+
# get patch
|
1216
|
+
tmp_X = img[k:k+bm.z_patch,l:l+bm.y_patch,m:m+bm.x_patch]
|
1217
|
+
if bm.patch_normalization:
|
1218
|
+
tmp_X = np.copy(tmp_X, order='C')
|
1219
|
+
for c in range(channels):
|
1220
|
+
tmp_X[:,:,:,c] -= np.mean(tmp_X[:,:,:,c])
|
1221
|
+
tmp_X[:,:,:,c] /= max(np.std(tmp_X[:,:,:,c]), 1e-6)
|
1222
|
+
X[i] = tmp_X
|
1223
|
+
|
1224
|
+
# predict batch
|
1225
|
+
Y = model.predict(X, verbose=0, steps=None, batch_size=bm.batch_size)
|
1226
|
+
|
1227
|
+
# loop over result patches
|
1228
|
+
for i, ID in enumerate(list_IDs[step*bm.batch_size:(step+1)*bm.batch_size]):
|
1229
|
+
rest = ID % (ysh*xsh)
|
1230
|
+
l = rest // xsh
|
1231
|
+
m = rest % xsh
|
1232
|
+
if step*bm.batch_size+i < max_i:
|
1233
|
+
probs[:,l:l+bm.y_patch,m:m+bm.x_patch] += Y[i]
|
1234
|
+
|
1235
|
+
# overlap in z direction
|
1236
|
+
if bm.stride_size < bm.z_patch:
|
1237
|
+
if j>0:
|
1238
|
+
probs[:bm.stride_size] += overlap
|
1239
|
+
overlap = probs[bm.stride_size:].copy()
|
1240
|
+
|
1241
|
+
# calculate result
|
1242
|
+
if z==z_indices[-1]:
|
1243
|
+
label[z:z+bm.z_patch] = np.argmax(probs, axis=-1).astype(np.uint8)
|
1082
1244
|
if bm.return_probs:
|
1083
|
-
|
1084
|
-
|
1245
|
+
final[z:z+bm.z_patch] = probs
|
1246
|
+
else:
|
1247
|
+
block_zsh = min(bm.stride_size, bm.z_patch)
|
1248
|
+
label[z:z+block_zsh] = np.argmax(probs[:block_zsh], axis=-1).astype(np.uint8)
|
1249
|
+
if bm.return_probs:
|
1250
|
+
final[z:z+block_zsh] = probs[:block_zsh]
|
1251
|
+
|
1252
|
+
# remove appendix
|
1253
|
+
if bm.return_probs:
|
1254
|
+
final = final[:-z_rest,:-y_rest,:-x_rest]
|
1255
|
+
label = label[:-z_rest,:-y_rest,:-x_rest]
|
1256
|
+
zsh, ysh, xsh = label.shape
|
1085
1257
|
|
1086
1258
|
# return probabilities
|
1087
1259
|
if bm.return_probs:
|
1260
|
+
counter = np.zeros((zsh, ysh, xsh, nb_labels), dtype=np.float32)
|
1261
|
+
nb = 0
|
1262
|
+
for k in range(0, zsh-bm.z_patch+1, bm.stride_size):
|
1263
|
+
for l in range(0, ysh-bm.y_patch+1, bm.stride_size):
|
1264
|
+
for m in range(0, xsh-bm.x_patch+1, bm.stride_size):
|
1265
|
+
counter[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch] += 1
|
1266
|
+
nb += 1
|
1088
1267
|
counter[counter==0] = 1
|
1089
1268
|
probabilities = final / counter
|
1090
|
-
if
|
1269
|
+
if bm.scaling:
|
1091
1270
|
probabilities = img_resize(probabilities, z_shape, y_shape, x_shape)
|
1092
1271
|
if np.any(region_of_interest):
|
1093
1272
|
min_z,max_z,min_y,max_y,min_x,max_x,original_zsh,original_ysh,original_xsh = region_of_interest[:]
|
@@ -1096,12 +1275,8 @@ def predict_semantic_segmentation(bm, img, path_to_model,
|
|
1096
1275
|
probabilities = np.copy(tmp, order='C')
|
1097
1276
|
results['probs'] = probabilities
|
1098
1277
|
|
1099
|
-
# get final
|
1100
|
-
label = np.argmax(final, axis=3)
|
1101
|
-
label = label.astype(np.uint8)
|
1102
|
-
|
1103
1278
|
# rescale final to input size
|
1104
|
-
if
|
1279
|
+
if bm.scaling:
|
1105
1280
|
label = img_resize(label, z_shape, y_shape, x_shape, labels=True)
|
1106
1281
|
|
1107
1282
|
# revert automatic cropping
|
@@ -1112,7 +1287,7 @@ def predict_semantic_segmentation(bm, img, path_to_model,
|
|
1112
1287
|
label = np.copy(tmp, order='C')
|
1113
1288
|
|
1114
1289
|
# get result
|
1115
|
-
label = get_labels(label, allLabels)
|
1290
|
+
label = get_labels(label, bm.allLabels)
|
1116
1291
|
results['regular'] = label
|
1117
1292
|
|
1118
1293
|
# load header from file
|
@@ -1130,10 +1305,10 @@ def predict_semantic_segmentation(bm, img, path_to_model,
|
|
1130
1305
|
# handle amira header
|
1131
1306
|
if header is not None:
|
1132
1307
|
if extension == '.am':
|
1133
|
-
header =
|
1308
|
+
header = set_image_dimensions(header[0], label)
|
1134
1309
|
if img_header is not None:
|
1135
1310
|
try:
|
1136
|
-
header =
|
1311
|
+
header = set_physical_size(header, img_header[0])
|
1137
1312
|
except:
|
1138
1313
|
pass
|
1139
1314
|
header = [header]
|
@@ -1151,7 +1326,7 @@ def predict_semantic_segmentation(bm, img, path_to_model,
|
|
1151
1326
|
|
1152
1327
|
# save result
|
1153
1328
|
if bm.path_to_image:
|
1154
|
-
save_data(bm.path_to_final, label, header=header, compress=
|
1329
|
+
save_data(bm.path_to_final, label, header=header, compress=bm.compression)
|
1155
1330
|
|
1156
1331
|
# paths to optional results
|
1157
1332
|
filename, extension = os.path.splitext(bm.path_to_final)
|
@@ -1169,17 +1344,17 @@ def predict_semantic_segmentation(bm, img, path_to_model,
|
|
1169
1344
|
cleaned_result = clean(label, bm.clean)
|
1170
1345
|
results['cleaned'] = cleaned_result
|
1171
1346
|
if bm.path_to_image:
|
1172
|
-
save_data(path_to_cleaned, cleaned_result, header=header, compress=
|
1347
|
+
save_data(path_to_cleaned, cleaned_result, header=header, compress=bm.compression)
|
1173
1348
|
if bm.fill:
|
1174
1349
|
filled_result = clean(label, bm.fill)
|
1175
1350
|
results['filled'] = filled_result
|
1176
1351
|
if bm.path_to_image:
|
1177
|
-
save_data(path_to_filled, filled_result, header=header, compress=
|
1352
|
+
save_data(path_to_filled, filled_result, header=header, compress=bm.compression)
|
1178
1353
|
if bm.clean and bm.fill:
|
1179
1354
|
cleaned_filled_result = cleaned_result + (filled_result - label)
|
1180
1355
|
results['cleaned_filled'] = cleaned_filled_result
|
1181
1356
|
if bm.path_to_image:
|
1182
|
-
save_data(path_to_cleaned_filled, cleaned_filled_result, header=header, compress=
|
1357
|
+
save_data(path_to_cleaned_filled, cleaned_filled_result, header=header, compress=bm.compression)
|
1183
1358
|
|
1184
1359
|
# post-processing with active contour
|
1185
1360
|
if bm.acwe:
|
@@ -1188,8 +1363,8 @@ def predict_semantic_segmentation(bm, img, path_to_model,
|
|
1188
1363
|
results['acwe'] = acwe_result
|
1189
1364
|
results['refined'] = refined_result
|
1190
1365
|
if bm.path_to_image:
|
1191
|
-
save_data(path_to_acwe, acwe_result, header=header, compress=
|
1192
|
-
save_data(path_to_refined, refined_result, header=header, compress=
|
1366
|
+
save_data(path_to_acwe, acwe_result, header=header, compress=bm.compression)
|
1367
|
+
save_data(path_to_refined, refined_result, header=header, compress=bm.compression)
|
1193
1368
|
|
1194
1369
|
return results, bm
|
1195
1370
|
|