biomedisa 24.5.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44):
  1. biomedisa/__init__.py +49 -0
  2. biomedisa/__main__.py +18 -0
  3. biomedisa/deeplearning.py +529 -0
  4. biomedisa/features/DataGenerator.py +299 -0
  5. biomedisa/features/DataGeneratorCrop.py +121 -0
  6. biomedisa/features/PredictDataGenerator.py +87 -0
  7. biomedisa/features/PredictDataGeneratorCrop.py +74 -0
  8. biomedisa/features/__init__.py +0 -0
  9. biomedisa/features/active_contour.py +430 -0
  10. biomedisa/features/amira_to_np/__init__.py +0 -0
  11. biomedisa/features/amira_to_np/amira_data_stream.py +980 -0
  12. biomedisa/features/amira_to_np/amira_grammar.py +369 -0
  13. biomedisa/features/amira_to_np/amira_header.py +290 -0
  14. biomedisa/features/amira_to_np/amira_helper.py +72 -0
  15. biomedisa/features/assd.py +167 -0
  16. biomedisa/features/biomedisa_helper.py +842 -0
  17. biomedisa/features/create_slices.py +277 -0
  18. biomedisa/features/crop_helper.py +581 -0
  19. biomedisa/features/curvop_numba.py +149 -0
  20. biomedisa/features/django_env.py +171 -0
  21. biomedisa/features/keras_helper.py +1195 -0
  22. biomedisa/features/nc_reader.py +179 -0
  23. biomedisa/features/pid.py +52 -0
  24. biomedisa/features/process_image.py +251 -0
  25. biomedisa/features/pycuda_test.py +85 -0
  26. biomedisa/features/random_walk/__init__.py +0 -0
  27. biomedisa/features/random_walk/gpu_kernels.py +184 -0
  28. biomedisa/features/random_walk/pycuda_large.py +826 -0
  29. biomedisa/features/random_walk/pycuda_large_allx.py +806 -0
  30. biomedisa/features/random_walk/pycuda_small.py +414 -0
  31. biomedisa/features/random_walk/pycuda_small_allx.py +493 -0
  32. biomedisa/features/random_walk/pyopencl_large.py +760 -0
  33. biomedisa/features/random_walk/pyopencl_small.py +441 -0
  34. biomedisa/features/random_walk/rw_large.py +389 -0
  35. biomedisa/features/random_walk/rw_small.py +307 -0
  36. biomedisa/features/remove_outlier.py +396 -0
  37. biomedisa/features/split_volume.py +167 -0
  38. biomedisa/interpolation.py +369 -0
  39. biomedisa/mesh.py +403 -0
  40. biomedisa-24.5.23.dist-info/LICENSE +191 -0
  41. biomedisa-24.5.23.dist-info/METADATA +261 -0
  42. biomedisa-24.5.23.dist-info/RECORD +44 -0
  43. biomedisa-24.5.23.dist-info/WHEEL +5 -0
  44. biomedisa-24.5.23.dist-info/top_level.txt +1 -0
@@ -0,0 +1,581 @@
1
+ ##########################################################################
2
+ ## ##
3
+ ## Copyright (c) 2019-2024 Philipp Lösel. All rights reserved. ##
4
+ ## ##
5
+ ## This file is part of the open source project biomedisa. ##
6
+ ## ##
7
+ ## Licensed under the European Union Public Licence (EUPL) ##
8
+ ## v1.2, or - as soon as they will be approved by the ##
9
+ ## European Commission - subsequent versions of the EUPL; ##
10
+ ## ##
11
+ ## You may redistribute it and/or modify it under the terms ##
12
+ ## of the EUPL v1.2. You may not use this work except in ##
13
+ ## compliance with this Licence. ##
14
+ ## ##
15
+ ## You can obtain a copy of the Licence at: ##
16
+ ## ##
17
+ ## https://joinup.ec.europa.eu/page/eupl-text-11-12 ##
18
+ ## ##
19
+ ## Unless required by applicable law or agreed to in ##
20
+ ## writing, software distributed under the Licence is ##
21
+ ## distributed on an "AS IS" basis, WITHOUT WARRANTIES ##
22
+ ## OR CONDITIONS OF ANY KIND, either express or implied. ##
23
+ ## ##
24
+ ## See the Licence for the specific language governing ##
25
+ ## permissions and limitations under the Licence. ##
26
+ ## ##
27
+ ##########################################################################
28
+
29
+ import os
30
+ from biomedisa.features.keras_helper import read_img_list
31
+ from biomedisa.features.biomedisa_helper import img_resize, load_data, save_data, set_labels_to_zero
32
+ from tensorflow.python.framework.errors_impl import ResourceExhaustedError
33
+ from tensorflow.keras.applications import DenseNet121, densenet
34
+ from tensorflow.keras.optimizers import Adam
35
+ from tensorflow.keras.models import Model, load_model
36
+ from tensorflow.keras.layers import Input, GlobalAveragePooling2D, Dropout, Dense
37
+ from tensorflow.keras.callbacks import Callback, ModelCheckpoint
38
+ from biomedisa.features.DataGeneratorCrop import DataGeneratorCrop
39
+ from biomedisa.features.PredictDataGeneratorCrop import PredictDataGeneratorCrop
40
+ import tensorflow as tf
41
+ import numpy as np
42
+ from glob import glob
43
+ import h5py
44
+ import tarfile
45
+ import matplotlib.pyplot as plt
46
+ import tempfile
47
+
48
class InputError(Exception):
    """Raised when image or label input data is invalid.

    Code in this module reports errors by assigning the class attribute
    ``InputError.message`` and then raising ``InputError()`` without
    arguments. The instance falls back to that class-level text, so it is
    available as ``InputError.message``, ``exc.message`` and ``str(exc)``.
    An explicit ``message`` argument takes precedence.
    """

    # Default text; code in this module overwrites it before raising.
    message = None

    def __init__(self, message=None):
        if message is None:
            # fall back to the text stored on the class (module convention)
            message = type(self).message
        # hand the text to Exception so that str(exc) is informative
        super().__init__(*(() if message is None else (message,)))
        self.message = message
