biomedisa 2024.5.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- biomedisa/__init__.py +53 -0
- biomedisa/__main__.py +18 -0
- biomedisa/biomedisa_features/DataGenerator.py +299 -0
- biomedisa/biomedisa_features/DataGeneratorCrop.py +121 -0
- biomedisa/biomedisa_features/PredictDataGenerator.py +87 -0
- biomedisa/biomedisa_features/PredictDataGeneratorCrop.py +74 -0
- biomedisa/biomedisa_features/__init__.py +0 -0
- biomedisa/biomedisa_features/active_contour.py +434 -0
- biomedisa/biomedisa_features/amira_to_np/__init__.py +0 -0
- biomedisa/biomedisa_features/amira_to_np/amira_data_stream.py +980 -0
- biomedisa/biomedisa_features/amira_to_np/amira_grammar.py +369 -0
- biomedisa/biomedisa_features/amira_to_np/amira_header.py +290 -0
- biomedisa/biomedisa_features/amira_to_np/amira_helper.py +72 -0
- biomedisa/biomedisa_features/assd.py +167 -0
- biomedisa/biomedisa_features/biomedisa_helper.py +801 -0
- biomedisa/biomedisa_features/create_slices.py +286 -0
- biomedisa/biomedisa_features/crop_helper.py +586 -0
- biomedisa/biomedisa_features/curvop_numba.py +149 -0
- biomedisa/biomedisa_features/django_env.py +172 -0
- biomedisa/biomedisa_features/keras_helper.py +1219 -0
- biomedisa/biomedisa_features/nc_reader.py +179 -0
- biomedisa/biomedisa_features/pid.py +52 -0
- biomedisa/biomedisa_features/process_image.py +253 -0
- biomedisa/biomedisa_features/pycuda_test.py +84 -0
- biomedisa/biomedisa_features/random_walk/__init__.py +0 -0
- biomedisa/biomedisa_features/random_walk/gpu_kernels.py +183 -0
- biomedisa/biomedisa_features/random_walk/pycuda_large.py +826 -0
- biomedisa/biomedisa_features/random_walk/pycuda_large_allx.py +806 -0
- biomedisa/biomedisa_features/random_walk/pycuda_small.py +414 -0
- biomedisa/biomedisa_features/random_walk/pycuda_small_allx.py +493 -0
- biomedisa/biomedisa_features/random_walk/pyopencl_large.py +760 -0
- biomedisa/biomedisa_features/random_walk/pyopencl_small.py +441 -0
- biomedisa/biomedisa_features/random_walk/rw_large.py +390 -0
- biomedisa/biomedisa_features/random_walk/rw_small.py +310 -0
- biomedisa/biomedisa_features/remove_outlier.py +399 -0
- biomedisa/biomedisa_features/split_volume.py +274 -0
- biomedisa/deeplearning.py +519 -0
- biomedisa/interpolation.py +371 -0
- biomedisa/mesh.py +406 -0
- biomedisa-2024.5.14.dist-info/LICENSE +191 -0
- biomedisa-2024.5.14.dist-info/METADATA +306 -0
- biomedisa-2024.5.14.dist-info/RECORD +44 -0
- biomedisa-2024.5.14.dist-info/WHEEL +5 -0
- biomedisa-2024.5.14.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1219 @@
|
|
1
|
+
##########################################################################
|
2
|
+
## ##
|
3
|
+
## Copyright (c) 2024 Philipp Lösel. All rights reserved. ##
|
4
|
+
## ##
|
5
|
+
## This file is part of the open source project biomedisa. ##
|
6
|
+
## ##
|
7
|
+
## Licensed under the European Union Public Licence (EUPL) ##
|
8
|
+
## v1.2, or - as soon as they will be approved by the ##
|
9
|
+
## European Commission - subsequent versions of the EUPL; ##
|
10
|
+
## ##
|
11
|
+
## You may redistribute it and/or modify it under the terms ##
|
12
|
+
## of the EUPL v1.2. You may not use this work except in ##
|
13
|
+
## compliance with this Licence. ##
|
14
|
+
## ##
|
15
|
+
## You can obtain a copy of the Licence at: ##
|
16
|
+
## ##
|
17
|
+
## https://joinup.ec.europa.eu/page/eupl-text-11-12 ##
|
18
|
+
## ##
|
19
|
+
## Unless required by applicable law or agreed to in ##
|
20
|
+
## writing, software distributed under the Licence is ##
|
21
|
+
## distributed on an "AS IS" basis, WITHOUT WARRANTIES ##
|
22
|
+
## OR CONDITIONS OF ANY KIND, either express or implied. ##
|
23
|
+
## ##
|
24
|
+
## See the Licence for the specific language governing ##
|
25
|
+
## permissions and limitations under the Licence. ##
|
26
|
+
## ##
|
27
|
+
##########################################################################
|
28
|
+
|
29
|
+
import os
|
30
|
+
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
31
|
+
try:
|
32
|
+
from tensorflow.keras.optimizers.legacy import SGD
|
33
|
+
except:
|
34
|
+
from tensorflow.keras.optimizers import SGD
|
35
|
+
from tensorflow.keras.models import Model, load_model
|
36
|
+
from tensorflow.keras.layers import (
|
37
|
+
Input, Conv3D, MaxPooling3D, UpSampling3D, Activation, Reshape,
|
38
|
+
BatchNormalization, Concatenate, ReLU, Add, GlobalAveragePooling3D,
|
39
|
+
Dense, Dropout, MaxPool3D, Flatten, Multiply)
|
40
|
+
from tensorflow.keras import backend as K
|
41
|
+
from tensorflow.keras.utils import to_categorical
|
42
|
+
from tensorflow.keras.callbacks import Callback, ModelCheckpoint, EarlyStopping
|
43
|
+
from biomedisa_features.DataGenerator import DataGenerator
|
44
|
+
from biomedisa_features.PredictDataGenerator import PredictDataGenerator
|
45
|
+
from biomedisa_features.biomedisa_helper import (
|
46
|
+
img_resize, load_data, save_data, set_labels_to_zero, id_generator)
|
47
|
+
from biomedisa_features.remove_outlier import clean, fill
|
48
|
+
from biomedisa_features.active_contour import activeContour
|
49
|
+
import matplotlib.pyplot as plt
|
50
|
+
import SimpleITK as sitk
|
51
|
+
import tensorflow as tf
|
52
|
+
import numpy as np
|
53
|
+
import cv2
|
54
|
+
import tarfile
|
55
|
+
from random import shuffle
|
56
|
+
import glob
|
57
|
+
import random
|
58
|
+
import numba
|
59
|
+
import re
|
60
|
+
import time
|
61
|
+
import h5py
|
62
|
+
import atexit
|
63
|
+
import shutil
|
64
|
+
|
65
|
+
class InputError(Exception):
    """Raised when the supplied image or label data cannot be used.

    Code elsewhere in this module also assigns ``message``, ``img_names``
    and ``label_names`` directly on the class before raising a bare
    ``InputError()``.  The instance attributes set in ``__init__`` shadow
    those class-level values on the raised instance.

    Parameters:
        message: optional human-readable description of the problem.
        img_names: optional list of image file names involved.
        label_names: optional list of label file names involved.
    """

    def __init__(self, message=None, img_names=None, label_names=None):
        self.message = message
        # A fresh list per instance: a literal ``[]`` default would be one
        # shared object that every instance could mutate.
        self.img_names = [] if img_names is None else img_names
        self.label_names = [] if label_names is None else label_names
|
70
|
+
|
71
|
+
def save_history(history, path_to_model, val_dice, train_dice):
    """Write accuracy and loss plots plus the raw history next to the model.

    ``history`` is indexed with the keys 'accuracy', 'val_accuracy', 'loss'
    and 'val_loss'; 'dice' and 'val_dice' are only read when ``train_dice``
    or ``val_dice`` is truthy.  Output names are derived from
    ``path_to_model`` by replacing '.h5' with '_acc.png', '_loss.png' and
    '.npy'.
    """
    # accuracy figure: collect (history key, legend caption) pairs in plot order
    curves = [('accuracy', 'Accuracy (train)'),
              ('val_accuracy', 'Accuracy (test)')]
    if train_dice:
        curves.append(('dice', 'Dice score (train)'))
    if val_dice:
        curves.append(('val_dice', 'Dice score (test)'))
    for key, _ in curves:
        plt.plot(history[key])
    plt.legend([caption for _, caption in curves], loc='upper left')
    plt.title('model accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.tight_layout()  # keeps labels from overlapping
    acc_path = path_to_model.replace('.h5', '_acc.png')
    plt.savefig(acc_path, dpi=300, bbox_inches='tight')
    plt.clf()

    # loss figure
    for key in ('loss', 'val_loss'):
        plt.plot(history[key])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc='upper left')
    plt.tight_layout()  # keeps labels from overlapping
    loss_path = path_to_model.replace('.h5', '_loss.png')
    plt.savefig(loss_path, dpi=300, bbox_inches='tight')
    plt.clf()

    # save history dictionary
    history_path = path_to_model.replace('.h5', '.npy')
    np.save(history_path, history)
