biomedisa 24.5.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. biomedisa/__init__.py +49 -0
  2. biomedisa/__main__.py +18 -0
  3. biomedisa/deeplearning.py +529 -0
  4. biomedisa/features/DataGenerator.py +299 -0
  5. biomedisa/features/DataGeneratorCrop.py +121 -0
  6. biomedisa/features/PredictDataGenerator.py +87 -0
  7. biomedisa/features/PredictDataGeneratorCrop.py +74 -0
  8. biomedisa/features/__init__.py +0 -0
  9. biomedisa/features/active_contour.py +430 -0
  10. biomedisa/features/amira_to_np/__init__.py +0 -0
  11. biomedisa/features/amira_to_np/amira_data_stream.py +980 -0
  12. biomedisa/features/amira_to_np/amira_grammar.py +369 -0
  13. biomedisa/features/amira_to_np/amira_header.py +290 -0
  14. biomedisa/features/amira_to_np/amira_helper.py +72 -0
  15. biomedisa/features/assd.py +167 -0
  16. biomedisa/features/biomedisa_helper.py +842 -0
  17. biomedisa/features/create_slices.py +277 -0
  18. biomedisa/features/crop_helper.py +581 -0
  19. biomedisa/features/curvop_numba.py +149 -0
  20. biomedisa/features/django_env.py +171 -0
  21. biomedisa/features/keras_helper.py +1195 -0
  22. biomedisa/features/nc_reader.py +179 -0
  23. biomedisa/features/pid.py +52 -0
  24. biomedisa/features/process_image.py +251 -0
  25. biomedisa/features/pycuda_test.py +85 -0
  26. biomedisa/features/random_walk/__init__.py +0 -0
  27. biomedisa/features/random_walk/gpu_kernels.py +184 -0
  28. biomedisa/features/random_walk/pycuda_large.py +826 -0
  29. biomedisa/features/random_walk/pycuda_large_allx.py +806 -0
  30. biomedisa/features/random_walk/pycuda_small.py +414 -0
  31. biomedisa/features/random_walk/pycuda_small_allx.py +493 -0
  32. biomedisa/features/random_walk/pyopencl_large.py +760 -0
  33. biomedisa/features/random_walk/pyopencl_small.py +441 -0
  34. biomedisa/features/random_walk/rw_large.py +389 -0
  35. biomedisa/features/random_walk/rw_small.py +307 -0
  36. biomedisa/features/remove_outlier.py +396 -0
  37. biomedisa/features/split_volume.py +167 -0
  38. biomedisa/interpolation.py +369 -0
  39. biomedisa/mesh.py +403 -0
  40. biomedisa-24.5.23.dist-info/LICENSE +191 -0
  41. biomedisa-24.5.23.dist-info/METADATA +261 -0
  42. biomedisa-24.5.23.dist-info/RECORD +44 -0
  43. biomedisa-24.5.23.dist-info/WHEEL +5 -0
  44. biomedisa-24.5.23.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1195 @@
1
+ ##########################################################################
2
+ ## ##
3
+ ## Copyright (c) 2019-2024 Philipp Lösel. All rights reserved. ##
4
+ ## ##
5
+ ## This file is part of the open source project biomedisa. ##
6
+ ## ##
7
+ ## Licensed under the European Union Public Licence (EUPL) ##
8
+ ## v1.2, or - as soon as they will be approved by the ##
9
+ ## European Commission - subsequent versions of the EUPL; ##
10
+ ## ##
11
+ ## You may redistribute it and/or modify it under the terms ##
12
+ ## of the EUPL v1.2. You may not use this work except in ##
13
+ ## compliance with this Licence. ##
14
+ ## ##
15
+ ## You can obtain a copy of the Licence at: ##
16
+ ## ##
17
+ ## https://joinup.ec.europa.eu/page/eupl-text-11-12 ##
18
+ ## ##
19
+ ## Unless required by applicable law or agreed to in ##
20
+ ## writing, software distributed under the Licence is ##
21
+ ## distributed on an "AS IS" basis, WITHOUT WARRANTIES ##
22
+ ## OR CONDITIONS OF ANY KIND, either express or implied. ##
23
+ ## ##
24
+ ## See the Licence for the specific language governing ##
25
+ ## permissions and limitations under the Licence. ##
26
+ ## ##
27
+ ##########################################################################
28
+
29
+ import os
30
+ try:
31
+ from tensorflow.keras.optimizers.legacy import SGD
32
+ except:
33
+ from tensorflow.keras.optimizers import SGD
34
+ from tensorflow.keras.models import Model, load_model
35
+ from tensorflow.keras.layers import (
36
+ Input, Conv3D, MaxPooling3D, UpSampling3D, Activation, Reshape,
37
+ BatchNormalization, Concatenate, ReLU, Add, GlobalAveragePooling3D,
38
+ Dense, Dropout, MaxPool3D, Flatten, Multiply)
39
+ from tensorflow.keras import backend as K
40
+ from tensorflow.keras.utils import to_categorical
41
+ from tensorflow.keras.callbacks import Callback, ModelCheckpoint, EarlyStopping
42
+ from biomedisa.features.DataGenerator import DataGenerator
43
+ from biomedisa.features.PredictDataGenerator import PredictDataGenerator
44
+ from biomedisa.features.biomedisa_helper import (
45
+ img_resize, load_data, save_data, set_labels_to_zero, id_generator, unique_file_path)
46
+ from biomedisa.features.remove_outlier import clean, fill
47
+ from biomedisa.features.active_contour import activeContour
48
+ import matplotlib.pyplot as plt
49
+ import SimpleITK as sitk
50
+ import tensorflow as tf
51
+ import numpy as np
52
+ import cv2
53
+ import tarfile
54
+ from random import shuffle
55
+ import glob
56
+ import random
57
+ import numba
58
+ import re
59
+ import time
60
+ import h5py
61
+ import atexit
62
+ import tempfile
63
+
64
class InputError(Exception):
    """Raised for invalid user input (unreadable files, mismatched shapes, ...).

    Code in this module sets the class attribute ``InputError.message`` and
    then raises ``InputError()`` without arguments, so handlers may read the
    text from the class. An instance only overrides ``message`` when a message
    is passed explicitly; otherwise ``err.message`` falls back to the
    class-level text instead of being shadowed by ``None``.
    """

    # class-level default so ``InputError.message`` always exists
    message = None

    def __init__(self, message=None):
        if message is None:
            super().__init__()
        else:
            super().__init__(message)
            self.message = message
67
+
68
def save_history(history, path_to_model, val_dice, train_dice):
    """Plot training curves and persist the history next to the model.

    Three files are written, each named by replacing ``'.h5'`` in
    ``path_to_model``: ``*_acc.png`` (accuracy and optional Dice curves),
    ``*_loss.png`` (train/test loss) and ``*.npy`` (the ``history`` dict).

    Args:
        history: dict of per-epoch values with keys 'accuracy', 'val_accuracy',
            'loss', 'val_loss' and, when plotted, 'dice' / 'val_dice'.
        path_to_model: model file path, expected to contain '.h5'.
        val_dice: also plot the validation Dice curve.
        train_dice: also plot the training Dice curve.
    """
    # (history key, legend text) in the order the curves are drawn
    accuracy_curves = [('accuracy', 'Accuracy (train)'),
                       ('val_accuracy', 'Accuracy (test)')]
    if train_dice:
        accuracy_curves.append(('dice', 'Dice score (train)'))
    if val_dice:
        accuracy_curves.append(('val_dice', 'Dice score (test)'))

    # accuracy figure
    for key, _ in accuracy_curves:
        plt.plot(history[key])
    plt.legend([text for _, text in accuracy_curves], loc='upper left')
    plt.title('model accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.tight_layout()  # keep labels from overlapping
    plt.savefig(path_to_model.replace('.h5', '_acc.png'), dpi=300, bbox_inches='tight')
    plt.clf()

    # loss figure
    for key in ('loss', 'val_loss'):
        plt.plot(history[key])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc='upper left')
    plt.tight_layout()  # keep labels from overlapping
    plt.savefig(path_to_model.replace('.h5', '_loss.png'), dpi=300, bbox_inches='tight')
    plt.clf()

    # raw history dictionary for later inspection
    np.save(path_to_model.replace('.h5', '.npy'), history)
