biomedisa 2024.5.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. biomedisa/__init__.py +53 -0
  2. biomedisa/__main__.py +18 -0
  3. biomedisa/biomedisa_features/DataGenerator.py +299 -0
  4. biomedisa/biomedisa_features/DataGeneratorCrop.py +121 -0
  5. biomedisa/biomedisa_features/PredictDataGenerator.py +87 -0
  6. biomedisa/biomedisa_features/PredictDataGeneratorCrop.py +74 -0
  7. biomedisa/biomedisa_features/__init__.py +0 -0
  8. biomedisa/biomedisa_features/active_contour.py +434 -0
  9. biomedisa/biomedisa_features/amira_to_np/__init__.py +0 -0
  10. biomedisa/biomedisa_features/amira_to_np/amira_data_stream.py +980 -0
  11. biomedisa/biomedisa_features/amira_to_np/amira_grammar.py +369 -0
  12. biomedisa/biomedisa_features/amira_to_np/amira_header.py +290 -0
  13. biomedisa/biomedisa_features/amira_to_np/amira_helper.py +72 -0
  14. biomedisa/biomedisa_features/assd.py +167 -0
  15. biomedisa/biomedisa_features/biomedisa_helper.py +801 -0
  16. biomedisa/biomedisa_features/create_slices.py +286 -0
  17. biomedisa/biomedisa_features/crop_helper.py +586 -0
  18. biomedisa/biomedisa_features/curvop_numba.py +149 -0
  19. biomedisa/biomedisa_features/django_env.py +172 -0
  20. biomedisa/biomedisa_features/keras_helper.py +1219 -0
  21. biomedisa/biomedisa_features/nc_reader.py +179 -0
  22. biomedisa/biomedisa_features/pid.py +52 -0
  23. biomedisa/biomedisa_features/process_image.py +253 -0
  24. biomedisa/biomedisa_features/pycuda_test.py +84 -0
  25. biomedisa/biomedisa_features/random_walk/__init__.py +0 -0
  26. biomedisa/biomedisa_features/random_walk/gpu_kernels.py +183 -0
  27. biomedisa/biomedisa_features/random_walk/pycuda_large.py +826 -0
  28. biomedisa/biomedisa_features/random_walk/pycuda_large_allx.py +806 -0
  29. biomedisa/biomedisa_features/random_walk/pycuda_small.py +414 -0
  30. biomedisa/biomedisa_features/random_walk/pycuda_small_allx.py +493 -0
  31. biomedisa/biomedisa_features/random_walk/pyopencl_large.py +760 -0
  32. biomedisa/biomedisa_features/random_walk/pyopencl_small.py +441 -0
  33. biomedisa/biomedisa_features/random_walk/rw_large.py +390 -0
  34. biomedisa/biomedisa_features/random_walk/rw_small.py +310 -0
  35. biomedisa/biomedisa_features/remove_outlier.py +399 -0
  36. biomedisa/biomedisa_features/split_volume.py +274 -0
  37. biomedisa/deeplearning.py +519 -0
  38. biomedisa/interpolation.py +371 -0
  39. biomedisa/mesh.py +406 -0
  40. biomedisa-2024.5.14.dist-info/LICENSE +191 -0
  41. biomedisa-2024.5.14.dist-info/METADATA +306 -0
  42. biomedisa-2024.5.14.dist-info/RECORD +44 -0
  43. biomedisa-2024.5.14.dist-info/WHEEL +5 -0
  44. biomedisa-2024.5.14.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1219 @@
1
+ ##########################################################################
2
+ ## ##
3
+ ## Copyright (c) 2024 Philipp Lösel. All rights reserved. ##
4
+ ## ##
5
+ ## This file is part of the open source project biomedisa. ##
6
+ ## ##
7
+ ## Licensed under the European Union Public Licence (EUPL) ##
8
+ ## v1.2, or - as soon as they will be approved by the ##
9
+ ## European Commission - subsequent versions of the EUPL; ##
10
+ ## ##
11
+ ## You may redistribute it and/or modify it under the terms ##
12
+ ## of the EUPL v1.2. You may not use this work except in ##
13
+ ## compliance with this Licence. ##
14
+ ## ##
15
+ ## You can obtain a copy of the Licence at: ##
16
+ ## ##
17
+ ## https://joinup.ec.europa.eu/page/eupl-text-11-12 ##
18
+ ## ##
19
+ ## Unless required by applicable law or agreed to in ##
20
+ ## writing, software distributed under the Licence is ##
21
+ ## distributed on an "AS IS" basis, WITHOUT WARRANTIES ##
22
+ ## OR CONDITIONS OF ANY KIND, either express or implied. ##
23
+ ## ##
24
+ ## See the Licence for the specific language governing ##
25
+ ## permissions and limitations under the Licence. ##
26
+ ## ##
27
+ ##########################################################################
28
+
29
+ import os
30
+ BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
31
+ try:
32
+ from tensorflow.keras.optimizers.legacy import SGD
33
+ except:
34
+ from tensorflow.keras.optimizers import SGD
35
+ from tensorflow.keras.models import Model, load_model
36
+ from tensorflow.keras.layers import (
37
+ Input, Conv3D, MaxPooling3D, UpSampling3D, Activation, Reshape,
38
+ BatchNormalization, Concatenate, ReLU, Add, GlobalAveragePooling3D,
39
+ Dense, Dropout, MaxPool3D, Flatten, Multiply)
40
+ from tensorflow.keras import backend as K
41
+ from tensorflow.keras.utils import to_categorical
42
+ from tensorflow.keras.callbacks import Callback, ModelCheckpoint, EarlyStopping
43
+ from biomedisa_features.DataGenerator import DataGenerator
44
+ from biomedisa_features.PredictDataGenerator import PredictDataGenerator
45
+ from biomedisa_features.biomedisa_helper import (
46
+ img_resize, load_data, save_data, set_labels_to_zero, id_generator)
47
+ from biomedisa_features.remove_outlier import clean, fill
48
+ from biomedisa_features.active_contour import activeContour
49
+ import matplotlib.pyplot as plt
50
+ import SimpleITK as sitk
51
+ import tensorflow as tf
52
+ import numpy as np
53
+ import cv2
54
+ import tarfile
55
+ from random import shuffle
56
+ import glob
57
+ import random
58
+ import numba
59
+ import re
60
+ import time
61
+ import h5py
62
+ import atexit
63
+ import shutil
64
+
65
class InputError(Exception):
    """Raised when training or validation input data is invalid.

    In this module the failure reason is usually stored on the *class*
    (``InputError.message = ...``) right before a bare ``InputError()`` is
    raised; presumably handlers read these class attributes (verify against
    callers). Note that the instance attributes set here shadow the class
    attributes on the raised instance.
    """
    def __init__(self, message=None, img_names=None, label_names=None):
        # None sentinels instead of [] defaults: a mutable default list would
        # be shared by every instance created without explicit arguments.
        self.message = message
        self.img_names = [] if img_names is None else img_names
        self.label_names = [] if label_names is None else label_names
