biomedisa 2024.5.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. biomedisa/__init__.py +53 -0
  2. biomedisa/__main__.py +18 -0
  3. biomedisa/biomedisa_features/DataGenerator.py +299 -0
  4. biomedisa/biomedisa_features/DataGeneratorCrop.py +121 -0
  5. biomedisa/biomedisa_features/PredictDataGenerator.py +87 -0
  6. biomedisa/biomedisa_features/PredictDataGeneratorCrop.py +74 -0
  7. biomedisa/biomedisa_features/__init__.py +0 -0
  8. biomedisa/biomedisa_features/active_contour.py +434 -0
  9. biomedisa/biomedisa_features/amira_to_np/__init__.py +0 -0
  10. biomedisa/biomedisa_features/amira_to_np/amira_data_stream.py +980 -0
  11. biomedisa/biomedisa_features/amira_to_np/amira_grammar.py +369 -0
  12. biomedisa/biomedisa_features/amira_to_np/amira_header.py +290 -0
  13. biomedisa/biomedisa_features/amira_to_np/amira_helper.py +72 -0
  14. biomedisa/biomedisa_features/assd.py +167 -0
  15. biomedisa/biomedisa_features/biomedisa_helper.py +801 -0
  16. biomedisa/biomedisa_features/create_slices.py +286 -0
  17. biomedisa/biomedisa_features/crop_helper.py +586 -0
  18. biomedisa/biomedisa_features/curvop_numba.py +149 -0
  19. biomedisa/biomedisa_features/django_env.py +172 -0
  20. biomedisa/biomedisa_features/keras_helper.py +1219 -0
  21. biomedisa/biomedisa_features/nc_reader.py +179 -0
  22. biomedisa/biomedisa_features/pid.py +52 -0
  23. biomedisa/biomedisa_features/process_image.py +253 -0
  24. biomedisa/biomedisa_features/pycuda_test.py +84 -0
  25. biomedisa/biomedisa_features/random_walk/__init__.py +0 -0
  26. biomedisa/biomedisa_features/random_walk/gpu_kernels.py +183 -0
  27. biomedisa/biomedisa_features/random_walk/pycuda_large.py +826 -0
  28. biomedisa/biomedisa_features/random_walk/pycuda_large_allx.py +806 -0
  29. biomedisa/biomedisa_features/random_walk/pycuda_small.py +414 -0
  30. biomedisa/biomedisa_features/random_walk/pycuda_small_allx.py +493 -0
  31. biomedisa/biomedisa_features/random_walk/pyopencl_large.py +760 -0
  32. biomedisa/biomedisa_features/random_walk/pyopencl_small.py +441 -0
  33. biomedisa/biomedisa_features/random_walk/rw_large.py +390 -0
  34. biomedisa/biomedisa_features/random_walk/rw_small.py +310 -0
  35. biomedisa/biomedisa_features/remove_outlier.py +399 -0
  36. biomedisa/biomedisa_features/split_volume.py +274 -0
  37. biomedisa/deeplearning.py +519 -0
  38. biomedisa/interpolation.py +371 -0
  39. biomedisa/mesh.py +406 -0
  40. biomedisa-2024.5.14.dist-info/LICENSE +191 -0
  41. biomedisa-2024.5.14.dist-info/METADATA +306 -0
  42. biomedisa-2024.5.14.dist-info/RECORD +44 -0
  43. biomedisa-2024.5.14.dist-info/WHEEL +5 -0
  44. biomedisa-2024.5.14.dist-info/top_level.txt +1 -0