102
+
103
def predict_blocksize(labelData, x_puffer, y_puffer, z_puffer):
    """Return the padded bounding box of all labelled voxels.

    Args:
        labelData: 3D label volume indexed (z, y, x); non-zero voxels are
            foreground.
        x_puffer, y_puffer, z_puffer: margin (in voxels) added on both sides of
            the foreground extent along x, y and z.

    Returns:
        ``(argmin_z, argmax_z, argmin_y, argmax_y, argmin_x, argmax_x)`` for
        slicing ``labelData[argmin_z:argmax_z, argmin_y:argmax_y,
        argmin_x:argmax_x]``, clipped to the volume. Each upper bound is the
        index of the last foreground voxel plus the margin (so with a zero
        margin the last foreground row is excluded by the slice). A volume
        without foreground yields its full extent.
    """
    zsh, ysh, xsh = labelData.shape

    # indices of z slices / y rows / x columns containing any foreground;
    # testing the index arrays themselves (not their truthiness) also keeps
    # foreground that lies only in column/row 0
    z_fg = np.flatnonzero(np.any(labelData, axis=(1, 2)))
    if z_fg.size == 0:
        # nothing labelled: keep the whole volume instead of an empty crop
        return 0, zsh, 0, ysh, 0, xsh
    y_fg = np.flatnonzero(np.any(labelData, axis=(0, 2)))
    x_fg = np.flatnonzero(np.any(labelData, axis=(0, 1)))

    argmin_z = max(int(z_fg[0]) - z_puffer, 0)
    argmax_z = min(int(z_fg[-1]) + z_puffer, zsh)
    argmin_y = max(int(y_fg[0]) - y_puffer, 0)
    argmax_y = min(int(y_fg[-1]) + y_puffer, ysh)
    argmin_x = max(int(x_fg[0]) - x_puffer, 0)
    argmax_x = min(int(x_fg[-1]) + x_puffer, xsh)
    return argmin_z, argmax_z, argmin_y, argmax_y, argmin_x, argmax_x
123
+
124
def get_image_dimensions(header, data):
    """Return a copy of an Amira header whose lattice size matches *data*.

    Args:
        header: Amira header as a NumPy byte array; its dtype is reused for the
            returned array.
        data: volume indexed (z, y, x) that the header should describe.

    Returns:
        New header array in which the first ``X Y Z`` triple of the
        ``define Lattice`` line and the matching ``Content "XxYxZ byte`` entry
        are replaced by the shape of *data*.

    Raises:
        AttributeError: if the header contains no ``define Lattice`` line.
    """
    raw = header.tobytes()
    # remember which codec decoded the header so it is re-encoded with the
    # same one; re-encoding a latin1-decoded header as UTF-8 would turn each
    # non-ASCII byte into two bytes and corrupt the header
    encoding = 'utf-8'
    try:
        text = raw.decode(encoding)
    except UnicodeDecodeError:
        encoding = 'latin1'
        text = raw.decode(encoding)

    # current image size stored in the header
    lattice = re.search(r'define Lattice (.*)\n', text).group(1)
    xsh, ysh, zsh = (int(value) for value in lattice.split())

    # new image size
    z, y, x = data.shape

    # change image size in header (first occurrence only)
    text = text.replace(f'{xsh} {ysh} {zsh}', f'{x} {y} {z}', 1)
    text = text.replace(f'Content "{xsh}x{ysh}x{zsh} byte', f'Content "{x}x{y}x{z} byte', 1)

    return np.frombuffer(text.encode(encoding), dtype=header.dtype)
150
+
151
def get_physical_size(header, img_header):
    """Copy the bounding boxes of an image header into a label header.

    Both headers are NumPy byte arrays holding Amira header text. The first
    ``BoundingBox a b c d e f,`` entry and the ``&BoundingBox ...,`` entry of
    the label header are replaced by the corresponding values of the image
    header.

    Args:
        header: label header; its dtype is reused for the returned array.
        img_header: image header providing the physical size.

    Returns:
        New label header array.

    Raises:
        AttributeError: if either header lacks one of the bounding box entries.
        ValueError: if a ``BoundingBox`` entry does not hold six values.
    """
    def _decode(arr):
        # return (text, codec) so the text can be re-encoded losslessly
        raw = arr.tobytes()
        try:
            return raw.decode('utf-8'), 'utf-8'
        except UnicodeDecodeError:
            return raw.decode('latin1'), 'latin1'

    # physical size from image header
    img_text, _ = _decode(img_header)
    i0, i1, i2, i3, i4, i5 = re.search(r'BoundingBox (.*),\n', img_text).group(1).split(' ')
    bounding_box_i = re.search(r'&BoundingBox (.*),\n', img_text).group(1)

    # physical size from label header; keep its codec for re-encoding (a
    # latin1 header re-encoded as UTF-8 would have its non-ASCII bytes changed)
    text, encoding = _decode(header)
    l0, l1, l2, l3, l4, l5 = re.search(r'BoundingBox (.*),\n', text).group(1).split(' ')
    bounding_box_l = re.search(r'&BoundingBox (.*),\n', text).group(1)

    # change physical size in label header (first occurrence only)
    text = text.replace(' '.join((l0, l1, l2, l3, l4, l5)), ' '.join((i0, i1, i2, i3, i4, i5)), 1)
    text = text.replace(bounding_box_l, bounding_box_i, 1)

    return np.frombuffer(text.encode(encoding), dtype=header.dtype)
189
+
190
@numba.jit(nopython=True)
def compute_position(position, zsh, ysh, xsh):
    """Fill *position* with each voxel's squared distance to the volume centre.

    The centre is ``(zsh // 2, ysh // 2, xsh // 2)``; entry ``[k, l, m]``
    becomes ``(cx - m)**2 + (cy - l)**2 + (cz - k)**2``. The array is modified
    in place and also returned.
    """
    cz, cy, cx = zsh // 2, ysh // 2, xsh // 2
    for k in range(zsh):
        dz = (cz - k) ** 2
        for l in range(ysh):
            dy = (cy - l) ** 2
            for m in range(xsh):
                position[k, l, m] = (cx - m) ** 2 + dy + dz
    return position
201
+
202
def make_conv_block(nb_filters, input_tensor, block):
    """Apply two (Conv3D -> BatchNormalization -> ReLU) stages to *input_tensor*.

    Layers are named ``conv_<block>_<stage>`` and
    ``batch_norm_<block>_<stage>`` (stage 1 and 2). These names must stay
    stable because saved weights are tied to them.

    Args:
        nb_filters: number of 3x3x3 filters per convolution.
        input_tensor: Keras tensor fed into the first stage.
        block: block index used in the layer names.

    Returns:
        Output tensor of the second stage.
    """
    def make_stage(input_tensor, stage):
        # NOTE: the convolution already applies ReLU before BatchNormalization;
        # kept as-is to preserve the network architecture
        name = 'conv_{}_{}'.format(block, stage)
        x = Conv3D(nb_filters, (3, 3, 3), activation='relu',
                   padding='same', name=name, data_format="channels_last")(input_tensor)
        name = 'batch_norm_{}_{}'.format(block, stage)
        try:
            x = BatchNormalization(name=name, synchronized=True)(x)
        except Exception:
            # presumably for Keras versions without the `synchronized` argument;
            # a bare except would also swallow KeyboardInterrupt/SystemExit
            x = BatchNormalization(name=name)(x)
        x = Activation('relu')(x)
        return x

    x = make_stage(input_tensor, 1)
    x = make_stage(x, 2)
    return x
218
+
219
def make_conv_block_resnet(nb_filters, input_tensor, block):
    """Residual conv block: two conv stages plus a 1x1x1 projection shortcut.

    Computes ``ReLU(Identity(x) + BN(Conv(ReLU(BN(Conv_relu(x))))))`` where
    ``Identity`` is a bias-free 1x1x1 Conv3D so the channel counts match for
    the addition. The layer names ``Identity<block>_1``,
    ``conv_<block>_<stage>`` and ``batch_norm_<block>_<stage>`` are kept
    stable because saved weights are tied to them.

    Args:
        nb_filters: number of filters of every convolution in the block.
        input_tensor: Keras tensor fed into the block.
        block: block index used in the layer names.

    Returns:
        Output tensor of the block.
    """
    def batch_norm(tensor, stage):
        name = 'batch_norm_{}_{}'.format(block, stage)
        try:
            return BatchNormalization(name=name, synchronized=True)(tensor)
        except Exception:
            # presumably for Keras versions without the `synchronized` argument;
            # a bare except would also swallow KeyboardInterrupt/SystemExit
            return BatchNormalization(name=name)(tensor)

    # Residual/Skip connection (1x1x1 projection to nb_filters channels)
    res = Conv3D(nb_filters, (1, 1, 1), padding='same', use_bias=False,
                 name="Identity{}_1".format(block))(input_tensor)

    # stage 1: conv (with ReLU) -> batch norm -> ReLU
    fx = Conv3D(nb_filters, (3, 3, 3), activation='relu', padding='same',
                name='conv_{}_{}'.format(block, 1), data_format="channels_last")(input_tensor)
    fx = batch_norm(fx, 1)
    fx = Activation('relu')(fx)

    # stage 2: conv -> batch norm (activation applied after the addition)
    fx = Conv3D(nb_filters, (3, 3, 3), padding='same',
                name='conv_{}_{}'.format(block, 2), data_format="channels_last")(fx)
    fx = batch_norm(fx, 2)

    out = Add()([res, fx])
    out = ReLU()(out)
    return out