|
105
|
+
|
106
|
+
def predict_blocksize(labelData, x_puffer, y_puffer, z_puffer):
    """Return the padded bounding box of the non-zero voxels of a 3D array.

    Parameters:
        labelData: array whose shape unpacks to (z, y, x).
        x_puffer, y_puffer, z_puffer: margin added on both sides per axis.

    Returns:
        ``(argmin_z, argmax_z, argmin_y, argmax_y, argmin_x, argmax_x)``.
        The minima are clamped to 0 and the maxima to the axis length.  If
        no voxel is non-zero, the padded initial sentinels (axis length for
        the minima, 0 for the maxima) are returned.
    """
    zsh, ysh, xsh = labelData.shape
    argmin_z, argmax_z, argmin_y, argmax_y, argmin_x, argmax_x = zsh, 0, ysh, 0, xsh, 0
    for k in range(zsh):
        y, x = np.nonzero(labelData[k])
        # Test the number of hits, not the index values: ``x.any()`` is
        # False when every labelled voxel of the slice sits at x == 0.
        if x.size:
            argmin_x = min(argmin_x, np.amin(x))
            argmax_x = max(argmax_x, np.amax(x))
            argmin_y = min(argmin_y, np.amin(y))
            argmax_y = max(argmax_y, np.amax(y))
            argmin_z = min(argmin_z, k)
            argmax_z = max(argmax_z, k)

    # add the margin and clamp to the volume
    argmin_x = max(argmin_x - x_puffer, 0)
    argmax_x = min(argmax_x + x_puffer, xsh)
    argmin_y = max(argmin_y - y_puffer, 0)
    argmax_y = min(argmax_y + y_puffer, ysh)
    argmin_z = max(argmin_z - z_puffer, 0)
    argmax_z = min(argmax_z + z_puffer, zsh)
    return argmin_z, argmax_z, argmin_y, argmax_y, argmin_x, argmax_x
|
126
|
+
|
127
|
+
def get_image_dimensions(header, data):
    """Return a copy of ``header`` whose lattice size matches ``data.shape``.

    Parameters:
        header: array whose bytes hold the textual header; it must contain
            a line ``define Lattice <x> <y> <z>``.
        data: array whose shape unpacks to (z, y, x).

    Returns:
        A new array with ``header.dtype`` in which the first occurrence of
        the old "x y z" triple and of ``Content "<x>x<y>x<z> byte`` is
        replaced by the dimensions of ``data``.

    Raises:
        AttributeError: if no ``define Lattice`` line is found.
    """

    # read header as string, remembering the codec so that the text can be
    # written back byte-for-byte (re-encoding latin1 text as UTF-8 would
    # change every non-ASCII byte)
    raw = header.tobytes()
    try:
        encoding = "utf-8"
        s = raw.decode(encoding)
    except UnicodeDecodeError:
        encoding = "latin1"
        s = raw.decode(encoding)

    # get image size in header
    lattice = re.search('define Lattice (.*)\n', s).group(1)
    xsh, ysh, zsh = lattice.split(' ')
    xsh, ysh, zsh = int(xsh), int(ysh), int(zsh)

    # new image size
    z, y, x = data.shape

    # change image size in header
    s = s.replace('%s %s %s' % (xsh, ysh, zsh), '%s %s %s' % (x, y, z), 1)
    s = s.replace('Content "%sx%sx%s byte' % (xsh, ysh, zsh),
                  'Content "%sx%sx%s byte' % (x, y, z), 1)

    # return header as array
    return np.frombuffer(s.encode(encoding), dtype=header.dtype)
|
153
|
+
|
154
|
+
def get_physical_size(header, img_header):
    """Copy the bounding boxes of ``img_header`` into the label ``header``.

    Both headers are searched for ``BoundingBox <six values>,`` and
    ``&BoundingBox <values>,``, each terminated by a newline.  The first
    occurrence of each label value string is replaced by the corresponding
    string from the image header.

    Parameters:
        header: array holding the bytes of the label header.
        img_header: array holding the bytes of the image header.

    Returns:
        A new array with ``header.dtype`` holding the updated label header.

    Raises:
        AttributeError: if one of the two patterns is missing in a header.
        ValueError: if a ``BoundingBox`` entry does not hold six values.
    """

    # read img_header as string
    raw = img_header.tobytes()
    try:
        s = raw.decode("utf-8")
    except UnicodeDecodeError:
        s = raw.decode("latin1")

    # get physical size from image header
    bounding_box = re.search('BoundingBox (.*),\n', s).group(1)
    i0, i1, i2, i3, i4, i5 = bounding_box.split(' ')
    bounding_box_i = re.search('&BoundingBox (.*),\n', s).group(1)

    # read label header as string, remembering the codec so the result can
    # be written back without altering non-ASCII bytes
    raw = header.tobytes()
    try:
        encoding = "utf-8"
        s = raw.decode(encoding)
    except UnicodeDecodeError:
        encoding = "latin1"
        s = raw.decode(encoding)

    # get physical size from label header
    bounding_box = re.search('BoundingBox (.*),\n', s).group(1)
    l0, l1, l2, l3, l4, l5 = bounding_box.split(' ')
    bounding_box_l = re.search('&BoundingBox (.*),\n', s).group(1)

    # change physical size in label header
    s = s.replace('%s %s %s %s %s %s' % (l0, l1, l2, l3, l4, l5),
                  '%s %s %s %s %s %s' % (i0, i1, i2, i3, i4, i5), 1)
    s = s.replace(bounding_box_l, bounding_box_i, 1)

    # return header as array
    try:
        b2 = s.encode(encoding)
    except UnicodeEncodeError:
        # text copied from the image header is not representable in the
        # label header's codec: fall back to the default encoding
        b2 = s.encode()
    return np.frombuffer(b2, dtype=header.dtype)
|
192
|
+
|
193
|
+
@numba.jit(nopython=True)
def compute_position(position, zsh, ysh, xsh):
    """Fill ``position[k, l, m]`` with the squared distance to the centre.

    The centre is ``(zsh//2, ysh//2, xsh//2)``; the filled array is returned.
    """
    centre_z = zsh // 2
    centre_y = ysh // 2
    centre_x = xsh // 2
    for iz in range(zsh):
        for iy in range(ysh):
            for ix in range(xsh):
                dx2 = (centre_x - ix) ** 2
                dy2 = (centre_y - iy) ** 2
                dz2 = (centre_z - iz) ** 2
                position[iz, iy, ix] = dx2 + dy2 + dz2
    return position
|
204
|
+
|
205
|
+
def make_conv_block(nb_filters, input_tensor, block):
    """Stack two conv / batch-norm / ReLU stages on ``input_tensor``.

    Layers are named ``conv_<block>_<stage>`` and
    ``batch_norm_<block>_<stage>`` with stage 1 and 2.  Returns the output
    tensor of the second stage.
    """
    tensor = input_tensor
    for stage in (1, 2):
        conv_name = 'conv_{}_{}'.format(block, stage)
        tensor = Conv3D(nb_filters, (3, 3, 3), activation='relu',
                        padding='same', name=conv_name,
                        data_format="channels_last")(tensor)
        bn_name = 'batch_norm_{}_{}'.format(block, stage)
        # fall back to the plain layer when the synchronized variant fails
        try:
            tensor = BatchNormalization(name=bn_name, synchronized=True)(tensor)
        except:
            tensor = BatchNormalization(name=bn_name)(tensor)
        tensor = Activation('relu')(tensor)
    return tensor
