biomedisa 24.5.23__py3-none-any.whl → 24.8.1__py3-none-any.whl

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -45,6 +45,7 @@ from biomedisa.features.biomedisa_helper import (
45
45
  img_resize, load_data, save_data, set_labels_to_zero, id_generator, unique_file_path)
46
46
  from biomedisa.features.remove_outlier import clean, fill
47
47
  from biomedisa.features.active_contour import activeContour
48
+ from tifffile import TiffFile, imread
48
49
  import matplotlib.pyplot as plt
49
50
  import SimpleITK as sitk
50
51
  import tensorflow as tf
@@ -100,7 +101,7 @@ def save_history(history, path_to_model, val_dice, train_dice):
100
101
  # save history dictonary
101
102
  np.save(path_to_model.replace('.h5','.npy'), history)
102
103
 
103
- def predict_blocksize(labelData, x_puffer, y_puffer, z_puffer):
104
+ def predict_blocksize(labelData, x_puffer=25, y_puffer=25, z_puffer=25):
104
105
  zsh, ysh, xsh = labelData.shape
105
106
  argmin_z, argmax_z, argmin_y, argmax_y, argmin_x, argmax_x = zsh, 0, ysh, 0, xsh, 0
106
107
  for k in range(zsh):
@@ -121,7 +122,7 @@ def predict_blocksize(labelData, x_puffer, y_puffer, z_puffer):
121
122
  argmax_z = argmax_z + z_puffer if argmax_z + z_puffer < zsh else zsh
122
123
  return argmin_z,argmax_z,argmin_y,argmax_y,argmin_x,argmax_x
123
124
 
124
- def get_image_dimensions(header, data):
125
+ def set_image_dimensions(header, data):
125
126
 
126
127
  # read header as string
127
128
  b = header.tobytes()
@@ -148,7 +149,7 @@ def get_image_dimensions(header, data):
148
149
  new_header = np.frombuffer(b2, dtype=header.dtype)
149
150
  return new_header
150
151
 
151
- def get_physical_size(header, img_header):
152
+ def set_physical_size(header, img_header):
152
153
 
153
154
  # read img_header as string
154
155
  b = img_header.tobytes()
@@ -354,7 +355,7 @@ def read_img_list(img_list, label_list, temp_img_dir, temp_label_dir):
354
355
  label_names.append(label_name)
355
356
  return img_names, label_names
356
357
 