247
+
248
def make_unet(input_shape, nb_labels, filters='32-64-128-256-512', resnet=False):
    """Build a 3D U-Net with a per-voxel softmax over ``nb_labels`` classes.

    Args:
        input_shape: ``(nb_plans, nb_rows, nb_cols, channels)`` of one patch.
        nb_labels: number of output classes.
        filters: dash-separated filter counts; all but the last define the
            encoder/decoder levels (one 2x2x2 max-pooling each), the last one
            the bottleneck block.
        resnet: use residual conv blocks instead of plain ones.

    Returns:
        Uncompiled Keras ``Model`` mapping a patch to per-voxel class
        probabilities of shape ``(nb_plans, nb_rows, nb_cols, nb_labels)``.

    Raises:
        ValueError: if fewer than two filter sizes are given, or the patch size
            is not divisible by ``2 ** (number of pooling levels)``.
    """
    nb_plans, nb_rows, nb_cols, _ = input_shape

    filter_spec = filters
    filters = np.array(filter_spec.split('-'), dtype=int)
    latent_space_size = filters[-1]
    filters = filters[:-1]
    if len(filters) == 0:
        raise ValueError(f'filters must contain at least two sizes separated by "-", got "{filter_spec}"')
    # every pooling level halves the patch and the decoder doubles it again;
    # any odd intermediate size would break the skip-connection concatenation
    factor = 2 ** len(filters)
    if any(dim % factor for dim in (nb_plans, nb_rows, nb_cols)):
        raise ValueError(f'patch size {tuple(input_shape[:3])} must be divisible by {factor} for filters "{filter_spec}"')

    conv_block = make_conv_block_resnet if resnet else make_conv_block
    inputs = Input(input_shape)

    # encoder
    convs = []
    previous = inputs
    block = 1
    for f in filters:
        conv = conv_block(f, previous, block)
        previous = MaxPooling3D(pool_size=(2, 2, 2))(conv)
        convs.append(conv)
        block += 1

    # bottleneck
    conv = conv_block(latent_space_size, previous, block)
    block += 1

    # decoder with skip connections
    for k, f in enumerate(filters[::-1]):
        up = Concatenate()([UpSampling3D(size=(2, 2, 2))(conv), convs[-(k+1)]])
        conv = conv_block(f, up, block)
        block += 1

    conv = Conv3D(nb_labels, (1, 1, 1), name=f'conv_{block}_1')(conv)

    # softmax over classes for every voxel
    x = Reshape((nb_plans * nb_rows * nb_cols, nb_labels))(conv)
    x = Activation('softmax')(x)
    outputs = Reshape((nb_plans, nb_rows, nb_cols, nb_labels))(x)

    return Model(inputs=inputs, outputs=outputs)
299
+
300
def get_labels(arr, allLabels):
    """Translate consecutive class indices back to the original label values.

    Args:
        arr: array of class indices (0, 1, 2, ...).
        allLabels: sequence where ``allLabels[i]`` is the label value of class ``i``.

    Returns:
        Array with the same shape and dtype as *arr*, holding label values.
    """
    restored = np.zeros_like(arr)
    for class_index in np.unique(arr):
        label_value = allLabels[class_index]
        restored[arr == class_index] = label_value
    return restored
306
+
307
def _base_extension(path):
    """Return the extension of *path*, looking through a trailing '.gz'."""
    stem, ext = os.path.splitext(path)
    if ext == '.gz':
        ext = os.path.splitext(stem)[1]
    return ext


def _extract_tar(archive, destination):
    """Extract *archive* into *destination*, refusing unsafe members if supported."""
    os.makedirs(destination, exist_ok=True)
    with tarfile.open(archive) as tar:
        if hasattr(tarfile, 'data_filter'):
            # Python 3.12+ (and security backports): reject absolute paths,
            # '..' components, links leaving the destination, device files
            tar.extractall(path=destination, filter='data')
        else:
            tar.extractall(path=destination)


def _find_volume_files(directory):
    """Return the sorted, non-hidden supported volume files below *directory*."""
    found = []
    for data_type in ['.am', '.tif', '.tiff', '.hdr', '.mhd', '.mha', '.nrrd', '.nii', '.nii.gz', '.zip', '.mrc']:
        pattern = os.path.join(glob.escape(directory), '**', '*' + data_type)
        found += [file for file in glob.glob(pattern, recursive=True)
                  if not os.path.basename(file).startswith('.')]
    return sorted(found)


def read_img_list(img_list, label_list, temp_img_dir, temp_label_dir):
    """Expand paired image/label inputs into matching lists of file paths.

    Each pair may be two files (passed through unchanged), two directories or
    two TAR archives (optionally gzipped). Directories and archives are
    searched recursively for supported volume files. Each pair's files are
    sorted independently and then appended, so images and labels are paired by
    sorted file name within each entry.

    Args:
        img_list, label_list: image and label inputs, paired by position.
        temp_img_dir, temp_label_dir: scratch directories; every archive is
            extracted into its own numbered subdirectory so files of earlier
            archives are not collected again for later entries.

    Returns:
        ``(img_names, label_names)`` lists of equal length.

    Raises:
        InputError: if an entry yields no files or different numbers of image
            and label files (the text is stored in ``InputError.message``).
    """
    img_names, label_names = [], []
    for index, (img_name, label_name) in enumerate(zip(img_list, label_list)):
        img_ext = _base_extension(img_name)
        label_ext = _base_extension(label_name)

        if (img_ext == '.tar' and label_ext == '.tar') or (os.path.isdir(img_name) and os.path.isdir(label_name)):
            img_dir, label_dir = img_name, label_name
            if img_ext == '.tar':
                img_dir = os.path.join(temp_img_dir, str(index))
                _extract_tar(img_name, img_dir)
            if label_ext == '.tar':
                label_dir = os.path.join(temp_label_dir, str(index))
                _extract_tar(label_name, label_dir)

            new_imgs = _find_volume_files(img_dir)
            new_labels = _find_volume_files(label_dir)
            if not new_imgs or not new_labels or len(new_imgs) != len(new_labels):
                if len(new_imgs) != len(new_labels):
                    InputError.message = 'Number of image and label files must be the same'
                elif img_ext == '.tar' and not new_imgs:
                    InputError.message = 'Invalid image TAR file'
                elif label_ext == '.tar' and not new_labels:
                    InputError.message = 'Invalid label TAR file'
                elif not new_imgs:
                    InputError.message = 'Invalid image data'
                else:
                    InputError.message = 'Invalid label data'
                raise InputError()
            img_names += new_imgs
            label_names += new_labels
        else:
            img_names.append(img_name)
            label_names.append(label_name)
    return img_names, label_names
356
+
357
def _scale_image_channels(arr, channels, normalize, normalization_parameters):
    """Min-max scale each channel of *arr* in place and return normalization stats.

    Every channel is shifted to a minimum of 0 and divided by its maximum
    (skipped for constant channels to avoid 0/0 = NaN). If
    ``normalization_parameters`` is None, the per-channel mean/std of the
    scaled data are computed and returned and no standardization is applied.
    Otherwise, if ``normalize`` is set, each channel is standardized and mapped
    onto the stored mean (row 0) and std (row 1).

    Args:
        arr: float array of shape (z, y, x, channels); modified in place.
        channels: number of channels.
        normalize: match every channel to the reference statistics.
        normalization_parameters: (2, channels) array or None.

    Returns:
        The (2, channels) normalization parameters (computed or passed in).
    """
    compute_parameters = normalization_parameters is None
    if compute_parameters:
        # allocate once for all channels; allocating inside the channel loop
        # sent channels >= 1 into the standardization branch with all-zero
        # reference statistics, which wiped them out
        normalization_parameters = np.zeros((2, channels))
    for c in range(channels):
        channel = arr[:, :, :, c]  # view into arr
        channel -= np.amin(channel)
        peak = np.amax(channel)
        if peak > 0:
            channel /= peak
        if compute_parameters:
            normalization_parameters[0, c] = np.mean(channel)
            normalization_parameters[1, c] = np.std(channel)
        elif normalize:
            mean, std = np.mean(channel), np.std(channel)
            arr[:, :, :, c] = (channel - mean) / std if std > 0 else channel - mean
            arr[:, :, :, c] = arr[:, :, :, c] * normalization_parameters[1, c] + normalization_parameters[0, c]
    return normalization_parameters