|
221
|
+
|
222
|
+
def make_conv_block_resnet(nb_filters, input_tensor, block):
    """Build a two-stage convolution block with a 1x1x1 shortcut.

    The shortcut (``Identity<block>_1``) and the output of the second
    batch-normalised convolution are added and passed through a ReLU.
    Returns the resulting tensor.
    """

    def batch_norm(tensor, stage):
        # fall back to the plain layer when the synchronized variant fails
        bn_name = 'batch_norm_{}_{}'.format(block, stage)
        try:
            return BatchNormalization(name=bn_name, synchronized=True)(tensor)
        except:
            return BatchNormalization(name=bn_name)(tensor)

    # Residual/Skip connection
    shortcut = Conv3D(nb_filters, (1, 1, 1), padding='same', use_bias=False,
                      name="Identity{}_1".format(block))(input_tensor)

    # stage 1: convolution (with ReLU), batch norm, ReLU
    branch = Conv3D(nb_filters, (3, 3, 3), activation='relu', padding='same',
                    name='conv_{}_{}'.format(block, 1),
                    data_format="channels_last")(input_tensor)
    branch = batch_norm(branch, 1)
    branch = Activation('relu')(branch)

    # stage 2: linear convolution and batch norm
    branch = Conv3D(nb_filters, (3, 3, 3), padding='same',
                    name='conv_{}_{}'.format(block, 2),
                    data_format="channels_last")(branch)
    branch = batch_norm(branch, 2)

    merged = Add()([shortcut, branch])
    return ReLU()(merged)
|
250
|
+
|
251
|
+
def make_unet(input_shape, nb_labels, filters='32-64-128-256-512', resnet=False):
    """Assemble a 3D U-Net and return it as a Keras ``Model``.

    Parameters:
        input_shape: 4-tuple unpacked as (planes, rows, cols, _).
        nb_labels: number of output channels of the final 1x1x1 convolution.
        filters: '-'-separated filter counts; the last entry is used for the
            bottleneck block, the others for the encoder/decoder levels.
        resnet: use ``make_conv_block_resnet`` instead of ``make_conv_block``.
    """
    nb_plans, nb_rows, nb_cols, _ = input_shape
    inputs = Input(input_shape)

    all_filters = np.array(filters.split('-'), dtype=int)
    latent_space_size = all_filters[-1]
    level_filters = all_filters[:-1]
    build_block = make_conv_block_resnet if resnet else make_conv_block

    # contracting path: keep every level's output for the skip connections
    skips = []
    for level, nb in enumerate(level_filters, start=1):
        source = inputs if level == 1 else pool
        conv = build_block(nb, source, level)
        pool = MaxPooling3D(pool_size=(2, 2, 2))(conv)
        skips.append(conv)
    block_id = len(level_filters) + 1

    # bottleneck
    conv = build_block(latent_space_size, pool, block_id)
    block_id += 1

    # expanding path: upsample and concatenate with the matching skip
    for nb, skip in zip(level_filters[::-1], reversed(skips)):
        up = Concatenate()([UpSampling3D(size=(2, 2, 2))(conv), skip])
        conv = build_block(nb, up, block_id)
        block_id += 1

    conv = Conv3D(nb_labels, (1, 1, 1), name=f'conv_{block_id}_1')(conv)

    # softmax over the label axis, applied on the flattened volume
    flat = Reshape((nb_plans * nb_rows * nb_cols, nb_labels))(conv)
    flat = Activation('softmax')(flat)
    outputs = Reshape((nb_plans, nb_rows, nb_cols, nb_labels))(flat)

    return Model(inputs=inputs, outputs=outputs)
|
302
|
+
|
303
|
+
def get_labels(arr, allLabels):
    """Map every value ``v`` in ``arr`` to ``allLabels[v]``.

    Returns a new array with the shape and dtype of ``arr``; ``arr`` itself
    is not modified.
    """
    mapped = np.zeros_like(arr)
    for index in np.unique(arr):
        selection = arr == index
        mapped[selection] = allLabels[index]
    return mapped
|
309
|
+
|
310
|
+
def _extract_tarball(tar_path):
    """Extract ``tar_path`` into a new ``BASE_DIR/tmp/<40-char id>`` directory.

    Returns the directory path.  The archive is always closed, and the
    target directory is removed again if extraction fails.
    """
    target = BASE_DIR + '/tmp/' + id_generator(40)
    try:
        with tarfile.open(tar_path) as tar:
            if hasattr(tarfile, 'data_filter'):
                # refuse members that would escape the target directory
                tar.extractall(path=target, filter='data')
            else:
                # NOTE(review): this Python has no extraction filters, so
                # archive members are extracted unchecked.
                tar.extractall(path=target)
    except Exception:
        shutil.rmtree(target, ignore_errors=True)
        raise
    return target


def read_img_list(img_list, label_list):
    """Expand paired image/label inputs into two flat lists of file names.

    ``img_list`` and ``label_list`` are walked pairwise.  A pair of TAR
    archives ('.tar' or '.tar.gz') or a pair of directories is searched
    recursively for files with a supported extension, with archives
    extracted below ``BASE_DIR/tmp`` first.  Any other pair is taken as
    two single files.

    Returns:
        ``(img_names, label_names)``.

    Raises:
        InputError: if an archive/directory pair yields no files or a
            different number of image and label files.  The reason is
            stored in ``InputError.message`` (class attribute).
    """
    # read filenames
    InputError.img_names = []
    InputError.label_names = []
    img_names, label_names = [], []
    data_types = ['.am','.tif','.tiff','.hdr','.mhd','.mha','.nrrd','.nii','.nii.gz','.zip','.mrc']

    def visible_files(root, data_type):
        # recursive search that skips hidden files
        return [file for file in glob.glob(root+'/**/*'+data_type, recursive=True)
                if not os.path.basename(file).startswith('.')]

    for img_name, label_name in zip(img_list, label_list):

        # check for tarball
        img_dir, img_ext = os.path.splitext(img_name)
        if img_ext == '.gz':
            img_dir, img_ext = os.path.splitext(img_dir)

        label_dir, label_ext = os.path.splitext(label_name)
        if label_ext == '.gz':
            label_dir, label_ext = os.path.splitext(label_dir)

        if (img_ext == '.tar' and label_ext == '.tar') or (os.path.isdir(img_name) and os.path.isdir(label_name)):

            # extract files
            if img_ext == '.tar':
                img_name = _extract_tarball(img_name)
            if label_ext == '.tar':
                try:
                    label_name = _extract_tarball(label_name)
                except Exception:
                    # do not leave the already extracted images behind
                    if img_ext == '.tar':
                        shutil.rmtree(img_name, ignore_errors=True)
                    raise

            for data_type in data_types:
                img_names += visible_files(img_name, data_type)
                label_names += visible_files(label_name, data_type)
            img_names = sorted(img_names)
            label_names = sorted(label_names)

            if len(img_names)==0 or len(label_names)==0 or len(img_names)!=len(label_names):
                # clean up extracted data before reporting the problem
                if img_ext == '.tar' and os.path.exists(img_name):
                    shutil.rmtree(img_name)
                if label_ext == '.tar' and os.path.exists(label_name):
                    shutil.rmtree(label_name)
                if len(img_names)!=len(label_names):
                    InputError.message = 'Number of image and label files must be the same'
                elif img_ext == '.tar':
                    InputError.message = 'Invalid image TAR file'
                elif label_ext == '.tar':
                    InputError.message = 'Invalid label TAR file'
                elif len(img_names)==0:
                    InputError.message = 'Invalid image data'
                else:
                    InputError.message = 'Invalid label data'
                raise InputError()
        else:
            img_names.append(img_name)
            label_names.append(label_name)
    return img_names, label_names
|
367
|
+
|
368
|
+
def remove_extracted_data(img_list, label_list):
    """Delete the extraction directories that the given paths live in.

    For every path containing ``BASE_DIR + '/tmp/'`` the leading part up to
    and including the 40 characters after that prefix is treated as the
    extraction directory and removed if it exists.
    """
    tmp_prefix = BASE_DIR + '/tmp/'
    cut = len(tmp_prefix) + 40
    for names in (img_list, label_list):
        for path in names:
            if tmp_prefix not in path:
                continue
            extracted_dir = path[:cut]
            if os.path.exists(extracted_dir):
                shutil.rmtree(extracted_dir)