357
- def load_training_data(normalize, img_list, label_list, channels, x_scale, y_scale, z_scale, no_scaling,
358
+ def load_training_data(normalize, img_list, label_list, channels, x_scale, y_scale, z_scale, scaling,
358
359
  crop_data, labels_to_compute, labels_to_remove, img_in=None, label_in=None,
359
360
  normalization_parameters=None, allLabels=None, header=None, extension='.tif',
360
361
  x_puffer=25, y_puffer=25, z_puffer=25):
@@ -386,7 +387,7 @@ def load_training_data(normalize, img_list, label_list, channels, x_scale, y_sca
386
387
  if crop_data:
387
388
  argmin_z,argmax_z,argmin_y,argmax_y,argmin_x,argmax_x = predict_blocksize(label, x_puffer, y_puffer, z_puffer)
388
389
  label = np.copy(label[argmin_z:argmax_z,argmin_y:argmax_y,argmin_x:argmax_x], order='C')
389
- if not no_scaling:
390
+ if scaling:
390
391
  label = img_resize(label, z_scale, y_scale, x_scale, labels=True)
391
392
 
392
393
  # if header is not single data stream Amira Mesh falling back to Multi-TIFF
@@ -412,7 +413,7 @@ def load_training_data(normalize, img_list, label_list, channels, x_scale, y_sca
412
413
  else:
413
414
  img = img_in
414
415
  img_names = ['img_1']
415
- if label_dim != img.shape:
416
+ if label_dim != img.shape[:3]:
416
417
  InputError.message = f'Dimensions of "{os.path.basename(img_names[0])}" and "{os.path.basename(label_names[0])}" do not match'
417
418
  raise InputError()
418
419
 
@@ -432,7 +433,7 @@ def load_training_data(normalize, img_list, label_list, channels, x_scale, y_sca
432
433
 
433
434
  # scale/resize image data
434
435
  img = img.astype(np.float32)
435
- if not no_scaling:
436
+ if scaling:
436
437
  img = img_resize(img, z_scale, y_scale, x_scale)
437
438
 
438
439
  # normalize image data
@@ -469,7 +470,7 @@ def load_training_data(normalize, img_list, label_list, channels, x_scale, y_sca
469
470
  if crop_data:
470
471
  argmin_z,argmax_z,argmin_y,argmax_y,argmin_x,argmax_x = predict_blocksize(a, x_puffer, y_puffer, z_puffer)
471
472
  a = np.copy(a[argmin_z:argmax_z,argmin_y:argmax_y,argmin_x:argmax_x], order='C')
472
- if not no_scaling:
473
+ if scaling:
473
474
  a = img_resize(a, z_scale, y_scale, x_scale, labels=True)
474
475
  label = np.append(label, a, axis=0)
475
476
 
@@ -481,7 +482,7 @@ def load_training_data(normalize, img_list, label_list, channels, x_scale, y_sca
481
482
  raise InputError()
482
483
  else:
483
484
  a = img_in[k]
484
- if label_dim != a.shape:
485
+ if label_dim != a.shape[:3]:
485
486
  InputError.message = f'Dimensions of "{os.path.basename(img_names[k])}" and "{os.path.basename(label_names[k])}" do not match'
486
487
  raise InputError()
487
488
  if len(a.shape)==3:
@@ -493,7 +494,7 @@ def load_training_data(normalize, img_list, label_list, channels, x_scale, y_sca
493
494
  if crop_data:
494
495
  a = np.copy(a[argmin_z:argmax_z,argmin_y:argmax_y,argmin_x:argmax_x], order='C')
495
496
  a = a.astype(np.float32)
496
- if not no_scaling:
497
+ if scaling:
497
498
  a = img_resize(a, z_scale, y_scale, x_scale)
498
499
  for c in range(channels):
499
500
  a[:,:,:,c] -= np.amin(a[:,:,:,c])
@@ -541,18 +542,17 @@ class CustomCallback(Callback):
541
542
  time_remaining = str(t // 60) + 'min'
542
543
  else:
543
544
  time_remaining = str(t // 3600) + 'h ' + str((t % 3600) // 60) + 'min'
544
- try:
545
- val_accuracy = round(float(logs["val_accuracy"])*100,1)
546
- image.message = 'Progress {}%, {} remaining, {}% accuracy'.format(percentage,time_remaining,val_accuracy)
547
- except KeyError:
548
- image.message = 'Progress {}%, {} remaining'.format(percentage,time_remaining)
545
+ image.message = 'Progress {}%, {} remaining'.format(percentage,time_remaining)
546
+ if 'best_val_dice' in logs:
547
+ best_val_dice = round(float(logs['best_val_dice'])*100,1)
548
+ image.message += f', {best_val_dice}% accuracy'
549
549
  image.save()
550
550
  print("Start epoch {} of training; got log keys: {}".format(epoch, keys))
551
551
 
552
552
  class MetaData(Callback):
553
553
  def __init__(self, path_to_model, configuration_data, allLabels,
554
554
  extension, header, crop_data, cropping_weights, cropping_config,
555
- normalization_parameters, cropping_norm):
555
+ normalization_parameters, cropping_norm, patch_normalization, scaling):
556
556
 
557
557
  self.path_to_model = path_to_model
558
558
  self.configuration_data = configuration_data
@@ -564,6 +564,8 @@ class MetaData(Callback):
564
564
  self.cropping_weights = cropping_weights
565
565
  self.cropping_config = cropping_config
566
566
  self.cropping_norm = cropping_norm
567
+ self.patch_normalization = patch_normalization
568
+ self.scaling = scaling
567
569
 
568
570
  def on_epoch_end(self, epoch, logs={}):
569
571
  hf = h5py.File(self.path_to_model, 'r')
@@ -574,6 +576,8 @@ class MetaData(Callback):
574
576
  group.create_dataset('configuration', data=self.configuration_data)
575
577
  group.create_dataset('normalization', data=self.normalization_parameters)
576
578
  group.create_dataset('labels', data=self.allLabels)
579
+ group.create_dataset('patch_normalization', data=int(self.patch_normalization))
580
+ group.create_dataset('scaling', data=int(self.scaling))
577
581
  if self.extension == '.am':
578
582
  group.create_dataset('extension', data=self.extension)
579
583
  group.create_dataset('header', data=self.header)
@@ -758,8 +762,8 @@ def dice_coef_loss(nb_labels):
758
762
  for index in range(1,nb_labels):
759
763
  dice += dice_coef(y_true[:,:,:,:,index], y_pred[:,:,:,:,index])
760
764
  dice = dice / (nb_labels-1)
761
- loss = -K.log(dice)
762
- #loss = 1 - dice
765
+ #loss = -K.log(dice)
766
+ loss = 1 - dice
763
767
  return loss
764
768
  return loss_fn
765
769
 
@@ -772,11 +776,13 @@ def train_semantic_segmentation(bm,
772
776
 
773
777
  # training data
774
778
  img, label, allLabels, normalization_parameters, header, extension, bm.channels = load_training_data(bm.normalize,
775
- img_list, label_list, None, bm.x_scale, bm.y_scale, bm.z_scale, bm.no_scaling, bm.crop_data,
779
+ img_list, label_list, None, bm.x_scale, bm.y_scale, bm.z_scale, bm.scaling, bm.crop_data,
776
780
  bm.only, bm.ignore, img, label, None, None, header, extension)
777
781
 
778
782
  # configuration data
779
- configuration_data = np.array([bm.channels, bm.x_scale, bm.y_scale, bm.z_scale, bm.normalize, 0, 1])
783
+ configuration_data = np.array([bm.channels,
784
+ bm.x_scale, bm.y_scale, bm.z_scale, bm.normalize,
785
+ normalization_parameters[0,0], normalization_parameters[1,0]])
780
786
 
781
787
  # img shape
782
788
  zsh, ysh, xsh, _ = img.shape
@@ -784,7 +790,7 @@ def train_semantic_segmentation(bm,
784
790
  # validation data
785
791
  if any(val_img_list) or img_val is not None:
786
792
  img_val, label_val, _, _, _, _, _ = load_training_data(bm.normalize,
787
- val_img_list, val_label_list, bm.channels, bm.x_scale, bm.y_scale, bm.z_scale, bm.no_scaling, bm.crop_data,
793
+ val_img_list, val_label_list, bm.channels, bm.x_scale, bm.y_scale, bm.z_scale, bm.scaling, bm.crop_data,
788
794
  bm.only, bm.ignore, img_val, label_val, normalization_parameters, allLabels)
789
795
 
790
796
  elif bm.validation_split:
@@ -908,7 +914,7 @@ def train_semantic_segmentation(bm,
908
914
  # save meta data
909
915
  meta_data = MetaData(bm.path_to_model, configuration_data, allLabels,
910
916
  extension, header, bm.crop_data, bm.cropping_weights, bm.cropping_config,
911
- normalization_parameters, bm.cropping_norm)
917
+ normalization_parameters, bm.cropping_norm, bm.patch_normalization, bm.scaling)
912
918
 
913
919
  # model checkpoint
914
920
  if img_val is not None:
@@ -946,22 +952,33 @@ def train_semantic_segmentation(bm,
946
952
  if img_val is not None and not bm.val_dice:
947
953
  save_history(history.history, bm.path_to_model, False, bm.train_dice)
948
954
 
949
- def load_prediction_data(path_to_img, channels, x_scale, y_scale, z_scale,
950
- no_scaling, normalize, normalization_parameters, region_of_interest,
951
- img, img_header):
955
+ def load_prediction_data(bm, channels, normalize, normalization_parameters,
956
+ region_of_interest, img, img_header, load_blockwise=False, z=None):
957
+
952
958
  # read image data
953
959
  if img is None:
954
- img, img_header = load_data(path_to_img, 'first_queue')
960
+ if load_blockwise:
961
+ img_header = None
962
+ tif = TiffFile(bm.path_to_image)
963
+ img = imread(bm.path_to_image, key=range(z,min(len(tif.pages),z+bm.z_patch)))
964
+ if img.shape[0] < bm.z_patch:
965
+ rest = bm.z_patch - img.shape[0]
966
+ tmp = imread(bm.path_to_image, key=range(len(tif.pages)-rest,len(tif.pages)))
967
+ img = np.append(img, tmp[::-1], axis=0)
968
+ else:
969
+ img, img_header = load_data(bm.path_to_image, 'first_queue')
955
970
 
956
971
  # verify validity
957
972
  if img is None:
958
- InputError.message = f'Invalid image data: {os.path.basename(path_to_img)}.'
973
+ InputError.message = f'Invalid image data: {os.path.basename(bm.path_to_image)}.'
959
974
  raise InputError()
960
975
 
961
- # preserve original image data
962
- img_data = img.copy()
976
+ # preserve original image data for post-processing
977
+ img_data = None
978
+ if bm.acwe:
979
+ img_data = img.copy()
963
980
 
964
- # handle all images having channels >=1
981
+ # handle all images using number of channels >=1
965
982
  if len(img.shape)==3:
966
983
  z_shape, y_shape, x_shape = img.shape
967
984
  img = img.reshape(z_shape, y_shape, x_shape, 1)
@@ -969,7 +986,7 @@ def load_prediction_data(path_to_img, channels, x_scale, y_scale, z_scale,
969
986
  InputError.message = f'Number of channels must be {channels}.'
970
987
  raise InputError()
971
988
 
972
- # image shape
989
+ # original image shape
973
990
  z_shape, y_shape, x_shape, _ = img.shape
974
991
 
975
992
  # automatic cropping of image to region of interest
@@ -981,8 +998,8 @@ def load_prediction_data(path_to_img, channels, x_scale, y_scale, z_scale,
981
998
 
982
999
  # scale/resize image data
983
1000
  img = img.astype(np.float32)
984
- if not no_scaling:
985
- img = img_resize(img, z_scale, y_scale, x_scale)
1001
+ if bm.scaling:
1002
+ img = img_resize(img, bm.z_scale, bm.y_scale, bm.x_scale)
986
1003
 
987
1004
  # normalize image data
988
1005
  for c in range(channels):
@@ -1000,94 +1017,256 @@ def load_prediction_data(path_to_img, channels, x_scale, y_scale, z_scale,
1000
1017
 
1001
1018
  return img, img_header, z_shape, y_shape, x_shape, region_of_interest, img_data
1002
1019
 
1003
- def predict_semantic_segmentation(bm, img, path_to_model,
1004
- z_patch, y_patch, x_patch, z_shape, y_shape, x_shape, compress, header,
1005
- img_header, stride_size, allLabels, batch_size, region_of_interest,
1006
- no_scaling, extension, img_data):
1020
+ def append_ghost_areas(bm, img):
1021
+ # append ghost areas to make image dimensions divisible by patch size (mirror edge areas)
1022
+ zsh, ysh, xsh, _ = img.shape
1023
+ z_rest = bm.z_patch - (zsh % bm.z_patch)
1024
+ if z_rest == bm.z_patch:
1025
+ z_rest = -zsh
1026
+ else:
1027
+ img = np.append(img, img[-z_rest:][::-1], axis=0)
1028
+ y_rest = bm.y_patch - (ysh % bm.y_patch)
1029
+ if y_rest == bm.y_patch:
1030
+ y_rest = -ysh
1031
+ else:
1032
+ img = np.append(img, img[:,-y_rest:][:,::-1], axis=1)
1033
+ x_rest = bm.x_patch - (xsh % bm.x_patch)
1034
+ if x_rest == bm.x_patch:
1035
+ x_rest = -xsh
1036
+ else:
1037
+ img = np.append(img, img[:,:,-x_rest:][:,:,::-1], axis=2)
1038
+ return img, z_rest, y_rest, x_rest
1007
1039
 
1008
- results = {}
1040
+ def predict_semantic_segmentation(bm,
1041
+ header, img_header,
1042
+ region_of_interest, extension, img_data,
1043
+ channels, normalization_parameters):
1009
1044
 
1010
- # img shape
1011
- zsh, ysh, xsh, csh = img.shape
1045
+ # initialize results
1046
+ results = {}
1012
1047
 
1013
1048
  # number of labels
1014
- nb_labels = len(allLabels)
1049
+ nb_labels = len(bm.allLabels)
1050
+ results['allLabels'] = bm.allLabels
1015
1051
 
1016
- # list of IDs
1017
- list_IDs = []
1052
+ # load model
1053
+ if bm.dice_loss:
1054
+ def loss_fn(y_true, y_pred):
1055
+ dice = 0
1056
+ for index in range(1, nb_labels):
1057
+ dice += dice_coef(y_true[:,:,:,:,index], y_pred[:,:,:,:,index])
1058
+ dice = dice / (nb_labels-1)
1059
+ #loss = -K.log(dice)
1060
+ loss = 1 - dice
1061
+ return loss
1062
+ custom_objects = {'dice_coef_loss': dice_coef_loss,'loss_fn': loss_fn}
1063
+ model = load_model(bm.path_to_model, custom_objects=custom_objects)
1064
+ else:
1065
+ model = load_model(bm.path_to_model)
1066
+
1067
+ # check if data can be loaded blockwise to save host memory
1068
+ load_blockwise = False
1069
+ if not bm.scaling and not bm.normalize and bm.path_to_image and not np.any(region_of_interest) and \
1070
+ os.path.splitext(bm.path_to_image)[1] in ['.tif', '.tiff'] and not bm.acwe:
1071
+ tif = TiffFile(bm.path_to_image)
1072
+ zsh = len(tif.pages)
1073
+ ysh, xsh = tif.pages[0].shape
1074
+
1075
+ # determine new image size after appending ghost areas to make image dimensions divisible by patch size
1076
+ z_rest = bm.z_patch - (zsh % bm.z_patch)
1077
+ if z_rest == bm.z_patch:
1078
+ z_rest = -zsh
1079
+ else:
1080
+ zsh += z_rest
1081
+ y_rest = bm.y_patch - (ysh % bm.y_patch)
1082
+ if y_rest == bm.y_patch:
1083
+ y_rest = -ysh
1084
+ else:
1085
+ ysh += y_rest
1086
+ x_rest = bm.x_patch - (xsh % bm.x_patch)
1087
+ if x_rest == bm.x_patch:
1088
+ x_rest = -xsh
1089
+ else:
1090
+ xsh += x_rest
1091
+
1092
+ # get Ids of patches
1093
+ list_IDs = []
1094
+ for k in range(0, zsh-bm.z_patch+1, bm.stride_size):
1095
+ for l in range(0, ysh-bm.y_patch+1, bm.stride_size):
1096
+ for m in range(0, xsh-bm.x_patch+1, bm.stride_size):
1097
+ list_IDs.append(k*ysh*xsh+l*xsh+m)
1018
1098
 
1019
- # get Ids of patches
1020
- for k in range(0, zsh-z_patch+1, stride_size):
1021
- for l in range(0, ysh-y_patch+1, stride_size):
1022
- for m in range(0, xsh-x_patch+1, stride_size):
1023
- list_IDs.append(k*ysh*xsh+l*xsh+m)
1099
+ # make length of list divisible by batch size
1100
+ rest = bm.batch_size - (len(list_IDs) % bm.batch_size)
1101
+ list_IDs = list_IDs + list_IDs[:rest]
1024
1102
 
1025
- # make length of list divisible by batch size
1026
- rest = batch_size - (len(list_IDs) % batch_size)
1027
- list_IDs = list_IDs + list_IDs[:rest]
1103
+ # prediction
1104
+ if len(list_IDs) > 400:
1105
+ load_blockwise = True
1028
1106
 
1029
- # number of patches
1030
- nb_patches = len(list_IDs)
1107
+ # load image data and calculate patch IDs
1108
+ if not load_blockwise:
1031
1109
 
1032
- # parameters
1033
- params = {'dim': (z_patch, y_patch, x_patch),
1034
- 'dim_img': (zsh, ysh, xsh),
1035
- 'batch_size': batch_size,
1036
- 'n_channels': csh,
1037
- 'patch_normalization': bm.patch_normalization}
1110
+ # load prediction data
1111
+ img, img_header, z_shape, y_shape, x_shape, region_of_interest, img_data = load_prediction_data(
1112
+ bm, channels, bm.normalize, normalization_parameters, region_of_interest, img_data, img_header)
1038
1113
 
1039
- # data generator
1040
- predict_generator = PredictDataGenerator(img, list_IDs, **params)
1114
+ # append ghost areas
1115
+ img, z_rest, y_rest, x_rest = append_ghost_areas(bm, img)
1041
1116
 
1042
- # load model
1043
- model = load_model(str(path_to_model))
1117
+ # img shape
1118
+ zsh, ysh, xsh, _ = img.shape
1119
+
1120
+ # list of IDs
1121
+ list_IDs = []
1122
+
1123
+ # get Ids of patches
1124
+ for k in range(0, zsh-bm.z_patch+1, bm.stride_size):
1125
+ for l in range(0, ysh-bm.y_patch+1, bm.stride_size):
1126
+ for m in range(0, xsh-bm.x_patch+1, bm.stride_size):
1127
+ list_IDs.append(k*ysh*xsh+l*xsh+m)
1128
+
1129
+ # make length of list divisible by batch size
1130
+ rest = bm.batch_size - (len(list_IDs) % bm.batch_size)
1131
+ list_IDs = list_IDs + list_IDs[:rest]
1132
+
1133
+ # number of patches
1134
+ nb_patches = len(list_IDs)
1135
+
1136
+ # load all patches on GPU memory
1137
+ if not load_blockwise and nb_patches < 400:
1138
+
1139
+ # parameters
1140
+ params = {'dim': (bm.z_patch, bm.y_patch, bm.x_patch),
1141
+ 'dim_img': (zsh, ysh, xsh),
1142
+ 'batch_size': bm.batch_size,
1143
+ 'n_channels': channels,
1144
+ 'patch_normalization': bm.patch_normalization}
1044
1145
 
1045
- # predict
1046
- if nb_patches < 400:
1146
+ # data generator
1147
+ predict_generator = PredictDataGenerator(img, list_IDs, **params)
1148
+
1149
+ # predict probabilities
1047
1150
  probabilities = model.predict(predict_generator, verbose=0, steps=None)
1151
+
1152
+ # create final
1153
+ final = np.zeros((zsh, ysh, xsh, nb_labels), dtype=np.float32)
1154
+ nb = 0
1155
+ for k in range(0, zsh-bm.z_patch+1, bm.stride_size):
1156
+ for l in range(0, ysh-bm.y_patch+1, bm.stride_size):
1157
+ for m in range(0, xsh-bm.x_patch+1, bm.stride_size):
1158
+ final[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch] += probabilities[nb]
1159
+ nb += 1
1160
+
1161
+ # calculate result
1162
+ label = np.argmax(final, axis=-1).astype(np.uint8)
1163
+
1048
1164
  else:
1049
- X = np.empty((batch_size, z_patch, y_patch, x_patch, csh), dtype=np.float32)
1050
- probabilities = np.zeros((nb_patches, z_patch, y_patch, x_patch, nb_labels), dtype=np.float32)
1051
-
1052
- # get image patches
1053
- for step in range(nb_patches//batch_size):
1054
- for i, ID in enumerate(list_IDs[step*batch_size:(step+1)*batch_size]):
1055
-
1056
- # get patch indices
1057
- k = ID // (ysh*xsh)
1058
- rest = ID % (ysh*xsh)
1059
- l = rest // xsh
1060
- m = rest % xsh
1061
-
1062
- # get patch
1063
- tmp_X = img[k:k+z_patch,l:l+y_patch,m:m+x_patch]
1064
- if bm.patch_normalization:
1065
- tmp_X = np.copy(tmp_X, order='C')
1066
- for c in range(csh):
1067
- tmp_X[:,:,:,c] -= np.mean(tmp_X[:,:,:,c])
1068
- tmp_X[:,:,:,c] /= max(np.std(tmp_X[:,:,:,c]), 1e-6)
1069
- X[i] = tmp_X
1070
-
1071
- probabilities[step*batch_size:(step+1)*batch_size] = model.predict(X, verbose=0, steps=None, batch_size=batch_size)
1072
-
1073
- # create final
1074
- final = np.zeros((zsh, ysh, xsh, nb_labels), dtype=np.float32)
1075
- if bm.return_probs:
1076
- counter = np.zeros((zsh, ysh, xsh, nb_labels), dtype=np.float32)
1077
- nb = 0
1078
- for k in range(0, zsh-z_patch+1, stride_size):
1079
- for l in range(0, ysh-y_patch+1, stride_size):
1080
- for m in range(0, xsh-x_patch+1, stride_size):
1081
- final[k:k+z_patch, l:l+y_patch, m:m+x_patch] += probabilities[nb]
1165
+ # stream data batchwise to GPU to reduce memory usage
1166
+ X = np.empty((bm.batch_size, bm.z_patch, bm.y_patch, bm.x_patch, channels), dtype=np.float32)
1167
+
1168
+ # allocate final array
1169
+ if bm.return_probs:
1170
+ final = np.zeros((zsh, ysh, xsh, nb_labels), dtype=np.float32)
1171
+
1172
+ # allocate result array
1173
+ label = np.zeros((zsh, ysh, xsh), dtype=np.uint8)
1174
+
1175
+ # predict segmentation block by block
1176
+ z_indices = range(0, zsh-bm.z_patch+1, bm.stride_size)
1177
+ for j, z in enumerate(z_indices):
1178
+
1179
+ # load blockwise
1180
+ if load_blockwise:
1181
+ img, _, _, _, _, _, _ = load_prediction_data(bm,
1182
+ channels, bm.normalize, normalization_parameters,
1183
+ region_of_interest, img_data, img_header, load_blockwise, z)
1184
+ img, _, _, _ = append_ghost_areas(bm, img)
1185
+
1186
+ # list of IDs
1187
+ list_IDs = []
1188
+
1189
+ # get Ids of patches
1190
+ for l in range(0, ysh-bm.y_patch+1, bm.stride_size):
1191
+ for m in range(0, xsh-bm.x_patch+1, bm.stride_size):
1192
+ list_IDs.append(z*ysh*xsh+l*xsh+m)
1193
+
1194
+ # make length of list divisible by batch size
1195
+ max_i = len(list_IDs)
1196
+ rest = bm.batch_size - (len(list_IDs) % bm.batch_size)
1197
+ list_IDs = list_IDs + list_IDs[:rest]
1198
+
1199
+ # number of patches
1200
+ nb_patches = len(list_IDs)
1201
+
1202
+ # allocate tmp probabilities array
1203
+ probs = np.zeros((bm.z_patch, ysh, xsh, nb_labels), dtype=np.float32)
1204
+
1205
+ # get one batch of image patches
1206
+ for step in range(nb_patches//bm.batch_size):
1207
+ for i, ID in enumerate(list_IDs[step*bm.batch_size:(step+1)*bm.batch_size]):
1208
+
1209
+ # get patch indices
1210
+ k=0 if load_blockwise else ID // (ysh*xsh)
1211
+ rest = ID % (ysh*xsh)
1212
+ l = rest // xsh
1213
+ m = rest % xsh
1214
+
1215
+ # get patch
1216
+ tmp_X = img[k:k+bm.z_patch,l:l+bm.y_patch,m:m+bm.x_patch]
1217
+ if bm.patch_normalization:
1218
+ tmp_X = np.copy(tmp_X, order='C')
1219
+ for c in range(channels):
1220
+ tmp_X[:,:,:,c] -= np.mean(tmp_X[:,:,:,c])
1221
+ tmp_X[:,:,:,c] /= max(np.std(tmp_X[:,:,:,c]), 1e-6)
1222
+ X[i] = tmp_X
1223
+
1224
+ # predict batch
1225
+ Y = model.predict(X, verbose=0, steps=None, batch_size=bm.batch_size)
1226
+
1227
+ # loop over result patches
1228
+ for i, ID in enumerate(list_IDs[step*bm.batch_size:(step+1)*bm.batch_size]):
1229
+ rest = ID % (ysh*xsh)
1230
+ l = rest // xsh
1231
+ m = rest % xsh
1232
+ if step*bm.batch_size+i < max_i:
1233
+ probs[:,l:l+bm.y_patch,m:m+bm.x_patch] += Y[i]
1234
+
1235
+ # overlap in z direction
1236
+ if bm.stride_size < bm.z_patch:
1237
+ if j>0:
1238
+ probs[:bm.stride_size] += overlap
1239
+ overlap = probs[bm.stride_size:].copy()
1240
+
1241
+ # calculate result
1242
+ if z==z_indices[-1]:
1243
+ label[z:z+bm.z_patch] = np.argmax(probs, axis=-1).astype(np.uint8)
1082
1244
  if bm.return_probs:
1083
- counter[k:k+z_patch, l:l+y_patch, m:m+x_patch] += 1
1084
- nb += 1
1245
+ final[z:z+bm.z_patch] = probs
1246
+ else:
1247
+ block_zsh = min(bm.stride_size, bm.z_patch)
1248
+ label[z:z+block_zsh] = np.argmax(probs[:block_zsh], axis=-1).astype(np.uint8)
1249
+ if bm.return_probs:
1250
+ final[z:z+block_zsh] = probs[:block_zsh]
1251
+
1252
+ # remove appendix
1253
+ if bm.return_probs:
1254
+ final = final[:-z_rest,:-y_rest,:-x_rest]
1255
+ label = label[:-z_rest,:-y_rest,:-x_rest]
1256
+ zsh, ysh, xsh = label.shape
1085
1257
 
1086
1258
  # return probabilities
1087
1259
  if bm.return_probs:
1260
+ counter = np.zeros((zsh, ysh, xsh, nb_labels), dtype=np.float32)
1261
+ nb = 0
1262
+ for k in range(0, zsh-bm.z_patch+1, bm.stride_size):
1263
+ for l in range(0, ysh-bm.y_patch+1, bm.stride_size):
1264
+ for m in range(0, xsh-bm.x_patch+1, bm.stride_size):
1265
+ counter[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch] += 1
1266
+ nb += 1
1088
1267
  counter[counter==0] = 1
1089
1268
  probabilities = final / counter
1090
- if not no_scaling:
1269
+ if bm.scaling:
1091
1270
  probabilities = img_resize(probabilities, z_shape, y_shape, x_shape)
1092
1271
  if np.any(region_of_interest):
1093
1272
  min_z,max_z,min_y,max_y,min_x,max_x,original_zsh,original_ysh,original_xsh = region_of_interest[:]
@@ -1096,12 +1275,8 @@ def predict_semantic_segmentation(bm, img, path_to_model,
1096
1275
  probabilities = np.copy(tmp, order='C')
1097
1276
  results['probs'] = probabilities
1098
1277
 
1099
- # get final
1100
- label = np.argmax(final, axis=3)
1101
- label = label.astype(np.uint8)
1102
-
1103
1278
  # rescale final to input size
1104
- if not no_scaling:
1279
+ if bm.scaling:
1105
1280
  label = img_resize(label, z_shape, y_shape, x_shape, labels=True)
1106
1281
 
1107
1282
  # revert automatic cropping
@@ -1112,7 +1287,7 @@ def predict_semantic_segmentation(bm, img, path_to_model,
1112
1287
  label = np.copy(tmp, order='C')
1113
1288
 
1114
1289
  # get result
1115
- label = get_labels(label, allLabels)
1290
+ label = get_labels(label, bm.allLabels)
1116
1291
  results['regular'] = label
1117
1292
 
1118
1293
  # load header from file
@@ -1130,10 +1305,10 @@ def predict_semantic_segmentation(bm, img, path_to_model,
1130
1305
  # handle amira header
1131
1306
  if header is not None:
1132
1307
  if extension == '.am':
1133
- header = get_image_dimensions(header[0], label)
1308
+ header = set_image_dimensions(header[0], label)
1134
1309
  if img_header is not None:
1135
1310
  try:
1136
- header = get_physical_size(header, img_header[0])
1311
+ header = set_physical_size(header, img_header[0])
1137
1312
  except:
1138
1313
  pass
1139
1314
  header = [header]
@@ -1151,7 +1326,7 @@ def predict_semantic_segmentation(bm, img, path_to_model,
1151
1326
 
1152
1327
  # save result
1153
1328
  if bm.path_to_image:
1154
- save_data(bm.path_to_final, label, header=header, compress=compress)
1329
+ save_data(bm.path_to_final, label, header=header, compress=bm.compression)
1155
1330
 
1156
1331
  # paths to optional results
1157
1332
  filename, extension = os.path.splitext(bm.path_to_final)
@@ -1169,17 +1344,17 @@ def predict_semantic_segmentation(bm, img, path_to_model,
1169
1344
  cleaned_result = clean(label, bm.clean)
1170
1345
  results['cleaned'] = cleaned_result
1171
1346
  if bm.path_to_image:
1172
- save_data(path_to_cleaned, cleaned_result, header=header, compress=compress)
1347
+ save_data(path_to_cleaned, cleaned_result, header=header, compress=bm.compression)
1173
1348
  if bm.fill:
1174
1349
  filled_result = clean(label, bm.fill)
1175
1350
  results['filled'] = filled_result
1176
1351
  if bm.path_to_image:
1177
- save_data(path_to_filled, filled_result, header=header, compress=compress)
1352
+ save_data(path_to_filled, filled_result, header=header, compress=bm.compression)
1178
1353
  if bm.clean and bm.fill:
1179
1354
  cleaned_filled_result = cleaned_result + (filled_result - label)
1180
1355
  results['cleaned_filled'] = cleaned_filled_result
1181
1356
  if bm.path_to_image:
1182
- save_data(path_to_cleaned_filled, cleaned_filled_result, header=header, compress=compress)
1357
+ save_data(path_to_cleaned_filled, cleaned_filled_result, header=header, compress=bm.compression)
1183
1358
 
1184
1359
  # post-processing with active contour
1185
1360
  if bm.acwe:
@@ -1188,8 +1363,8 @@ def predict_semantic_segmentation(bm, img, path_to_model,
1188
1363
  results['acwe'] = acwe_result
1189
1364
  results['refined'] = refined_result
1190
1365
  if bm.path_to_image:
1191
- save_data(path_to_acwe, acwe_result, header=header, compress=compress)
1192
- save_data(path_to_refined, refined_result, header=header, compress=compress)
1366
+ save_data(path_to_acwe, acwe_result, header=header, compress=bm.compression)
1367
+ save_data(path_to_refined, refined_result, header=header, compress=bm.compression)
1193
1368
 
1194
1369
  return results, bm
1195
1370