def _prepare_training_label(label, name, labels_to_compute, labels_to_remove, crop_data,
                            no_scaling, x_scale, y_scale, z_scale, x_puffer, y_puffer, z_puffer):
    """Filter, report, crop and resize one label volume.

    Returns:
        ``(label, crop_box)``; ``crop_box`` is the tuple returned by
        ``predict_blocksize`` when ``crop_data`` is set, otherwise None.
    """
    label = set_labels_to_zero(label, labels_to_compute, labels_to_remove)
    label_values, counts = np.unique(label, return_counts=True)
    print(f'{os.path.basename(name)}:', 'Labels:', label_values[1:], 'Sizes:', counts[1:])
    crop_box = None
    if crop_data:
        crop_box = predict_blocksize(label, x_puffer, y_puffer, z_puffer)
        zmin, zmax, ymin, ymax, xmin, xmax = crop_box
        label = np.copy(label[zmin:zmax, ymin:ymax, xmin:xmax], order='C')
    if not no_scaling:
        label = img_resize(label, z_scale, y_scale, x_scale, labels=True)
    return label, crop_box


def _prepare_training_image(img, name, channels, crop_box, no_scaling, x_scale, y_scale, z_scale):
    """Add a channel axis if needed, check the channel count, crop and resize.

    Returns:
        ``(img, channels)`` with ``img`` as float32 (z, y, x, channels);
        ``channels`` is taken from the image when passed as None.

    Raises:
        InputError: if the image has a different number of channels.
    """
    if len(img.shape) == 3:
        z_shape, y_shape, x_shape = img.shape
        img = img.reshape(z_shape, y_shape, x_shape, 1)
    if channels is None:
        channels = img.shape[3]
    if channels != img.shape[3]:
        InputError.message = f'Number of channels must be {channels} for "{os.path.basename(name)}"'
        raise InputError()
    if crop_box is not None:
        zmin, zmax, ymin, ymax, xmin, xmax = crop_box
        img = np.copy(img[zmin:zmax, ymin:ymax, xmin:xmax], order='C')
    img = img.astype(np.float32)
    if not no_scaling:
        img = img_resize(img, z_scale, y_scale, x_scale)
    return img, channels


def load_training_data(normalize, img_list, label_list, channels, x_scale, y_scale, z_scale, no_scaling,
        crop_data, labels_to_compute, labels_to_remove, img_in=None, label_in=None,
        normalization_parameters=None, allLabels=None, header=None, extension='.tif',
        x_puffer=25, y_puffer=25, z_puffer=25):
    """Load, crop, rescale and normalize image/label volumes for training or validation.

    Data is read from files when ``img_list`` is non-empty (paths, directories
    or TAR archives expanded by ``read_img_list``). Otherwise it is taken from
    ``img_in``/``label_in``, each a single array or a list of arrays. All
    volumes are concatenated along z.

    Args:
        normalize: match every image's channel statistics to
            ``normalization_parameters``.
        img_list, label_list: file inputs, paired by position.
        channels: expected channel count, or None to take it from the first image.
        x_scale, y_scale, z_scale, no_scaling: target size for ``img_resize``
            (skipped when ``no_scaling``).
        crop_data: crop each volume to its padded label bounding box.
        labels_to_compute, labels_to_remove: passed to ``set_labels_to_zero``.
        img_in, label_in: in-memory data used when ``img_list`` is empty.
        normalization_parameters: reference (2, channels) mean/std, or None to
            compute them from the first image.
        allLabels: label values defining the class order, or None to use the
            unique values of the loaded labels.
        header, extension: label header/extension for in-memory data
            (replaced by the first label file's when loading from files).
        x_puffer, y_puffer, z_puffer: crop margins in voxels.

    Returns:
        ``(img, label, allLabels, normalization_parameters, header, extension,
        channels)``; ``img`` is float32 (z, y, x, channels) clipped to [0, 1],
        ``label`` holds class indices into ``allLabels``.

    Raises:
        InputError: for unreadable data, mismatched image/label dimensions or
            inconsistent channel counts (text in ``InputError.message``).
    """
    from_files = any(img_list)

    # make temporary directories (TAR archives are extracted into them)
    with tempfile.TemporaryDirectory() as temp_img_dir:
        with tempfile.TemporaryDirectory() as temp_label_dir:

            # read image lists
            if from_files:
                img_names, label_names = read_img_list(img_list, label_list, temp_img_dir, temp_label_dir)

            # load first label
            if from_files:
                label, header, extension = load_data(label_names[0], 'first_queue', True)
                if label is None:
                    InputError.message = f'Invalid label data "{os.path.basename(label_names[0])}"'
                    raise InputError()
            elif type(label_in) is list:
                label = label_in[0]
                label_names = [f'label_{i}' for i in range(1, len(label_in) + 1)]
            else:
                label = label_in
                label_names = ['label_1']
            label_dim = label.shape
            label, crop_box = _prepare_training_label(
                label, label_names[0], labels_to_compute, labels_to_remove, crop_data,
                no_scaling, x_scale, y_scale, z_scale, x_puffer, y_puffer, z_puffer)

            # if header is not single data stream Amira Mesh falling back to Multi-TIFF
            if extension != '.am':
                if extension != '.tif':
                    print(f'Warning! Please use -hf or --header_file="path_to_training_label{extension}" for prediction to save your result as "{extension}"')
                extension, header = '.tif', None
            elif len(header) > 1:
                print('Warning! Multiple data streams are not supported. Falling back to TIFF.')
                extension, header = '.tif', None
            else:
                header = header[0]

            # load first img
            if from_files:
                img, _ = load_data(img_names[0], 'first_queue')
                if img is None:
                    InputError.message = f'Invalid image data "{os.path.basename(img_names[0])}"'
                    raise InputError()
            elif type(img_in) is list:
                img = img_in[0]
                img_names = [f'img_{i}' for i in range(1, len(img_in) + 1)]
            else:
                img = img_in
                img_names = ['img_1']
            # compare spatial dimensions only: multi-channel images are (z, y, x, c)
            if label_dim != img.shape[:3]:
                InputError.message = f'Dimensions of "{os.path.basename(img_names[0])}" and "{os.path.basename(label_names[0])}" do not match'
                raise InputError()
            img, channels = _prepare_training_image(
                img, img_names[0], channels, crop_box, no_scaling, x_scale, y_scale, z_scale)
            normalization_parameters = _scale_image_channels(
                img, channels, normalize, normalization_parameters)

            # remaining volumes; parts are concatenated once instead of
            # repeatedly copying the growing arrays with np.append
            labels, imgs = [label], [img]
            if from_files or type(img_in) is list:
                number_of_images = len(img_names) if from_files else len(img_in)

                for k in range(1, number_of_images):

                    # append label
                    if from_files:
                        a, _ = load_data(label_names[k], 'first_queue')
                        if a is None:
                            InputError.message = f'Invalid label data "{os.path.basename(label_names[k])}"'
                            raise InputError()
                    else:
                        a = label_in[k]
                    label_dim = a.shape
                    a, crop_box = _prepare_training_label(
                        a, label_names[k], labels_to_compute, labels_to_remove, crop_data,
                        no_scaling, x_scale, y_scale, z_scale, x_puffer, y_puffer, z_puffer)
                    labels.append(a)

                    # append image
                    if from_files:
                        a, _ = load_data(img_names[k], 'first_queue')
                        if a is None:
                            InputError.message = f'Invalid image data "{os.path.basename(img_names[k])}"'
                            raise InputError()
                    else:
                        a = img_in[k]
                    if label_dim != a.shape[:3]:
                        InputError.message = f'Dimensions of "{os.path.basename(img_names[k])}" and "{os.path.basename(label_names[k])}" do not match'
                        raise InputError()
                    a, _ = _prepare_training_image(
                        a, img_names[k], channels, crop_box, no_scaling, x_scale, y_scale, z_scale)
                    _scale_image_channels(a, channels, normalize, normalization_parameters)
                    imgs.append(a)

            label = labels[0] if len(labels) == 1 else np.concatenate(labels, axis=0)
            img = imgs[0] if len(imgs) == 1 else np.concatenate(imgs, axis=0)
            del labels, imgs  # release the per-volume parts before further processing

            # limit intensity range
            img[img < 0] = 0
            img[img > 1] = 1

            # get labels
            if allLabels is None:
                allLabels = np.unique(label)

            # labels must be in ascending order
            for k, l in enumerate(allLabels):
                label[label == l] = k

            return img, label, allLabels, normalization_parameters, header, extension, channels