@@ -0,0 +1,586 @@
1
+ ##########################################################################
2
+ ## ##
3
+ ## Copyright (c) 2024 Philipp Lösel. All rights reserved. ##
4
+ ## ##
5
+ ## This file is part of the open source project biomedisa. ##
6
+ ## ##
7
+ ## Licensed under the European Union Public Licence (EUPL) ##
8
+ ## v1.2, or - as soon as they will be approved by the ##
9
+ ## European Commission - subsequent versions of the EUPL; ##
10
+ ## ##
11
+ ## You may redistribute it and/or modify it under the terms ##
12
+ ## of the EUPL v1.2. You may not use this work except in ##
13
+ ## compliance with this Licence. ##
14
+ ## ##
15
+ ## You can obtain a copy of the Licence at: ##
16
+ ## ##
17
+ ## https://joinup.ec.europa.eu/page/eupl-text-11-12 ##
18
+ ## ##
19
+ ## Unless required by applicable law or agreed to in ##
20
+ ## writing, software distributed under the Licence is ##
21
+ ## distributed on an "AS IS" basis, WITHOUT WARRANTIES ##
22
+ ## OR CONDITIONS OF ANY KIND, either express or implied. ##
23
+ ## ##
24
+ ## See the Licence for the specific language governing ##
25
+ ## permissions and limitations under the Licence. ##
26
+ ## ##
27
+ ##########################################################################
28
+
29
+ import os
30
+ from biomedisa_features.keras_helper import read_img_list, remove_extracted_data
31
+ from biomedisa_features.biomedisa_helper import img_resize, load_data, save_data, set_labels_to_zero
32
+ from tensorflow.python.framework.errors_impl import ResourceExhaustedError
33
+ from tensorflow.keras.applications import DenseNet121, densenet
34
+ from tensorflow.keras.optimizers import Adam
35
+ from tensorflow.keras.models import Model, load_model
36
+ from tensorflow.keras.layers import Input, GlobalAveragePooling2D, Dropout, Dense
37
+ from tensorflow.keras.callbacks import Callback, ModelCheckpoint
38
+ from biomedisa_features.DataGeneratorCrop import DataGeneratorCrop
39
+ from biomedisa_features.PredictDataGeneratorCrop import PredictDataGeneratorCrop
40
+ import tensorflow as tf
41
+ import numpy as np
42
+ from glob import glob
43
+ import h5py
44
+ import tarfile
45
+ import matplotlib.pyplot as plt
46
+
47
class InputError(Exception):
    """Exception raised when training or prediction input data is invalid.

    Code in this module usually assigns ``message``, ``img_names`` and
    ``label_names`` on the *class* itself and then raises ``InputError()``,
    so handlers may read the class attributes. The instance attributes set
    here are what an explicitly constructed exception carries.

    Args:
        message: Human-readable description of the problem.
        img_names: Image file names involved (a fresh empty list if omitted).
        label_names: Label file names involved (a fresh empty list if omitted).
    """

    def __init__(self, message=None, img_names=None, label_names=None):
        self.message = message
        # Use a new list per instance instead of a shared mutable default.
        self.img_names = [] if img_names is None else img_names
        self.label_names = [] if label_names is None else label_names