51
+
52
def save_history(history, path_to_model):
    """Save accuracy and loss curves of a training run as PNG figures.

    Writes ``<stem>_acc.png`` and ``<stem>_loss.png``, where ``<stem>`` is
    ``path_to_model`` without its file extension (``model.h5`` gives
    ``model_acc.png`` and ``model_loss.png``).

    Args:
        history (dict): per-epoch metric lists such as ``History.history``
            from ``model.fit``. Must contain 'accuracy' and 'loss';
            'val_accuracy' and 'val_loss' are plotted when present.
        path_to_model (str): path of the model the curves belong to.
    """
    # Strip the extension instead of str.replace(".h5", ...): a path without
    # ".h5" would be returned unchanged and the PNG would overwrite the model.
    stem = os.path.splitext(path_to_model)[0]

    # summarize history for accuracy
    fig = plt.figure()
    plt.plot(history['accuracy'])
    legend = ['train']
    if 'val_accuracy' in history:
        plt.plot(history['val_accuracy'])
        # without 'val_loss' the validation accuracy is labelled as Dice score
        legend.append('test' if 'val_loss' in history else 'test (Dice)')
    plt.legend(legend, loc='upper left')
    plt.title('model accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.tight_layout()  # To prevent overlapping of subplots
    plt.savefig(stem + '_acc.png', dpi=300, bbox_inches='tight')
    plt.close(fig)

    # summarize history for loss
    fig = plt.figure()
    plt.plot(history['loss'])
    legend = ['train']
    if 'val_loss' in history:
        plt.plot(history['val_loss'])
        legend.append('test')
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(legend, loc='upper left')
    plt.tight_layout()  # To prevent overlapping of subplots
    plt.savefig(stem + '_loss.png', dpi=300, bbox_inches='tight')
    plt.close(fig)
77
+
78
def make_densenet(inputshape):
    """Build the slice classifier used for cropping.

    A DenseNet121 backbone without its top classifier is frozen
    (``trainable = False``) and always run in inference mode
    (``training=False``). On top of it sit global average pooling, 30 %
    dropout and a single sigmoid unit, so the model maps one input slice to
    one probability.

    Args:
        inputshape: shape of one input slice; callers in this module pass
            (height, width, channels).

    Returns:
        an uncompiled Keras ``Model``.
    """
    # frozen backbone without the classification head
    backbone = DenseNet121(input_tensor=Input(inputshape), include_top=False)
    backbone.trainable = False

    # preprocessing -> backbone -> pooling -> dropout -> sigmoid head
    image = Input(inputshape)
    preprocessed = densenet.preprocess_input(image)
    features = backbone(preprocessed, training=False)
    pooled = GlobalAveragePooling2D()(features)
    dropped = Dropout(0.3)(pooled)
    probability = Dense(1, activation='sigmoid')(dropped)

    return Model(inputs=image, outputs=probability)
96
+
97
def _crop_label_profile(label, labels_to_compute, labels_to_remove):
    """Return per-slice foreground flags of a label volume along z, y and x.

    The label is cast to uint8 and filtered with ``set_labels_to_zero``. The
    result concatenates one flag per z-slice, then per y-slice, then per
    x-slice (True if the slice contains any non-zero voxel). Assumes a 3-D
    (z, y, x) label volume.
    """
    label = set_labels_to_zero(label.astype(np.uint8), labels_to_compute, labels_to_remove)
    return np.concatenate((np.any(label, axis=(1, 2)),
                           np.any(label, axis=(0, 2)),
                           np.any(label, axis=(0, 1))))


def _crop_resample_image(img, label_shape, x_scale, y_scale, z_scale):
    """Resample a float (z, y, x, c) image into stacked slices along z, y and x.

    The number of slices per axis follows ``label_shape``, so every slice
    lines up with one entry of the label profile. Assuming
    ``img_resize(img, z, y, x)`` returns shape (z, y, x, c), the three
    stacks can only be concatenated when ``x_scale == y_scale == z_scale``.
    """
    zsh, ysh, xsh = label_shape[0], label_shape[1], label_shape[2]
    img_z = img_resize(img, zsh, y_scale, x_scale)
    img_y = np.swapaxes(img_resize(img, z_scale, ysh, x_scale), 0, 1)
    img_x = np.swapaxes(img_resize(img, z_scale, y_scale, xsh), 0, 2)
    return np.concatenate((img_z, img_y, img_x), axis=0)


def _crop_scale_channels(img, normalize, normalization_parameters):
    """Min-max scale every channel of a float (..., c) image in place.

    If ``normalization_parameters`` is None, a new (2, c) array holding the
    mean (row 0) and standard deviation (row 1) of each scaled channel is
    created, and the image is left in [0, 1] (it serves as the reference).
    Otherwise, when ``normalize`` is truthy, each channel is standardized and
    mapped onto the given mean/std.

    Returns:
        the normalization parameters (newly created or the ones passed in).
    """
    channels = img.shape[-1]
    compute_reference = normalization_parameters is None
    if compute_reference:
        # Allocate once for all channels. Re-testing "is None" inside the loop
        # would skip the statistics of every channel after the first and
        # normalize those channels against zeros instead.
        normalization_parameters = np.zeros((2, channels))
    for c in range(channels):
        channel = img[..., c]  # view: the in-place ops below modify img
        channel -= np.amin(channel)
        maximum = np.amax(channel)
        if maximum > 0:  # constant channel: avoid 0/0 -> NaN
            channel /= maximum
        if compute_reference:
            normalization_parameters[0, c] = np.mean(channel)
            normalization_parameters[1, c] = np.std(channel)
        elif normalize:
            mean, std = np.mean(channel), np.std(channel)
            channel -= mean
            if std > 0:
                channel /= std
            channel *= normalization_parameters[1, c]
            channel += normalization_parameters[0, c]
    return normalization_parameters


def load_cropping_training_data(normalize, img_list, label_list, x_scale, y_scale, z_scale,
        labels_to_compute, labels_to_remove, img_in, label_in, normalization_parameters=None, channels=None):
    """Load image/label pairs and build the 2-D slice dataset for the cropping network.

    Data is read from the file lists ``img_list`` / ``label_list`` when
    ``img_list`` contains any truthy entry. Otherwise it comes from the
    in-memory ``img_in`` / ``label_in`` (one array or a list of arrays).
    Every volume is resampled into slices along z, y and x. Each slice is
    flagged by whether its label slice contains foreground, and image
    channels are min-max scaled (and optionally standardized).

    Args:
        normalize: if truthy, standardize each channel and map it onto
            ``normalization_parameters``.
        img_list, label_list: paths handed to ``read_img_list``.
        x_scale, y_scale, z_scale: slice size after resampling; must be equal
            for the slice stacks to be concatenated.
        labels_to_compute, labels_to_remove: passed to ``set_labels_to_zero``.
        img_in, label_in: in-memory data used when ``img_list`` is empty.
        normalization_parameters: (2, channels) reference mean/std. If None,
            it is computed from the first image.
        channels: expected number of image channels. If None, it is taken
            from the first image.

    Returns:
        ``(img_rgb, label, normalization_parameters, channels)``, where
        ``img_rgb`` is a uint8 array of slices with exactly three channels
        (channels are reused or cut off) and ``label`` is a 1-D boolean
        array with one entry per slice.

    Raises:
        InputError: if no data is found, a file cannot be read, or an image
            has the wrong number of channels. The text is stored in
            ``InputError.message``.
    """
    with tempfile.TemporaryDirectory() as temp_img_dir, \
            tempfile.TemporaryDirectory() as temp_label_dir:

        # decide where the data comes from
        use_files = any(img_list)
        if use_files:
            img_names, label_names = read_img_list(img_list, label_list, temp_img_dir, temp_label_dir)
            number_of_images = len(img_names)
        elif isinstance(img_in, list):
            number_of_images = len(img_in)
        else:
            number_of_images = 1

        if number_of_images == 0:
            InputError.message = 'No training data found.'
            raise InputError()

        label_profiles, image_slices = [], []
        for k in range(number_of_images):

            # label: one foreground flag per z-, y- and x-slice
            if use_files:
                label_data = load_data(label_names[k], 'first_queue')[0]
                if label_data is None:
                    InputError.message = f'Invalid label data "{os.path.basename(label_names[k])}"'
                    raise InputError()
            else:
                label_data = label_in[k] if isinstance(label_in, list) else label_in
            label_profiles.append(_crop_label_profile(label_data, labels_to_compute, labels_to_remove))

            # image: handle all images having channels >= 1
            if use_files:
                img = load_data(img_names[k], 'first_queue')[0]
                if img is None:
                    InputError.message = f'Invalid image data "{os.path.basename(img_names[k])}"'
                    raise InputError()
                where = f'"{os.path.basename(img_names[k])}"'
            else:
                img = img_in[k] if isinstance(img_in, list) else img_in
                where = f'input image {k}'
            if len(img.shape) == 3:
                img = img.reshape(img.shape + (1,))
            if channels is None:
                channels = img.shape[3]
            if img.shape[3] != channels:
                InputError.message = f'Number of channels must be {channels} for {where}'
                raise InputError()

            # resample along the label's grid so slices match the label profile
            img = _crop_resample_image(img.astype(np.float32), label_data.shape,
                                       x_scale, y_scale, z_scale)
            normalization_parameters = _crop_scale_channels(img, normalize, normalization_parameters)
            image_slices.append(img)

    # concatenate once (repeated np.append in the loop is quadratic)
    label = np.concatenate(label_profiles)
    img = np.concatenate(image_slices, axis=0)

    # limit intensity range
    img = np.uint8(np.clip(img, 0, 1) * 255)

    # number of channels must be three (reuse or cut off)
    used_channels = min(3, channels)
    img_rgb = np.empty(img.shape[:3] + (3,), dtype=np.uint8)
    for i in range(3):
        img_rgb[..., i] = img[..., i % used_channels]

    return img_rgb, label, normalization_parameters, channels
229
+
230
def _crop_cycle_ids(ids, length, description):
    """Return ``ids`` repeated cyclically and truncated to exactly ``length`` entries.

    Raises:
        InputError: if ``ids`` is empty although ``length`` > 0. A plain
            repeat-until-long-enough loop would never terminate in that case.
    """
    if length == 0:
        return []
    if not ids:
        InputError.message = ('The cropping network needs slices with and without labels, '
                              f'but no {description} slices were found.')
        raise InputError()
    repeats = -(-length // len(ids))  # ceiling division
    return (ids * repeats)[:length]


def train_cropping(img, label, path_to_model, epochs, batch_size,
        validation_split, flip_x, flip_y, flip_z, rotate,
        img_val, label_val):
    """Train the slice classifier used for automatic cropping.

    The model from ``make_densenet`` is trained in two phases, each for
    ``max(1, epochs)`` epochs:
      1. the model as built by ``make_densenet``, with learning rate 1e-3;
      2. the whole model reloaded from the checkpoint and unfrozen, with
         learning rate 1e-5.

    Checkpoints are written to ``path_to_model``. When validation data
    exists, only improving checkpoints are kept. Slices with labels (fg) and
    without labels (bg) are balanced by cyclic upsampling of the smaller
    group.

    Args:
        img: training slices of shape (n, y, x, channels).
        label: length-n array, truthy for slices that contain labels.
        path_to_model: checkpoint path (str).
        epochs, batch_size: training settings.
        validation_split: if truthy and no validation images are given, the
            first ``validation_split`` fraction of the fg and bg slice IDs is
            used for training and the rest for validation.
        flip_x, flip_y, flip_z, rotate: accepted for interface compatibility;
            not used here.
        img_val, label_val: optional separate validation slices and flags.

    Raises:
        InputError: if the training or validation data has no slices with,
            or no slices without, labels.
    """
    # img shape
    _, ysh, xsh, channels = img.shape

    # IDs of slices with (fg) and without (bg) labels
    list_IDs_fg = list(np.flatnonzero(label))
    list_IDs_bg = list(np.flatnonzero(np.logical_not(label)))

    # validation data (np.any(None) is False, i.e. img_val=None means "none")
    use_val_data = bool(np.any(img_val))
    has_validation = use_val_data or bool(validation_split)
    if use_val_data:
        list_IDs_val_fg = list(np.flatnonzero(label_val))
        list_IDs_val_bg = list(np.flatnonzero(np.logical_not(label_val)))
    elif validation_split:
        split_fg = int(len(list_IDs_fg) * validation_split)
        split_bg = int(len(list_IDs_bg) * validation_split)
        list_IDs_val_fg = list_IDs_fg[split_fg:]
        list_IDs_val_bg = list_IDs_bg[split_bg:]
        list_IDs_fg = list_IDs_fg[:split_fg]
        list_IDs_bg = list_IDs_bg[:split_bg]

    # upsample IDs so that fg and bg lists have equal length
    max_IDs = max(len(list_IDs_fg), len(list_IDs_bg))
    list_IDs_fg = _crop_cycle_ids(list_IDs_fg, max_IDs, 'labeled training')
    list_IDs_bg = _crop_cycle_ids(list_IDs_bg, max_IDs, 'unlabeled training')
    if has_validation:
        max_IDs = max(len(list_IDs_val_fg), len(list_IDs_val_bg))
        list_IDs_val_fg = _crop_cycle_ids(list_IDs_val_fg, max_IDs, 'labeled validation')
        list_IDs_val_bg = _crop_cycle_ids(list_IDs_val_bg, max_IDs, 'unlabeled validation')

    # input shape
    input_shape = (ysh, xsh, channels)

    # generator parameters (validation data is not shuffled)
    params = {'dim': (ysh, xsh),
              'batch_size': batch_size,
              'n_classes': 2,
              'n_channels': channels,
              'shuffle': True}
    params_val = dict(params, shuffle=False)

    # data generators
    training_generator = DataGeneratorCrop(img, label, list_IDs_fg, list_IDs_bg, **params)
    if use_val_data:
        validation_generator = DataGeneratorCrop(img_val, label_val, list_IDs_val_fg, list_IDs_val_bg, **params_val)
    elif validation_split:
        validation_generator = DataGeneratorCrop(img, label, list_IDs_val_fg, list_IDs_val_bg, **params_val)
    else:
        validation_generator = None

    # create a MirroredStrategy
    cdo = tf.distribute.ReductionToOneDevice()
    strategy = tf.distribute.MirroredStrategy(cross_device_ops=cdo)
    print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

    # checkpoint callback (keep only improvements when validation exists)
    checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(str(path_to_model), save_best_only=has_validation)

    # phase 1: compile and train the model as built by make_densenet
    with strategy.scope():
        model = make_densenet(input_shape)
        model.compile(loss=tf.keras.losses.BinaryCrossentropy(),
                      optimizer=Adam(learning_rate=0.001),
                      metrics=['accuracy'])
    history = model.fit(training_generator,
                        validation_data=validation_generator,
                        epochs=max(1, epochs),
                        callbacks=[checkpoint_cb])

    # save results in figure on train end
    if has_validation:
        save_history(history.history, path_to_model.replace(".h5", "_cropping.h5"))

    # phase 2: reload the checkpoint, unfreeze everything and fine-tune
    with strategy.scope():
        model = load_model(str(path_to_model))
        model.trainable = True
        model.compile(loss=tf.keras.losses.BinaryCrossentropy(),
                      optimizer=Adam(learning_rate=1e-5),
                      metrics=['accuracy'])
    history = model.fit(training_generator,
                        validation_data=validation_generator,
                        epochs=max(1, epochs),
                        callbacks=[checkpoint_cb])

    # save results in figure on train end
    if has_validation:
        save_history(history.history, path_to_model.replace(".h5", "_cropfine.h5"))
353
+
354
def load_data_to_crop(path_to_img, channels, x_scale, y_scale, z_scale,
        normalize, normalization_parameters, img):
    """Load an image volume and turn it into slices for the cropping network.

    The volume is resampled into slices along z, y and x. Each channel is
    min-max scaled and, if ``normalize`` is truthy, standardized and mapped
    onto ``normalization_parameters`` (row 0: mean, row 1: std). The result
    is converted to uint8 with exactly three channels (reused or cut off).

    Args:
        path_to_img: file to load when ``img`` is None.
        channels: required number of image channels.
        x_scale, y_scale, z_scale: slice size after resampling.
        normalize: whether to apply ``normalization_parameters``.
        normalization_parameters: (2, channels) reference mean/std.
        img: already loaded image data, or None.

    Returns:
        ``(img_rgb, z_shape, y_shape, x_shape, img_data)``: the slices, the
        original volume dimensions, and a C-ordered copy of the unmodified
        input data.

    Raises:
        InputError: if the image cannot be read or has the wrong number of
            channels. The text is stored in ``InputError.message``.
    """
    # read image data
    if img is None:
        img, _, _ = load_data(path_to_img, 'first_queue', return_extension=True)
    if img is None:
        InputError.message = "Invalid image data %s." %(os.path.basename(path_to_img))
        raise InputError()
    # untouched copy of the input, used later for the actual crop
    img_data = np.copy(img, order='C')

    # handle all images having channels >= 1
    if len(img.shape) == 3:
        img = img.reshape(img.shape + (1,))
    if img.shape[3] != channels:
        InputError.message = f'Number of channels must be {channels}.'
        raise InputError()
    z_shape, y_shape, x_shape, _ = img.shape

    # slices along z, y and x
    img = img.astype(np.float32)
    img_z = img_resize(img, z_shape, y_scale, x_scale)
    img_y = np.swapaxes(img_resize(img, z_scale, y_shape, x_scale), 0, 1)
    img_x = np.swapaxes(img_resize(img, z_scale, y_scale, x_shape), 0, 2)
    img = np.concatenate((img_z, img_y, img_x), axis=0)

    # scale (and optionally normalize) each channel in place
    for c in range(channels):
        channel = img[..., c]  # view: in-place ops modify img
        channel -= np.amin(channel)
        maximum = np.amax(channel)
        if maximum > 0:  # constant channel: avoid 0/0 -> NaN
            channel /= maximum
        if normalize:
            mean, std = np.mean(channel), np.std(channel)
            channel -= mean
            if std > 0:
                channel /= std
            channel *= normalization_parameters[1, c]
            channel += normalization_parameters[0, c]
    img = np.uint8(np.clip(img, 0, 1) * 255)

    # number of channels must be three (reuse or cut off)
    used_channels = min(3, channels)
    img_rgb = np.empty(img.shape[:3] + (3,), dtype=np.uint8)
    for i in range(3):
        img_rgb[..., i] = img[..., i % used_channels]
    return img_rgb, z_shape, y_shape, x_shape, img_data
394
+
395
def _crop_load_weights(model, path_to_model):
    """Copy the datasets of group 'cropping_weights' in ``path_to_model`` into ``model``.

    Datasets are named '0', '1', ... in the order in which the weights of
    ``model.layers`` are enumerated. The HDF5 file is always closed.
    """
    with h5py.File(path_to_model, 'r') as hf:
        cropping_weights = hf.get('cropping_weights')
        iterator = 0
        for layer in model.layers:
            n_weights = len(layer.get_weights())
            if n_weights:
                # read into numpy arrays while the file is still open
                layer.set_weights([cropping_weights[str(iterator + i)][()] for i in range(n_weights)])
                iterator += n_weights


def _crop_remove_short_runs(mask):
    """Clear short runs (length 1 to 7) of ones enclosed by zeros, in place.

    ``mask`` is a 1-D array of 0/1 values. Positions ``k`` from 4 to
    ``len(mask) - 5`` are scanned in order, so earlier removals can affect
    later matches. At each ``k``, runs are tested from shortest to longest
    and only the first match is cleared. The window for a run of length
    ``r`` starts at ``k - (r // 2 + 1)`` and covers the run plus one zero
    on each side.
    """
    patterns = [np.array([0] + [1] * run + [0]) for run in range(1, 8)]
    for k in range(4, len(mask) - 4):
        for pattern in patterns:
            start = k - ((len(pattern) - 2) // 2 + 1)
            stop = start + len(pattern)
            if np.all(mask[start:stop] == pattern):
                mask[start:stop] = 0
                break


def _crop_bounds(profile, size, puffer):
    """Return (lower, upper) bounds around the first/last foreground entry of ``profile``.

    Adds ``puffer`` below and ``puffer + 1`` above, clipped to [0, size].
    Without any foreground the full range (0, size) results.
    """
    lower = max(0, np.argmax(profile) - puffer)
    upper = min(size, size - np.argmax(np.flip(profile)) + puffer + 1)
    return lower, upper


def crop_volume(img, path_to_model, path_to_final, z_shape, y_shape, x_shape, batch_size,
        debug_cropping, save_cropped, img_data, x_range, y_range, z_range, x_puffer=25, y_puffer=25, z_puffer=25):
    """Predict the region of interest of a volume and crop ``img_data`` to it.

    Every slice in ``img`` (the z-, y- and x-slices produced by
    ``load_data_to_crop``) is classified as foreground or background. Short
    isolated foreground runs are removed, and the first and last foreground
    slice per axis (plus a margin) define the bounds. A given
    ``z_range``/``y_range``/``x_range`` overrides the prediction for that
    axis.

    Args:
        img: uint8 slices of shape (z_shape + y_shape + x_shape, y, x, c).
        path_to_model: HDF5 file containing the 'cropping_weights' group.
        path_to_final: output path for the cropped volume / debug plot.
        z_shape, y_shape, x_shape: dimensions of the original volume.
        batch_size: prediction batch size.
        debug_cropping: if truthy (and ``path_to_final`` is set), save a PNG
            plot of the probabilities and the chosen bounds.
        save_cropped: if truthy (and ``path_to_final`` is set), save the
            cropped volume.
        img_data: original volume that is cropped.
        x_range, y_range, z_range: optional fixed (lower, upper) bounds.
        x_puffer, y_puffer, z_puffer: margins around the detected foreground.

    Returns:
        ``(z_lower, z_upper, y_lower, y_upper, x_lower, x_upper, cropped_volume)``.
    """
    # img shape
    zsh, ysh, xsh, channels = img.shape

    # slice IDs, padded cyclically to an exact multiple of the batch size
    list_IDs = list(range(zsh))
    rest = -zsh % batch_size
    list_IDs += [list_IDs[i % zsh] for i in range(rest)]

    # data generator
    params = {'dim': (ysh, xsh),
              'dim_img': (zsh, ysh, xsh),
              'batch_size': batch_size,
              'n_channels': channels}
    predict_generator = PredictDataGeneratorCrop(img, list_IDs, **params)

    # build model and load weights
    model = make_densenet((ysh, xsh, channels))
    _crop_load_weights(model, path_to_model)

    # predict (drop the padding)
    probabilities = model.predict(predict_generator, verbose=0, steps=None)
    probabilities = np.ravel(probabilities[:zsh])

    # plot prediction
    debug_plot = debug_cropping and path_to_final
    if debug_plot:
        fig = plt.figure()
        x = range(len(probabilities))
        plt.plot(x, list(probabilities))

    # create mask
    probabilities[probabilities > 0.5] = 1
    probabilities[probabilities <= 0.5] = 0

    # remove outliers
    _crop_remove_short_runs(probabilities)

    # bounds per axis (z-, y- and x-slices are stored one after another)
    if z_range is not None:
        z_lower, z_upper = z_range
    else:
        z_lower, z_upper = _crop_bounds(probabilities[:z_shape], z_shape, z_puffer)

    if y_range is not None:
        y_lower, y_upper = y_range
    else:
        y_lower, y_upper = _crop_bounds(probabilities[z_shape:z_shape+y_shape], y_shape, y_puffer)

    if x_range is not None:
        x_lower, x_upper = x_range
    else:
        x_lower, x_upper = _crop_bounds(probabilities[z_shape+y_shape:], x_shape, x_puffer)

    # plot result
    if debug_plot:
        y = np.zeros_like(probabilities)
        y[z_lower:z_upper] = 1
        y[z_shape+y_lower:z_shape+y_upper] = 1
        y[z_shape+y_shape+x_lower:z_shape+y_shape+x_upper] = 1
        plt.plot(x, y, '--')
        plt.tight_layout()  # To prevent overlapping of subplots
        # derive the PNG name from the stem so it never replaces the data file
        plt.savefig(os.path.splitext(path_to_final)[0] + '.png', dpi=300)
        plt.close(fig)

    # crop image data
    cropped_volume = img_data[z_lower:z_upper, y_lower:y_upper, x_lower:x_upper]
    if save_cropped and path_to_final:
        save_data(path_to_final, cropped_volume, compress=False)

    return z_lower, z_upper, y_lower, y_upper, x_lower, x_upper, cropped_volume
506
+
507
+ #=====================
508
+ # main functions
509
+ #=====================
510
+
511
def load_and_train(normalize, path_to_img, path_to_labels, path_to_model,
        epochs, batch_size, validation_split,
        flip_x, flip_y, flip_z, rotate, labels_to_compute, labels_to_remove,
        path_val_img=None, path_val_labels=None,
        img=None, label=None, img_val=None, label_val=None,
        x_scale=256, y_scale=256, z_scale=256):
    """Load training (and validation) data, train the cropping network and export it.

    Args:
        normalize: whether images are normalized to the training statistics.
        path_to_img, path_to_labels: training file lists (may be empty when
            ``img``/``label`` are given in memory).
        path_to_model: checkpoint path used by ``train_cropping``.
        epochs, batch_size, validation_split, flip_x, flip_y, flip_z, rotate:
            forwarded to ``train_cropping``.
        labels_to_compute, labels_to_remove: label filtering.
        path_val_img, path_val_labels: validation file lists; None (the
            default) means no validation files, equivalent to ``[None]``.
        img, label, img_val, label_val: optional in-memory data.
        x_scale, y_scale, z_scale: slice size after resampling.

    Returns:
        ``(cropping_weights, cropping_config, normalization_parameters)``:
        the weight arrays of all layers in order, the array
        ``[channels, x_scale, y_scale, z_scale, normalize, 0, 1]`` (the last
        two entries are read back as legacy mean/std by ``crop_data``), and
        the (2, channels) normalization statistics.
    """
    # None replaces the former shared mutable default [None]
    if path_val_img is None:
        path_val_img = [None]
    if path_val_labels is None:
        path_val_labels = [None]

    # load training data
    img, label, normalization_parameters, channels = load_cropping_training_data(normalize,
        path_to_img, path_to_labels, x_scale, y_scale, z_scale, labels_to_compute, labels_to_remove,
        img, label)

    # load validation data with the training statistics and channel count
    if any(path_val_img) or img_val is not None:
        img_val, label_val, _, _ = load_cropping_training_data(normalize,
            path_val_img, path_val_labels, x_scale, y_scale, z_scale,
            labels_to_compute, labels_to_remove,
            img_val, label_val, normalization_parameters, channels)

    # train cropping
    train_cropping(img, label, path_to_model, epochs,
        batch_size, validation_split,
        flip_x, flip_y, flip_z, rotate,
        img_val, label_val)

    # collect the weights of all layers in model order
    model = load_model(str(path_to_model))
    cropping_weights = [arr for layer in model.layers for arr in layer.get_weights()]

    # configuration data
    cropping_config = np.array([channels, x_scale, y_scale, z_scale, normalize, 0, 1])

    return cropping_weights, cropping_config, normalization_parameters
548
+
549
def crop_data(path_to_data, path_to_model, path_to_cropped_image, batch_size,
        debug_cropping=False, save_cropped=True, img_data=None,
        x_range=None, y_range=None, z_range=None):
    """Crop an image volume to the region of interest predicted by the cropping network.

    The configuration is read from the 'cropping_meta' group of
    ``path_to_model``. Models without a 'normalization' dataset use the old
    single-channel configuration with scalar mean/std.

    Args:
        path_to_data: image file (ignored when ``img_data`` is given).
        path_to_model: HDF5 file with 'cropping_meta' and 'cropping_weights'.
        path_to_cropped_image: output path for the cropped volume, or None.
        batch_size: prediction batch size.
        debug_cropping, save_cropped: forwarded to ``crop_volume``.
        img_data: optional in-memory image volume.
        x_range, y_range, z_range: optional fixed (lower, upper) bounds.

    Returns:
        ``(region_of_interest, cropped_volume)``, where region_of_interest is
        ``[z_lower, z_upper, y_lower, y_upper, x_lower, x_upper]``.
    """
    # get meta data (context manager closes the file even on errors)
    with h5py.File(path_to_model, 'r') as hf:
        meta = hf.get('cropping_meta')
        configuration = meta.get('configuration')
        channels, x_scale, y_scale, z_scale, normalize, mu, sig = np.array(configuration)[:]
        channels, x_scale, y_scale, z_scale, normalize, mu, sig = int(channels), int(x_scale), \
            int(y_scale), int(z_scale), int(normalize), float(mu), float(sig)
        if '/cropping_meta/normalization' in hf:
            normalization_parameters = np.array(meta.get('normalization'), dtype=float)
        else:
            # old configuration: single channel with scalar mean/std
            normalization_parameters = np.array([[mu], [sig]])
            channels = 1

    # load data
    img, z_shape, y_shape, x_shape, img_data = load_data_to_crop(path_to_data, channels,
        x_scale, y_scale, z_scale, normalize, normalization_parameters, img_data)

    # make prediction
    z_lower, z_upper, y_lower, y_upper, x_lower, x_upper, cropped_volume = crop_volume(img, path_to_model,
        path_to_cropped_image, z_shape, y_shape, x_shape, batch_size, debug_cropping, save_cropped, img_data,
        x_range, y_range, z_range)

    # region of interest
    region_of_interest = np.array([z_lower, z_upper, y_lower, y_upper, x_lower, x_upper])

    return region_of_interest, cropped_volume
581
+