520
+
521
class CustomCallback(Callback):
    """Report training progress to the web app's ``Upload`` record.

    After every epoch the ``Upload`` with primary key ``id`` is reloaded. If
    its status is 3, training is stopped; otherwise a progress message with
    the estimated remaining time (and validation accuracy, when logged) is
    saved to it.
    """

    def __init__(self, id, epochs):
        # let Keras' Callback initialise its bookkeeping attributes
        # (e.g. model/params) before adding our own
        super().__init__()
        self.epochs = epochs
        self.id = id

    def on_epoch_begin(self, batch, logs=None):
        # `batch` receives the epoch index (Keras passes it positionally)
        self.epoch_time_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        logs = logs if logs is not None else {}
        import django
        django.setup()
        from biomedisa_app.models import Upload
        image = Upload.objects.get(pk=self.id)
        if image.status == 3:
            # NOTE(review): status 3 presumably means the job was stopped by the user -- confirm
            self.model.stop_training = True
            return

        keys = list(logs.keys())
        percentage = round((int(epoch)+1)*100/float(self.epochs))
        # remaining time estimated from the duration of the last epoch (seconds)
        t = round(time.time() - self.epoch_time_start) * (self.epochs-int(epoch)-1)
        if t < 3600:
            time_remaining = str(t // 60) + 'min'
        else:
            time_remaining = str(t // 3600) + 'h ' + str((t % 3600) // 60) + 'min'
        try:
            val_accuracy = round(float(logs["val_accuracy"])*100,1)
            image.message = 'Progress {}%, {} remaining, {}% accuracy'.format(percentage,time_remaining,val_accuracy)
        except KeyError:
            image.message = 'Progress {}%, {} remaining'.format(percentage,time_remaining)
        image.save()
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))
551
+
552
class MetaData(Callback):
    """Store biomedisa meta data inside the saved HDF5 model file.

    After every epoch the model file is checked for a ``/meta`` group. If it is
    missing, the following are written:

    - the configuration, normalization parameters and label values;
    - the Amira extension/header when the extension is '.am';
    - the cropping network's configuration, normalization and weights when
      ``crop_data`` is set.

    The check runs every epoch because the model file is re-saved whenever a
    better model is found (e.g. by ``Metrics``); presumably that rewrite drops
    these groups.
    """

    def __init__(self, path_to_model, configuration_data, allLabels,
                 extension, header, crop_data, cropping_weights, cropping_config,
                 normalization_parameters, cropping_norm):
        # let Keras' Callback initialise its bookkeeping attributes
        # (e.g. model/params) before adding our own
        super().__init__()
        self.path_to_model = path_to_model
        self.configuration_data = configuration_data
        self.normalization_parameters = normalization_parameters
        self.allLabels = allLabels
        self.extension = extension
        self.header = header
        self.crop_data = crop_data
        self.cropping_weights = cropping_weights
        self.cropping_config = cropping_config
        self.cropping_norm = cropping_norm

    def on_epoch_end(self, epoch, logs=None):
        # read-only check first so the file is only opened for writing when needed;
        # context managers close the file even if writing fails
        with h5py.File(self.path_to_model, 'r') as hf:
            has_meta = '/meta' in hf
        if has_meta:
            return
        with h5py.File(self.path_to_model, 'r+') as hf:
            group = hf.create_group('meta')
            group.create_dataset('configuration', data=self.configuration_data)
            group.create_dataset('normalization', data=self.normalization_parameters)
            group.create_dataset('labels', data=self.allLabels)
            if self.extension == '.am':
                group.create_dataset('extension', data=self.extension)
                group.create_dataset('header', data=self.header)
            if self.crop_data:
                cm_group = hf.create_group('cropping_meta')
                cm_group.create_dataset('configuration', data=self.cropping_config)
                cm_group.create_dataset('normalization', data=self.cropping_norm)
                cw_group = hf.create_group('cropping_weights')
                for iterator, arr in enumerate(self.cropping_weights):
                    cw_group.create_dataset(str(iterator), data=arr)
588
+
589
class Metrics(Callback):
    """Evaluate the model patch-wise on a full volume every ``validation_freq`` epochs.

    Patches starting at the flat voxel indices ``list_IDs`` are predicted and
    their class scores summed into a volume of shape ``dim_img + (n_classes,)``.

    - ``train=True``: only the Dice score of the training data is stored in
      ``logs['dice']``.
    - ``train=False``: the following are computed and logged:
      - validation cross-entropy, accuracy and Dice;
      - the best model saved to ``bm.path_to_model``;
      - the history plotted via ``save_history``;
      - training stopped after ``early_stopping`` validations without a new
        best Dice.
    """

    def __init__(self, bm, img, label, list_IDs, dim_img, n_classes, train):
        # let Keras' Callback initialise its bookkeeping attributes
        # (e.g. model/params) before adding our own
        super().__init__()
        self.dim_patch = (bm.z_patch, bm.y_patch, bm.x_patch)
        self.dim_img = dim_img
        self.list_IDs = list_IDs
        self.batch_size = bm.batch_size
        self.label = label
        self.img = img
        self.path_to_model = bm.path_to_model
        self.early_stopping = bm.early_stopping
        self.validation_freq = bm.validation_freq
        self.n_classes = n_classes
        self.n_channels = bm.channels
        self.average_dice = bm.average_dice
        self.django_env = bm.django_env
        self.patch_normalization = bm.patch_normalization
        self.train = train
        self.train_dice = bm.train_dice

    def on_train_begin(self, logs=None):
        self.history = {}
        self.history['val_accuracy'] = []
        self.history['accuracy'] = []
        self.history['val_dice'] = []
        self.history['dice'] = []
        self.history['val_loss'] = []
        self.history['loss'] = []

    def _patch_origin(self, ID):
        """Convert a flat voxel index into the (z, y, x) corner of its patch."""
        k, rest = divmod(ID, self.dim_img[1] * self.dim_img[2])
        l, m = divmod(rest, self.dim_img[2])
        return k, l, m

    def _dice_score(self, result):
        """Dice score of *result* against ``self.label``.

        Classes (or, for the global score, all foreground) absent from both
        label and prediction count as perfect agreement (1.0) instead of
        producing 0/0 = NaN, matching the smoothed ``dice_coef`` in this module.
        """
        if self.average_dice:
            dice = 0
            for l in range(1, self.n_classes):
                true_l, pred_l = self.label == l, result == l
                denominator = float(true_l.sum() + pred_l.sum())
                if denominator > 0:
                    dice += 2 * np.logical_and(true_l, pred_l).sum() / denominator
                else:
                    dice += 1.0
            dice /= float(self.n_classes-1)
        else:
            denominator = float((self.label > 0).sum() + (result > 0).sum())
            if denominator > 0:
                # `self.label > 0` instead of `(self.label + result) > 0`: same
                # condition where both agree, without uint8 overflow
                dice = 2 * np.logical_and(self.label == result, self.label > 0).sum() / denominator
            else:
                dice = 1.0
        return dice

    def on_epoch_end(self, epoch, logs=None):
        # Keras passes its own logs dict; the entries written below must land in it
        if logs is None:
            logs = {}
        if epoch % self.validation_freq != 0:
            return

        result = np.zeros((*self.dim_img, self.n_classes), dtype=np.float32)
        pz, py, px = self.dim_patch

        len_IDs = len(self.list_IDs)
        # include the final partial batch so every patch is predicted; with
        # floor division up to batch_size-1 randomly chosen patches were
        # skipped each epoch and their regions stayed unpredicted
        n_batches = (len_IDs + self.batch_size - 1) // self.batch_size
        np.random.shuffle(self.list_IDs)

        for batch in range(n_batches):
            # Generate indexes of the batch
            list_IDs_batch = self.list_IDs[batch*self.batch_size:(batch+1)*self.batch_size]

            # Generate data
            X_val = np.empty((len(list_IDs_batch), *self.dim_patch, self.n_channels), dtype=np.float32)
            for i, ID in enumerate(list_IDs_batch):
                k, l, m = self._patch_origin(ID)
                tmp_X = self.img[k:k+pz, l:l+py, m:m+px]
                if self.patch_normalization:
                    tmp_X = np.copy(tmp_X, order='C')
                    for c in range(self.n_channels):
                        tmp_X[:,:,:,c] -= np.mean(tmp_X[:,:,:,c])
                        tmp_X[:,:,:,c] /= max(np.std(tmp_X[:,:,:,c]), 1e-6)
                X_val[i] = tmp_X

            # Prediction segmentation
            y_predict = np.asarray(self.model.predict(X_val, verbose=0, steps=None, batch_size=self.batch_size))

            # accumulate class scores of overlapping patches
            for i, ID in enumerate(list_IDs_batch):
                k, l, m = self._patch_origin(ID)
                result[k:k+pz, l:l+py, m:m+px] += y_predict[i]

        # calculate categorical crossentropy
        if not self.train:
            crossentropy = categorical_crossentropy(self.label, softmax(result))

        # get result
        result = np.argmax(result, axis=-1).astype(np.uint8)

        # calculate standard accuracy
        if not self.train:
            accuracy = np.sum(self.label==result) / float(self.label.size)

        dice = self._dice_score(result)

        if self.train:
            logs['dice'] = dice
            return

        # save best model only (empty history also covers runs resumed at epoch > 0)
        if epoch == 0 or not self.history['val_dice'] or round(dice,4) > max(self.history['val_dice']):
            self.model.save(str(self.path_to_model))

        # add accuracy to history
        self.history['loss'].append(round(logs['loss'],4))
        self.history['accuracy'].append(round(logs['accuracy'],4))
        if self.train_dice:
            self.history['dice'].append(round(logs['dice'],4))
        self.history['val_accuracy'].append(round(accuracy,4))
        self.history['val_dice'].append(round(dice,4))
        self.history['val_loss'].append(round(crossentropy,4))

        # tensorflow monitoring variables
        logs['val_loss'] = crossentropy
        logs['val_accuracy'] = accuracy
        logs['val_dice'] = dice
        logs['best_acc'] = max(self.history['accuracy'])
        if self.train_dice:
            logs['best_dice'] = max(self.history['dice'])
        logs['best_val_acc'] = max(self.history['val_accuracy'])
        logs['best_val_dice'] = max(self.history['val_dice'])

        # plot history in figure and save as numpy array
        save_history(self.history, self.path_to_model, True, self.train_dice)

        # print accuracies
        print('\nValidation history:')
        print('train_acc:', self.history['accuracy'])
        if self.train_dice:
            print('train_dice:', self.history['dice'])
        print('val_acc:', self.history['val_accuracy'])
        print('val_dice:', self.history['val_dice'])
        print('')

        # early stopping: best Dice not reached within the last `early_stopping` validations
        if self.early_stopping > 0 and max(self.history['val_dice']) not in self.history['val_dice'][-self.early_stopping:]:
            self.model.stop_training = True