|
379
|
+
|
380
|
+
def load_training_data(normalize, img_list, label_list, channels, x_scale, y_scale, z_scale, no_scaling,
        crop_data, labels_to_compute, labels_to_remove, img_in=None, label_in=None,
        normalization_parameters=None, allLabels=None, header=None, extension='.tif',
        x_puffer=25, y_puffer=25, z_puffer=25):
    """Load, preprocess and stack image and label volumes for training.

    Data is read from ``img_list``/``label_list`` when ``img_list`` holds a
    truthy entry, otherwise taken from ``img_in``/``label_in`` (an array or
    a list of arrays).  Per volume: labels are passed through
    ``set_labels_to_zero``, optionally cropped to the label bounding box
    (``crop_data``) and resized (unless ``no_scaling``).  Images are cast
    to float32, cropped/resized the same way and rescaled per channel to
    [0, 1].  All volumes are concatenated along axis 0.

    ``normalization_parameters`` (shape (2, channels): mean, std) is
    computed from the first image when not supplied.  When supplied and
    ``normalize`` is set, the first image is standardised and mapped to
    those parameters.  Further images are mapped whenever ``normalize`` is
    set.

    Returns:
        ``(img, label, allLabels, normalization_parameters, header,
        extension, channels)``.  ``img`` is clipped to [0, 1] and ``label``
        values are replaced by their index in ``allLabels``.

    Raises:
        InputError: on unreadable data, mismatching dimensions or a wrong
            channel count.  The reason is stored in ``InputError.message``.
    """
    from_files = any(img_list)
    labels_from_files = any(label_list)

    # read image lists
    if from_files:
        img_names, label_names = read_img_list(img_list, label_list)
        InputError.img_names = img_names
        InputError.label_names = label_names
        img_titles = [os.path.basename(name) for name in img_names]
        label_titles = [os.path.basename(name) for name in label_names]
    else:
        # In-memory input has no file names; use positional placeholders so
        # that log lines and error messages do not raise a NameError.
        count = len(img_in) if type(img_in) is list else 1
        img_titles = [f'img_in[{k}]' for k in range(count)]
        label_titles = [f'label_in[{k}]' for k in range(count)]

    # load first label
    if from_files:
        label, header, extension = load_data(label_names[0], 'first_queue', True)
        if label is None:
            InputError.message = f'Invalid label data "{label_titles[0]}"'
            raise InputError()
    elif type(label_in) is list:
        label = label_in[0]
    else:
        label = label_in
    label_dim = label.shape
    label = set_labels_to_zero(label, labels_to_compute, labels_to_remove)
    label_values, counts = np.unique(label, return_counts=True)
    print(f'{label_titles[0]}:', 'Labels:', label_values[1:], 'Sizes:', counts[1:])
    if crop_data:
        argmin_z,argmax_z,argmin_y,argmax_y,argmin_x,argmax_x = predict_blocksize(label, x_puffer, y_puffer, z_puffer)
        label = np.copy(label[argmin_z:argmax_z,argmin_y:argmax_y,argmin_x:argmax_x], order='C')
    if not no_scaling:
        label = img_resize(label, z_scale, y_scale, x_scale, labels=True)

    # if header is not single data stream Amira Mesh falling back to Multi-TIFF
    if extension != '.am':
        if extension != '.tif':
            print(f'Warning! Please use -hf or --header_file="path_to_training_label{extension}" for prediction to save your result as "{extension}"')
        extension, header = '.tif', None
    elif len(header) > 1:
        print('Warning! Multiple data streams are not supported. Falling back to TIFF.')
        extension, header = '.tif', None
    else:
        header = header[0]

    # load first img
    if from_files:
        img, _ = load_data(img_names[0], 'first_queue')
        if img is None:
            InputError.message = f'Invalid image data "{img_titles[0]}"'
            raise InputError()
    elif type(img_in) is list:
        img = img_in[0]
    else:
        img = img_in
    # NOTE(review): this compares the full shapes, so an image that already
    # has a channel axis only passes if the label has one too - confirm.
    if label_dim != img.shape:
        InputError.message = f'Dimensions of "{img_titles[0]}" and "{label_titles[0]}" do not match'
        raise InputError()

    # ensure images have channels >=1
    if len(img.shape)==3:
        z_shape, y_shape, x_shape = img.shape
        img = img.reshape(z_shape, y_shape, x_shape, 1)
    if channels is None:
        channels = img.shape[3]
    if channels != img.shape[3]:
        InputError.message = f'Number of channels must be {channels} for "{img_titles[0]}"'
        raise InputError()

    # crop data
    if crop_data:
        img = np.copy(img[argmin_z:argmax_z,argmin_y:argmax_y,argmin_x:argmax_x], order='C')

    # scale/resize image data
    img = img.astype(np.float32)
    if not no_scaling:
        img = img_resize(img, z_scale, y_scale, x_scale)

    # normalize image data
    # Decide once, before the loop, whether the parameters have to be
    # computed; deciding per channel would leave mean/std of every channel
    # but the first at zero.
    compute_parameters = normalization_parameters is None
    if compute_parameters:
        normalization_parameters = np.zeros((2,channels))
    for c in range(channels):
        img[:,:,:,c] -= np.amin(img[:,:,:,c])
        img[:,:,:,c] /= np.amax(img[:,:,:,c])
        if compute_parameters:
            normalization_parameters[0,c] = np.mean(img[:,:,:,c])
            normalization_parameters[1,c] = np.std(img[:,:,:,c])
        elif normalize:
            mean, std = np.mean(img[:,:,:,c]), np.std(img[:,:,:,c])
            img[:,:,:,c] = (img[:,:,:,c] - mean) / std
            img[:,:,:,c] = img[:,:,:,c] * normalization_parameters[1,c] + normalization_parameters[0,c]

    # volumes are collected and concatenated once at the end instead of
    # re-copying the growing arrays for every additional volume
    label_parts = [label]
    img_parts = [img]

    # loop over list of images
    if from_files or type(img_in) is list:
        number_of_images = len(img_names) if from_files else len(img_in)

        for k in range(1, number_of_images):

            # append label
            if labels_from_files:
                a, _ = load_data(label_names[k], 'first_queue')
                if a is None:
                    InputError.message = f'Invalid label data "{label_titles[k]}"'
                    raise InputError()
            else:
                a = label_in[k]
            label_dim = a.shape
            a = set_labels_to_zero(a, labels_to_compute, labels_to_remove)
            label_values, counts = np.unique(a, return_counts=True)
            print(f'{label_titles[k]}:', 'Labels:', label_values[1:], 'Sizes:', counts[1:])
            if crop_data:
                argmin_z,argmax_z,argmin_y,argmax_y,argmin_x,argmax_x = predict_blocksize(a, x_puffer, y_puffer, z_puffer)
                a = np.copy(a[argmin_z:argmax_z,argmin_y:argmax_y,argmin_x:argmax_x], order='C')
            if not no_scaling:
                a = img_resize(a, z_scale, y_scale, x_scale, labels=True)
            label_parts.append(a)

            # append image
            if from_files:
                a, _ = load_data(img_names[k], 'first_queue')
                if a is None:
                    InputError.message = f'Invalid image data "{img_titles[k]}"'
                    raise InputError()
            else:
                a = img_in[k]
            if label_dim != a.shape:
                InputError.message = f'Dimensions of "{img_titles[k]}" and "{label_titles[k]}" do not match'
                raise InputError()
            if len(a.shape)==3:
                z_shape, y_shape, x_shape = a.shape
                a = a.reshape(z_shape, y_shape, x_shape, 1)
            if a.shape[3] != channels:
                InputError.message = f'Number of channels must be {channels} for "{img_titles[k]}"'
                raise InputError()
            if crop_data:
                a = np.copy(a[argmin_z:argmax_z,argmin_y:argmax_y,argmin_x:argmax_x], order='C')
            a = a.astype(np.float32)
            if not no_scaling:
                a = img_resize(a, z_scale, y_scale, x_scale)
            for c in range(channels):
                a[:,:,:,c] -= np.amin(a[:,:,:,c])
                a[:,:,:,c] /= np.amax(a[:,:,:,c])
                if normalize:
                    mean, std = np.mean(a[:,:,:,c]), np.std(a[:,:,:,c])
                    a[:,:,:,c] = (a[:,:,:,c] - mean) / std
                    a[:,:,:,c] = a[:,:,:,c] * normalization_parameters[1,c] + normalization_parameters[0,c]
            img_parts.append(a)

    if len(img_parts) > 1:
        label = np.concatenate(label_parts, axis=0)
        img = np.concatenate(img_parts, axis=0)

    # remove extracted data
    if from_files:
        remove_extracted_data(img_names, label_names)

    # limit intensity range
    img[img<0] = 0
    img[img>1] = 1

    # get labels
    if allLabels is None:
        allLabels = np.unique(label)

    # labels must be in ascending order
    for k, l in enumerate(allLabels):
        label[label==l] = k

    return img, label, allLabels, normalization_parameters, header, extension, channels