70
+
71
def save_history(history, path_to_model, val_dice, train_dice):
    """Plot the training history and store it next to the model file.

    Writes ``<base>_acc.png`` (accuracy and optional Dice curves),
    ``<base>_loss.png`` (loss curves) and ``<base>.npy`` (the ``history``
    dict saved with ``np.save``), where ``<base>`` is ``path_to_model``
    without its file extension.

    Args:
        history: dict with lists under 'accuracy', 'val_accuracy', 'loss',
            'val_loss' and, when the respective flag is set, 'dice' and
            'val_dice'.
        path_to_model: path of the model file (typically ``*.h5``).
        val_dice: plot the validation Dice curve.
        train_dice: plot the training Dice curve.
    """
    # Derive output paths by stripping only the final extension. The former
    # str.replace('.h5', ...) also rewrote '.h5' inside directory names and,
    # for a path without '.h5', returned the model path itself so that
    # savefig overwrote the model file.
    base_path = os.path.splitext(path_to_model)[0]

    # summarize history for accuracy
    plt.plot(history['accuracy'])
    plt.plot(history['val_accuracy'])
    legend = ['Accuracy (train)', 'Accuracy (test)']
    if train_dice:
        plt.plot(history['dice'])
        legend.append('Dice score (train)')
    if val_dice:
        plt.plot(history['val_dice'])
        legend.append('Dice score (test)')
    plt.legend(legend, loc='upper left')
    plt.title('model accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.tight_layout() # To prevent overlapping of subplots
    plt.savefig(base_path + '_acc.png', dpi=300, bbox_inches='tight')
    plt.clf()

    # summarize history for loss
    plt.plot(history['loss'])
    plt.plot(history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc='upper left')
    plt.tight_layout() # To prevent overlapping of subplots
    plt.savefig(base_path + '_loss.png', dpi=300, bbox_inches='tight')
    plt.clf()

    # save history dictionary
    np.save(base_path + '.npy', history)
105
+
106
def predict_blocksize(labelData, x_puffer, y_puffer, z_puffer):
    """Bounding box of the non-zero voxels of ``labelData``, enlarged by margins.

    Args:
        labelData: 3D label volume indexed as [z, y, x].
        x_puffer, y_puffer, z_puffer: margin added on both sides of the box
            along each axis; the result is clamped to the volume.

    Returns:
        (argmin_z, argmax_z, argmin_y, argmax_y, argmin_x, argmax_x), used by
        callers as slice bounds ``[argmin:argmax]``. The maximum is the last
        foreground index plus the margin, so with a zero margin the last
        foreground index itself lies outside the slice. If ``labelData``
        has no foreground the result is not meaningful.
    """
    zsh, ysh, xsh = labelData.shape
    argmin_z, argmax_z, argmin_y, argmax_y, argmin_x, argmax_x = zsh, 0, ysh, 0, xsh, 0
    for k in range(zsh):
        y, x = np.nonzero(labelData[k])
        # test the number of hits, not x.any(): x.any() is False when every
        # foreground voxel of the slice lies in column x == 0
        if x.size > 0:
            argmin_x = min(argmin_x, np.amin(x))
            argmax_x = max(argmax_x, np.amax(x))
            argmin_y = min(argmin_y, np.amin(y))
            argmax_y = max(argmax_y, np.amax(y))
            argmin_z = min(argmin_z, k)
            argmax_z = max(argmax_z, k)
    # enlarge by the margins and clamp to the volume
    argmin_x = argmin_x - x_puffer if argmin_x - x_puffer > 0 else 0
    argmax_x = argmax_x + x_puffer if argmax_x + x_puffer < xsh else xsh
    argmin_y = argmin_y - y_puffer if argmin_y - y_puffer > 0 else 0
    argmax_y = argmax_y + y_puffer if argmax_y + y_puffer < ysh else ysh
    argmin_z = argmin_z - z_puffer if argmin_z - z_puffer > 0 else 0
    argmax_z = argmax_z + z_puffer if argmax_z + z_puffer < zsh else zsh
    return argmin_z,argmax_z,argmin_y,argmax_y,argmin_x,argmax_x
126
+
127
def get_image_dimensions(header, data):
    """Return a copy of an Amira header whose lattice size matches ``data``.

    The first 'define Lattice X Y Z' size occurrence and the first
    'Content "XxYxZ byte' entry are replaced by the shape of ``data``
    (read as z, y, x).

    Args:
        header: header as a NumPy byte array.
        data: 3D volume whose shape is written into the header.

    Returns:
        The modified header as a NumPy array of ``header.dtype``.
    """
    # read header as string; remember the codec that worked so that the
    # header is re-encoded with the same one (encoding a latin1-decoded
    # header as UTF-8 would change every non-ASCII byte)
    b = header.tobytes()
    encoding = 'utf-8'
    try:
        s = b.decode(encoding)
    except UnicodeDecodeError:
        encoding = 'latin1'
        s = b.decode(encoding)

    # get image size in header
    lattice = re.search(r'define Lattice (.*)\n', s).group(1)
    xsh, ysh, zsh = (int(v) for v in lattice.split(' '))

    # new image size
    z, y, x = data.shape

    # change image size in header
    s = s.replace('%s %s %s' % (xsh, ysh, zsh), '%s %s %s' % (x, y, z), 1)
    s = s.replace('Content "%sx%sx%s byte' % (xsh, ysh, zsh), 'Content "%sx%sx%s byte' % (x, y, z), 1)

    # return header as array
    return np.frombuffer(s.encode(encoding), dtype=header.dtype)
153
+
154
def get_physical_size(header, img_header):
    """Copy the physical extent of an image header into a label header.

    Both arguments are Amira headers as NumPy byte arrays. In ``header`` the
    first occurrence of the values of its first 'BoundingBox <6 values>,'
    match is replaced by the image's values, and likewise for the
    '&BoundingBox ...,' entry.

    Returns:
        The modified label header as a NumPy array of ``header.dtype``.
    """
    # read img_header as string
    b = img_header.tobytes()
    try:
        s = b.decode("utf-8")
    except UnicodeDecodeError:
        s = b.decode("latin1")

    # get physical size from image header
    bounding_box = re.search(r'BoundingBox (.*),\n', s).group(1)
    i0, i1, i2, i3, i4, i5 = bounding_box.split(' ')
    bounding_box_i = re.search(r'&BoundingBox (.*),\n', s).group(1)

    # read label header as string; keep the codec for re-encoding so that
    # non-ASCII bytes of a latin1 header survive unchanged
    b = header.tobytes()
    encoding = 'utf-8'
    try:
        s = b.decode(encoding)
    except UnicodeDecodeError:
        encoding = 'latin1'
        s = b.decode(encoding)

    # get physical size from label header
    bounding_box = re.search(r'BoundingBox (.*),\n', s).group(1)
    l0, l1, l2, l3, l4, l5 = bounding_box.split(' ')
    bounding_box_l = re.search(r'&BoundingBox (.*),\n', s).group(1)

    # change physical size in label header
    s = s.replace('%s %s %s %s %s %s' % (l0, l1, l2, l3, l4, l5), '%s %s %s %s %s %s' % (i0, i1, i2, i3, i4, i5), 1)
    s = s.replace(bounding_box_l, bounding_box_i, 1)

    # return header as array
    return np.frombuffer(s.encode(encoding), dtype=header.dtype)
192
+
193
@numba.jit(nopython=True)
def compute_position(position, zsh, ysh, xsh):
    """Fill ``position`` with each voxel's squared distance to the volume centre.

    The centre is (zsh//2, ysh//2, xsh//2); ``position`` is written in place
    for all indices [0, zsh) x [0, ysh) x [0, xsh) and also returned.
    """
    cz, cy, cx = zsh // 2, ysh // 2, xsh // 2
    for k in range(zsh):
        dz = (cz - k) ** 2
        for l in range(ysh):
            dzy = dz + (cy - l) ** 2
            for m in range(xsh):
                position[k, l, m] = dzy + (cx - m) ** 2
    return position
204
+
205
def make_conv_block(nb_filters, input_tensor, block):
    """Build a plain U-Net block: two stages of Conv3D(relu) -> BatchNorm -> ReLU.

    Layer names are 'conv_{block}_{stage}' and 'batch_norm_{block}_{stage}'
    for stage 1 and 2 (kept stable; presumably relied upon when weights are
    matched by name — verify before renaming).
    """
    def make_stage(input_tensor, stage):
        name = 'conv_{}_{}'.format(block, stage)
        x = Conv3D(nb_filters, (3, 3, 3), activation='relu',
                padding='same', name=name, data_format="channels_last")(input_tensor)
        name = 'batch_norm_{}_{}'.format(block, stage)
        try:
            # synchronized batch statistics where the Keras version supports it
            x = BatchNormalization(name=name, synchronized=True)(x)
        except (TypeError, ValueError):
            # presumably raised by Keras versions without the 'synchronized'
            # argument; the former bare except also swallowed e.g. KeyboardInterrupt
            x = BatchNormalization(name=name)(x)
        x = Activation('relu')(x)
        return x

    x = make_stage(input_tensor, 1)
    x = make_stage(x, 2)
    return x
221
+
222
def make_conv_block_resnet(nb_filters, input_tensor, block):
    """Build a residual U-Net block.

    out = ReLU(Conv1x1(x) + BN2(Conv2(ReLU(BN1(Conv1_relu(x))))))

    where Conv1/Conv2 are 3x3x3 convolutions and the shortcut is a bias-free
    1x1x1 convolution named 'Identity{block}_1'. Other layer names follow
    'conv_{block}_{stage}' / 'batch_norm_{block}_{stage}'.
    """
    def batch_norm(name, x):
        # synchronized batch statistics where the Keras version supports it;
        # the fallback presumably covers versions without 'synchronized'
        # (the former bare except also swallowed e.g. KeyboardInterrupt)
        try:
            return BatchNormalization(name=name, synchronized=True)(x)
        except (TypeError, ValueError):
            return BatchNormalization(name=name)(x)

    # Residual/Skip connection
    res = Conv3D(nb_filters, (1, 1, 1), padding='same', use_bias=False, name="Identity{}_1".format(block))(input_tensor)

    # stage 1
    fx = Conv3D(nb_filters, (3, 3, 3), activation='relu', padding='same',
            name='conv_{}_{}'.format(block, 1), data_format="channels_last")(input_tensor)
    fx = batch_norm('batch_norm_{}_{}'.format(block, 1), fx)
    fx = Activation('relu')(fx)

    # stage 2 (no activation before the residual addition)
    fx = Conv3D(nb_filters, (3, 3, 3), padding='same',
            name='conv_{}_{}'.format(block, 2), data_format="channels_last")(fx)
    fx = batch_norm('batch_norm_{}_{}'.format(block, 2), fx)

    out = Add()([res, fx])
    out = ReLU()(out)

    return out
250
+
251
def make_unet(input_shape, nb_labels, filters='32-64-128-256-512', resnet=False):
    """Build a 3D U-Net with a per-voxel softmax over ``nb_labels`` classes.

    Args:
        input_shape: (planes, rows, cols, channels) of one input patch.
        nb_labels: number of output classes.
        filters: dash-separated filter counts; all but the last define the
            encoder/decoder levels, the last one the bottleneck.
        resnet: use residual conv blocks instead of plain ones.

    Returns:
        A Keras ``Model`` mapping the input patch to class probabilities of
        shape (planes, rows, cols, nb_labels).
    """
    nb_plans, nb_rows, nb_cols, _ = input_shape
    make_block = make_conv_block_resnet if resnet else make_conv_block

    inputs = Input(input_shape)

    filters = np.array(filters.split('-'), dtype=int)
    latent_space_size = filters[-1]
    filters = filters[:-1]
    convs = []

    # encoder; start from the input so that a 'filters' string with only a
    # bottleneck entry no longer fails with an unbound 'pool'
    pool = inputs
    i = 1
    for f in filters:
        conv = make_block(f, pool, i)
        pool = MaxPooling3D(pool_size=(2, 2, 2))(conv)
        convs.append(conv)
        i += 1

    # bottleneck
    conv = make_block(latent_space_size, pool, i)
    i += 1

    # decoder with skip connections from the matching encoder level
    for k, f in enumerate(filters[::-1]):
        up = Concatenate()([UpSampling3D(size=(2, 2, 2))(conv), convs[-(k+1)]])
        conv = make_block(f, up, i)
        i += 1

    conv = Conv3D(nb_labels, (1, 1, 1), name=f'conv_{i}_1')(conv)

    # softmax over the class axis
    x = Reshape((nb_plans * nb_rows * nb_cols, nb_labels))(conv)
    x = Activation('softmax')(x)
    outputs = Reshape((nb_plans, nb_rows, nb_cols, nb_labels))(x)

    model = Model(inputs=inputs, outputs=outputs)

    return model
302
+
303
def get_labels(arr, allLabels):
    """Translate class indices back to the original label values.

    Every value ``k`` that occurs in ``arr`` is replaced by ``allLabels[k]``;
    the result has the shape and dtype of ``arr``.
    """
    restored = np.zeros_like(arr)
    mapping = {index: allLabels[index] for index in np.unique(arr)}
    for index, value in mapping.items():
        restored[arr == index] = value
    return restored
309
+
310
def read_img_list(img_list, label_list):
    """Expand paired image/label inputs into matching lists of file paths.

    Each (img_name, label_name) pair is either two files, two directories,
    or two TAR archives ('.tar' or '.tar.gz'). Directories and archives are
    searched recursively for supported volume formats, skipping hidden files.
    Archives are extracted to ``BASE_DIR + '/tmp/' + <40-char id>``.

    Returns:
        (img_names, label_names): lists of equal length where img_names[k]
        belongs to label_names[k]. Within a directory/archive pair, files are
        paired by sorted path. The pairs keep the order of the input.

    Raises:
        InputError: if a directory/archive pair is empty or holds different
            numbers of images and labels; the reason is stored in
            ``InputError.message`` and extracted folders are removed.
    """
    def _extract(archive):
        # extract into a fresh temporary folder; the 'data' filter (where the
        # Python version provides it) refuses members escaping the folder
        target = BASE_DIR + '/tmp/' + id_generator(40)
        with tarfile.open(archive) as tar:
            if hasattr(tarfile, 'data_filter'):
                tar.extractall(path=target, filter='data')
            else:
                tar.extractall(path=target)
        return target

    def _collect(folder):
        # all supported, non-hidden volumes below folder, sorted by path
        found = []
        for data_type in ['.am','.tif','.tiff','.hdr','.mhd','.mha','.nrrd','.nii','.nii.gz','.zip','.mrc']:
            found += [file for file in glob.glob(folder+'/**/*'+data_type, recursive=True)
                      if not os.path.basename(file).startswith('.')]
        return sorted(found)

    # read filenames
    InputError.img_names = []
    InputError.label_names = []
    img_names, label_names = [], []
    for img_name, label_name in zip(img_list, label_list):

        # check for tarball
        img_dir, img_ext = os.path.splitext(img_name)
        if img_ext == '.gz':
            img_dir, img_ext = os.path.splitext(img_dir)

        label_dir, label_ext = os.path.splitext(label_name)
        if label_ext == '.gz':
            label_dir, label_ext = os.path.splitext(label_dir)

        if (img_ext == '.tar' and label_ext == '.tar') or (os.path.isdir(img_name) and os.path.isdir(label_name)):

            # extract files
            if img_ext == '.tar':
                img_name = _extract(img_name)
            if label_ext == '.tar':
                label_name = _extract(label_name)

            # collect and validate per pair: a global sort over all pairs
            # could interleave files of different archives and mis-pair them
            pair_imgs = _collect(img_name)
            pair_labels = _collect(label_name)
            if len(pair_imgs)==0 or len(pair_labels)==0 or len(pair_imgs)!=len(pair_labels):
                if img_ext == '.tar' and os.path.exists(img_name):
                    shutil.rmtree(img_name)
                if label_ext == '.tar' and os.path.exists(label_name):
                    shutil.rmtree(label_name)
                # also drop folders extracted for earlier pairs
                remove_extracted_data(img_names, label_names)
                if len(pair_imgs)!=len(pair_labels):
                    InputError.message = 'Number of image and label files must be the same'
                elif img_ext == '.tar':
                    InputError.message = 'Invalid image TAR file'
                elif label_ext == '.tar':
                    InputError.message = 'Invalid label TAR file'
                elif len(pair_imgs)==0:
                    InputError.message = 'Invalid image data'
                else:
                    InputError.message = 'Invalid label data'
                raise InputError()
            img_names += pair_imgs
            label_names += pair_labels
        else:
            img_names.append(img_name)
            label_names.append(label_name)
    return img_names, label_names
367
+
368
def remove_extracted_data(img_list, label_list):
    """Delete temporary extraction folders referenced by the given paths.

    For every path (images first, then labels) that contains
    ``BASE_DIR + '/tmp/'``, the path is cut 40 characters after that prefix
    (the length of the folder id used when archives are extracted) and the
    resulting folder is removed if it still exists.
    """
    tmp_root = BASE_DIR + '/tmp/'
    cut = len(tmp_root) + 40
    for paths in (img_list, label_list):
        for path in paths:
            if tmp_root not in path:
                continue
            folder = path[:cut]
            if os.path.exists(folder):
                shutil.rmtree(folder)
379
+
380
def load_training_data(normalize, img_list, label_list, channels, x_scale, y_scale, z_scale, no_scaling,
        crop_data, labels_to_compute, labels_to_remove, img_in=None, label_in=None,
        normalization_parameters=None, allLabels=None, header=None, extension='.tif',
        x_puffer=25, y_puffer=25, z_puffer=25):
    """Load, crop, rescale and normalize training images with their labels.

    Data are read from files when ``img_list`` has any entry (see
    ``read_img_list``); extracted archives are removed afterwards, also when
    an error is raised. Otherwise the in-memory ``img_in``/``label_in`` are
    used: a single array each or lists of arrays. All volumes are
    concatenated along the first axis.

    Per volume:
      * ``set_labels_to_zero`` is applied to the label with
        ``labels_to_compute`` and ``labels_to_remove``.
      * With ``crop_data``, label and image are cropped to the label's
        bounding box enlarged by the puffer margins.
      * Unless ``no_scaling`` is set, both are resized with ``img_resize``.
      * Every image channel is min-max scaled to [0, 1].
      * If ``normalization_parameters`` is None, it becomes the per-channel
        mean/std of the first image. Otherwise, when ``normalize`` is set,
        each channel is mapped onto the stored mean/std.

    Finally, intensities are clipped to [0, 1] and labels are replaced by
    their index in ``allLabels``.

    The header/extension of the first label file are kept only for
    single-stream Amira files; otherwise TIFF is used.

    Returns:
        (img, label, allLabels, normalization_parameters, header, extension,
        channels), where ``img`` is float32 with a trailing channel axis.

    Raises:
        InputError: for unreadable data or mismatching dimensions/channels;
            the reason is stored in ``InputError.message``.
    """
    from_files = any(img_list)

    # read image lists
    if from_files:
        img_names, label_names = read_img_list(img_list, label_list)
        InputError.img_names = img_names
        InputError.label_names = label_names
        number_of_images = len(img_names)
    else:
        number_of_images = len(img_in) if type(img_in) is list else 1
        # placeholder names for log and error messages; previously this path
        # failed on the unbound 'label_names'
        img_names = ['image_{}'.format(k) for k in range(number_of_images)]
        label_names = ['label_{}'.format(k) for k in range(number_of_images)]

    images, labels = [], []
    try:
        for k in range(number_of_images):
            img_file = os.path.basename(img_names[k])
            label_file = os.path.basename(label_names[k])

            # load label (the first label file also provides header/extension)
            if from_files:
                if k == 0:
                    lbl, header, extension = load_data(label_names[k], 'first_queue', True)
                else:
                    lbl, _ = load_data(label_names[k], 'first_queue')
                if lbl is None:
                    InputError.message = f'Invalid label data "{label_file}"'
                    raise InputError()
            else:
                lbl = label_in[k] if type(label_in) is list else label_in
            label_dim = lbl.shape
            lbl = set_labels_to_zero(lbl, labels_to_compute, labels_to_remove)
            label_values, counts = np.unique(lbl, return_counts=True)
            print(f'{label_file}:', 'Labels:', label_values[1:], 'Sizes:', counts[1:])
            if crop_data:
                argmin_z,argmax_z,argmin_y,argmax_y,argmin_x,argmax_x = predict_blocksize(lbl, x_puffer, y_puffer, z_puffer)
                lbl = np.copy(lbl[argmin_z:argmax_z,argmin_y:argmax_y,argmin_x:argmax_x], order='C')
            if not no_scaling:
                lbl = img_resize(lbl, z_scale, y_scale, x_scale, labels=True)
            labels.append(lbl)

            # if header is not single data stream Amira Mesh falling back to Multi-TIFF
            if k == 0:
                if extension != '.am':
                    if extension != '.tif':
                        print(f'Warning! Please use -hf or --header_file="path_to_training_label{extension}" for prediction to save your result as "{extension}"')
                    extension, header = '.tif', None
                elif len(header) > 1:
                    print('Warning! Multiple data streams are not supported. Falling back to TIFF.')
                    extension, header = '.tif', None
                else:
                    header = header[0]

            # load image
            if from_files:
                im, _ = load_data(img_names[k], 'first_queue')
                if im is None:
                    InputError.message = f'Invalid image data "{img_file}"'
                    raise InputError()
            else:
                im = img_in[k] if type(img_in) is list else img_in
            # compare spatial dimensions only: multi-channel images carry a
            # trailing channel axis that the label does not have
            if label_dim != im.shape[:3]:
                InputError.message = f'Dimensions of "{img_file}" and "{label_file}" do not match'
                raise InputError()

            # ensure images have channels >=1
            if im.ndim == 3:
                im = im.reshape(*im.shape, 1)
            if channels is None:
                channels = im.shape[3]
            if im.shape[3] != channels:
                InputError.message = f'Number of channels must be {channels} for "{img_file}"'
                raise InputError()

            # crop with the label's bounding box, then scale/resize
            if crop_data:
                im = np.copy(im[argmin_z:argmax_z,argmin_y:argmax_y,argmin_x:argmax_x], order='C')
            im = im.astype(np.float32)
            if not no_scaling:
                im = img_resize(im, z_scale, y_scale, x_scale)

            # normalize image data; decide once per image whether the
            # parameters are computed, so that all channels (not only the
            # first) receive their own mean/std
            compute_parameters = normalization_parameters is None
            if compute_parameters:
                normalization_parameters = np.zeros((2, channels))
            for c in range(channels):
                im[:,:,:,c] -= np.amin(im[:,:,:,c])
                peak = np.amax(im[:,:,:,c])
                if peak > 0:  # a constant channel would otherwise become NaN
                    im[:,:,:,c] /= peak
                if compute_parameters:
                    normalization_parameters[0,c] = np.mean(im[:,:,:,c])
                    normalization_parameters[1,c] = np.std(im[:,:,:,c])
                elif normalize:
                    mean, std = np.mean(im[:,:,:,c]), np.std(im[:,:,:,c])
                    centered = im[:,:,:,c] - mean
                    if std > 0:
                        centered /= std
                    im[:,:,:,c] = centered * normalization_parameters[1,c] + normalization_parameters[0,c]
            images.append(im)
    finally:
        # remove extracted data (also when an error was raised)
        if from_files:
            remove_extracted_data(img_names, label_names)

    # concatenate once instead of repeated np.append
    label = labels[0] if len(labels) == 1 else np.concatenate(labels, axis=0)
    img = images[0] if len(images) == 1 else np.concatenate(images, axis=0)

    # limit intensity range
    img[img<0] = 0
    img[img>1] = 1

    # get labels
    if allLabels is None:
        allLabels = np.unique(label)

    # labels must be in ascending order
    for k, l in enumerate(allLabels):
        label[label==l] = k

    return img, label, allLabels, normalization_parameters, header, extension, channels
541
+
542
class CustomCallback(Callback):
    """Report training progress to the biomedisa web app.

    After every epoch the ``Upload`` object with primary key ``id`` is
    loaded. If its status is 3 (presumably "stopped by the user" — verify),
    training is stopped. Otherwise its message is set to the progress,
    estimated remaining time and, when logged, the validation accuracy.
    """
    def __init__(self, id, epochs):
        # initialize the Keras Callback base (model/params bookkeeping)
        super().__init__()
        self.epochs = epochs
        self.id = id

    def on_epoch_begin(self, batch, logs=None):
        self.epoch_time_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        import django
        django.setup()
        from biomedisa_app.models import Upload
        logs = {} if logs is None else logs
        image = Upload.objects.get(pk=self.id)
        if image.status == 3:
            self.model.stop_training = True
            return
        keys = list(logs.keys())
        percentage = round((int(epoch)+1)*100/float(self.epochs))
        # remaining time estimated from the duration of the current epoch
        t = round(time.time() - self.epoch_time_start) * (self.epochs-int(epoch)-1)
        if t < 3600:
            time_remaining = str(t // 60) + 'min'
        else:
            time_remaining = str(t // 3600) + 'h ' + str((t % 3600) // 60) + 'min'
        try:
            val_accuracy = round(float(logs["val_accuracy"])*100,1)
            image.message = 'Progress {}%, {} remaining, {}% accuracy'.format(percentage,time_remaining,val_accuracy)
        except KeyError:
            image.message = 'Progress {}%, {} remaining'.format(percentage,time_remaining)
        image.save()
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))
572
+
573
class MetaData(Callback):
    """Add training metadata to the saved HDF5 model file.

    At the end of an epoch, if the file at ``path_to_model`` has no '/meta'
    group yet, a 'meta' group is written with 'configuration',
    'normalization' and 'labels' (plus 'extension' and 'header' for Amira
    output). With ``crop_data``, the 'cropping_meta' and 'cropping_weights'
    groups are written as well. The file is expected to exist already
    (presumably written by a preceding save callback — verify ordering).
    """
    def __init__(self, path_to_model, configuration_data, allLabels,
        extension, header, crop_data, cropping_weights, cropping_config,
        normalization_parameters, cropping_norm):
        # initialize the Keras Callback base (model/params bookkeeping)
        super().__init__()
        self.path_to_model = path_to_model
        self.configuration_data = configuration_data
        self.normalization_parameters = normalization_parameters
        self.allLabels = allLabels
        self.extension = extension
        self.header = header
        self.crop_data = crop_data
        self.cropping_weights = cropping_weights
        self.cropping_config = cropping_config
        self.cropping_norm = cropping_norm

    def on_epoch_end(self, epoch, logs=None):
        # check read-only first so the file is opened for writing only when
        # metadata is missing; context managers close it even on errors
        with h5py.File(self.path_to_model, 'r') as hf:
            has_meta = '/meta' in hf
        if has_meta:
            return
        with h5py.File(self.path_to_model, 'r+') as hf:
            group = hf.create_group('meta')
            group.create_dataset('configuration', data=self.configuration_data)
            group.create_dataset('normalization', data=self.normalization_parameters)
            group.create_dataset('labels', data=self.allLabels)
            if self.extension == '.am':
                group.create_dataset('extension', data=self.extension)
                group.create_dataset('header', data=self.header)
            if self.crop_data:
                cm_group = hf.create_group('cropping_meta')
                cm_group.create_dataset('configuration', data=self.cropping_config)
                cm_group.create_dataset('normalization', data=self.cropping_norm)
                cw_group = hf.create_group('cropping_weights')
                for iterator, arr in enumerate(self.cropping_weights):
                    cw_group.create_dataset(str(iterator), data=arr)
609
+
610
class Metrics(Callback):
    """Evaluate the model on whole volumes every ``validation_freq`` epochs.

    All patches whose origins are listed in ``list_IDs`` (flat C-order
    indices into a volume of shape ``dim_img``) are predicted. The outputs of
    overlapping patches are summed into a full-size score volume whose
    arg-max is compared with ``label``.

    With ``train=True`` only the Dice score is computed and stored in
    ``logs['dice']``. Otherwise:
      * accuracy, Dice and a categorical cross-entropy are recorded in
        ``self.history``;
      * the model is saved to ``path_to_model`` in the first epoch and
        whenever the rounded validation Dice exceeds all earlier values;
      * the history is plotted;
      * training stops when the best Dice is not among the last
        ``early_stopping`` validations (if ``early_stopping > 0``).
    """
    def __init__(self, bm, img, label, list_IDs, dim_img, n_classes, train):
        # initialize the Keras Callback base (model/params bookkeeping)
        super().__init__()
        self.dim_patch = (bm.z_patch, bm.y_patch, bm.x_patch)
        self.dim_img = dim_img
        self.list_IDs = list_IDs
        self.batch_size = bm.batch_size
        self.label = label
        self.img = img
        self.path_to_model = bm.path_to_model
        self.early_stopping = bm.early_stopping
        self.validation_freq = bm.validation_freq
        self.n_classes = n_classes
        self.n_channels = bm.channels
        self.average_dice = bm.average_dice
        self.django_env = bm.django_env
        self.patch_normalization = bm.patch_normalization
        self.train = train
        self.train_dice = bm.train_dice

    def on_train_begin(self, logs=None):
        self.history = {}
        self.history['val_accuracy'] = []
        self.history['accuracy'] = []
        self.history['val_dice'] = []
        self.history['dice'] = []
        self.history['val_loss'] = []
        self.history['loss'] = []

    def _patch_origin(self, ID):
        # flat C-order index -> (plane, row, column) of the patch origin
        k, rest = divmod(ID, self.dim_img[1]*self.dim_img[2])
        l, m = divmod(rest, self.dim_img[2])
        return k, l, m

    def _dice(self, result):
        # Dice between self.label and result; a denominator of zero (class
        # absent from both) counts as a perfect match instead of NaN
        if self.average_dice:
            dice = 0
            for l in range(1, self.n_classes):
                true_l = self.label == l
                pred_l = result == l
                denominator = float(true_l.sum() + pred_l.sum())
                if denominator > 0:
                    dice += 2 * np.logical_and(true_l, pred_l).sum() / denominator
                else:
                    dice += 1.0
            dice /= float(self.n_classes-1)
        else:
            denominator = float((self.label>0).sum() + (result>0).sum())
            # 'result>0' instead of '(label+result)>0': equivalent where the
            # two agree, but cannot overflow for small unsigned dtypes
            if denominator > 0:
                dice = 2 * np.logical_and(self.label==result, result>0).sum() / denominator
            else:
                dice = 1.0
        return dice

    def on_epoch_end(self, epoch, logs=None):
        if logs is None:
            logs = {}
        if epoch % self.validation_freq != 0:
            return

        result = np.zeros((*self.dim_img, self.n_classes), dtype=np.float32)

        # Predict every listed patch. Rounding the batch count up means no
        # patch is skipped any more (previously the remainder was dropped,
        # leaving its voxels unscored). The last batch is zero-padded to
        # the full batch size and the padding predictions are ignored.
        n_batches = int(np.ceil(len(self.list_IDs) / self.batch_size))

        for batch in range(n_batches):
            # Generate indexes of the batch
            list_IDs_batch = self.list_IDs[batch*self.batch_size:(batch+1)*self.batch_size]

            # Initialization
            X_val = np.zeros((self.batch_size, *self.dim_patch, self.n_channels), dtype=np.float32)

            # Generate data
            for i, ID in enumerate(list_IDs_batch):
                k, l, m = self._patch_origin(ID)
                tmp_X = self.img[k:k+self.dim_patch[0],l:l+self.dim_patch[1],m:m+self.dim_patch[2]]
                if self.patch_normalization:
                    tmp_X = np.copy(tmp_X, order='C')
                    for c in range(self.n_channels):
                        tmp_X[:,:,:,c] -= np.mean(tmp_X[:,:,:,c])
                        tmp_X[:,:,:,c] /= max(np.std(tmp_X[:,:,:,c]), 1e-6)
                X_val[i] = tmp_X

            # Prediction segmentation
            y_predict = np.asarray(self.model.predict(X_val, verbose=0, steps=None, batch_size=self.batch_size))

            # accumulate patch predictions into the full volume
            for i, ID in enumerate(list_IDs_batch):
                k, l, m = self._patch_origin(ID)
                result[k:k+self.dim_patch[0],l:l+self.dim_patch[1],m:m+self.dim_patch[2]] += y_predict[i]

        # calculate categorical crossentropy
        if not self.train:
            crossentropy = categorical_crossentropy(self.label, softmax(result))

        # get result
        result = np.argmax(result, axis=-1)
        result = result.astype(np.uint8)

        # calculate standard accuracy
        if not self.train:
            accuracy = np.sum(self.label==result) / float(self.label.size)

        # calculate dice score
        dice = self._dice(result)

        if self.train:
            logs['dice'] = dice
            return

        # save best model only
        if epoch == 0 or round(dice,4) > max(self.history['val_dice']):
            self.model.save(str(self.path_to_model))

        # add accuracy to history
        self.history['loss'].append(round(logs['loss'],4))
        self.history['accuracy'].append(round(logs['accuracy'],4))
        if self.train_dice:
            self.history['dice'].append(round(logs['dice'],4))
        self.history['val_accuracy'].append(round(accuracy,4))
        self.history['val_dice'].append(round(dice,4))
        self.history['val_loss'].append(round(crossentropy,4))

        # tensorflow monitoring variables
        logs['val_loss'] = crossentropy
        logs['val_accuracy'] = accuracy
        logs['val_dice'] = dice
        logs['best_acc'] = max(self.history['accuracy'])
        if self.train_dice:
            logs['best_dice'] = max(self.history['dice'])
        logs['best_val_acc'] = max(self.history['val_accuracy'])
        logs['best_val_dice'] = max(self.history['val_dice'])

        # plot history in figure and save as numpy array
        save_history(self.history, self.path_to_model, True, self.train_dice)

        # print accuracies
        print('\nValidation history:')
        print('train_acc:', self.history['accuracy'])
        if self.train_dice:
            print('train_dice:', self.history['dice'])
        print('val_acc:', self.history['val_accuracy'])
        print('val_dice:', self.history['val_dice'])
        print('')

        # early stopping
        if self.early_stopping > 0 and max(self.history['val_dice']) not in self.history['val_dice'][-self.early_stopping:]:
            self.model.stop_training = True
745
+
746
def softmax(x):
    """Return the softmax of ``x`` along its last axis.

    The largest entry of each row is subtracted before exponentiating so that
    big inputs cannot overflow ``np.exp``; the shift cancels in the ratio and
    does not change the result.
    """
    row_peak = np.max(x, axis=-1, keepdims=True)
    weights = np.exp(x - row_peak)
    totals = np.sum(weights, axis=-1, keepdims=True)
    return weights / totals
751
+
752
@numba.jit(nopython=True)
def categorical_crossentropy(true_labels, predicted_probs):
    """Mean categorical cross-entropy of a dense 3D prediction.

    Parameters
    ----------
    true_labels : 3D array (z, y, x) of integer class indices.
    predicted_probs : 4D array (z, y, x, classes); its last axis is indexed by
        the entries of ``true_labels``.

    Returns
    -------
    float
        Mean over all voxels of ``-log(p)``, where ``p`` is the probability of
        the true class clipped to ``[1e-7, 1 - 1e-7]`` so that ``log(0)`` cannot
        occur.
    """
    eps = 1e-7
    zsh, ysh, xsh = true_labels.shape
    loss = 0.0
    for z in range(zsh):
        for y in range(ysh):
            for x in range(xsh):
                # Only the true-class probability contributes, so clip and log
                # that single value instead of building clipped and -log copies
                # of the whole (z, y, x, classes) array.
                p = float(predicted_probs[z, y, x, true_labels[z, y, x]])
                if p < eps:
                    p = eps
                elif p > 1.0 - eps:
                    p = 1.0 - eps
                loss -= np.log(p)
    return loss / float(zsh * ysh * xsh)
767
+
768
def dice_coef(y_true, y_pred, smooth=1e-5):
    """Smoothed Dice coefficient between two tensors.

    Returns ``(2 * sum(y_true * y_pred) + smooth) /
    (sum(y_true) + sum(y_pred) + smooth)`` summed over all elements. ``smooth``
    keeps the ratio defined (it evaluates to 1 when both inputs are all zero).
    """
    overlap = K.sum(Multiply()([y_true, y_pred]))
    denominator = K.sum(y_true) + K.sum(y_pred)
    return (2. * overlap + smooth) / (denominator + smooth)
772
+
773
def dice_coef_loss(nb_labels):
    """Build a loss equal to ``-log`` of the mean Dice score over classes 1..n-1.

    Parameters
    ----------
    nb_labels : int
        Number of classes including class 0. Class 0 is excluded from the
        average, so at least two classes are required.

    Returns
    -------
    callable
        ``loss_fn(y_true, y_pred)``; both tensors carry the classes along their
        last axis (``y_true`` presumably one-hot -- confirm with the data
        generator).

    Raises
    ------
    ValueError
        If ``nb_labels < 2`` (there would be no class to average over).
    """
    if nb_labels < 2:
        raise ValueError(f'dice loss needs at least two labels, got {nb_labels}.')

    def loss_fn(y_true, y_pred):
        dice = 0
        for index in range(1, nb_labels):
            # "..." selects the class channel regardless of the tensor's rank
            dice += dice_coef(y_true[..., index], y_pred[..., index])
        dice = dice / (nb_labels - 1)
        return -K.log(dice)

    return loss_fn
786
+
787
def _get_patch_ids(label, shape, patch, stride, balance):
    """Enumerate patch origins on a regular grid as flat voxel indices.

    A patch starting at voxel ``(k, l, m)`` of a volume with shape
    ``(zsh, ysh, xsh)`` is identified by ``k*ysh*xsh + l*xsh + m``.

    Parameters
    ----------
    label : array sliceable as ``label[z0:z1, y0:y1, x0:x1]``; only read when
        ``balance`` is truthy.
    shape : ``(zsh, ysh, xsh)`` of the volume.
    patch : ``(z_patch, y_patch, x_patch)`` patch size.
    stride : step between neighbouring patch origins along every axis.
    balance : if truthy, patches without any nonzero label go to the second
        list; otherwise every patch goes to the first list.

    Returns
    -------
    (list, list)
        ``(ids_fg, ids_bg)`` in z-major scan order.
    """
    zsh, ysh, xsh = shape
    z_patch, y_patch, x_patch = patch
    ids_fg, ids_bg = [], []
    for k in range(0, zsh-z_patch+1, stride):
        for l in range(0, ysh-y_patch+1, stride):
            for m in range(0, xsh-x_patch+1, stride):
                patch_id = k*ysh*xsh + l*xsh + m
                if balance and not np.any(label[k:k+z_patch, l:l+y_patch, m:m+x_patch]):
                    ids_bg.append(patch_id)
                else:
                    ids_fg.append(patch_id)
    return ids_fg, ids_bg


def train_semantic_segmentation(bm,
    img_list, label_list,
    val_img_list, val_label_list,
    img=None, label=None,
    img_val=None, label_val=None,
    header=None, extension='.tif'):
    """Train a U-Net for semantic segmentation, saving it to ``bm.path_to_model``.

    Training data come from ``img_list``/``label_list`` or the arrays
    ``img``/``label``. Validation data come from ``val_img_list``/
    ``val_label_list`` or ``img_val``/``label_val``. If neither is given and
    ``bm.validation_split`` is set, the first
    ``round(zsh * bm.validation_split)`` slices are kept for training and the
    remaining slices are used for validation.

    All other settings (patch size, strides, balancing, augmentation, network,
    optimizer, early stopping, ...) are read from ``bm``. Nothing is returned.
    The model is written by ``ModelCheckpoint`` or by the ``Metrics`` callback
    (when ``bm.val_dice``). The training history is written with
    ``save_history``.
    """

    # training data
    img, label, allLabels, normalization_parameters, header, extension, bm.channels = load_training_data(bm.normalize,
        img_list, label_list, None, bm.x_scale, bm.y_scale, bm.z_scale, bm.no_scaling, bm.crop_data,
        bm.only, bm.ignore, img, label, None, None, header, extension)

    # configuration data
    configuration_data = np.array([bm.channels, bm.x_scale, bm.y_scale, bm.z_scale, bm.normalize, 0, 1])

    # img shape
    zsh, ysh, xsh, _ = img.shape

    # validation data
    if any(val_img_list) or img_val is not None:
        img_val, label_val, _, _, _, _, _ = load_training_data(bm.normalize,
            val_img_list, val_label_list, bm.channels, bm.x_scale, bm.y_scale, bm.z_scale, bm.no_scaling, bm.crop_data,
            bm.only, bm.ignore, img_val, label_val, normalization_parameters, allLabels)

    elif bm.validation_split:
        # keep the first `split` slices for training, validate on the rest
        split = round(zsh * bm.validation_split)
        img_val = np.copy(img[split:], order='C')
        label_val = np.copy(label[split:], order='C')
        img = np.copy(img[:split], order='C')
        label = np.copy(label[:split], order='C')
        zsh, ysh, xsh, _ = img.shape

    # patch size shared by training and validation
    patch_size = (bm.z_patch, bm.y_patch, bm.x_patch)

    # get IDs of training patches; with bm.balance, patches without any
    # label go into the separate background list
    list_IDs_fg, list_IDs_bg = _get_patch_ids(label, (zsh, ysh, xsh),
        patch_size, bm.stride_size, bm.balance)

    if img_val is not None:

        # img_val shape
        zsh_val, ysh_val, xsh_val, _ = img_val.shape

        # get validation IDs of patches; when validating with Dice
        # (bm.val_dice) all patches are collected in the first list
        list_IDs_val_fg, list_IDs_val_bg = _get_patch_ids(label_val,
            (zsh_val, ysh_val, xsh_val), patch_size,
            bm.validation_stride_size, bm.balance and not bm.val_dice)

    # number of labels
    nb_labels = len(allLabels)

    # input shape
    input_shape = (bm.z_patch, bm.y_patch, bm.x_patch, bm.channels)

    # parameters
    params = {'batch_size': bm.batch_size,
              'dim': patch_size,
              'dim_img': (zsh, ysh, xsh),
              'n_classes': nb_labels,
              'n_channels': bm.channels,
              'augment': (bm.flip_x, bm.flip_y, bm.flip_z, bm.swapaxes, bm.rotate),
              'patch_normalization': bm.patch_normalization}

    # data generator
    validation_generator = None
    training_generator = DataGenerator(img, label, list_IDs_fg, list_IDs_bg, True, True, False, **params)
    if img_val is not None:
        if bm.val_dice:
            val_metrics = Metrics(bm, img_val, label_val, list_IDs_val_fg, (zsh_val, ysh_val, xsh_val), nb_labels, False)
        else:
            params['dim_img'] = (zsh_val, ysh_val, xsh_val)
            params['augment'] = (False, False, False, False, 0)
            validation_generator = DataGenerator(img_val, label_val, list_IDs_val_fg, list_IDs_val_bg, True, False, False, **params)

    # monitor dice score on training data
    if bm.train_dice:
        train_metrics = Metrics(bm, img, label, list_IDs_fg, (zsh, ysh, xsh), nb_labels, True)

    # create a MirroredStrategy
    cdo = tf.distribute.ReductionToOneDevice()
    strategy = tf.distribute.MirroredStrategy(cross_device_ops=cdo)
    ngpus = int(strategy.num_replicas_in_sync)
    print(f'Number of devices: {ngpus}')
    if ngpus == 1 and os.name == 'nt':
        # relies on a private TF attribute; only registered on Windows
        atexit.register(strategy._extended._collective_ops._pool.close)

    # compile model
    with strategy.scope():

        # build model
        model = make_unet(input_shape, nb_labels, bm.network_filters, bm.resnet)
        model.summary()

        # pretrained model
        if bm.pretrained_model:
            model_pretrained = load_model(bm.pretrained_model)
            model.set_weights(model_pretrained.get_weights())
            if not bm.fine_tune:
                # freeze conv_{k}_{1,2} for k in nb_blocks+1 .. 2*nb_blocks-1
                # and conv_{2*nb_blocks}_1
                nb_blocks = len(bm.network_filters.split('-'))
                for k in range(nb_blocks+1, 2*nb_blocks):
                    for l in [1, 2]:
                        name = f'conv_{k}_{l}'
                        layer = model.get_layer(name)
                        layer.trainable = False
                name = f'conv_{2*nb_blocks}_1'
                layer = model.get_layer(name)
                layer.trainable = False

        # optimizer
        # NOTE(review): newer Keras optimizers reject `decay`; confirm against
        # the TensorFlow version this package pins.
        sgd = SGD(learning_rate=bm.learning_rate, decay=1e-6, momentum=0.9, nesterov=True)

        # compile model
        loss = dice_coef_loss(nb_labels) if bm.dice_loss else 'categorical_crossentropy'
        model.compile(loss=loss,
                      optimizer=sgd,
                      metrics=['accuracy'])

    # save meta data
    meta_data = MetaData(bm.path_to_model, configuration_data, allLabels,
        extension, header, bm.crop_data, bm.cropping_weights, bm.cropping_config,
        normalization_parameters, bm.cropping_norm)

    # model checkpoint
    if img_val is not None:
        if bm.val_dice:
            callbacks = [val_metrics, meta_data]
        else:
            model_checkpoint_callback = ModelCheckpoint(
                filepath=str(bm.path_to_model),
                save_weights_only=False,
                monitor='val_accuracy',
                mode='max',
                save_best_only=True)
            callbacks = [model_checkpoint_callback, meta_data]
            if bm.early_stopping > 0:
                callbacks.insert(0, EarlyStopping(monitor='val_accuracy', mode='max', patience=bm.early_stopping))
    else:
        callbacks = [ModelCheckpoint(filepath=str(bm.path_to_model)), meta_data]

    # monitor dice score on training data
    if bm.train_dice:
        callbacks = [train_metrics] + callbacks

    # custom callback (inserted before the last callback, meta_data)
    if bm.django_env and not bm.remote:
        callbacks.insert(-1, CustomCallback(bm.img_id, bm.epochs))

    # train model
    history = model.fit(training_generator,
                        epochs=bm.epochs,
                        validation_data=validation_generator,
                        callbacks=callbacks,
                        workers=bm.workers)

    # save monitoring figure on train end
    if img_val is not None and not bm.val_dice:
        save_history(history.history, bm.path_to_model, False, bm.train_dice)
969
+
970
def load_prediction_data(path_to_img, channels, x_scale, y_scale, z_scale,
    no_scaling, normalize, normalization_parameters, region_of_interest,
    img, img_header):
    """Load and preprocess an image volume for prediction.

    Steps:

    1. Read ``path_to_img`` with ``load_data`` when ``img`` is None.
    2. Add a channel axis to 3D data and check the channel count.
    3. Crop to ``region_of_interest`` if it is given.
    4. Resize to ``(z_scale, y_scale, x_scale)`` unless ``no_scaling``.
    5. Rescale every channel to [0, 1].
    6. If ``normalize``, map the channel mean/std to
       ``normalization_parameters[0, c]`` / ``normalization_parameters[1, c]``
       and clip to [0, 1].

    Returns
    -------
    img : preprocessed array of shape (z, y, x, channels)
    img_header : header from ``load_data`` (or the one passed in)
    z_shape, y_shape, x_shape : spatial shape after cropping, before resizing
    region_of_interest : ``[min_z, max_z, min_y, max_y, min_x, max_x, zsh,
        ysh, xsh]`` (crop bounds plus uncropped shape), or the argument
        unchanged when no cropping was requested
    img_data : copy of the image as loaded, before any preprocessing

    Raises
    ------
    InputError
        If the image cannot be read or has the wrong number of channels.
    """
    # read image data
    if img is None:
        img, img_header = load_data(path_to_img, 'first_queue')
        InputError.img_names = [path_to_img]
        InputError.label_names = []

    # verify validity
    if img is None:
        InputError.message = f'Invalid image data: {os.path.basename(path_to_img)}.'
        raise InputError()

    # preserve original image data
    img_data = img.copy()

    # handle all images having channels >=1
    if len(img.shape) == 3:
        z_shape, y_shape, x_shape = img.shape
        img = img.reshape(z_shape, y_shape, x_shape, 1)
    if img.shape[3] != channels:
        InputError.message = f'Number of channels must be {channels}.'
        raise InputError()

    # image shape
    z_shape, y_shape, x_shape, _ = img.shape

    # automatic cropping of image to region of interest
    if np.any(region_of_interest):
        min_z, max_z, min_y, max_y, min_x, max_x = region_of_interest[:]
        img = np.copy(img[min_z:max_z, min_y:max_y, min_x:max_x], order='C')
        region_of_interest = np.array([min_z, max_z, min_y, max_y, min_x, max_x, z_shape, y_shape, x_shape])
        z_shape, y_shape, x_shape = max_z-min_z, max_y-min_y, max_x-min_x

    # scale/resize image data
    img = img.astype(np.float32)
    if not no_scaling:
        img = img_resize(img, z_scale, y_scale, x_scale)

    # normalize image data
    for c in range(channels):
        img[:,:,:,c] -= np.amin(img[:,:,:,c])
        # a constant channel has range 0: leave it at 0 instead of 0/0 = NaN
        value_range = np.amax(img[:,:,:,c])
        if value_range > 0:
            img[:,:,:,c] /= value_range
        if normalize:
            mean, std = np.mean(img[:,:,:,c]), np.std(img[:,:,:,c])
            # likewise, do not divide by a zero standard deviation
            img[:,:,:,c] = (img[:,:,:,c] - mean) / (std if std > 0 else 1)
            img[:,:,:,c] = img[:,:,:,c] * normalization_parameters[1,c] + normalization_parameters[0,c]

    # limit intensity range
    if normalize:
        img[img<0] = 0
        img[img>1] = 1

    return img, img_header, z_shape, y_shape, x_shape, region_of_interest, img_data
1025
+
1026
def predict_semantic_segmentation(bm, img, path_to_model,
    z_patch, y_patch, x_patch, z_shape, y_shape, x_shape, compress, header,
    img_header, stride_size, allLabels, batch_size, region_of_interest,
    no_scaling, extension, img_data):
    """Predict a label volume for a preprocessed image with a saved model.

    ``img`` (z, y, x, channels) is tiled into patches of size
    ``(z_patch, y_patch, x_patch)`` with step ``stride_size``. The model's
    per-patch outputs are summed into a full-size volume whose argmax over the
    last axis gives the labels. The labels are then:

    * resized back to ``(z_shape, y_shape, x_shape)`` unless ``no_scaling``;
    * pasted into the uncropped volume described by ``region_of_interest``
      if it is given;
    * mapped with ``get_labels(label, allLabels)``.

    Optional outputs, all controlled by ``bm``:

    * averaged probabilities (``bm.return_probs``);
    * cleaned/filled results (``bm.clean`` / ``bm.fill``);
    * active-contour results (``bm.acwe``).

    When ``bm.path_to_image`` is set, each result is also saved next to
    ``bm.path_to_final``.

    Returns
    -------
    (dict, bm)
        ``results`` always contains ``'regular'`` and ``'header'``. It also
        contains ``'probs'``, ``'cleaned'``, ``'filled'``,
        ``'cleaned_filled'``, ``'acwe'`` and ``'refined'`` when requested.
        ``bm.path_to_final`` may have been updated.
    """

    results = {}

    # img shape
    zsh, ysh, xsh, csh = img.shape

    # number of labels
    nb_labels = len(allLabels)

    # get IDs of patches (flat index of each patch origin)
    list_IDs = []
    for k in range(0, zsh-z_patch+1, stride_size):
        for l in range(0, ysh-y_patch+1, stride_size):
            for m in range(0, xsh-x_patch+1, stride_size):
                list_IDs.append(k*ysh*xsh+l*xsh+m)

    # make length of list divisible by batch size: pad nothing if it already
    # is, and cycle through the IDs if the list is shorter than the padding
    padding = -len(list_IDs) % batch_size
    list_IDs = list_IDs + [list_IDs[i % len(list_IDs)] for i in range(padding)]

    # number of patches (padding included; only the first ones are used below)
    nb_patches = len(list_IDs)

    # parameters
    params = {'dim': (z_patch, y_patch, x_patch),
              'dim_img': (zsh, ysh, xsh),
              'batch_size': batch_size,
              'n_channels': csh,
              'patch_normalization': bm.patch_normalization}

    # data generator
    predict_generator = PredictDataGenerator(img, list_IDs, **params)

    # load model
    model = load_model(str(path_to_model))

    # predict
    if nb_patches < 400:
        probabilities = model.predict(predict_generator, verbose=0, steps=None)
    else:
        X = np.empty((batch_size, z_patch, y_patch, x_patch, csh), dtype=np.float32)
        probabilities = np.zeros((nb_patches, z_patch, y_patch, x_patch, nb_labels), dtype=np.float32)

        # get image patches
        for step in range(nb_patches//batch_size):
            for i, ID in enumerate(list_IDs[step*batch_size:(step+1)*batch_size]):

                # get patch indices
                k = ID // (ysh*xsh)
                offset = ID % (ysh*xsh)
                l = offset // xsh
                m = offset % xsh

                # get patch
                tmp_X = img[k:k+z_patch, l:l+y_patch, m:m+x_patch]
                if bm.patch_normalization:
                    tmp_X = np.copy(tmp_X, order='C')
                    for c in range(csh):
                        tmp_X[:,:,:,c] -= np.mean(tmp_X[:,:,:,c])
                        tmp_X[:,:,:,c] /= max(np.std(tmp_X[:,:,:,c]), 1e-6)
                X[i] = tmp_X

            probabilities[step*batch_size:(step+1)*batch_size] = model.predict(X, verbose=0, steps=None, batch_size=batch_size)

    # create final (sum of overlapping patch outputs)
    final = np.zeros((zsh, ysh, xsh, nb_labels), dtype=np.float32)
    if bm.return_probs:
        counter = np.zeros((zsh, ysh, xsh, nb_labels), dtype=np.float32)
    nb = 0
    for k in range(0, zsh-z_patch+1, stride_size):
        for l in range(0, ysh-y_patch+1, stride_size):
            for m in range(0, xsh-x_patch+1, stride_size):
                final[k:k+z_patch, l:l+y_patch, m:m+x_patch] += probabilities[nb]
                if bm.return_probs:
                    counter[k:k+z_patch, l:l+y_patch, m:m+x_patch] += 1
                nb += 1

    # return probabilities (average over overlapping patches)
    if bm.return_probs:
        counter[counter==0] = 1
        probabilities = final / counter
        if not no_scaling:
            probabilities = img_resize(probabilities, z_shape, y_shape, x_shape)
        if np.any(region_of_interest):
            min_z,max_z,min_y,max_y,min_x,max_x,original_zsh,original_ysh,original_xsh = region_of_interest[:]
            tmp = np.zeros((original_zsh, original_ysh, original_xsh, nb_labels), dtype=np.float32)
            tmp[min_z:max_z,min_y:max_y,min_x:max_x] = probabilities
            probabilities = np.copy(tmp, order='C')
        results['probs'] = probabilities

    # get final
    label = np.argmax(final, axis=3)
    label = label.astype(np.uint8)

    # rescale final to input size
    if not no_scaling:
        label = img_resize(label, z_shape, y_shape, x_shape, labels=True)

    # revert automatic cropping
    if np.any(region_of_interest):
        min_z,max_z,min_y,max_y,min_x,max_x,original_zsh,original_ysh,original_xsh = region_of_interest[:]
        tmp = np.zeros((original_zsh, original_ysh, original_xsh), dtype=np.uint8)
        tmp[min_z:max_z,min_y:max_y,min_x:max_x] = label
        label = np.copy(tmp, order='C')

    # get result
    label = get_labels(label, allLabels)
    results['regular'] = label

    # load header from file
    if bm.header_file and os.path.exists(bm.header_file):
        _, header = load_data(bm.header_file)
        # update file extension
        if header is not None and bm.path_to_image:
            extension = os.path.splitext(bm.header_file)[1]
            if extension == '.gz':
                extension = '.nii.gz'
            bm.path_to_final = os.path.splitext(bm.path_to_final)[0] + extension
            if bm.django_env and not bm.remote and not bm.tmp_dir:
                from biomedisa_app.views import unique_file_path
                bm.path_to_final = unique_file_path(bm.path_to_final)

    # handle amira header
    if header is not None:
        if extension == '.am':
            header = get_image_dimensions(header[0], label)
            if img_header is not None:
                try:
                    header = get_physical_size(header, img_header[0])
                except Exception:
                    # best effort: keep the header from get_image_dimensions
                    # if the physical size cannot be transferred
                    pass
            header = [header]
        else:
            # build new header
            if img_header is None:
                zsh, ysh, xsh = label.shape
                img_header = sitk.Image(xsh, ysh, zsh, header.GetPixelID())
            # copy metadata
            for key in header.GetMetaDataKeys():
                if not (re.match(r'Segment\d+_Extent$', key) or key=='Segmentation_ConversionParameters'):
                    img_header.SetMetaData(key, header.GetMetaData(key))
            header = img_header
    results['header'] = header

    # save result
    if bm.path_to_image:
        save_data(bm.path_to_final, label, header=header, compress=compress)

        # paths to optional results
        filename, extension = os.path.splitext(bm.path_to_final)
        if extension == '.gz':
            extension = '.nii.gz'
            filename = filename[:-4]
        path_to_cleaned = filename + '.cleaned' + extension
        path_to_filled = filename + '.filled' + extension
        path_to_cleaned_filled = filename + '.cleaned.filled' + extension
        path_to_refined = filename + '.refined' + extension
        path_to_acwe = filename + '.acwe' + extension

    # remove outliers
    if bm.clean:
        cleaned_result = clean(label, bm.clean)
        results['cleaned'] = cleaned_result
        if bm.path_to_image:
            save_data(path_to_cleaned, cleaned_result, header=header, compress=compress)
    if bm.fill:
        # NOTE(review): this calls clean() with the fill threshold; a hole
        # filling function may have been intended. If clean() only removes
        # voxels (as its "remove outliers" use suggests), `filled_result - label`
        # below is never positive and can wrap around for unsigned labels.
        filled_result = clean(label, bm.fill)
        results['filled'] = filled_result
        if bm.path_to_image:
            save_data(path_to_filled, filled_result, header=header, compress=compress)
    if bm.clean and bm.fill:
        cleaned_filled_result = cleaned_result + (filled_result - label)
        results['cleaned_filled'] = cleaned_filled_result
        if bm.path_to_image:
            save_data(path_to_cleaned_filled, cleaned_filled_result, header=header, compress=compress)

    # post-processing with active contour
    if bm.acwe:
        acwe_result = activeContour(img_data, label, bm.acwe_alpha, bm.acwe_smooth, bm.acwe_steps)
        refined_result = activeContour(img_data, label, simple=True)
        results['acwe'] = acwe_result
        results['refined'] = refined_result
        if bm.path_to_image:
            save_data(path_to_acwe, acwe_result, header=header, compress=compress)
            save_data(path_to_refined, refined_result, header=header, compress=compress)

    return results, bm
1219
+