724
+
725
def softmax(x):
    """Return the softmax of *x* along its last axis.

    The per-row maximum is subtracted before exponentiation so large inputs
    do not overflow.
    """
    stabilized = np.exp(x - np.max(x, axis=-1, keepdims=True))
    normalizer = stabilized.sum(axis=-1, keepdims=True)
    return stabilized / normalizer
730
+
731
@numba.jit(nopython=True)
def categorical_crossentropy(true_labels, predicted_probs):
    """Mean negative log-probability of the true class over all voxels.

    Args:
        true_labels: integer class index per voxel, indexed (z, y, x).
        predicted_probs: class probabilities indexed (z, y, x, class).

    Returns:
        ``mean(-log(clip(p_true, 1e-7, 1 - 1e-7)))`` as a float.

    Only the probability of the true class is read and clipped per voxel.
    This avoids building clipped and log-transformed copies of the whole
    (z, y, x, class) array.
    """
    eps = 1e-7
    zsh, ysh, xsh = true_labels.shape
    loss = 0.0
    for z in range(zsh):
        for y in range(ysh):
            for x in range(xsh):
                p = float(predicted_probs[z, y, x, true_labels[z, y, x]])
                # clip to avoid log(0)
                p = min(max(p, eps), 1.0 - eps)
                loss -= np.log(p)
    return loss / float(zsh*ysh*xsh)
746
+
747
def dice_coef(y_true, y_pred, smooth=1e-5):
    """Smoothed Dice coefficient of two (soft) masks using Keras backend ops.

    Returns ``(2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)``;
    ``smooth`` keeps the ratio defined (equal to 1) when both masks are empty.
    """
    overlap = K.sum(Multiply()([y_true, y_pred]))
    numerator = 2. * overlap + smooth
    denominator = K.sum(y_true) + K.sum(y_pred) + smooth
    return numerator / denominator
751
+
752
def dice_coef_loss(nb_labels):
    """Build a loss function: negative log of the mean foreground Dice score.

    Channel 0 of the last axis is skipped; the Dice coefficient is averaged
    over channels ``1 .. nb_labels-1`` of the 5-D ``y_true`` / ``y_pred``
    tensors and the loss is ``-log(mean_dice)``.
    """
    foreground_channels = range(1, nb_labels)

    def loss_fn(y_true, y_pred):
        dice_sum = sum(dice_coef(y_true[:,:,:,:,ch], y_pred[:,:,:,:,ch])
                       for ch in foreground_channels)
        mean_dice = dice_sum / (nb_labels-1)
        return -K.log(mean_dice)

    return loss_fn
765
+
766
def _collect_patch_ids(label, dim_img, patch_shape, stride, balance):
    """Enumerate flat start indices of all patches on a strided grid.

    A patch whose origin is voxel ``(k, l, m)`` is encoded as
    ``k*ysh*xsh + l*xsh + m`` with ``(zsh, ysh, xsh) = dim_img``. Only patches
    that fit completely inside the volume are produced, in raster order
    (z outermost, x innermost).

    Parameters
    ----------
    label : array-like
        Label volume indexed as ``label[z, y, x]``; only read when ``balance``
        is truthy.
    dim_img : tuple of int
        Spatial shape ``(zsh, ysh, xsh)`` of the volume.
    patch_shape : tuple of int
        Patch size ``(z_patch, y_patch, x_patch)``.
    stride : int
        Step between neighbouring patch origins along every axis.
    balance : bool
        If truthy, IDs of patches containing any nonzero label go to the first
        list and all others to the second. If falsy, every ID goes to the
        first list and the second stays empty.

    Returns
    -------
    tuple of (list of int, list of int)
        ``(foreground_ids, background_ids)``.
    """
    zsh, ysh, xsh = dim_img
    z_patch, y_patch, x_patch = patch_shape
    ids_fg, ids_bg = [], []
    for k in range(0, zsh-z_patch+1, stride):
        for l in range(0, ysh-y_patch+1, stride):
            for m in range(0, xsh-x_patch+1, stride):
                patch_id = k*ysh*xsh + l*xsh + m
                if not balance or np.any(label[k:k+z_patch, l:l+y_patch, m:m+x_patch]):
                    ids_fg.append(patch_id)
                else:
                    ids_bg.append(patch_id)
    return ids_fg, ids_bg