|
541
|
+
|
542
|
+
class CustomCallback(Callback):
    """Keras callback that writes training progress to an ``Upload`` record.

    After every epoch the record with primary key ``id`` is loaded.  If its
    ``status`` equals 3, training is stopped (presumably a stop request -
    confirm against the ``Upload`` model).  Otherwise the record's
    ``message`` is updated with the progress, an estimate of the remaining
    time and, if available, the validation accuracy.
    """

    def __init__(self, id, epochs):
        # let the Keras base class initialise its own state
        super().__init__()
        self.epochs = epochs
        self.id = id

    def on_epoch_begin(self, batch, logs=None):
        # remember the start time to estimate the remaining duration
        self.epoch_time_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        import django
        django.setup()
        from biomedisa_app.models import Upload
        # the declared default is None; avoid failing on ``logs.keys()``
        logs = {} if logs is None else logs
        image = Upload.objects.get(pk=self.id)
        if image.status == 3:
            self.model.stop_training = True
        else:
            keys = list(logs.keys())
            percentage = round((int(epoch)+1)*100/float(self.epochs))
            # duration of this epoch times the number of epochs left
            t = round(time.time() - self.epoch_time_start) * (self.epochs-int(epoch)-1)
            if t < 3600:
                time_remaining = str(t // 60) + 'min'
            else:
                time_remaining = str(t // 3600) + 'h ' + str((t % 3600) // 60) + 'min'
            try:
                val_accuracy = round(float(logs["val_accuracy"])*100,1)
                image.message = 'Progress {}%, {} remaining, {}% accuracy'.format(percentage,time_remaining,val_accuracy)
            except KeyError:
                image.message = 'Progress {}%, {} remaining'.format(percentage,time_remaining)
            image.save()
            print("Start epoch {} of training; got log keys: {}".format(epoch, keys))
|
572
|
+
|
573
|
+
class MetaData(Callback):
    """Keras callback that stores training metadata inside the model file.

    After each epoch the HDF5 file at ``path_to_model`` is inspected.  If it
    has no ``/meta`` group yet, the configuration, normalization parameters
    and labels are written (plus extension and header for '.am' results,
    and the cropping network's data when ``crop_data`` is set).
    """

    def __init__(self, path_to_model, configuration_data, allLabels,
            extension, header, crop_data, cropping_weights, cropping_config,
            normalization_parameters, cropping_norm):
        # let the Keras base class initialise its own state
        super().__init__()

        self.path_to_model = path_to_model
        self.configuration_data = configuration_data
        self.normalization_parameters = normalization_parameters
        self.allLabels = allLabels
        self.extension = extension
        self.header = header
        self.crop_data = crop_data
        self.cropping_weights = cropping_weights
        self.cropping_config = cropping_config
        self.cropping_norm = cropping_norm

    def on_epoch_end(self, epoch, logs=None):
        # context managers guarantee that the file is closed even if a
        # dataset cannot be written
        with h5py.File(self.path_to_model, 'r') as hf:
            has_meta = '/meta' in hf
        if has_meta:
            return

        with h5py.File(self.path_to_model, 'r+') as hf:
            group = hf.create_group('meta')
            group.create_dataset('configuration', data=self.configuration_data)
            group.create_dataset('normalization', data=self.normalization_parameters)
            group.create_dataset('labels', data=self.allLabels)
            if self.extension == '.am':
                group.create_dataset('extension', data=self.extension)
                group.create_dataset('header', data=self.header)
            if self.crop_data:
                cm_group = hf.create_group('cropping_meta')
                cm_group.create_dataset('configuration', data=self.cropping_config)
                cm_group.create_dataset('normalization', data=self.cropping_norm)
                cw_group = hf.create_group('cropping_weights')
                for iterator, arr in enumerate(self.cropping_weights):
                    cw_group.create_dataset(str(iterator), data=arr)
|
609
|
+
|
610
|
+
class Metrics(Callback):
    """Keras callback that scores the model on a whole volume after selected epochs.

    Every ``validation_freq`` epochs the callback predicts the patches listed
    in ``list_IDs``, accumulates the per-patch outputs into one volume and
    derives the Dice score from it.

    With ``train=True`` only ``logs['dice']`` is written. With ``train=False``
    the callback additionally computes accuracy and categorical crossentropy,
    saves the model whenever the validation Dice improves, records the
    history, and triggers early stopping.
    """

    def __init__(self, bm, img, label, list_IDs, dim_img, n_classes, train):
        """Store the volume, its labels and the settings read from ``bm``.

        Args:
            bm: configuration object; the attributes read here are
                ``z_patch``, ``y_patch``, ``x_patch``, ``batch_size``,
                ``path_to_model``, ``early_stopping``, ``validation_freq``,
                ``channels``, ``average_dice``, ``django_env``,
                ``patch_normalization`` and ``train_dice``.
            img: image volume indexed as ``img[z, y, x, channel]``.
            label: label volume compared voxel-wise with the arg-max result.
            list_IDs: flat patch-origin IDs (``z*Y*X + y*X + x``); shuffled
                in place on every evaluation.
            dim_img: ``(Z, Y, X)`` extent used to decode the IDs.
            n_classes: number of classes in the model output.
            train: ``True`` when scoring training data, ``False`` for validation.
        """
        super().__init__()
        self.dim_patch = (bm.z_patch, bm.y_patch, bm.x_patch)
        self.dim_img = dim_img
        self.list_IDs = list_IDs
        self.batch_size = bm.batch_size
        self.label = label
        self.img = img
        self.path_to_model = bm.path_to_model
        self.early_stopping = bm.early_stopping
        self.validation_freq = bm.validation_freq
        self.n_classes = n_classes
        self.n_channels = bm.channels
        self.average_dice = bm.average_dice
        self.django_env = bm.django_env
        self.patch_normalization = bm.patch_normalization
        self.train = train
        self.train_dice = bm.train_dice

    def on_train_begin(self, logs=None):
        """Reset the recorded history at the start of training."""
        self.history = {
            'val_accuracy': [],
            'accuracy': [],
            'val_dice': [],
            'dice': [],
            'val_loss': [],
            'loss': [],
        }

    def _patch_origin(self, ID):
        """Decode a flat patch ID into its ``(z, y, x)`` origin."""
        plane = self.dim_img[1] * self.dim_img[2]
        k, rest = ID // plane, ID % plane
        return k, rest // self.dim_img[2], rest % self.dim_img[2]

    @staticmethod
    def _dice_ratio(intersection, denominator):
        """Return ``2*intersection/denominator``; 1.0 if both masks are empty.

        An empty denominator means neither the label nor the prediction
        contains the class. Dividing would yield NaN, which poisons every
        later ``max()`` / ``>`` comparison on the history.
        """
        if denominator == 0:
            return 1.0
        return 2 * intersection / float(denominator)

    def on_epoch_end(self, epoch, logs=None):
        """Evaluate the model on the stored volume and update ``logs``.

        Runs only when ``epoch`` is a multiple of ``validation_freq``.
        In validation mode ``logs`` must already contain ``'loss'`` and
        ``'accuracy'`` (and ``'dice'`` when ``train_dice`` is set).
        """
        if logs is None:
            logs = {}
        if epoch % self.validation_freq != 0:
            return

        dz, dy, dx = self.dim_patch
        result = np.zeros((*self.dim_img, self.n_classes), dtype=np.float32)

        # only complete batches are evaluated; shuffling varies which
        # patches fall into the dropped remainder
        len_IDs = len(self.list_IDs)
        n_batches = int(np.floor(len_IDs / self.batch_size))
        np.random.shuffle(self.list_IDs)

        for batch in range(n_batches):
            # patch origins of this batch
            list_IDs_batch = self.list_IDs[batch*self.batch_size:(batch+1)*self.batch_size]
            origins = [self._patch_origin(ID) for ID in list_IDs_batch]

            # assemble the input batch
            X_val = np.empty((self.batch_size, *self.dim_patch, self.n_channels), dtype=np.float32)
            for i, (k, l, m) in enumerate(origins):
                tmp_X = self.img[k:k+dz, l:l+dy, m:m+dx]
                if self.patch_normalization:
                    # copy first so the stored image is not modified
                    tmp_X = np.copy(tmp_X, order='C')
                    for c in range(self.n_channels):
                        tmp_X[:,:,:,c] -= np.mean(tmp_X[:,:,:,c])
                        tmp_X[:,:,:,c] /= max(np.std(tmp_X[:,:,:,c]), 1e-6)
                X_val[i] = tmp_X

            # predict and accumulate overlapping patches
            y_predict = np.asarray(self.model.predict(X_val, verbose=0, steps=None, batch_size=self.batch_size))
            for i, (k, l, m) in enumerate(origins):
                result[k:k+dz, l:l+dy, m:m+dx] += y_predict[i]

        # categorical crossentropy on the accumulated scores (validation only)
        if not self.train:
            crossentropy = categorical_crossentropy(self.label, softmax(result))

        # hard segmentation
        result = np.argmax(result, axis=-1)
        result = result.astype(np.uint8)

        # standard accuracy (validation only)
        if not self.train:
            accuracy = np.sum(self.label==result) / float(self.label.size)

        # dice score: mean over foreground classes, or all foreground pooled
        if self.average_dice:
            dice = 0
            for l in range(1, self.n_classes):
                dice += self._dice_ratio(
                    np.logical_and(self.label==l, result==l).sum(),
                    (self.label==l).sum() + (result==l).sum())
            dice /= float(self.n_classes-1)
        else:
            # "equal and foreground"; testing label>0 instead of
            # (label+result)>0 avoids wrap-around in narrow integer dtypes
            dice = self._dice_ratio(
                np.logical_and(self.label==result, self.label>0).sum(),
                (self.label>0).sum() + (result>0).sum())

        if self.train:
            logs['dice'] = dice
            return

        # save best model only
        if epoch == 0 or round(dice,4) > max(self.history['val_dice']):
            self.model.save(str(self.path_to_model))

        # add accuracy to history
        self.history['loss'].append(round(logs['loss'],4))
        self.history['accuracy'].append(round(logs['accuracy'],4))
        if self.train_dice:
            self.history['dice'].append(round(logs['dice'],4))
        self.history['val_accuracy'].append(round(accuracy,4))
        self.history['val_dice'].append(round(dice,4))
        self.history['val_loss'].append(round(crossentropy,4))

        # tensorflow monitoring variables
        logs['val_loss'] = crossentropy
        logs['val_accuracy'] = accuracy
        logs['val_dice'] = dice
        logs['best_acc'] = max(self.history['accuracy'])
        if self.train_dice:
            logs['best_dice'] = max(self.history['dice'])
        logs['best_val_acc'] = max(self.history['val_accuracy'])
        logs['best_val_dice'] = max(self.history['val_dice'])

        # plot history in figure and save as numpy array
        save_history(self.history, self.path_to_model, True, self.train_dice)

        # print accuracies
        print('\nValidation history:')
        print('train_acc:', self.history['accuracy'])
        if self.train_dice:
            print('train_dice:', self.history['dice'])
        print('val_acc:', self.history['val_accuracy'])
        print('val_dice:', self.history['val_dice'])
        print('')

        # early stopping: the best val dice is older than the last
        # `early_stopping` evaluations
        if self.early_stopping > 0 and max(self.history['val_dice']) not in self.history['val_dice'][-self.early_stopping:]:
            self.model.stop_training = True
|
745
|
+
|
746
|
+
def softmax(x):
    """Return the softmax of ``x`` taken along its last axis.

    The per-row maximum is subtracted before exponentiation so large
    inputs do not overflow; this does not change the result.
    """
    shifted = x - np.max(x, axis=-1, keepdims=True)
    numerators = np.exp(shifted)
    return numerators / np.sum(numerators, axis=-1, keepdims=True)
|
751
|
+
|
752
|
+
@numba.jit(nopython=True)
def categorical_crossentropy(true_labels, predicted_probs):
    """Return the mean categorical crossentropy over a labelled volume.

    ``true_labels`` is unpacked as a 3-D array of class indices and
    ``predicted_probs`` is indexed as ``[z, y, x, class]``. Probabilities
    are clipped to ``[1e-7, 1 - 1e-7]`` so the logarithm stays finite.
    """
    # negative log-likelihood of every class at every voxel
    neg_log_probs = -np.log(np.clip(predicted_probs, 1e-7, 1 - 1e-7))
    depth, height, width = true_labels.shape
    # sum the entries selected by the true class of each voxel
    total = 0
    for z in range(depth):
        for y in range(height):
            for x in range(width):
                total += neg_log_probs[z, y, x, true_labels[z, y, x]]
    return total / float(depth * height * width)
|
767
|
+
|
768
|
+
def dice_coef(y_true, y_pred, smooth=1e-5):
    """Return the smoothed Dice coefficient of two tensors.

    Computes ``(2*sum(y_true*y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)``;
    ``smooth`` keeps the ratio defined when both tensors sum to zero.
    """
    overlap = K.sum(Multiply()([y_true, y_pred]))
    total = K.sum(y_true) + K.sum(y_pred)
    return (2. * overlap + smooth) / (total + smooth)
|
772
|
+
|
773
|
+
def dice_coef_loss(nb_labels):
    """Build a Dice-based loss function for ``nb_labels`` classes.

    The returned ``loss_fn(y_true, y_pred)`` averages ``dice_coef`` over the
    class channels ``1 .. nb_labels-1`` of the last axis (channel 0 is left
    out) and returns the negative logarithm of that mean.
    """
    def loss_fn(y_true, y_pred):
        per_class = [dice_coef(y_true[:,:,:,:,index], y_pred[:,:,:,:,index])
                     for index in range(1, nb_labels)]
        # start from 0 and add in channel order, like a running sum
        mean_dice = sum(per_class, 0) / (nb_labels-1)
        # negative log of the mean dice (an alternative would be 1 - dice)
        return -K.log(mean_dice)
    return loss_fn
|
786
|
+
|
787
|
+
def train_semantic_segmentation(bm,
    img_list, label_list,
    val_img_list, val_label_list,
    img=None, label=None,
    img_val=None, label_val=None,
    header=None, extension='.tif'):
    """Train a patch-based 3D U-Net and save it to ``bm.path_to_model``.

    Steps: load the training (and optional validation) volumes, enumerate
    patch origins, build the data generators and callbacks, compile the
    model inside a ``MirroredStrategy`` scope and call ``model.fit``.

    Validation data comes from ``val_img_list`` / ``img_val`` when given;
    otherwise, if ``bm.validation_split`` is set, the trailing part of the
    training volume along the first axis is split off. The function writes
    ``bm.channels`` and returns nothing.
    """

    def collect_patch_ids(label_volume, volume_shape, stride, separate_background):
        """Return ``(foreground_ids, background_ids)`` of all patch origins.

        IDs are flat indices ``z*Y*X + y*X + x``. When
        ``separate_background`` is false every patch is listed as foreground;
        otherwise patches whose labels are all zero go to the background list.
        """
        depth, height, width = volume_shape
        foreground, background = [], []
        for z0 in range(0, depth-bm.z_patch+1, stride):
            for y0 in range(0, height-bm.y_patch+1, stride):
                for x0 in range(0, width-bm.x_patch+1, stride):
                    flat_id = z0*height*width + y0*width + x0
                    if separate_background and not np.any(
                            label_volume[z0:z0+bm.z_patch, y0:y0+bm.y_patch, x0:x0+bm.x_patch]):
                        background.append(flat_id)
                    else:
                        foreground.append(flat_id)
        return foreground, background

    # training data
    (img, label, allLabels, normalization_parameters,
        header, extension, bm.channels) = load_training_data(bm.normalize,
        img_list, label_list, None, bm.x_scale, bm.y_scale, bm.z_scale, bm.no_scaling, bm.crop_data,
        bm.only, bm.ignore, img, label, None, None, header, extension)

    # configuration data
    configuration_data = np.array([bm.channels, bm.x_scale, bm.y_scale, bm.z_scale, bm.normalize, 0, 1])

    # img shape
    zsh, ysh, xsh, _ = img.shape

    # validation data
    if any(val_img_list) or img_val is not None:
        img_val, label_val, *_ = load_training_data(bm.normalize,
            val_img_list, val_label_list, bm.channels, bm.x_scale, bm.y_scale, bm.z_scale, bm.no_scaling, bm.crop_data,
            bm.only, bm.ignore, img_val, label_val, normalization_parameters, allLabels)

    elif bm.validation_split:
        # slices from `split` onwards become the validation set
        split = round(zsh * bm.validation_split)
        img_val = np.copy(img[split:], order='C')
        label_val = np.copy(label[split:], order='C')
        img = np.copy(img[:split], order='C')
        label = np.copy(label[:split], order='C')
        zsh, ysh, xsh, _ = img.shape

    # IDs of training patches
    list_IDs_fg, list_IDs_bg = collect_patch_ids(
        label, (zsh, ysh, xsh), bm.stride_size, bool(bm.balance))

    # IDs of validation patches
    if img_val is not None:
        zsh_val, ysh_val, xsh_val, _ = img_val.shape
        list_IDs_val_fg, list_IDs_val_bg = collect_patch_ids(
            label_val, (zsh_val, ysh_val, xsh_val), bm.validation_stride_size,
            bool(bm.balance and not bm.val_dice))

    # number of labels
    nb_labels = len(allLabels)

    # input shape
    input_shape = (bm.z_patch, bm.y_patch, bm.x_patch, bm.channels)

    # generator parameters
    params = {'batch_size': bm.batch_size,
              'dim': (bm.z_patch, bm.y_patch, bm.x_patch),
              'dim_img': (zsh, ysh, xsh),
              'n_classes': nb_labels,
              'n_channels': bm.channels,
              'augment': (bm.flip_x, bm.flip_y, bm.flip_z, bm.swapaxes, bm.rotate),
              'patch_normalization': bm.patch_normalization}

    # data generators
    validation_generator = None
    training_generator = DataGenerator(img, label, list_IDs_fg, list_IDs_bg, True, True, False, **params)
    if img_val is not None:
        if bm.val_dice:
            val_metrics = Metrics(bm, img_val, label_val, list_IDs_val_fg, (zsh_val, ysh_val, xsh_val), nb_labels, False)
        else:
            # validation uses its own volume extent and no augmentation
            params['dim_img'] = (zsh_val, ysh_val, xsh_val)
            params['augment'] = (False, False, False, False, 0)
            validation_generator = DataGenerator(img_val, label_val, list_IDs_val_fg, list_IDs_val_bg, True, False, False, **params)

    # monitor dice score on training data
    if bm.train_dice:
        train_metrics = Metrics(bm, img, label, list_IDs_fg, (zsh, ysh, xsh), nb_labels, True)

    # create a MirroredStrategy
    cdo = tf.distribute.ReductionToOneDevice()
    strategy = tf.distribute.MirroredStrategy(cross_device_ops=cdo)
    ngpus = int(strategy.num_replicas_in_sync)
    print(f'Number of devices: {ngpus}')
    if ngpus == 1 and os.name == 'nt':
        atexit.register(strategy._extended._collective_ops._pool.close)

    # build and compile the model
    with strategy.scope():

        model = make_unet(input_shape, nb_labels, bm.network_filters, bm.resnet)
        model.summary()

        # pretrained model
        if bm.pretrained_model:
            model_pretrained = load_model(bm.pretrained_model)
            model.set_weights(model_pretrained.get_weights())
            if not bm.fine_tune:
                # freeze the listed convolution layers
                nb_blocks = len(bm.network_filters.split('-'))
                frozen_names = [f'conv_{k}_{l}'
                                for k in range(nb_blocks+1, 2*nb_blocks)
                                for l in (1, 2)]
                frozen_names.append(f'conv_{2*nb_blocks}_1')
                for layer_name in frozen_names:
                    model.get_layer(layer_name).trainable = False

        # optimizer
        sgd = SGD(learning_rate=bm.learning_rate, decay=1e-6, momentum=0.9, nesterov=True)

        # compile model
        loss = dice_coef_loss(nb_labels) if bm.dice_loss else 'categorical_crossentropy'
        model.compile(loss=loss,
                      optimizer=sgd,
                      metrics=['accuracy'])

    # save meta data
    meta_data = MetaData(bm.path_to_model, configuration_data, allLabels,
        extension, header, bm.crop_data, bm.cropping_weights, bm.cropping_config,
        normalization_parameters, bm.cropping_norm)

    # model checkpoint
    if img_val is None:
        callbacks = [ModelCheckpoint(filepath=str(bm.path_to_model)), meta_data]
    elif bm.val_dice:
        # the Metrics callback saves the best model and stops early itself
        callbacks = [val_metrics, meta_data]
    else:
        model_checkpoint_callback = ModelCheckpoint(
            filepath=str(bm.path_to_model),
            save_weights_only=False,
            monitor='val_accuracy',
            mode='max',
            save_best_only=True)
        callbacks = [model_checkpoint_callback, meta_data]
        if bm.early_stopping > 0:
            callbacks.insert(0, EarlyStopping(monitor='val_accuracy', mode='max', patience=bm.early_stopping))

    # monitor dice score on training data (must run before the other callbacks)
    if bm.train_dice:
        callbacks = [train_metrics] + callbacks

    # custom callback, placed just before the meta data callback
    if bm.django_env and not bm.remote:
        callbacks.insert(-1, CustomCallback(bm.img_id, bm.epochs))

    # train model
    history = model.fit(training_generator,
                        epochs=bm.epochs,
                        validation_data=validation_generator,
                        callbacks=callbacks,
                        workers=bm.workers)

    # save monitoring figure on train end
    if img_val is not None and not bm.val_dice:
        save_history(history.history, bm.path_to_model, False, bm.train_dice)
|
969
|
+
|
970
|
+
def load_prediction_data(path_to_img, channels, x_scale, y_scale, z_scale,
    no_scaling, normalize, normalization_parameters, region_of_interest,
    img, img_header):
    """Load (if necessary) and preprocess one image volume for prediction.

    Args:
        path_to_img: file to read when ``img`` is ``None``.
        channels: required size of the channel axis.
        x_scale, y_scale, z_scale: target size passed to ``img_resize``.
        no_scaling: skip resizing when true.
        normalize: when true, match each channel's mean/std to
            ``normalization_parameters`` and clip the result to ``[0, 1]``.
        normalization_parameters: indexed as ``[0, c]`` (offset added) and
            ``[1, c]`` (factor multiplied) for channel ``c``.
        region_of_interest: optional ``(min_z, max_z, min_y, max_y, min_x,
            max_x)`` crop; ignored when it contains no non-zero entry.
        img: image already in memory, or ``None`` to read ``path_to_img``.
        img_header: header belonging to ``img``; replaced by the loaded
            header when the image is read from disk.

    Returns:
        ``(img, img_header, z_shape, y_shape, x_shape, region_of_interest,
        img_data)``. ``img`` is the preprocessed float32 volume, the shape
        values are the extent before resizing (after cropping), a used
        ``region_of_interest`` is extended by the original extent, and
        ``img_data`` is an untouched copy of the input image.

    Raises:
        InputError: if no image could be loaded or the channel count differs.
    """
    # read image data
    if img is None:
        img, img_header = load_data(path_to_img, 'first_queue')
        InputError.img_names = [path_to_img]
        InputError.label_names = []

    # verify validity
    if img is None:
        InputError.message = f'Invalid image data: {os.path.basename(path_to_img)}.'
        raise InputError()

    # preserve original image data
    img_data = img.copy()

    # give single-channel volumes an explicit channel axis
    if len(img.shape)==3:
        img = img.reshape(*img.shape, 1)
    if img.shape[3] != channels:
        InputError.message = f'Number of channels must be {channels}.'
        raise InputError()

    # image shape
    z_shape, y_shape, x_shape, _ = img.shape

    # automatic cropping of image to region of interest
    if np.any(region_of_interest):
        min_z, max_z, min_y, max_y, min_x, max_x = region_of_interest[:]
        img = np.copy(img[min_z:max_z,min_y:max_y,min_x:max_x], order='C')
        # remember the uncropped extent so the crop can be reverted later
        region_of_interest = np.array([min_z,max_z,min_y,max_y,min_x,max_x,z_shape,y_shape,x_shape])
        z_shape, y_shape, x_shape = max_z-min_z, max_y-min_y, max_x-min_x

    # scale/resize image data
    img = img.astype(np.float32)
    if not no_scaling:
        img = img_resize(img, z_scale, y_scale, x_scale)

    # normalize image data channel by channel
    for c in range(channels):
        channel = img[:,:,:,c]  # view: in-place operations modify img
        channel -= np.amin(channel)
        peak = np.amax(channel)
        # a constant channel has peak 0; dividing would fill it with NaN
        if peak != 0:
            channel /= peak
        if normalize:
            mean, std = np.mean(channel), np.std(channel)
            # same guard for the standard deviation of a constant channel
            if std == 0:
                std = 1
            channel[...] = (channel - mean) / std
            channel[...] = channel * normalization_parameters[1,c] + normalization_parameters[0,c]

    # limit intensity range
    if normalize:
        img[img<0] = 0
        img[img>1] = 1

    return img, img_header, z_shape, y_shape, x_shape, region_of_interest, img_data
|
1025
|
+
|
1026
|
+
def predict_semantic_segmentation(bm, img, path_to_model,
    z_patch, y_patch, x_patch, z_shape, y_shape, x_shape, compress, header,
    img_header, stride_size, allLabels, batch_size, region_of_interest,
    no_scaling, extension, img_data):
    """Segment a preprocessed volume patch by patch and post-process the result.

    The model at ``path_to_model`` is applied to every patch, the outputs
    are summed into one score volume and the arg-max gives the label volume.
    That volume is resized back to ``(z_shape, y_shape, x_shape)`` unless
    ``no_scaling``, pasted into the original extent when a region of
    interest was used, and mapped to the original label values with
    ``get_labels``.

    Optional outputs depend on ``bm``: averaged class probabilities
    (``return_probs``), cleaned / filled results (``clean``, ``fill``) and
    active-contour refinements (``acwe``). Files are written only when
    ``bm.path_to_image`` is set. ``bm.path_to_final`` may be rewritten when
    a header file changes the extension.

    Returns:
        ``(results, bm)``. ``results`` always contains ``'regular'`` and
        ``'header'``; ``'probs'``, ``'cleaned'``, ``'filled'``,
        ``'cleaned_filled'``, ``'acwe'`` and ``'refined'`` are added when
        the corresponding option is enabled.
    """
    results = {}

    # img shape
    zsh, ysh, xsh, csh = img.shape

    # number of labels
    nb_labels = len(allLabels)

    # flat IDs (z*Y*X + y*X + x) of all patch origins
    list_IDs = []
    for k in range(0, zsh-z_patch+1, stride_size):
        for l in range(0, ysh-y_patch+1, stride_size):
            for m in range(0, xsh-x_patch+1, stride_size):
                list_IDs.append(k*ysh*xsh+l*xsh+m)

    # make length of list divisible by batch size; cycle through the IDs
    # so the padding is complete even when fewer patches exist than
    # padding entries are needed (a plain list_IDs[:rest] falls short)
    n_ids = len(list_IDs)
    rest = batch_size - (n_ids % batch_size)
    if n_ids:
        list_IDs = list_IDs + [list_IDs[i % n_ids] for i in range(rest)]

    # number of patches (including padding)
    nb_patches = len(list_IDs)

    # parameters
    params = {'dim': (z_patch, y_patch, x_patch),
              'dim_img': (zsh, ysh, xsh),
              'batch_size': batch_size,
              'n_channels': csh,
              'patch_normalization': bm.patch_normalization}

    # data generator
    predict_generator = PredictDataGenerator(img, list_IDs, **params)

    # load model
    model = load_model(str(path_to_model))

    # predict: small jobs go through the generator, large jobs batch by batch
    if nb_patches < 400:
        probabilities = model.predict(predict_generator, verbose=0, steps=None)
    else:
        X = np.empty((batch_size, z_patch, y_patch, x_patch, csh), dtype=np.float32)
        probabilities = np.zeros((nb_patches, z_patch, y_patch, x_patch, nb_labels), dtype=np.float32)

        for step in range(nb_patches//batch_size):
            for i, ID in enumerate(list_IDs[step*batch_size:(step+1)*batch_size]):

                # decode patch origin
                k = ID // (ysh*xsh)
                remainder = ID % (ysh*xsh)
                l = remainder // xsh
                m = remainder % xsh

                # get patch
                tmp_X = img[k:k+z_patch,l:l+y_patch,m:m+x_patch]
                if bm.patch_normalization:
                    # copy first so the image itself is not modified
                    tmp_X = np.copy(tmp_X, order='C')
                    for c in range(csh):
                        tmp_X[:,:,:,c] -= np.mean(tmp_X[:,:,:,c])
                        tmp_X[:,:,:,c] /= max(np.std(tmp_X[:,:,:,c]), 1e-6)
                X[i] = tmp_X

            probabilities[step*batch_size:(step+1)*batch_size] = model.predict(X, verbose=0, steps=None, batch_size=batch_size)

    # sum the patch outputs; only the first n_ids entries are real patches,
    # visited in the same order the IDs were generated
    final = np.zeros((zsh, ysh, xsh, nb_labels), dtype=np.float32)
    if bm.return_probs:
        counter = np.zeros((zsh, ysh, xsh, nb_labels), dtype=np.float32)
    nb = 0
    for k in range(0, zsh-z_patch+1, stride_size):
        for l in range(0, ysh-y_patch+1, stride_size):
            for m in range(0, xsh-x_patch+1, stride_size):
                final[k:k+z_patch, l:l+y_patch, m:m+x_patch] += probabilities[nb]
                if bm.return_probs:
                    counter[k:k+z_patch, l:l+y_patch, m:m+x_patch] += 1
                nb += 1

    # return probabilities, averaged over the patches covering each voxel
    if bm.return_probs:
        counter[counter==0] = 1
        probabilities = final / counter
        if not no_scaling:
            probabilities = img_resize(probabilities, z_shape, y_shape, x_shape)
        if np.any(region_of_interest):
            min_z,max_z,min_y,max_y,min_x,max_x,original_zsh,original_ysh,original_xsh = region_of_interest[:]
            tmp = np.zeros((original_zsh, original_ysh, original_xsh, nb_labels), dtype=np.float32)
            tmp[min_z:max_z,min_y:max_y,min_x:max_x] = probabilities
            probabilities = np.copy(tmp, order='C')
        results['probs'] = probabilities

    # get final
    label = np.argmax(final, axis=3)
    label = label.astype(np.uint8)

    # rescale final to input size
    if not no_scaling:
        label = img_resize(label, z_shape, y_shape, x_shape, labels=True)

    # revert automatic cropping
    if np.any(region_of_interest):
        min_z,max_z,min_y,max_y,min_x,max_x,original_zsh,original_ysh,original_xsh = region_of_interest[:]
        tmp = np.zeros((original_zsh, original_ysh, original_xsh), dtype=np.uint8)
        tmp[min_z:max_z,min_y:max_y,min_x:max_x] = label
        label = np.copy(tmp, order='C')

    # get result
    label = get_labels(label, allLabels)
    results['regular'] = label

    # load header from file
    if bm.header_file and os.path.exists(bm.header_file):
        _, header = load_data(bm.header_file)
        # update file extension
        if header is not None and bm.path_to_image:
            extension = os.path.splitext(bm.header_file)[1]
            if extension == '.gz':
                extension = '.nii.gz'
            bm.path_to_final = os.path.splitext(bm.path_to_final)[0] + extension
            if bm.django_env and not bm.remote and not bm.tmp_dir:
                from biomedisa_app.views import unique_file_path
                bm.path_to_final = unique_file_path(bm.path_to_final)

    # handle amira header
    if header is not None:
        if extension == '.am':
            header = get_image_dimensions(header[0], label)
            if img_header is not None:
                # best effort: keep the header unchanged if the physical
                # size cannot be derived from the image header
                try:
                    header = get_physical_size(header, img_header[0])
                except Exception:
                    pass
            header = [header]
        else:
            # build new header
            if img_header is None:
                label_z, label_y, label_x = label.shape
                img_header = sitk.Image(label_x, label_y, label_z, header.GetPixelID())
            # copy metadata, except segment extents and conversion parameters
            for key in header.GetMetaDataKeys():
                if not (re.match(r'Segment\d+_Extent$', key) or key=='Segmentation_ConversionParameters'):
                    img_header.SetMetaData(key, header.GetMetaData(key))
            header = img_header
    results['header'] = header

    # save result
    if bm.path_to_image:
        save_data(bm.path_to_final, label, header=header, compress=compress)

        # paths to optional results (only needed when files are written)
        filename, extension = os.path.splitext(bm.path_to_final)
        if extension == '.gz':
            extension = '.nii.gz'
            filename = filename[:-4]
        path_to_cleaned = filename + '.cleaned' + extension
        path_to_filled = filename + '.filled' + extension
        path_to_cleaned_filled = filename + '.cleaned.filled' + extension
        path_to_refined = filename + '.refined' + extension
        path_to_acwe = filename + '.acwe' + extension

    # remove outliers
    if bm.clean:
        cleaned_result = clean(label, bm.clean)
        results['cleaned'] = cleaned_result
        if bm.path_to_image:
            save_data(path_to_cleaned, cleaned_result, header=header, compress=compress)
    if bm.fill:
        # NOTE(review): the fill step calls clean() with bm.fill; confirm
        # whether a dedicated fill routine was intended here
        filled_result = clean(label, bm.fill)
        results['filled'] = filled_result
        if bm.path_to_image:
            save_data(path_to_filled, filled_result, header=header, compress=compress)
    if bm.clean and bm.fill:
        cleaned_filled_result = cleaned_result + (filled_result - label)
        results['cleaned_filled'] = cleaned_filled_result
        if bm.path_to_image:
            save_data(path_to_cleaned_filled, cleaned_filled_result, header=header, compress=compress)

    # post-processing with active contour
    if bm.acwe:
        acwe_result = activeContour(img_data, label, bm.acwe_alpha, bm.acwe_smooth, bm.acwe_steps)
        refined_result = activeContour(img_data, label, simple=True)
        results['acwe'] = acwe_result
        results['refined'] = refined_result
        if bm.path_to_image:
            save_data(path_to_acwe, acwe_result, header=header, compress=compress)
            save_data(path_to_refined, refined_result, header=header, compress=compress)

    return results, bm
|
1219
|
+
|