biomedisa-2024.5.14-py3-none-any.whl
This diff shows the content of a publicly available package version as it was released to a supported public registry. It is provided for informational purposes only and reflects the package exactly as it appears in that registry; every file in this release is newly added.
- biomedisa/__init__.py +53 -0
- biomedisa/__main__.py +18 -0
- biomedisa/biomedisa_features/DataGenerator.py +299 -0
- biomedisa/biomedisa_features/DataGeneratorCrop.py +121 -0
- biomedisa/biomedisa_features/PredictDataGenerator.py +87 -0
- biomedisa/biomedisa_features/PredictDataGeneratorCrop.py +74 -0
- biomedisa/biomedisa_features/__init__.py +0 -0
- biomedisa/biomedisa_features/active_contour.py +434 -0
- biomedisa/biomedisa_features/amira_to_np/__init__.py +0 -0
- biomedisa/biomedisa_features/amira_to_np/amira_data_stream.py +980 -0
- biomedisa/biomedisa_features/amira_to_np/amira_grammar.py +369 -0
- biomedisa/biomedisa_features/amira_to_np/amira_header.py +290 -0
- biomedisa/biomedisa_features/amira_to_np/amira_helper.py +72 -0
- biomedisa/biomedisa_features/assd.py +167 -0
- biomedisa/biomedisa_features/biomedisa_helper.py +801 -0
- biomedisa/biomedisa_features/create_slices.py +286 -0
- biomedisa/biomedisa_features/crop_helper.py +586 -0
- biomedisa/biomedisa_features/curvop_numba.py +149 -0
- biomedisa/biomedisa_features/django_env.py +172 -0
- biomedisa/biomedisa_features/keras_helper.py +1219 -0
- biomedisa/biomedisa_features/nc_reader.py +179 -0
- biomedisa/biomedisa_features/pid.py +52 -0
- biomedisa/biomedisa_features/process_image.py +253 -0
- biomedisa/biomedisa_features/pycuda_test.py +84 -0
- biomedisa/biomedisa_features/random_walk/__init__.py +0 -0
- biomedisa/biomedisa_features/random_walk/gpu_kernels.py +183 -0
- biomedisa/biomedisa_features/random_walk/pycuda_large.py +826 -0
- biomedisa/biomedisa_features/random_walk/pycuda_large_allx.py +806 -0
- biomedisa/biomedisa_features/random_walk/pycuda_small.py +414 -0
- biomedisa/biomedisa_features/random_walk/pycuda_small_allx.py +493 -0
- biomedisa/biomedisa_features/random_walk/pyopencl_large.py +760 -0
- biomedisa/biomedisa_features/random_walk/pyopencl_small.py +441 -0
- biomedisa/biomedisa_features/random_walk/rw_large.py +390 -0
- biomedisa/biomedisa_features/random_walk/rw_small.py +310 -0
- biomedisa/biomedisa_features/remove_outlier.py +399 -0
- biomedisa/biomedisa_features/split_volume.py +274 -0
- biomedisa/deeplearning.py +519 -0
- biomedisa/interpolation.py +371 -0
- biomedisa/mesh.py +406 -0
- biomedisa-2024.5.14.dist-info/LICENSE +191 -0
- biomedisa-2024.5.14.dist-info/METADATA +306 -0
- biomedisa-2024.5.14.dist-info/RECORD +44 -0
- biomedisa-2024.5.14.dist-info/WHEEL +5 -0
- biomedisa-2024.5.14.dist-info/top_level.txt +1 -0
biomedisa/biomedisa_features/crop_helper.py
@@ -0,0 +1,586 @@
##########################################################################
##                                                                      ##
##  Copyright (c) 2024 Philipp Lösel. All rights reserved.              ##
##                                                                      ##
##  This file is part of the open source project biomedisa.             ##
##                                                                      ##
##  Licensed under the European Union Public Licence (EUPL)             ##
##  v1.2, or - as soon as they will be approved by the                  ##
##  European Commission - subsequent versions of the EUPL;              ##
##                                                                      ##
##  You may redistribute it and/or modify it under the terms            ##
##  of the EUPL v1.2. You may not use this work except in               ##
##  compliance with this Licence.                                       ##
##                                                                      ##
##  You can obtain a copy of the Licence at:                            ##
##                                                                      ##
##  https://joinup.ec.europa.eu/page/eupl-text-11-12                    ##
##                                                                      ##
##  Unless required by applicable law or agreed to in                   ##
##  writing, software distributed under the Licence is                  ##
##  distributed on an "AS IS" basis, WITHOUT WARRANTIES                 ##
##  OR CONDITIONS OF ANY KIND, either express or implied.               ##
##                                                                      ##
##  See the Licence for the specific language governing                 ##
##  permissions and limitations under the Licence.                      ##
##                                                                      ##
##########################################################################

import os
from biomedisa_features.keras_helper import read_img_list, remove_extracted_data
from biomedisa_features.biomedisa_helper import img_resize, load_data, save_data, set_labels_to_zero
from tensorflow.python.framework.errors_impl import ResourceExhaustedError
from tensorflow.keras.applications import DenseNet121, densenet
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, GlobalAveragePooling2D, Dropout, Dense
from tensorflow.keras.callbacks import Callback, ModelCheckpoint
from biomedisa_features.DataGeneratorCrop import DataGeneratorCrop
from biomedisa_features.PredictDataGeneratorCrop import PredictDataGeneratorCrop
import tensorflow as tf
import numpy as np
from glob import glob
import h5py
import tarfile
import matplotlib.pyplot as plt

class InputError(Exception):
    def __init__(self, message=None, img_names=[], label_names=[]):
        self.message = message
        self.img_names = img_names
        self.label_names = label_names

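# Sketch (illustrative, not part of this file): the module reports bad inputs
# by assigning to attributes of the InputError *class* (e.g.
# InputError.message = ...) before raising, so a handler can read the context
# off the exception type itself. Uncommented, this minimal demo of the same
# class-attribute pattern runs as-is:
#
#   class DemoError(Exception):
#       message = None
#   try:
#       DemoError.message = 'Invalid image data "demo.tif"'
#       raise DemoError()
#   except DemoError:
#       print(DemoError.message)  # -> Invalid image data "demo.tif"
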
def save_history(history, path_to_model):
    # summarize history for accuracy
    plt.plot(history['accuracy'])
    plt.plot(history['val_accuracy'])
    if 'val_loss' in history:
        plt.legend(['train', 'test'], loc='upper left')
    else:
        plt.legend(['train', 'test (Dice)'], loc='upper left')
    plt.title('model accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.tight_layout()  # To prevent overlapping of subplots
    plt.savefig(path_to_model.replace(".h5","_acc.png"), dpi=300, bbox_inches='tight')
    plt.clf()
    # summarize history for loss
    plt.plot(history['loss'])
    if 'val_loss' in history:
        plt.plot(history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc='upper left')
    plt.tight_layout()  # To prevent overlapping of subplots
    plt.savefig(path_to_model.replace(".h5","_loss.png"), dpi=300, bbox_inches='tight')
    plt.clf()

def make_densenet(inputshape):
    base_model = DenseNet121(
        input_tensor=Input(inputshape),
        include_top=False)

    base_model.trainable = False

    inputs = Input(inputshape)
    x = densenet.preprocess_input(inputs)

    x = base_model(x, training=False)
    x = GlobalAveragePooling2D()(x)
    x = Dropout(0.3)(x)

    outputs = Dense(1, activation='sigmoid')(x)

    model = Model(inputs, outputs)
    return model

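# Usage sketch (illustrative, not part of this file): make_densenet returns a
# frozen DenseNet121 backbone topped by a single sigmoid unit, i.e. a binary
# per-slice classifier ("does this slice contain the object?"). The shape
# below matches the default x_scale/y_scale of 256 and the three-channel
# data prepared by load_cropping_training_data:
#
#   model = make_densenet((256, 256, 3))
#   model.summary()
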
def load_cropping_training_data(normalize, img_list, label_list, x_scale, y_scale, z_scale,
        labels_to_compute, labels_to_remove, img_in, label_in, normalization_parameters=None, channels=None):

    # read image lists
    if any(img_list):
        img_names, label_names = read_img_list(img_list, label_list)
        InputError.img_names = img_names
        InputError.label_names = label_names

    # load first label
    if any(img_list):
        a, _, _ = load_data(label_names[0], 'first_queue', True)
        if a is None:
            InputError.message = f'Invalid label data "{os.path.basename(label_names[0])}"'
            raise InputError()
    elif type(label_in) is list:
        a = label_in[0]
    else:
        a = label_in
    a = a.astype(np.uint8)
    a = set_labels_to_zero(a, labels_to_compute, labels_to_remove)
    label_z = np.any(a,axis=(1,2))
    label_y = np.any(a,axis=(0,2))
    label_x = np.any(a,axis=(0,1))
    label = np.append(label_z,label_y,axis=0)
    label = np.append(label,label_x,axis=0)

    # load first img
    if any(img_list):
        img, _ = load_data(img_names[0], 'first_queue')
        if img is None:
            InputError.message = f'Invalid image data "{os.path.basename(img_names[0])}"'
            raise InputError()
    elif type(img_in) is list:
        img = img_in[0]
    else:
        img = img_in
    # handle all images having channels >=1
    if len(img.shape)==3:
        z_shape, y_shape, x_shape = img.shape
        img = img.reshape(z_shape, y_shape, x_shape, 1)
    if channels is None:
        channels = img.shape[3]
    if img.shape[3] != channels:
        InputError.message = f'Number of channels must be {channels} for "{os.path.basename(img_names[0])}"'
        raise InputError()
    img = img.astype(np.float32)
    img_z = img_resize(img, a.shape[0], y_scale, x_scale)
    img_y = np.swapaxes(img_resize(img, z_scale, a.shape[1], x_scale),0,1)
    img_x = np.swapaxes(img_resize(img, z_scale, y_scale, a.shape[2]),0,2)
    img = np.append(img_z,img_y,axis=0)
    img = np.append(img,img_x,axis=0)

    # normalize image data
    for c in range(channels):
        img[:,:,:,c] -= np.amin(img[:,:,:,c])
        img[:,:,:,c] /= np.amax(img[:,:,:,c])
        if normalization_parameters is None:
            normalization_parameters = np.zeros((2,channels))
            normalization_parameters[0,c] = np.mean(img[:,:,:,c])
            normalization_parameters[1,c] = np.std(img[:,:,:,c])
        elif normalize:
            mean, std = np.mean(img[:,:,:,c]), np.std(img[:,:,:,c])
            img[:,:,:,c] = (img[:,:,:,c] - mean) / std
            img[:,:,:,c] = img[:,:,:,c] * normalization_parameters[1,c] + normalization_parameters[0,c]

    # loop over list of images
    if any(img_list) or type(img_in) is list:
        number_of_images = len(img_names) if any(img_list) else len(img_in)

        for k in range(1, number_of_images):

            # append label
            if any(label_list):
                a, _ = load_data(label_names[k], 'first_queue')
                if a is None:
                    InputError.message = f'Invalid label data "{os.path.basename(label_names[k])}"'
                    raise InputError()
            else:
                a = label_in[k]
            a = a.astype(np.uint8)
            a = set_labels_to_zero(a, labels_to_compute, labels_to_remove)
            next_label_z = np.any(a,axis=(1,2))
            next_label_y = np.any(a,axis=(0,2))
            next_label_x = np.any(a,axis=(0,1))
            label = np.append(label,next_label_z,axis=0)
            label = np.append(label,next_label_y,axis=0)
            label = np.append(label,next_label_x,axis=0)

            # append image
            if any(img_list):
                a, _ = load_data(img_names[k], 'first_queue')
                if a is None:
                    InputError.message = f'Invalid image data "{os.path.basename(img_names[k])}"'
                    raise InputError()
            else:
                a = img_in[k]
            if len(a.shape)==3:
                z_shape, y_shape, x_shape = a.shape
                a = a.reshape(z_shape, y_shape, x_shape, 1)
            if a.shape[3] != channels:
                InputError.message = f'Number of channels must be {channels} for "{os.path.basename(img_names[k])}"'
                raise InputError()
            a = a.astype(np.float32)
            img_z = img_resize(a, a.shape[0], y_scale, x_scale)
            img_y = np.swapaxes(img_resize(a, z_scale, a.shape[1], x_scale),0,1)
            img_x = np.swapaxes(img_resize(a, z_scale, y_scale, a.shape[2]),0,2)
            next_img = np.append(img_z,img_y,axis=0)
            next_img = np.append(next_img,img_x,axis=0)
            for c in range(channels):
                next_img[:,:,:,c] -= np.amin(next_img[:,:,:,c])
                next_img[:,:,:,c] /= np.amax(next_img[:,:,:,c])
                if normalize:
                    mean, std = np.mean(next_img[:,:,:,c]), np.std(next_img[:,:,:,c])
                    next_img[:,:,:,c] = (next_img[:,:,:,c] - mean) / std
                    next_img[:,:,:,c] = next_img[:,:,:,c] * normalization_parameters[1,c] + normalization_parameters[0,c]
            img = np.append(img, next_img, axis=0)

    # remove extracted data
    if any(img_list):
        remove_extracted_data(img_names, label_names)

    # limit intensity range
    img[img<0] = 0
    img[img>1] = 1
    img = np.uint8(img*255)

    # number of channels must be three (reuse or cut off)
    min_channels = min(3, channels)
    img_rgb = np.empty((img.shape[:3] + (3,)), dtype=np.uint8)
    for i in range(3):
        img_rgb[...,i] = img[...,i % min_channels]

    return img_rgb, label, normalization_parameters, channels

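# Sketch (illustrative, not part of this file): the training targets above
# are built by projecting the 3D label volume onto each axis with np.any,
# so a slice counts as foreground if it contains any labeled voxel:
#
#   import numpy as np
#   vol = np.zeros((4, 5, 6), dtype=np.uint8)
#   vol[1:3, 2, 3] = 1                     # object occupies z-slices 1 and 2
#   label_z = np.any(vol, axis=(1, 2))     # [False, True, True, False]
#   label_y = np.any(vol, axis=(0, 2))     # per-y-slice flags
#   label_x = np.any(vol, axis=(0, 1))     # per-x-slice flags
#   targets = np.append(np.append(label_z, label_y), label_x)
#   print(targets.shape)                   # (15,) = 4 + 5 + 6
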
def train_cropping(img, label, path_to_model, epochs, batch_size,
        validation_split, flip_x, flip_y, flip_z, rotate,
        img_val, label_val):

    # img shape
    zsh, ysh, xsh, channels = img.shape

    # list of IDs
    list_IDs_fg = list(np.where(label)[0])
    list_IDs_bg = list(np.where(label==False)[0])

    # validation data
    if np.any(img_val):
        list_IDs_val_fg = list(np.where(label_val)[0])
        list_IDs_val_bg = list(np.where(label_val==False)[0])
    elif validation_split:
        split_fg = int(len(list_IDs_fg) * validation_split)
        split_bg = int(len(list_IDs_bg) * validation_split)
        list_IDs_val_fg = list_IDs_fg[split_fg:]
        list_IDs_val_bg = list_IDs_bg[split_bg:]
        list_IDs_fg = list_IDs_fg[:split_fg]
        list_IDs_bg = list_IDs_bg[:split_bg]

    # upsample IDs
    max_IDs = max(len(list_IDs_fg), len(list_IDs_bg))
    tmp_fg = []
    while len(tmp_fg) < max_IDs:
        tmp_fg.extend(list_IDs_fg)
    tmp_fg = tmp_fg[:max_IDs]
    list_IDs_fg = tmp_fg[:]

    tmp_bg = []
    while len(tmp_bg) < max_IDs:
        tmp_bg.extend(list_IDs_bg)
    tmp_bg = tmp_bg[:max_IDs]
    list_IDs_bg = tmp_bg[:]

    # validation data
    if np.any(img_val) or validation_split:
        max_IDs = max(len(list_IDs_val_fg), len(list_IDs_val_bg))
        tmp_fg = []
        while len(tmp_fg) < max_IDs:
            tmp_fg.extend(list_IDs_val_fg)
        tmp_fg = tmp_fg[:max_IDs]
        list_IDs_val_fg = tmp_fg[:]
        tmp_bg = []
        while len(tmp_bg) < max_IDs:
            tmp_bg.extend(list_IDs_val_bg)
        tmp_bg = tmp_bg[:max_IDs]
        list_IDs_val_bg = tmp_bg[:]

    # input shape
    input_shape = (ysh, xsh, channels)

    # parameters
    params = {'dim': (ysh, xsh),
              'batch_size': batch_size,
              'n_classes': 2,
              'n_channels': channels,
              'shuffle': True}

    # validation parameters
    params_val = {'dim': (ysh, xsh),
              'batch_size': batch_size,
              'n_classes': 2,
              'n_channels': channels,
              'shuffle': False}

    # data generator
    training_generator = DataGeneratorCrop(img, label, list_IDs_fg, list_IDs_bg, **params)
    if np.any(img_val):
        validation_generator = DataGeneratorCrop(img_val, label_val, list_IDs_val_fg, list_IDs_val_bg, **params_val)
    elif validation_split:
        validation_generator = DataGeneratorCrop(img, label, list_IDs_val_fg, list_IDs_val_bg, **params_val)
    else:
        validation_generator = None

    # create a MirroredStrategy
    cdo = tf.distribute.ReductionToOneDevice()
    strategy = tf.distribute.MirroredStrategy(cross_device_ops=cdo)
    print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

    # create callback
    if np.any(img_val) or validation_split:
        save_best_only = True
    else:
        save_best_only = False
    checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(str(path_to_model), save_best_only=save_best_only)

    # compile model
    with strategy.scope():
        model = make_densenet(input_shape)
        model.compile(loss=tf.keras.losses.BinaryCrossentropy(),
                      optimizer=Adam(learning_rate=0.001),
                      metrics=['accuracy'])
    # train model
    history = model.fit(training_generator,
                        validation_data=validation_generator,
                        epochs=max(1,epochs),
                        callbacks=checkpoint_cb)

    # save results in figure on train end
    if np.any(img_val) or validation_split:
        save_history(history.history, path_to_model.replace(".h5","_cropping.h5"))

    # compile model for finetuning
    with strategy.scope():
        model = load_model(str(path_to_model))
        model.trainable = True
        model.compile(loss=tf.keras.losses.BinaryCrossentropy(),
                      optimizer=Adam(learning_rate=1e-5),
                      metrics=['accuracy'])

    # finetune model
    history = model.fit(training_generator,
                        validation_data=validation_generator,
                        epochs=max(1,epochs),
                        callbacks=checkpoint_cb)

    # save results in figure on train end
    if np.any(img_val) or validation_split:
        save_history(history.history, path_to_model.replace(".h5","_cropfine.h5"))

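# Sketch (illustrative, not part of this file): the "upsample IDs" step in
# train_cropping balances foreground and background slices by cycling the
# shorter ID list until both lists are equally long:
#
#   fg, bg = [0, 3, 4], [1, 2, 5, 6, 7, 8]
#   max_IDs = max(len(fg), len(bg))        # 6
#   tmp = []
#   while len(tmp) < max_IDs:
#       tmp.extend(fg)
#   fg = tmp[:max_IDs]                     # [0, 3, 4, 0, 3, 4]
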
def load_data_to_crop(path_to_img, channels, x_scale, y_scale, z_scale,
        normalize, normalization_parameters, img):
    # read image data
    if img is None:
        img, _, _ = load_data(path_to_img, 'first_queue', return_extension=True)
        InputError.img_names = [path_to_img]
        InputError.label_names = []
    img_data = np.copy(img, order='C')
    if img is None:
        InputError.message = "Invalid image data %s." %(os.path.basename(path_to_img))
        raise InputError()
    # handle all images having channels >=1
    if len(img.shape)==3:
        z_shape, y_shape, x_shape = img.shape
        img = img.reshape(z_shape, y_shape, x_shape, 1)
    if img.shape[3] != channels:
        InputError.message = f'Number of channels must be {channels}.'
        raise InputError()
    z_shape, y_shape, x_shape, _ = img.shape
    img = img.astype(np.float32)
    img_z = img_resize(img, z_shape, y_scale, x_scale)
    img_y = np.swapaxes(img_resize(img,z_scale,y_shape,x_scale),0,1)
    img_x = np.swapaxes(img_resize(img,z_scale,y_scale,x_shape),0,2)
    img = np.append(img_z,img_y,axis=0)
    img = np.append(img,img_x,axis=0)
    for c in range(channels):
        img[:,:,:,c] -= np.amin(img[:,:,:,c])
        img[:,:,:,c] /= np.amax(img[:,:,:,c])
        if normalize:
            mean, std = np.mean(img[:,:,:,c]), np.std(img[:,:,:,c])
            img[:,:,:,c] = (img[:,:,:,c] - mean) / std
            img[:,:,:,c] = img[:,:,:,c] * normalization_parameters[1,c] + normalization_parameters[0,c]
    img[img<0] = 0
    img[img>1] = 1
    img = np.uint8(img*255)

    # number of channels must be three (reuse or cut off)
    channels = min(3, channels)
    img_rgb = np.empty((img.shape[:3] + (3,)), dtype=np.uint8)
    for i in range(3):
        img_rgb[...,i] = img[...,i % channels]
    return img_rgb, z_shape, y_shape, x_shape, img_data

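# Sketch (illustrative, not part of this file): prediction data is normalized
# per channel the same way as the training data - scaled to [0,1],
# standardized, then mapped onto the stored training mean/std so intensity
# distributions match between datasets:
#
#   import numpy as np
#   x = np.random.rand(8, 8).astype(np.float32) * 500.0
#   x -= x.min(); x /= x.max()             # scale to [0,1]
#   ref_mean, ref_std = 0.4, 0.2           # hypothetical training statistics
#   x = (x - x.mean()) / x.std()           # standardize
#   x = x * ref_std + ref_mean             # match the training distribution
#   print(x.mean(), x.std())               # ~0.4, ~0.2
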
def crop_volume(img, path_to_model, path_to_final, z_shape, y_shape, x_shape, batch_size,
        debug_cropping, save_cropped, img_data, x_range, y_range, z_range, x_puffer=25, y_puffer=25, z_puffer=25):

    # img shape
    zsh, ysh, xsh, channels = img.shape

    # list of IDs
    list_IDs = [x for x in range(zsh)]

    # make length of list divisible by batch size
    rest = batch_size - (len(list_IDs) % batch_size)
    list_IDs = list_IDs + list_IDs[:rest]

    # parameters
    params = {'dim': (ysh, xsh),
              'dim_img': (zsh, ysh, xsh),
              'batch_size': batch_size,
              'n_channels': channels}

    # data generator
    predict_generator = PredictDataGeneratorCrop(img, list_IDs, **params)

    # input shape
    input_shape = (ysh, xsh, channels)

    # load model
    model = make_densenet(input_shape)

    # load weights
    hf = h5py.File(path_to_model, 'r')
    cropping_weights = hf.get('cropping_weights')
    iterator = 0
    for layer in model.layers:
        if layer.get_weights() != []:
            new_weights = []
            for arr in layer.get_weights():
                new_weights.append(cropping_weights.get(str(iterator)))
                iterator += 1
            layer.set_weights(new_weights)
    hf.close()

    # predict
    probabilities = model.predict(predict_generator, verbose=0, steps=None)
    probabilities = probabilities[:zsh]
    probabilities = np.ravel(probabilities)

    # plot prediction
    if debug_cropping and path_to_final:
        import matplotlib.pyplot as plt
        import matplotlib
        x = range(len(probabilities))
        y = list(probabilities)
        plt.plot(x, y)

    # create mask
    probabilities[probabilities > 0.5] = 1
    probabilities[probabilities <= 0.5] = 0

    # remove outliers
    for k in range(4,zsh-4):
        if np.all(probabilities[k-1:k+2] == np.array([0,1,0])):
            probabilities[k-1:k+2] = 0
        elif np.all(probabilities[k-2:k+2] == np.array([0,1,1,0])):
            probabilities[k-2:k+2] = 0
        elif np.all(probabilities[k-2:k+3] == np.array([0,1,1,1,0])):
            probabilities[k-2:k+3] = 0
        elif np.all(probabilities[k-3:k+3] == np.array([0,1,1,1,1,0])):
            probabilities[k-3:k+3] = 0
        elif np.all(probabilities[k-3:k+4] == np.array([0,1,1,1,1,1,0])):
            probabilities[k-3:k+4] = 0
        elif np.all(probabilities[k-4:k+4] == np.array([0,1,1,1,1,1,1,0])):
            probabilities[k-4:k+4] = 0
        elif np.all(probabilities[k-4:k+5] == np.array([0,1,1,1,1,1,1,1,0])):
            probabilities[k-4:k+5] = 0

    # create final
    if z_range is not None:
        z_lower, z_upper = z_range
    else:
        z_lower = max(0,np.argmax(probabilities[:z_shape]) - z_puffer)
        z_upper = min(z_shape,z_shape - np.argmax(np.flip(probabilities[:z_shape])) + z_puffer +1)

    if y_range is not None:
        y_lower, y_upper = y_range
    else:
        y_lower = max(0,np.argmax(probabilities[z_shape:z_shape+y_shape]) - y_puffer)
        y_upper = min(y_shape,y_shape - np.argmax(np.flip(probabilities[z_shape:z_shape+y_shape])) + y_puffer +1)

    if x_range is not None:
        x_lower, x_upper = x_range
    else:
        x_lower = max(0,np.argmax(probabilities[z_shape+y_shape:]) - x_puffer)
        x_upper = min(x_shape,x_shape - np.argmax(np.flip(probabilities[z_shape+y_shape:])) + x_puffer +1)

    # plot result
    if debug_cropping and path_to_final:
        y = np.zeros_like(probabilities)
        y[z_lower:z_upper] = 1
        y[z_shape+y_lower:z_shape+y_upper] = 1
        y[z_shape+y_shape+x_lower:z_shape+y_shape+x_upper] = 1
        plt.plot(x, y, '--')
        plt.tight_layout()  # To prevent overlapping of subplots
        #matplotlib.use("GTK3Agg")
        plt.savefig(path_to_final.replace('.tif','.png'), dpi=300)

    # crop image data
    cropped_volume = img_data[z_lower:z_upper, y_lower:y_upper, x_lower:x_upper]
    if save_cropped and path_to_final:
        save_data(path_to_final, cropped_volume, compress=False)

    return z_lower, z_upper, y_lower, y_upper, x_lower, x_upper, cropped_volume

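# Sketch (illustrative, not part of this file): crop bounds are read off the
# binary per-slice mask with argmax (index of the first 1) and argmax of the
# flipped mask (index of the last 1), padded by the "puffer" margin and
# clamped to the volume shape:
#
#   import numpy as np
#   p = np.array([0, 0, 1, 1, 1, 0, 0, 0], dtype=float)
#   n, puffer = len(p), 1
#   lower = max(0, int(np.argmax(p)) - puffer)                   # 1
#   upper = min(n, n - int(np.argmax(np.flip(p))) + puffer + 1)  # 7
#   print(p[lower:upper])                  # object plus a one-slice margin
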
#=====================
# main functions
#=====================

def load_and_train(normalize,path_to_img,path_to_labels,path_to_model,
        epochs,batch_size,validation_split,
        flip_x,flip_y,flip_z,rotate,labels_to_compute,labels_to_remove,
        path_val_img=[None],path_val_labels=[None],
        img=None, label=None, img_val=None, label_val=None,
        x_scale=256, y_scale=256, z_scale=256):

    # load training data
    img, label, normalization_parameters, channels = load_cropping_training_data(normalize,
        path_to_img, path_to_labels, x_scale, y_scale, z_scale, labels_to_compute, labels_to_remove,
        img, label)

    # load validation data
    if any(path_val_img) or img_val is not None:
        img_val, label_val, _, _ = load_cropping_training_data(normalize,
            path_val_img, path_val_labels, x_scale, y_scale, z_scale,
            labels_to_compute, labels_to_remove,
            img_val, label_val, normalization_parameters, channels)

    # train cropping
    train_cropping(img, label, path_to_model, epochs,
        batch_size, validation_split,
        flip_x, flip_y, flip_z, rotate,
        img_val, label_val)

    # load weights
    model = load_model(str(path_to_model))
    cropping_weights = []
    for layer in model.layers:
        if layer.get_weights() != []:
            for arr in layer.get_weights():
                cropping_weights.append(arr)

    # configuration data
    cropping_config = np.array([channels, x_scale, y_scale, z_scale, normalize, 0, 1])

    return cropping_weights, cropping_config, normalization_parameters

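# Sketch (illustrative, not part of this file): the configuration vector
# returned above is what crop_data() reads back from the model file under
# 'cropping_meta/configuration' (the caller is expected to store it there,
# e.g. biomedisa/deeplearning.py in this wheel). A hypothetical round trip:
#
#   import h5py, numpy as np
#   with h5py.File('demo_model.h5', 'w') as hf:      # hypothetical file
#       meta = hf.create_group('cropping_meta')
#       meta.create_dataset('configuration',
#           data=np.array([1, 256, 256, 256, 1, 0, 1]))
#   with h5py.File('demo_model.h5', 'r') as hf:
#       channels, x_scale, y_scale, z_scale, normalize, mu, sig = \
#           np.array(hf['cropping_meta']['configuration'])[:]
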
def crop_data(path_to_data, path_to_model, path_to_cropped_image, batch_size,
        debug_cropping=False, save_cropped=True, img_data=None,
        x_range=None, y_range=None, z_range=None):

    # get meta data
    hf = h5py.File(path_to_model, 'r')
    meta = hf.get('cropping_meta')
    configuration = meta.get('configuration')
    channels, x_scale, y_scale, z_scale, normalize, mu, sig = np.array(configuration)[:]
    channels, x_scale, y_scale, z_scale, normalize, mu, sig = int(channels), int(x_scale), \
        int(y_scale), int(z_scale), int(normalize), float(mu), float(sig)
    if '/cropping_meta/normalization' in hf:
        normalization_parameters = np.array(meta.get('normalization'), dtype=float)
    else:
        # old configuration
        normalization_parameters = np.array([[mu],[sig]])
        channels = 1
    hf.close()

    # load data
    img, z_shape, y_shape, x_shape, img_data = load_data_to_crop(path_to_data, channels,
        x_scale, y_scale, z_scale, normalize, normalization_parameters, img_data)

    # make prediction
    z_lower, z_upper, y_lower, y_upper, x_lower, x_upper, cropped_volume = crop_volume(img, path_to_model,
        path_to_cropped_image, z_shape, y_shape, x_shape, batch_size, debug_cropping, save_cropped, img_data,
        x_range, y_range, z_range)

    # region of interest
    region_of_interest = np.array([z_lower, z_upper, y_lower, y_upper, x_lower, x_upper])

    return region_of_interest, cropped_volume
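# Usage sketch (illustrative; file names are hypothetical): cropping a volume
# with a previously trained cropping network. crop_data returns the bounding
# box as [z_lower, z_upper, y_lower, y_upper, x_lower, x_upper] plus the
# cropped array, and writes the cropped image when save_cropped=True:
#
#   roi, cropped = crop_data('volume.tif', 'cropping_model.h5',
#                            'volume_cropped.tif', batch_size=32)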