def train_semantic_segmentation(bm,
    img_list, label_list,
    val_img_list, val_label_list,
    img=None, label=None,
    img_val=None, label_val=None,
    header=None, extension='.tif'):
    """Train a U-Net semantic segmentation model saved at ``bm.path_to_model``.

    Training data are loaded with ``load_training_data`` (from ``img_list`` /
    ``label_list`` or the pre-loaded ``img`` / ``label``). Validation data
    come from ``val_img_list`` / ``img_val`` or, if neither is given and
    ``bm.validation_split`` is set, from the trailing z-slices of the
    training volume.

    Parameters
    ----------
    bm : object
        Configuration object. Many attributes are read (patch sizes, strides,
        augmentation flags, optimizer and callback settings, paths), and
        ``bm.channels`` is set from the loaded training data.
    img_list, label_list : list
        Training image / label sources forwarded to ``load_training_data``.
    val_img_list, val_label_list : list or None
        Validation image / label sources. Treated as "no validation files"
        when None or when no entry is truthy.
    img, label, img_val, label_val : array or None
        Optional pre-loaded training / validation volumes.
    header, extension : optional
        Forwarded to ``load_training_data`` for the training data.
    """
    # training data
    img, label, allLabels, normalization_parameters, header, extension, bm.channels = load_training_data(bm.normalize,
        img_list, label_list, None, bm.x_scale, bm.y_scale, bm.z_scale, bm.no_scaling, bm.crop_data,
        bm.only, bm.ignore, img, label, None, None, header, extension)

    # configuration data
    configuration_data = np.array([bm.channels, bm.x_scale, bm.y_scale, bm.z_scale, bm.normalize, 0, 1])

    # img shape
    zsh, ysh, xsh, _ = img.shape

    # validation data (guard against val_img_list=None, which any() rejects)
    if (val_img_list is not None and any(val_img_list)) or img_val is not None:
        img_val, label_val, _, _, _, _, _ = load_training_data(bm.normalize,
            val_img_list, val_label_list, bm.channels, bm.x_scale, bm.y_scale, bm.z_scale, bm.no_scaling, bm.crop_data,
            bm.only, bm.ignore, img_val, label_val, normalization_parameters, allLabels)

    elif bm.validation_split:
        split = round(zsh * bm.validation_split)
        img_val = np.copy(img[split:], order='C')
        label_val = np.copy(label[split:], order='C')
        img = np.copy(img[:split], order='C')
        label = np.copy(label[:split], order='C')
        zsh, ysh, xsh, _ = img.shape

    patch_shape = (bm.z_patch, bm.y_patch, bm.x_patch)

    # IDs of training patches (split into fg/bg only when balancing)
    list_IDs_fg, list_IDs_bg = _collect_patch_ids(
        label, (zsh, ysh, xsh), patch_shape, bm.stride_size, bm.balance)

    if img_val is not None:

        # img_val shape
        zsh_val, ysh_val, xsh_val, _ = img_val.shape

        # IDs of validation patches; the Dice-based validation uses all patches
        list_IDs_val_fg, list_IDs_val_bg = _collect_patch_ids(
            label_val, (zsh_val, ysh_val, xsh_val), patch_shape,
            bm.validation_stride_size, bm.balance and not bm.val_dice)

    # number of labels
    nb_labels = len(allLabels)

    # input shape
    input_shape = (bm.z_patch, bm.y_patch, bm.x_patch, bm.channels)

    # parameters
    params = {'batch_size': bm.batch_size,
              'dim': (bm.z_patch, bm.y_patch, bm.x_patch),
              'dim_img': (zsh, ysh, xsh),
              'n_classes': nb_labels,
              'n_channels': bm.channels,
              'augment': (bm.flip_x, bm.flip_y, bm.flip_z, bm.swapaxes, bm.rotate),
              'patch_normalization': bm.patch_normalization}

    # data generator
    validation_generator = None
    training_generator = DataGenerator(img, label, list_IDs_fg, list_IDs_bg, True, True, False, **params)
    if img_val is not None:
        if bm.val_dice:
            val_metrics = Metrics(bm, img_val, label_val, list_IDs_val_fg, (zsh_val, ysh_val, xsh_val), nb_labels, False)
        else:
            params['dim_img'] = (zsh_val, ysh_val, xsh_val)
            params['augment'] = (False, False, False, False, 0)
            validation_generator = DataGenerator(img_val, label_val, list_IDs_val_fg, list_IDs_val_bg, True, False, False, **params)

    # monitor dice score on training data
    if bm.train_dice:
        train_metrics = Metrics(bm, img, label, list_IDs_fg, (zsh, ysh, xsh), nb_labels, True)

    # create a MirroredStrategy
    cdo = tf.distribute.ReductionToOneDevice()
    strategy = tf.distribute.MirroredStrategy(cross_device_ops=cdo)
    ngpus = int(strategy.num_replicas_in_sync)
    print(f'Number of devices: {ngpus}')
    if ngpus == 1 and os.name == 'nt':
        atexit.register(strategy._extended._collective_ops._pool.close)

    # compile model
    with strategy.scope():

        # build model
        model = make_unet(input_shape, nb_labels, bm.network_filters, bm.resnet)
        model.summary()

        # pretrained model
        if bm.pretrained_model:
            model_pretrained = load_model(bm.pretrained_model)
            model.set_weights(model_pretrained.get_weights())
            if not bm.fine_tune:
                # freeze the conv layers named conv_{nb_blocks+1}_* .. conv_{2*nb_blocks}_1
                nb_blocks = len(bm.network_filters.split('-'))
                for k in range(nb_blocks+1, 2*nb_blocks):
                    for l in [1,2]:
                        name = f'conv_{k}_{l}'
                        layer = model.get_layer(name)
                        layer.trainable = False
                name = f'conv_{2*nb_blocks}_1'
                layer = model.get_layer(name)
                layer.trainable = False

        # optimizer
        sgd = SGD(learning_rate=bm.learning_rate, decay=1e-6, momentum=0.9, nesterov=True)

        # compile model
        loss=dice_coef_loss(nb_labels) if bm.dice_loss else 'categorical_crossentropy'
        model.compile(loss=loss,
                      optimizer=sgd,
                      metrics=['accuracy'])

    # save meta data
    meta_data = MetaData(bm.path_to_model, configuration_data, allLabels,
        extension, header, bm.crop_data, bm.cropping_weights, bm.cropping_config,
        normalization_parameters, bm.cropping_norm)

    # model checkpoint
    if img_val is not None:
        if bm.val_dice:
            callbacks = [val_metrics, meta_data]
        else:
            model_checkpoint_callback = ModelCheckpoint(
                filepath=str(bm.path_to_model),
                save_weights_only=False,
                monitor='val_accuracy',
                mode='max',
                save_best_only=True)
            callbacks = [model_checkpoint_callback, meta_data]
            if bm.early_stopping > 0:
                callbacks.insert(0, EarlyStopping(monitor='val_accuracy', mode='max', patience=bm.early_stopping))
    else:
        callbacks = [ModelCheckpoint(filepath=str(bm.path_to_model)), meta_data]

    # monitor dice score on training data
    if bm.train_dice:
        callbacks = [train_metrics] + callbacks

    # custom callback (inserted just before the meta data callback)
    if bm.django_env and not bm.remote:
        callbacks.insert(-1, CustomCallback(bm.img_id, bm.epochs))

    # train model
    history = model.fit(training_generator,
                        epochs=bm.epochs,
                        validation_data=validation_generator,
                        callbacks=callbacks,
                        workers=bm.workers)

    # save monitoring figure on train end
    if img_val is not None and not bm.val_dice:
        save_history(history.history, bm.path_to_model, False, bm.train_dice)
948
+
949
def load_prediction_data(path_to_img, channels, x_scale, y_scale, z_scale,
    no_scaling, normalize, normalization_parameters, region_of_interest,
    img, img_header):
    """Load and preprocess an image volume for prediction.

    Parameters
    ----------
    path_to_img : str
        Image file read with ``load_data`` when ``img`` is None.
    channels : int
        Expected number of channels; a 3-D image is treated as one channel.
    x_scale, y_scale, z_scale : int
        Target size passed to ``img_resize`` unless ``no_scaling`` is set.
    no_scaling : bool
        Skip resizing when truthy.
    normalize : bool
        If truthy, standardize each channel and map it onto the stored
        training statistics ``normalization_parameters[0, c]`` (added) and
        ``normalization_parameters[1, c]`` (multiplied), then clip to [0, 1].
    normalization_parameters : ndarray or None
        Only read when ``normalize`` is truthy.
    region_of_interest : sequence of 6 ints or None
        ``(min_z, max_z, min_y, max_y, min_x, max_x)`` crop box, or falsy for
        no cropping.
    img : ndarray or None
        Pre-loaded image, ``(z, y, x)`` or ``(z, y, x, channels)``.
    img_header : object or None
        Header returned alongside the image.

    Returns
    -------
    tuple
        ``(img, img_header, z_shape, y_shape, x_shape, region_of_interest,
        img_data)`` where ``img`` is the preprocessed float32 4-D volume,
        the shapes are the spatial size after cropping and before resizing,
        ``region_of_interest`` is extended with the uncropped shape when a crop
        was applied, and ``img_data`` is an untouched copy of the input.

    Raises
    ------
    InputError
        If the image cannot be loaded or its channel count differs from
        ``channels``.
    """
    # read image data
    if img is None:
        img, img_header = load_data(path_to_img, 'first_queue')

    # verify validity
    if img is None:
        InputError.message = f'Invalid image data: {os.path.basename(path_to_img)}.'
        raise InputError()

    # preserve original image data
    img_data = img.copy()

    # handle all images having channels >=1
    if len(img.shape)==3:
        z_shape, y_shape, x_shape = img.shape
        img = img.reshape(z_shape, y_shape, x_shape, 1)
    if img.shape[3] != channels:
        InputError.message = f'Number of channels must be {channels}.'
        raise InputError()

    # image shape
    z_shape, y_shape, x_shape, _ = img.shape

    # automatic cropping of image to region of interest
    if np.any(region_of_interest):
        min_z, max_z, min_y, max_y, min_x, max_x = region_of_interest[:]
        img = np.copy(img[min_z:max_z,min_y:max_y,min_x:max_x], order='C')
        # remember the uncropped shape so the crop can be reverted later
        region_of_interest = np.array([min_z,max_z,min_y,max_y,min_x,max_x,z_shape,y_shape,x_shape])
        z_shape, y_shape, x_shape = max_z-min_z, max_y-min_y, max_x-min_x

    # scale/resize image data
    img = img.astype(np.float32)
    if not no_scaling:
        img = img_resize(img, z_scale, y_scale, x_scale)

    # normalize image data channel by channel
    for c in range(channels):
        channel = img[:,:,:,c]  # view: in-place ops below update img
        channel -= np.amin(channel)
        max_value = np.amax(channel)
        # a constant channel is all zeros here; dividing would fill it with NaN
        if max_value > 0:
            channel /= max_value
        if normalize:
            mean, std = np.mean(channel), np.std(channel)
            channel -= mean
            if std > 0:
                channel /= std
            img[:,:,:,c] = channel * normalization_parameters[1,c] + normalization_parameters[0,c]

    # limit intensity range
    if normalize:
        np.clip(img, 0, 1, out=img)

    return img, img_header, z_shape, y_shape, x_shape, region_of_interest, img_data