52
+
53
def save_history(history, path_to_model):
    """Save accuracy and loss curves of a training run as PNG files.

    The figures are written next to ``path_to_model`` with its extension
    replaced by ``_acc.png`` and ``_loss.png``.

    Args:
        history: Mapping from metric name to per-epoch values (e.g. Keras'
            ``History.history``). ``accuracy`` and ``loss`` are required;
            ``val_accuracy`` and ``val_loss`` are plotted when present.
        path_to_model: Path of the model file the plots belong to.
    """
    # splitext keeps the plots from overwriting the model file itself when
    # the path does not end in ".h5" (str.replace would leave it unchanged).
    base = os.path.splitext(str(path_to_model))[0]
    has_val_loss = 'val_loss' in history

    # accuracy curves
    plt.plot(history['accuracy'])
    acc_legend = ['train']
    if 'val_accuracy' in history:
        plt.plot(history['val_accuracy'])
        # without a validation loss the validation accuracy is a Dice score
        acc_legend.append('test' if has_val_loss else 'test (Dice)')
    plt.legend(acc_legend, loc='upper left')
    plt.title('model accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.tight_layout()  # To prevent overlapping of subplots
    plt.savefig(base + '_acc.png', dpi=300, bbox_inches='tight')
    plt.clf()

    # loss curves; only label a "test" curve if one was actually drawn
    plt.plot(history['loss'])
    loss_legend = ['train']
    if has_val_loss:
        plt.plot(history['val_loss'])
        loss_legend.append('test')
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(loss_legend, loc='upper left')
    plt.tight_layout()  # To prevent overlapping of subplots
    plt.savefig(base + '_loss.png', dpi=300, bbox_inches='tight')
    plt.clf()
78
+
79
def make_densenet(inputshape):
    """Build a binary slice classifier on a frozen DenseNet121 backbone.

    The backbone is created without its classification head (DenseNet121's
    default ``weights`` argument is used) and frozen. Inputs go through Keras'
    DenseNet preprocessing, then global average pooling, 30% dropout and a
    single sigmoid unit.

    Args:
        inputshape: Shape of one input slice, e.g. ``(height, width, 3)``.

    Returns:
        An uncompiled ``tf.keras`` ``Model`` producing one probability per slice.
    """
    backbone = DenseNet121(input_tensor=Input(inputshape), include_top=False)
    backbone.trainable = False

    model_input = Input(inputshape)
    preprocessed = densenet.preprocess_input(model_input)
    # training=False keeps the backbone in inference mode even when fine-tuned
    features = backbone(preprocessed, training=False)
    pooled = GlobalAveragePooling2D()(features)
    regularized = Dropout(0.3)(pooled)
    probability = Dense(1, activation='sigmoid')(regularized)

    return Model(model_input, probability)
97
+
98
def _label_slice_flags(label_volume, labels_to_compute, labels_to_remove):
    """Return, for every z-, y- and x-slice of a label volume, whether it contains a label.

    The volume is cast to uint8 and filtered with ``set_labels_to_zero`` first.
    The flags are concatenated in z, y, x order (length ``z + y + x``).
    """
    label_volume = set_labels_to_zero(label_volume.astype(np.uint8), labels_to_compute, labels_to_remove)
    return np.concatenate((np.any(label_volume, axis=(1, 2)),
                           np.any(label_volume, axis=(0, 2)),
                           np.any(label_volume, axis=(0, 1))), axis=0)


def _resized_slice_stack(volume, shape, x_scale, y_scale, z_scale):
    """Resize a (z, y, x, channels) volume three times and stack its z-, y- and x-slices.

    In each pass the slicing axis keeps the extent given by ``shape`` (z, y, x)
    while the other two axes are resized to the corresponding ``*_scale``.
    """
    z_shape, y_shape, x_shape = shape[:3]
    img_z = img_resize(volume, z_shape, y_scale, x_scale)
    img_y = np.swapaxes(img_resize(volume, z_scale, y_shape, x_scale), 0, 1)
    img_x = np.swapaxes(img_resize(volume, z_scale, y_scale, x_shape), 0, 2)
    return np.concatenate((img_z, img_y, img_x), axis=0)


def load_cropping_training_data(normalize, img_list, label_list, x_scale, y_scale, z_scale,
        labels_to_compute, labels_to_remove, img_in, label_in, normalization_parameters=None, channels=None):
    """Build the 2D slice dataset used to train the cropping network.

    Each image/label pair is sliced along z, y and x. Image slices are resized
    with ``img_resize``, min-max scaled per channel and (optionally) matched to
    reference statistics. Labels become one boolean "slice contains a label"
    flag per slice.

    Args:
        normalize: If true, match each image's per-channel mean/std to
            ``normalization_parameters``.
        img_list, label_list: File lists passed to ``read_img_list``; used when
            ``any(img_list)`` is true.
        x_scale, y_scale, z_scale: In-plane sizes used by ``img_resize``.
        labels_to_compute, labels_to_remove: Passed to ``set_labels_to_zero``.
        img_in, label_in: In-memory volume (or list of volumes) used when no
            file list is given.
        normalization_parameters: ``(2, channels)`` array of reference means
            (row 0) and stds (row 1). If None, it is computed from the first
            image, which itself is left un-matched.
        channels: Expected channel count; taken from the first image if None.

    Returns:
        ``(img_rgb, label, normalization_parameters, channels)`` where
        ``img_rgb`` is a uint8 slice stack with exactly three channels
        (channels are reused or cut off) and ``label`` is the concatenated
        boolean flag vector.

    Raises:
        InputError: On unreadable data, a channel mismatch or no data at all.
    """
    from_files = any(img_list)
    if from_files:
        img_names, label_names = read_img_list(img_list, label_list)
        InputError.img_names = img_names
        InputError.label_names = label_names
        number_of_images = len(img_names)
    elif type(img_in) is list:
        number_of_images = len(img_in)
    else:
        number_of_images = 1

    if number_of_images == 0:
        InputError.message = 'No training data found.'
        raise InputError()

    # Compute reference statistics from the first image unless they are given.
    # The decision is made once, before any channel loop, so that every
    # channel gets its own statistics.
    compute_parameters = normalization_parameters is None

    label_parts, img_parts = [], []
    for k in range(number_of_images):
        # in-memory data has no file name to report
        name = os.path.basename(img_names[k]) if from_files else f'image {k}'

        # label volume
        if from_files:
            label_volume, _ = load_data(label_names[k], 'first_queue')
            if label_volume is None:
                InputError.message = f'Invalid label data "{os.path.basename(label_names[k])}"'
                raise InputError()
        elif k == 0 and type(label_in) is not list:
            label_volume = label_in
        else:
            label_volume = label_in[k]
        label_parts.append(_label_slice_flags(label_volume, labels_to_compute, labels_to_remove))

        # image volume, always handled as (z, y, x, channels)
        if from_files:
            volume, _ = load_data(img_names[k], 'first_queue')
            if volume is None:
                InputError.message = f'Invalid image data "{name}"'
                raise InputError()
        elif k == 0 and type(img_in) is not list:
            volume = img_in
        else:
            volume = img_in[k]
        if len(volume.shape) == 3:
            volume = volume.reshape(volume.shape + (1,))
        if channels is None:
            channels = volume.shape[3]
        if volume.shape[3] != channels:
            InputError.message = f'Number of channels must be {channels} for "{name}"'
            raise InputError()

        # resize with the label's extent so image slices line up with label flags
        slices = _resized_slice_stack(volume.astype(np.float32), label_volume.shape,
                                      x_scale, y_scale, z_scale)

        # per-channel min-max scaling (a constant channel stays at zero)
        for c in range(channels):
            slices[..., c] -= np.amin(slices[..., c])
            peak = np.amax(slices[..., c])
            if peak > 0:
                slices[..., c] /= peak

        if k == 0 and compute_parameters:
            normalization_parameters = np.zeros((2, channels))
            for c in range(channels):
                normalization_parameters[0, c] = np.mean(slices[..., c])
                normalization_parameters[1, c] = np.std(slices[..., c])
        elif normalize:
            for c in range(channels):
                mean, std = np.mean(slices[..., c]), np.std(slices[..., c])
                centered = slices[..., c] - mean
                if std > 0:
                    centered /= std
                slices[..., c] = centered * normalization_parameters[1, c] + normalization_parameters[0, c]
        img_parts.append(slices)

    # remove extracted data
    if from_files:
        remove_extracted_data(img_names, label_names)

    label = np.concatenate(label_parts, axis=0)
    img = np.concatenate(img_parts, axis=0)
    del img_parts

    # limit intensity range and convert to 8 bit
    np.clip(img, 0, 1, out=img)
    img = (img * 255).astype(np.uint8)

    # number of channels must be three (reuse or cut off)
    used_channels = min(3, channels)
    img_rgb = img[..., [i % used_channels for i in range(3)]]

    return img_rgb, label, normalization_parameters, channels
232
+
233
def _upsample_ids(ids, size):
    """Return ``ids`` repeated cyclically and truncated to exactly ``size`` entries."""
    if not ids or size <= 0:
        return []
    repeats = -(-size // len(ids))  # ceil(size / len(ids))
    return (ids * repeats)[:size]


def train_cropping(img, label, path_to_model, epochs, batch_size,
        validation_split, flip_x, flip_y, flip_z, rotate,
        img_val, label_val):
    """Train the DenseNet slice classifier used for automatic cropping.

    Slices with a label (foreground) and without (background) are balanced by
    cyclic upsampling. The network is first trained with a frozen backbone and
    then fine-tuned with all weights trainable. Checkpoints go to
    ``path_to_model``; only the best one is kept when validation data exists.

    Args:
        img: uint8 slice stack of shape (n, y, x, channels).
        label: Boolean per-slice foreground flags of length n.
        path_to_model: Checkpoint path.
        epochs: Epochs per phase (at least one is run).
        batch_size: Batch size for the data generators.
        validation_split: Fraction of slices kept for training when no
            separate validation data is given; the rest validates.
        flip_x, flip_y, flip_z, rotate: Accepted for interface compatibility;
            not used here.
        img_val, label_val: Optional separate validation slices/flags.

    Raises:
        InputError: If the training or validation set lacks foreground or
            background slices (balancing would be impossible).
    """
    model_path = str(path_to_model)
    _, ysh, xsh, channels = img.shape
    has_val_data = img_val is not None
    has_validation = has_val_data or bool(validation_split)

    # list of IDs
    list_IDs_fg = list(np.flatnonzero(label))
    list_IDs_bg = list(np.flatnonzero(np.logical_not(label)))

    # validation data
    if has_val_data:
        list_IDs_val_fg = list(np.flatnonzero(label_val))
        list_IDs_val_bg = list(np.flatnonzero(np.logical_not(label_val)))
    elif validation_split:
        split_fg = int(len(list_IDs_fg) * validation_split)
        split_bg = int(len(list_IDs_bg) * validation_split)
        list_IDs_fg, list_IDs_val_fg = list_IDs_fg[:split_fg], list_IDs_fg[split_fg:]
        list_IDs_bg, list_IDs_val_bg = list_IDs_bg[:split_bg], list_IDs_bg[split_bg:]

    # balance classes; an empty class could never be upsampled
    if not list_IDs_fg or not list_IDs_bg:
        InputError.message = ('Cropping training data must contain slices both with '
                              'and without labels.')
        raise InputError()
    max_IDs = max(len(list_IDs_fg), len(list_IDs_bg))
    list_IDs_fg = _upsample_ids(list_IDs_fg, max_IDs)
    list_IDs_bg = _upsample_ids(list_IDs_bg, max_IDs)

    if has_validation:
        if not list_IDs_val_fg or not list_IDs_val_bg:
            InputError.message = ('Cropping validation data must contain slices both with '
                                  'and without labels.')
            raise InputError()
        max_IDs = max(len(list_IDs_val_fg), len(list_IDs_val_bg))
        list_IDs_val_fg = _upsample_ids(list_IDs_val_fg, max_IDs)
        list_IDs_val_bg = _upsample_ids(list_IDs_val_bg, max_IDs)

    # input shape
    input_shape = (ysh, xsh, channels)

    # parameters
    params = {'dim': (ysh, xsh),
              'batch_size': batch_size,
              'n_classes': 2,
              'n_channels': channels,
              'shuffle': True}

    # validation parameters
    params_val = {'dim': (ysh, xsh),
                  'batch_size': batch_size,
                  'n_classes': 2,
                  'n_channels': channels,
                  'shuffle': False}

    # data generators
    training_generator = DataGeneratorCrop(img, label, list_IDs_fg, list_IDs_bg, **params)
    if has_val_data:
        validation_generator = DataGeneratorCrop(img_val, label_val, list_IDs_val_fg, list_IDs_val_bg, **params_val)
    elif validation_split:
        validation_generator = DataGeneratorCrop(img, label, list_IDs_val_fg, list_IDs_val_bg, **params_val)
    else:
        validation_generator = None

    # create a MirroredStrategy
    cdo = tf.distribute.ReductionToOneDevice()
    strategy = tf.distribute.MirroredStrategy(cross_device_ops=cdo)
    print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

    # keep only the best checkpoint when there is something to validate on
    checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(model_path, save_best_only=has_validation)

    # phase 1: train the head on a frozen backbone
    with strategy.scope():
        model = make_densenet(input_shape)
        model.compile(loss=tf.keras.losses.BinaryCrossentropy(),
                      optimizer=Adam(learning_rate=0.001),
                      metrics=['accuracy'])
    history = model.fit(training_generator,
                        validation_data=validation_generator,
                        epochs=max(1, epochs),
                        callbacks=[checkpoint_cb])

    # save results in figure on train end
    if has_validation:
        save_history(history.history, model_path.replace(".h5", "_cropping.h5"))

    # phase 2: fine-tune all weights starting from the saved checkpoint
    with strategy.scope():
        model = load_model(model_path)
        model.trainable = True
        model.compile(loss=tf.keras.losses.BinaryCrossentropy(),
                      optimizer=Adam(learning_rate=1e-5),
                      metrics=['accuracy'])
    history = model.fit(training_generator,
                        validation_data=validation_generator,
                        epochs=max(1, epochs),
                        callbacks=[checkpoint_cb])

    # save results in figure on train end
    if has_validation:
        save_history(history.history, model_path.replace(".h5", "_cropfine.h5"))
356
+
357
def load_data_to_crop(path_to_img, channels, x_scale, y_scale, z_scale,
        normalize, normalization_parameters, img):
    """Prepare one volume for the cropping network.

    The volume is sliced along z, y and x (resized with ``img_resize``),
    min-max scaled per channel, optionally matched to the reference
    statistics, converted to uint8 and expanded/cut to three channels.

    Args:
        path_to_img: File to load when ``img`` is None (also used in messages).
        channels: Required number of image channels.
        x_scale, y_scale, z_scale: In-plane sizes used by ``img_resize``.
        normalize: If true, match per-channel mean/std to ``normalization_parameters``.
        normalization_parameters: ``(2, channels)`` reference means (row 0)
            and stds (row 1).
        img: In-memory volume, or None to load ``path_to_img``.

    Returns:
        ``(img_rgb, z_shape, y_shape, x_shape, img_data)`` where ``img_data``
        is an unmodified C-ordered copy of the input volume.

    Raises:
        InputError: If the data cannot be read or has the wrong channel count.
    """
    # read image data
    if img is None:
        img, _, _ = load_data(path_to_img, 'first_queue', return_extension=True)
        InputError.img_names = [path_to_img]
        InputError.label_names = []
    if img is None:
        InputError.message = "Invalid image data %s." %(os.path.basename(path_to_img))
        raise InputError()
    img_data = np.copy(img, order='C')

    # handle all images as (z, y, x, channels)
    if len(img.shape) == 3:
        img = img.reshape(img.shape + (1,))
    if img.shape[3] != channels:
        InputError.message = f'Number of channels must be {channels}.'
        raise InputError()
    z_shape, y_shape, x_shape, _ = img.shape

    img = img.astype(np.float32)
    img_z = img_resize(img, z_shape, y_scale, x_scale)
    img_y = np.swapaxes(img_resize(img, z_scale, y_shape, x_scale), 0, 1)
    img_x = np.swapaxes(img_resize(img, z_scale, y_scale, x_shape), 0, 2)
    img = np.concatenate((img_z, img_y, img_x), axis=0)

    for c in range(channels):
        # min-max scaling; a constant channel stays at zero instead of NaN
        img[..., c] -= np.amin(img[..., c])
        peak = np.amax(img[..., c])
        if peak > 0:
            img[..., c] /= peak
        if normalize:
            mean, std = np.mean(img[..., c]), np.std(img[..., c])
            centered = img[..., c] - mean
            if std > 0:
                centered /= std
            img[..., c] = centered * normalization_parameters[1, c] + normalization_parameters[0, c]

    np.clip(img, 0, 1, out=img)
    img = (img * 255).astype(np.uint8)

    # number of channels must be three (reuse or cut off)
    used_channels = min(3, channels)
    img_rgb = img[..., [i % used_channels for i in range(3)]]
    return img_rgb, z_shape, y_shape, x_shape, img_data
399
+
400
def crop_volume(img, path_to_model, path_to_final, z_shape, y_shape, x_shape, batch_size,
        debug_cropping, save_cropped, img_data, x_range, y_range, z_range, x_puffer=25, y_puffer=25, z_puffer=25):
    """Predict which slices contain the object and crop ``img_data`` accordingly.

    ``img`` holds the z-slices, then the y-slices, then the x-slices of the
    volume (``z_shape + y_shape + x_shape`` entries). Each slice is classified
    with the network whose weights are stored under ``cropping_weights`` in
    ``path_to_model``. Short positive runs are removed, and the first/last
    positive slice per axis, widened by the ``*_puffer`` margins, gives the
    bounds. Explicit ``*_range`` tuples override the prediction for that axis.

    Args:
        img: uint8 slice stack of shape (n, y, x, channels).
        path_to_model: HDF5 file containing the ``cropping_weights`` group.
        path_to_final: Output path for the cropped volume (and debug plot).
        z_shape, y_shape, x_shape: Extent of the original volume.
        batch_size: Prediction batch size.
        debug_cropping: Save a plot of probabilities and chosen bounds.
        save_cropped: Save the cropped volume to ``path_to_final``.
        img_data: Original volume to crop.
        x_range, y_range, z_range: Optional ``(lower, upper)`` overrides.
        x_puffer, y_puffer, z_puffer: Margins added around detected bounds.

    Returns:
        ``(z_lower, z_upper, y_lower, y_upper, x_lower, x_upper, cropped_volume)``.
    """
    zsh, ysh, xsh, channels = img.shape

    # pad the slice IDs cyclically so their number is a multiple of batch_size
    padding = -zsh % batch_size
    list_IDs = list(range(zsh)) + [i % zsh for i in range(padding)]

    # parameters
    params = {'dim': (ysh, xsh),
              'dim_img': (zsh, ysh, xsh),
              'batch_size': batch_size,
              'n_channels': channels}

    # data generator
    predict_generator = PredictDataGeneratorCrop(img, list_IDs, **params)

    # build model
    model = make_densenet((ysh, xsh, channels))

    # copy stored weights into the layers in layer order; the file is closed
    # even if loading fails
    with h5py.File(path_to_model, 'r') as hf:
        stored_weights = hf.get('cropping_weights')
        index = 0
        for layer in model.layers:
            current = layer.get_weights()
            if current:
                layer.set_weights([np.asarray(stored_weights[str(index + i)])
                                   for i in range(len(current))])
                index += len(current)

    # predict (padding entries are dropped)
    probabilities = model.predict(predict_generator, verbose=0, steps=None)
    probabilities = np.ravel(probabilities[:zsh])

    # plot prediction on a dedicated figure so repeated calls don't accumulate
    debug_plot = debug_cropping and path_to_final
    if debug_plot:
        fig, ax = plt.subplots()
        slice_index = np.arange(len(probabilities))
        ax.plot(slice_index, probabilities.copy())

    # create mask
    probabilities[probabilities > 0.5] = 1
    probabilities[probabilities <= 0.5] = 0

    # remove outliers: runs of 1..7 positive slices flanked by negatives.
    # Each entry is (slices before k, slices from k on, pattern); for every k
    # the patterns are tried in this order and the first match is cleared.
    island_windows = [(before, after, np.array([0] + [1] * (before + after - 2) + [0]))
                      for before, after in ((1, 2), (2, 2), (2, 3), (3, 3), (3, 4), (4, 4), (4, 5))]
    for k in range(4, zsh - 4):
        for before, after, pattern in island_windows:
            if np.all(probabilities[k - before:k + after] == pattern):
                probabilities[k - before:k + after] = 0
                break

    # bounds per axis (upper bounds are exclusive)
    if z_range is not None:
        z_lower, z_upper = z_range
    else:
        z_lower = max(0, np.argmax(probabilities[:z_shape]) - z_puffer)
        z_upper = min(z_shape, z_shape - np.argmax(np.flip(probabilities[:z_shape])) + z_puffer + 1)

    if y_range is not None:
        y_lower, y_upper = y_range
    else:
        y_lower = max(0, np.argmax(probabilities[z_shape:z_shape+y_shape]) - y_puffer)
        y_upper = min(y_shape, y_shape - np.argmax(np.flip(probabilities[z_shape:z_shape+y_shape])) + y_puffer + 1)

    if x_range is not None:
        x_lower, x_upper = x_range
    else:
        x_lower = max(0, np.argmax(probabilities[z_shape+y_shape:]) - x_puffer)
        x_upper = min(x_shape, x_shape - np.argmax(np.flip(probabilities[z_shape+y_shape:])) + x_puffer + 1)

    # plot result
    if debug_plot:
        y = np.zeros_like(probabilities)
        y[z_lower:z_upper] = 1
        y[z_shape+y_lower:z_shape+y_upper] = 1
        y[z_shape+y_shape+x_lower:z_shape+y_shape+x_upper] = 1
        ax.plot(slice_index, y, '--')
        fig.tight_layout()  # To prevent overlapping of subplots
        # splitext handles any extension (".tif", ".tiff", ...) without
        # colliding with the cropped-volume output path
        fig.savefig(os.path.splitext(path_to_final)[0] + '.png', dpi=300)
        plt.close(fig)

    # crop image data
    cropped_volume = img_data[z_lower:z_upper, y_lower:y_upper, x_lower:x_upper]
    if save_cropped and path_to_final:
        save_data(path_to_final, cropped_volume, compress=False)

    return z_lower, z_upper, y_lower, y_upper, x_lower, x_upper, cropped_volume
511
+
512
+ #=====================
513
+ # main functions
514
+ #=====================
515
+
516
def load_and_train(normalize, path_to_img, path_to_labels, path_to_model,
        epochs, batch_size, validation_split,
        flip_x, flip_y, flip_z, rotate, labels_to_compute, labels_to_remove,
        path_val_img=None, path_val_labels=None,
        img=None, label=None, img_val=None, label_val=None,
        x_scale=256, y_scale=256, z_scale=256):
    """Load training (and validation) data, train the cropping network and export it.

    Args:
        normalize: Match per-channel statistics of every image to the first one.
        path_to_img, path_to_labels: Training file lists (see
            ``load_cropping_training_data``).
        path_to_model: Checkpoint path used during training.
        epochs, batch_size, validation_split: Training settings.
        flip_x, flip_y, flip_z, rotate: Forwarded to ``train_cropping``.
        labels_to_compute, labels_to_remove: Label filtering options.
        path_val_img, path_val_labels: Optional validation file lists
            (None behaves like ``[None]``).
        img, label, img_val, label_val: Optional in-memory data.
        x_scale, y_scale, z_scale: Slice sizes for resizing.

    Returns:
        ``(cropping_weights, cropping_config, normalization_parameters)``:
        the trained weight arrays in layer order, the array
        ``[channels, x_scale, y_scale, z_scale, normalize, 0, 1]`` (the trailing
        0, 1 fill the legacy mu/sig slots), and the reference statistics.
    """
    # avoid shared mutable defaults; [None] means "no validation files"
    if path_val_img is None:
        path_val_img = [None]
    if path_val_labels is None:
        path_val_labels = [None]

    # load training data
    img, label, normalization_parameters, channels = load_cropping_training_data(normalize,
        path_to_img, path_to_labels, x_scale, y_scale, z_scale, labels_to_compute, labels_to_remove,
        img, label)

    # load validation data with the training statistics
    if any(path_val_img) or img_val is not None:
        img_val, label_val, _, _ = load_cropping_training_data(normalize,
            path_val_img, path_val_labels, x_scale, y_scale, z_scale,
            labels_to_compute, labels_to_remove,
            img_val, label_val, normalization_parameters, channels)

    # train cropping
    train_cropping(img, label, path_to_model, epochs,
        batch_size, validation_split,
        flip_x, flip_y, flip_z, rotate,
        img_val, label_val)

    # collect the trained weights in layer order
    model = load_model(str(path_to_model))
    cropping_weights = [arr for layer in model.layers for arr in layer.get_weights()]

    # configuration data
    cropping_config = np.array([channels, x_scale, y_scale, z_scale, normalize, 0, 1])

    return cropping_weights, cropping_config, normalization_parameters
553
+
554
def crop_data(path_to_data, path_to_model, path_to_cropped_image, batch_size,
        debug_cropping=False, save_cropped=True, img_data=None,
        x_range=None, y_range=None, z_range=None):
    """Determine and apply the region of interest of a volume.

    The cropping configuration is read from the ``cropping_meta`` group of
    ``path_to_model``. The volume is then prepared with ``load_data_to_crop``
    and cropped with ``crop_volume``.

    Args:
        path_to_data: Volume file to load when ``img_data`` is None.
        path_to_model: HDF5 model file with ``cropping_meta`` and ``cropping_weights``.
        path_to_cropped_image: Output path for the cropped volume.
        batch_size: Prediction batch size.
        debug_cropping: Save a debug plot of the prediction.
        save_cropped: Save the cropped volume.
        img_data: Optional in-memory volume.
        x_range, y_range, z_range: Optional ``(lower, upper)`` overrides.

    Returns:
        ``(region_of_interest, cropped_volume)`` where ``region_of_interest``
        is ``[z_lower, z_upper, y_lower, y_upper, x_lower, x_upper]``.
    """
    # read meta data; the context manager closes the file even on errors
    with h5py.File(path_to_model, 'r') as hf:
        meta = hf.get('cropping_meta')
        configuration = meta.get('configuration')
        channels, x_scale, y_scale, z_scale, normalize, mu, sig = np.array(configuration)[:]
        channels, x_scale, y_scale, z_scale, normalize, mu, sig = int(channels), int(x_scale), \
            int(y_scale), int(z_scale), int(normalize), float(mu), float(sig)
        if '/cropping_meta/normalization' in hf:
            normalization_parameters = np.array(meta.get('normalization'), dtype=float)
        else:
            # old configuration: single-channel model with scalar mu/sig
            normalization_parameters = np.array([[mu], [sig]])
            channels = 1

    # load data
    img, z_shape, y_shape, x_shape, img_data = load_data_to_crop(path_to_data, channels,
        x_scale, y_scale, z_scale, normalize, normalization_parameters, img_data)

    # make prediction
    z_lower, z_upper, y_lower, y_upper, x_lower, x_upper, cropped_volume = crop_volume(img, path_to_model,
        path_to_cropped_image, z_shape, y_shape, x_shape, batch_size, debug_cropping, save_cropped, img_data,
        x_range, y_range, z_range)

    # region of interest
    region_of_interest = np.array([z_lower, z_upper, y_lower, y_upper, x_lower, x_upper])

    return region_of_interest, cropped_volume
586
+