1002
+
1003
def predict_semantic_segmentation(bm, img, path_to_model,
    z_patch, y_patch, x_patch, z_shape, y_shape, x_shape, compress, header,
    img_header, stride_size, allLabels, batch_size, region_of_interest,
    no_scaling, extension, img_data):
    """Segment a preprocessed volume with a trained model and assemble results.

    The volume is tiled into patches on a grid with step ``stride_size``.
    Per-patch class probabilities are summed into a full-size array whose
    argmax is the label volume. That label volume is then:

    - resized back to ``(z_shape, y_shape, x_shape)`` unless ``no_scaling``;
    - placed back into the uncropped frame when ``region_of_interest`` is set;
    - mapped to the original label values via ``get_labels``;
    - optionally saved and post-processed (clean / fill / active contour).

    Parameters
    ----------
    bm : object
        Configuration object. Reads e.g. ``patch_normalization``,
        ``return_probs``, ``header_file``, ``path_to_image``, ``clean``,
        ``fill`` and ``acwe``. May update ``bm.path_to_final``.
    img : ndarray
        Preprocessed image indexed as ``img[z, y, x, channel]``.
    path_to_model : str or path-like
        Model loaded with ``load_model``.
    allLabels : sequence
        Label values; its length is the number of model output classes.
    batch_size : int
        Prediction batch size; the patch list is padded to a multiple of it.
    region_of_interest : array or None
        As returned by ``load_prediction_data`` (crop box plus uncropped shape).
    img_data : ndarray
        Original image, used for active-contour post-processing.

    Returns
    -------
    results : dict
        Always contains 'regular'. It also contains 'header' when a header is
        available, and 'probs', 'cleaned', 'filled', 'cleaned_filled', 'acwe'
        and 'refined' when the matching options are enabled.
    bm : object
        The configuration object, possibly with an updated ``path_to_final``.
    """
    results = {}

    # img shape
    zsh, ysh, xsh, csh = img.shape

    # number of labels
    nb_labels = len(allLabels)

    # list of IDs
    list_IDs = []

    # get Ids of patches
    for k in range(0, zsh-z_patch+1, stride_size):
        for l in range(0, ysh-y_patch+1, stride_size):
            for m in range(0, xsh-x_patch+1, stride_size):
                list_IDs.append(k*ysh*xsh+l*xsh+m)

    # make length of list divisible by batch size by cycling over existing IDs;
    # the padded predictions are ignored when patches are reassembled below
    rest = (-len(list_IDs)) % batch_size
    list_IDs = list_IDs + [list_IDs[i % len(list_IDs)] for i in range(rest)]

    # number of patches
    nb_patches = len(list_IDs)

    # parameters
    params = {'dim': (z_patch, y_patch, x_patch),
              'dim_img': (zsh, ysh, xsh),
              'batch_size': batch_size,
              'n_channels': csh,
              'patch_normalization': bm.patch_normalization}

    # data generator
    predict_generator = PredictDataGenerator(img, list_IDs, **params)

    # load model
    model = load_model(str(path_to_model))

    # predict
    if nb_patches < 400:
        probabilities = model.predict(predict_generator, verbose=0, steps=None)
    else:
        X = np.empty((batch_size, z_patch, y_patch, x_patch, csh), dtype=np.float32)
        probabilities = np.zeros((nb_patches, z_patch, y_patch, x_patch, nb_labels), dtype=np.float32)

        # get image patches
        for step in range(nb_patches//batch_size):
            for i, ID in enumerate(list_IDs[step*batch_size:(step+1)*batch_size]):

                # get patch indices (decode k*ysh*xsh + l*xsh + m)
                k = ID // (ysh*xsh)
                offset = ID % (ysh*xsh)
                l = offset // xsh
                m = offset % xsh

                # get patch
                tmp_X = img[k:k+z_patch,l:l+y_patch,m:m+x_patch]
                if bm.patch_normalization:
                    tmp_X = np.copy(tmp_X, order='C')
                    for c in range(csh):
                        tmp_X[:,:,:,c] -= np.mean(tmp_X[:,:,:,c])
                        tmp_X[:,:,:,c] /= max(np.std(tmp_X[:,:,:,c]), 1e-6)
                X[i] = tmp_X

            probabilities[step*batch_size:(step+1)*batch_size] = model.predict(X, verbose=0, steps=None, batch_size=batch_size)

    # create final: accumulate probabilities of the (unpadded) grid patches
    final = np.zeros((zsh, ysh, xsh, nb_labels), dtype=np.float32)
    if bm.return_probs:
        counter = np.zeros((zsh, ysh, xsh, nb_labels), dtype=np.float32)
    nb = 0
    for k in range(0, zsh-z_patch+1, stride_size):
        for l in range(0, ysh-y_patch+1, stride_size):
            for m in range(0, xsh-x_patch+1, stride_size):
                final[k:k+z_patch, l:l+y_patch, m:m+x_patch] += probabilities[nb]
                if bm.return_probs:
                    counter[k:k+z_patch, l:l+y_patch, m:m+x_patch] += 1
                nb += 1

    # return probabilities (average over overlapping patches)
    if bm.return_probs:
        counter[counter==0] = 1
        probabilities = final / counter
        if not no_scaling:
            probabilities = img_resize(probabilities, z_shape, y_shape, x_shape)
        if np.any(region_of_interest):
            min_z,max_z,min_y,max_y,min_x,max_x,original_zsh,original_ysh,original_xsh = region_of_interest[:]
            tmp = np.zeros((original_zsh, original_ysh, original_xsh, nb_labels), dtype=np.float32)
            tmp[min_z:max_z,min_y:max_y,min_x:max_x] = probabilities
            probabilities = np.copy(tmp, order='C')
        results['probs'] = probabilities

    # get final
    label = np.argmax(final, axis=3)
    label = label.astype(np.uint8)

    # rescale final to input size
    if not no_scaling:
        label = img_resize(label, z_shape, y_shape, x_shape, labels=True)

    # revert automatic cropping
    if np.any(region_of_interest):
        min_z,max_z,min_y,max_y,min_x,max_x,original_zsh,original_ysh,original_xsh = region_of_interest[:]
        tmp = np.zeros((original_zsh, original_ysh, original_xsh), dtype=np.uint8)
        tmp[min_z:max_z,min_y:max_y,min_x:max_x] = label
        label = np.copy(tmp, order='C')

    # get result
    label = get_labels(label, allLabels)
    results['regular'] = label

    # load header from file
    if bm.header_file and os.path.exists(bm.header_file):
        _, header = load_data(bm.header_file)
        # update file extension
        if header is not None and bm.path_to_image:
            extension = os.path.splitext(bm.header_file)[1]
            if extension == '.gz':
                extension = '.nii.gz'
            bm.path_to_final = os.path.splitext(bm.path_to_final)[0] + extension
            if bm.django_env and not bm.remote and not bm.tarfile:
                bm.path_to_final = unique_file_path(bm.path_to_final)

    # handle amira header
    if header is not None:
        if extension == '.am':
            header = get_image_dimensions(header[0], label)
            if img_header is not None:
                try:
                    header = get_physical_size(header, img_header[0])
                except Exception:
                    # best effort: keep the header without physical size info
                    pass
            header = [header]
        else:
            # build new header
            if img_header is None:
                zsh, ysh, xsh = label.shape
                img_header = sitk.Image(xsh, ysh, zsh, header.GetPixelID())
            # copy metadata
            for key in header.GetMetaDataKeys():
                if not (re.match(r'Segment\d+_Extent$', key) or key=='Segmentation_ConversionParameters'):
                    img_header.SetMetaData(key, header.GetMetaData(key))
            header = img_header
        results['header'] = header

    # save result
    if bm.path_to_image:
        save_data(bm.path_to_final, label, header=header, compress=compress)

        # paths to optional results
        filename, extension = os.path.splitext(bm.path_to_final)
        if extension == '.gz':
            extension = '.nii.gz'
            filename = filename[:-4]
        path_to_cleaned = filename + '.cleaned' + extension
        path_to_filled = filename + '.filled' + extension
        path_to_cleaned_filled = filename + '.cleaned.filled' + extension
        path_to_refined = filename + '.refined' + extension
        path_to_acwe = filename + '.acwe' + extension

    # remove outliers
    if bm.clean:
        cleaned_result = clean(label, bm.clean)
        results['cleaned'] = cleaned_result
        if bm.path_to_image:
            save_data(path_to_cleaned, cleaned_result, header=header, compress=compress)
    if bm.fill:
        filled_result = clean(label, bm.fill)
        results['filled'] = filled_result
        if bm.path_to_image:
            save_data(path_to_filled, filled_result, header=header, compress=compress)
    if bm.clean and bm.fill:
        cleaned_filled_result = cleaned_result + (filled_result - label)
        results['cleaned_filled'] = cleaned_filled_result
        if bm.path_to_image:
            save_data(path_to_cleaned_filled, cleaned_filled_result, header=header, compress=compress)

    # post-processing with active contour
    if bm.acwe:
        acwe_result = activeContour(img_data, label, bm.acwe_alpha, bm.acwe_smooth, bm.acwe_steps)
        refined_result = activeContour(img_data, label, simple=True)
        results['acwe'] = acwe_result
        results['refined'] = refined_result
        if bm.path_to_image:
            save_data(path_to_acwe, acwe_result, header=header, compress=compress)
            save_data(path_to_refined, refined_result, header=header, compress=compress)

    return results, bm
1195
+