biomedisa 24.8.5__py3-none-any.whl → 24.8.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
biomedisa/deeplearning.py CHANGED
@@ -71,7 +71,7 @@ def deep_learning(img_data, label_data=None, val_img_data=None, val_label_data=N
71
71
  learning_rate=0.01, stride_size=32, validation_stride_size=32, validation_freq=1,
72
72
  batch_size=None, x_scale=256, y_scale=256, z_scale=256, scaling=True, early_stopping=0,
73
73
  pretrained_model=None, fine_tune=False, workers=1, cropping_epochs=50,
74
- x_range=None, y_range=None, z_range=None, header=None, extension='.tif',
74
+ x_range=None, y_range=None, z_range=None, header=None, extension=None,
75
75
  img_header=None, img_extension='.tif', average_dice=False, django_env=False,
76
76
  path=None, success=True, return_probs=False, patch_normalization=False,
77
77
  z_patch=64, y_patch=64, x_patch=64, path_to_logfile=None, img_id=None, label_id=None,
@@ -228,9 +228,11 @@ def deep_learning(img_data, label_data=None, val_img_data=None, val_label_data=N
228
228
  bm.scaling = bool(meta['scaling'][()])
229
229
 
230
230
  # check if amira header is available in the network
231
- if bm.header is None and meta.get('header') is not None:
231
+ if bm.extension is None and bm.header is None and meta.get('header') is not None:
232
232
  bm.header = [np.array(meta.get('header'))]
233
233
  bm.extension = '.am'
234
+ if bm.extension is None:
235
+ bm.extension = '.tif'
234
236
 
235
237
  # crop data
236
238
  crop_data = True if 'cropping_weights' in hf else False
@@ -301,6 +303,12 @@ def deep_learning(img_data, label_data=None, val_img_data=None, val_label_data=N
301
303
  results, bm = predict_segmentation(bm, region_of_interest,
302
304
  channels, normalization_parameters)
303
305
 
306
+ from mpi4py import MPI
307
+ comm = MPI.COMM_WORLD
308
+ rank = comm.Get_rank()
309
+ if rank>0:
310
+ return 0
311
+
304
312
  # results
305
313
  if cropped_volume is not None:
306
314
  results['cropped_volume'] = cropped_volume
@@ -488,8 +496,8 @@ if __name__ == '__main__':
488
496
  help='Location of mask')
489
497
  parser.add_argument('-rf','--refinement', action='store_true', default=False,
490
498
  help='Refine segmentation on full size data')
491
- parser.add_argument('-ext','--extension', type=str, default='.tif',
492
- help='Save data for example as NRRD file using --extension=".nrrd"')
499
+ parser.add_argument('-ext','--extension', type=str, default=None,
500
+ help='Save data in formats like NRRD or TIFF using --extension=".nrrd"')
493
501
  bm = parser.parse_args()
494
502
  bm.success = True
495
503
 
@@ -129,39 +129,25 @@ class DataGenerator(tf.keras.utils.Sequence):
129
129
  def __len__(self):
130
130
  'Denotes the number of batches per epoch'
131
131
  if len(self.list_IDs_bg) > 0:
132
- len_IDs = 2 * max(len(self.list_IDs_fg), len(self.list_IDs_bg))
132
+ len_IDs = max(len(self.list_IDs_fg), len(self.list_IDs_bg))
133
+ n_batches = len_IDs // (self.batch_size // 2)
133
134
  else:
134
135
  len_IDs = len(self.list_IDs_fg)
135
- n_batches = int(np.floor(len_IDs / self.batch_size))
136
+ n_batches = len_IDs // self.batch_size
136
137
  return n_batches
137
138
 
138
139
  def __getitem__(self, index):
139
140
  'Generate one batch of data'
140
-
141
141
  if len(self.list_IDs_bg) > 0:
142
-
143
- # len IDs
144
- len_IDs = max(len(self.list_IDs_fg), len(self.list_IDs_bg))
145
-
146
- # upsample lists of indexes to the same size
147
- repetitions = int(np.floor(len_IDs / len(self.list_IDs_fg))) + 1
148
- upsampled_indexes_fg = np.tile(self.indexes_fg, repetitions)
149
- upsampled_indexes_fg = upsampled_indexes_fg[:len_IDs]
150
-
151
- repetitions = int(np.floor(len_IDs / len(self.list_IDs_bg))) + 1
152
- upsampled_indexes_bg = np.tile(self.indexes_bg, repetitions)
153
- upsampled_indexes_bg = upsampled_indexes_bg[:len_IDs]
154
-
155
142
  # Generate indexes of the batch
156
- tmp_batch_size = int(self.batch_size / 2)
157
- indexes_fg = upsampled_indexes_fg[index*tmp_batch_size:(index+1)*tmp_batch_size]
158
- indexes_bg = upsampled_indexes_bg[index*tmp_batch_size:(index+1)*tmp_batch_size]
143
+ half_batch_size = self.batch_size // 2
144
+ indexes_fg = self.indexes_fg[index*half_batch_size:(index+1)*half_batch_size]
145
+ indexes_bg = self.indexes_bg[index*half_batch_size:(index+1)*half_batch_size]
159
146
 
160
147
  # Find list of IDs
161
148
  list_IDs_temp = [self.list_IDs_fg[k] for k in indexes_fg] + [self.list_IDs_bg[k] for k in indexes_bg]
162
149
 
163
150
  else:
164
-
165
151
  # Generate indexes of the batch
166
152
  indexes_fg = self.indexes_fg[index*self.batch_size:(index+1)*self.batch_size]
167
153
 
@@ -175,11 +161,22 @@ class DataGenerator(tf.keras.utils.Sequence):
175
161
 
176
162
  def on_epoch_end(self):
177
163
  'Updates indexes after each epoch'
178
- self.indexes_fg = np.arange(len(self.list_IDs_fg))
179
- self.indexes_bg = np.arange(len(self.list_IDs_bg))
164
+ if len(self.list_IDs_bg) > 0:
165
+ # upsample lists of indexes
166
+ indexes_fg = np.arange(len(self.list_IDs_fg))
167
+ indexes_bg = np.arange(len(self.list_IDs_bg))
168
+ len_IDs = max(len(self.list_IDs_fg), len(self.list_IDs_bg))
169
+ repetitions = len_IDs // len(self.list_IDs_fg) + 1
170
+ self.indexes_fg = np.tile(indexes_fg, repetitions)
171
+ repetitions = len_IDs // len(self.list_IDs_bg) + 1
172
+ self.indexes_bg = np.tile(indexes_bg, repetitions)
173
+ else:
174
+ self.indexes_fg = np.arange(len(self.list_IDs_fg))
175
+ # shuffle indexes
180
176
  if self.shuffle == True:
181
177
  np.random.shuffle(self.indexes_fg)
182
- np.random.shuffle(self.indexes_bg)
178
+ if len(self.list_IDs_bg) > 0:
179
+ np.random.shuffle(self.indexes_bg)
183
180
 
184
181
  def __data_generation(self, list_IDs_temp):
185
182
  'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
@@ -28,22 +28,25 @@
28
28
 
29
29
  import numpy as np
30
30
  import tensorflow as tf
31
- from scipy.ndimage import gaussian_filter, map_coordinates
31
+ from scipy.ndimage import gaussian_filter, map_coordinates, rotate
32
+ import random
32
33
 
33
34
  def elastic_transform(image, alpha=100, sigma=20):
34
- zsh, ysh, xsh = image.shape
35
+ ysh, xsh, csh = image.shape
35
36
  dx = gaussian_filter((np.random.rand(ysh, xsh) * 2 - 1) * alpha, sigma)
36
37
  dy = gaussian_filter((np.random.rand(ysh, xsh) * 2 - 1) * alpha, sigma)
37
38
  y, x = np.meshgrid(np.arange(ysh), np.arange(xsh), indexing='ij')
38
39
  indices = np.reshape(y+dy, (-1, 1)), np.reshape(x+dx, (-1, 1))
39
- for k in range(zsh):
40
- image[k] = map_coordinates(image[k], indices, order=1, mode='reflect').reshape(ysh, xsh)
40
+ for k in range(csh):
41
+ image[:,:,k] = map_coordinates(image[:,:,k], indices, order=0, mode='reflect').reshape(ysh, xsh)
41
42
  return image
42
43
 
43
44
  class DataGeneratorCrop(tf.keras.utils.Sequence):
44
45
  'Generates data for Keras'
45
- def __init__(self, img, label, list_IDs_fg, list_IDs_bg, batch_size=32, dim=(32,32,32),
46
- n_channels=3, n_classes=2, shuffle=True):
46
+ def __init__(self, img, label, list_IDs_fg, list_IDs_bg, batch_size=32,
47
+ dim=(32,32,32), n_channels=3, n_classes=2, shuffle=True,
48
+ augment=(False,False,False,0), train=True):
49
+
47
50
  'Initialization'
48
51
  self.dim = dim
49
52
  self.list_IDs_fg = list_IDs_fg
@@ -54,6 +57,8 @@ class DataGeneratorCrop(tf.keras.utils.Sequence):
54
57
  self.n_channels = n_channels
55
58
  self.n_classes = n_classes
56
59
  self.shuffle = shuffle
60
+ self.augment = augment
61
+ self.train = train
57
62
  self.on_epoch_end()
58
63
 
59
64
  def __len__(self):
@@ -108,14 +113,37 @@ class DataGeneratorCrop(tf.keras.utils.Sequence):
108
113
  def __data_generation(self, list_IDs_temp):
109
114
  'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
110
115
 
116
+ # get augmentation parameter
117
+ flip_x, flip_y, flip_z, rotation = self.augment
118
+ elastic = False
119
+
111
120
  # Initialization
112
121
  X = np.empty((self.batch_size, *self.dim, self.n_channels), dtype=np.uint8)
113
122
  y = np.empty((self.batch_size,), dtype=np.int32)
114
123
 
115
124
  # Generate data
116
125
  for i, ID in enumerate(list_IDs_temp):
117
- X[i,...] = self.img[ID,...]
118
- y[i] = self.label[ID]
126
+ tmp_X = self.img[ID,...].copy()
127
+
128
+ # augmentation
129
+ if self.train and (any(self.augment) or elastic):
130
+ if flip_x and np.random.randint(2) and abs(self.label[ID])!=3:
131
+ tmp_X = np.flip(tmp_X, 1)
132
+ if flip_y and np.random.randint(2):
133
+ if abs(self.label[ID])==1:
134
+ tmp_X = np.flip(tmp_X, 0)
135
+ elif abs(self.label[ID])==3:
136
+ tmp_X = np.flip(tmp_X, 1)
137
+ if flip_z and np.random.randint(2) and abs(self.label[ID])!=1:
138
+ tmp_X = np.flip(tmp_X, 0)
139
+ if rotation:
140
+ angle = random.uniform(-rotation, rotation)
141
+ tmp_X = rotate(tmp_X, angle, order=0, mode='reflect', reshape=False)
142
+ if elastic:
143
+ tmp_X = elastic_transform(tmp_X)
144
+
145
+ X[i,...] = tmp_X
146
+ y[i] = 0 if self.label[ID] < 0 else 1
119
147
 
120
148
  return X, y
121
149
 
@@ -44,6 +44,7 @@ import h5py
44
44
  import tarfile
45
45
  import matplotlib.pyplot as plt
46
46
  import tempfile
47
+ import copy
47
48
 
48
49
  class InputError(Exception):
49
50
  def __init__(self, message=None):
@@ -117,9 +118,12 @@ def load_cropping_training_data(normalize, img_list, label_list, x_scale, y_scal
117
118
  a = label_in
118
119
  a = a.astype(np.uint8)
119
120
  a = set_labels_to_zero(a, labels_to_compute, labels_to_remove)
120
- label_z = np.any(a,axis=(1,2))
121
- label_y = np.any(a,axis=(0,2))
122
- label_x = np.any(a,axis=(0,1))
121
+ label_z = np.any(a,axis=(1,2)).astype(np.int8) * 1
122
+ label_y = np.any(a,axis=(0,2)).astype(np.int8) * 2
123
+ label_x = np.any(a,axis=(0,1)).astype(np.int8) * 3
124
+ label_z[label_z==0] = -1
125
+ label_y[label_y==0] = -2
126
+ label_x[label_x==0] = -3
123
127
  label = np.append(label_z,label_y,axis=0)
124
128
  label = np.append(label,label_x,axis=0)
125
129
 
@@ -178,12 +182,15 @@ def load_cropping_training_data(normalize, img_list, label_list, x_scale, y_scal
178
182
  a = label_in[k]
179
183
  a = a.astype(np.uint8)
180
184
  a = set_labels_to_zero(a, labels_to_compute, labels_to_remove)
181
- next_label_z = np.any(a,axis=(1,2))
182
- next_label_y = np.any(a,axis=(0,2))
183
- next_label_x = np.any(a,axis=(0,1))
184
- label = np.append(label,next_label_z,axis=0)
185
- label = np.append(label,next_label_y,axis=0)
186
- label = np.append(label,next_label_x,axis=0)
185
+ label_z = np.any(a,axis=(1,2)).astype(np.int8) * 1
186
+ label_y = np.any(a,axis=(0,2)).astype(np.int8) * 2
187
+ label_x = np.any(a,axis=(0,1)).astype(np.int8) * 3
188
+ label_z[label_z==0] = -1
189
+ label_y[label_y==0] = -2
190
+ label_x[label_x==0] = -3
191
+ label = np.append(label,label_z,axis=0)
192
+ label = np.append(label,label_y,axis=0)
193
+ label = np.append(label,label_x,axis=0)
187
194
 
188
195
  # append image
189
196
  if any(img_list):
@@ -228,20 +235,20 @@ def load_cropping_training_data(normalize, img_list, label_list, x_scale, y_scal
228
235
  return img_rgb, label, normalization_parameters, channels
229
236
 
230
237
  def train_cropping(img, label, path_to_model, epochs, batch_size,
231
- validation_split, flip_x, flip_y, flip_z, rotate,
232
- img_val, label_val):
238
+ validation_split, flip_x, flip_y, flip_z, rotate,
239
+ img_val, label_val):
233
240
 
234
241
  # img shape
235
242
  zsh, ysh, xsh, channels = img.shape
236
243
 
237
244
  # list of IDs
238
- list_IDs_fg = list(np.where(label)[0])
239
- list_IDs_bg = list(np.where(label==False)[0])
245
+ list_IDs_fg = list(np.where(label>0)[0])
246
+ list_IDs_bg = list(np.where(label<0)[0])
240
247
 
241
248
  # validation data
242
249
  if np.any(img_val):
243
- list_IDs_val_fg = list(np.where(label_val)[0])
244
- list_IDs_val_bg = list(np.where(label_val==False)[0])
250
+ list_IDs_val_fg = list(np.where(label_val>0)[0])
251
+ list_IDs_val_bg = list(np.where(label_val<0)[0])
245
252
  elif validation_split:
246
253
  split_fg = int(len(list_IDs_fg) * validation_split)
247
254
  split_bg = int(len(list_IDs_bg) * validation_split)
@@ -287,14 +294,13 @@ def train_cropping(img, label, path_to_model, epochs, batch_size,
287
294
  'batch_size': batch_size,
288
295
  'n_classes': 2,
289
296
  'n_channels': channels,
290
- 'shuffle': True}
297
+ 'shuffle': True,
298
+ 'augment': (flip_x, flip_y, flip_z, rotate),
299
+ 'train': True}
291
300
 
292
301
  # validation parameters
293
- params_val = {'dim': (ysh, xsh),
294
- 'batch_size': batch_size,
295
- 'n_classes': 2,
296
- 'n_channels': channels,
297
- 'shuffle': False}
302
+ params_val = copy.deepcopy(params)
303
+ params_val['train'] = False
298
304
 
299
305
  # data generator
300
306
  training_generator = DataGeneratorCrop(img, label, list_IDs_fg, list_IDs_bg, **params)
@@ -497,7 +503,7 @@ def crop_volume(img, path_to_model, path_to_final, z_shape, y_shape, x_shape, ba
497
503
  # main functions
498
504
  #=====================
499
505
 
500
- def load_and_train(bm, x_scale=256, y_scale=256, z_scale=256):
506
+ def load_and_train(bm, x_scale=256, y_scale=256, z_scale=256, batch_size=24):
501
507
 
502
508
  # load training data
503
509
  img, label, normalization_parameters, channels = load_cropping_training_data(bm.normalize,
@@ -512,8 +518,8 @@ def load_and_train(bm, x_scale=256, y_scale=256, z_scale=256):
512
518
  bm.only, bm.ignore, bm.val_img_data, bm.val_label_data, normalization_parameters, channels)
513
519
 
514
520
  # train cropping
515
- train_cropping(img, label, bm.path_to_model, bm.cropping_epochs,
516
- bm.batch_size, bm.validation_split,
521
+ train_cropping(img, label, bm.path_to_model,
522
+ bm.cropping_epochs, batch_size, bm.validation_split,
517
523
  bm.flip_x, bm.flip_y, bm.flip_z, bm.rotate,
518
524
  img_val, label_val)
519
525
 
@@ -531,7 +537,7 @@ def load_and_train(bm, x_scale=256, y_scale=256, z_scale=256):
531
537
 
532
538
  return cropping_weights, cropping_config, normalization_parameters
533
539
 
534
- def crop_data(bm):
540
+ def crop_data(bm, batch_size=32):
535
541
 
536
542
  # get meta data
537
543
  hf = h5py.File(bm.path_to_model, 'r')
@@ -554,7 +560,7 @@ def crop_data(bm):
554
560
 
555
561
  # make prediction
556
562
  z_lower, z_upper, y_lower, y_upper, x_lower, x_upper, cropped_volume = crop_volume(img, bm.path_to_model,
557
- bm.path_to_cropped_image, z_shape, y_shape, x_shape, bm.batch_size, bm.debug_cropping, bm.save_cropped, bm.img_data)
563
+ bm.path_to_cropped_image, z_shape, y_shape, x_shape, batch_size, bm.debug_cropping, bm.save_cropped, bm.img_data)
558
564
 
559
565
  # region of interest
560
566
  region_of_interest = np.array([z_lower, z_upper, y_lower, y_upper, x_lower, x_upper])
@@ -400,6 +400,12 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
400
400
  label = label_in
401
401
  label_names = ['label_1']
402
402
  label_dim = label.shape
403
+ # no-scaling for list of images needs negative values as it encodes padded areas as -1
404
+ label_dtype = label.dtype
405
+ if label_dtype==np.uint8:
406
+ label_dtype = np.int16
407
+ elif label_dtype in [np.uint16, np.uint32]:
408
+ label_dtype = np.int32
403
409
  label = set_labels_to_zero(label, bm.only, bm.ignore)
404
410
  if any([bm.x_range, bm.y_range, bm.z_range]):
405
411
  if len(label_names)>1:
@@ -418,8 +424,6 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
418
424
 
419
425
  # if header is not single data stream Amira Mesh falling back to Multi-TIFF
420
426
  if extension != '.am':
421
- if extension != '.tif':
422
- print(f'Warning! Please use --header_file="path_to_training_label{extension}" for prediction to save your result as "{extension}"')
423
427
  extension, header = '.tif', None
424
428
  elif len(header) > 1:
425
429
  print('Warning! Multiple data streams are not supported. Falling back to TIFF.')
@@ -475,6 +479,11 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
475
479
  img[:,:,:,c] = (img[:,:,:,c] - mean) / std
476
480
  img[:,:,:,c] = img[:,:,:,c] * normalization_parameters[1,c] + normalization_parameters[0,c]
477
481
 
482
+ # pad data
483
+ if not bm.scaling:
484
+ img_data_list = [img]
485
+ label_data_list = [label]
486
+
478
487
  # loop over list of images
479
488
  if any(img_list) or type(img_in) is list:
480
489
  number_of_images = len(img_names) if any(img_list) else len(img_in)
@@ -498,7 +507,9 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
498
507
  label_values, counts = np.unique(a, return_counts=True)
499
508
  print(f'{os.path.basename(label_names[k])}:', 'Labels:', label_values[1:], 'Sizes:', counts[1:])
500
509
  a = img_resize(a, bm.z_scale, bm.y_scale, bm.x_scale, labels=True)
501
- label = np.append(label, a, axis=0)
510
+ label = np.append(label, a, axis=0)
511
+ else:
512
+ label_data_list.append(a)
502
513
 
503
514
  # append image
504
515
  if any(img_list):
@@ -529,11 +540,33 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
529
540
  mean, std = np.mean(a[:,:,:,c]), np.std(a[:,:,:,c])
530
541
  a[:,:,:,c] = (a[:,:,:,c] - mean) / std
531
542
  a[:,:,:,c] = a[:,:,:,c] * normalization_parameters[1,c] + normalization_parameters[0,c]
532
- img = np.append(img, a, axis=0)
543
+ if bm.scaling:
544
+ img = np.append(img, a, axis=0)
545
+ else:
546
+ img_data_list.append(a)
547
+
548
+ # pad and append data to a single volume
549
+ if not bm.scaling and len(img_data_list)>1:
550
+ target_y, target_x = 0, 0
551
+ for img in img_data_list:
552
+ target_y = max(target_y, img.shape[1])
553
+ target_x = max(target_x, img.shape[2])
554
+ img = np.empty((0, target_y, target_x, channels), dtype=np.float32)
555
+ label = np.empty((0, target_y, target_x), dtype=label_dtype)
556
+ for k in range(len(img_data_list)):
557
+ pad_y = target_y - img_data_list[k].shape[1]
558
+ pad_x = target_x - img_data_list[k].shape[2]
559
+ pad_width = [(0, 0), (0, pad_y), (0, pad_x), (0, 0)]
560
+ tmp = np.pad(img_data_list[k], pad_width, mode='constant', constant_values=0)
561
+ img = np.append(img, tmp, axis=0)
562
+ pad_width = [(0, 0), (0, pad_y), (0, pad_x)]
563
+ tmp = np.pad(label_data_list[k].astype(label_dtype), pad_width, mode='constant', constant_values=-1)
564
+ label = np.append(label, tmp, axis=0)
533
565
 
534
566
  # limit intensity range
535
- img[img<0] = 0
536
- img[img>1] = 1
567
+ if bm.normalize:
568
+ img[img<0] = 0
569
+ img[img>1] = 1
537
570
 
538
571
  if bm.separation:
539
572
  allLabels = np.array([0,1])
@@ -541,6 +574,8 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
541
574
  # get labels
542
575
  if allLabels is None:
543
576
  allLabels = np.unique(label)
577
+ index = np.argwhere(allLabels<0)
578
+ allLabels = np.delete(allLabels, index)
544
579
 
545
580
  # labels must be in ascending order
546
581
  for k, l in enumerate(allLabels):
@@ -761,6 +796,35 @@ class Metrics(Callback):
761
796
  if self.early_stopping > 0 and max(self.history['val_dice']) not in self.history['val_dice'][-self.early_stopping:]:
762
797
  self.model.stop_training = True
763
798
 
799
+ class HistoryCallback(Callback):
800
+ def __init__(self, bm):
801
+ self.path_to_model = bm.path_to_model
802
+ self.train_dice = bm.train_dice
803
+ self.val_img_data = bm.val_img_data
804
+
805
+ def on_train_begin(self, logs={}):
806
+ self.history = {}
807
+ self.history['loss'] = []
808
+ self.history['accuracy'] = []
809
+ if self.train_dice:
810
+ self.history['dice'] = []
811
+ if self.val_img_data is not None:
812
+ self.history['val_loss'] = []
813
+ self.history['val_accuracy'] = []
814
+
815
+ def on_epoch_end(self, epoch, logs={}):
816
+ # append history
817
+ self.history['loss'].append(logs['loss'])
818
+ self.history['accuracy'].append(logs['accuracy'])
819
+ if self.train_dice:
820
+ self.history['dice'].append(logs['dice'])
821
+ if self.val_img_data is not None:
822
+ self.history['val_loss'].append(logs['val_loss'])
823
+ self.history['val_accuracy'].append(logs['val_accuracy'])
824
+
825
+ # plot history in figure and save as numpy array
826
+ save_history(self.history, self.path_to_model)
827
+
764
828
  def softmax(x):
765
829
  # Avoiding numerical instability by subtracting the maximum value
766
830
  exp_values = np.exp(x - np.max(x, axis=-1, keepdims=True))
@@ -833,25 +897,30 @@ def train_segmentation(bm):
833
897
  list_IDs_fg, list_IDs_bg = [], []
834
898
 
835
899
  # get IDs of patches
836
- if bm.balance:
837
- for k in range(0, zsh-bm.z_patch+1, bm.stride_size):
838
- for l in range(0, ysh-bm.y_patch+1, bm.stride_size):
839
- for m in range(0, xsh-bm.x_patch+1, bm.stride_size):
840
- if np.any(bm.label_data[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]):
841
- list_IDs_fg.append(k*ysh*xsh+l*xsh+m)
842
- else:
843
- list_IDs_bg.append(k*ysh*xsh+l*xsh+m)
844
- else:
845
- for k in range(0, zsh-bm.z_patch+1, bm.stride_size):
846
- for l in range(0, ysh-bm.y_patch+1, bm.stride_size):
847
- for m in range(0, xsh-bm.x_patch+1, bm.stride_size):
848
- if bm.separation:
849
- centerLabel = bm.label_data[k+bm.z_patch//2,l+bm.y_patch//2,m+bm.x_patch//2]
850
- patch = bm.label_data[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]
851
- if centerLabel>0 and np.any(patch!=centerLabel):
852
- list_IDs_fg.append(k*ysh*xsh+l*xsh+m)
900
+ for k in range(0, zsh-bm.z_patch+1, bm.stride_size):
901
+ for l in range(0, ysh-bm.y_patch+1, bm.stride_size):
902
+ for m in range(0, xsh-bm.x_patch+1, bm.stride_size):
903
+ patch = bm.label_data[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]
904
+ index = k*ysh*xsh+l*xsh+m
905
+ if not np.any(patch==-1): # ignore padded areas
906
+ if bm.balance:
907
+ if bm.separation:
908
+ centerLabel = bm.label_data[k+bm.z_patch//2,l+bm.y_patch//2,m+bm.x_patch//2]
909
+ if centerLabel>0 and np.any(np.logical_and(patch!=centerLabel, patch>0)):
910
+ list_IDs_fg.append(index)
911
+ elif centerLabel>0 and np.any(patch!=centerLabel):
912
+ list_IDs_bg.append(index)
913
+ elif np.any(patch>0):
914
+ list_IDs_fg.append(index)
915
+ else:
916
+ list_IDs_bg.append(index)
853
917
  else:
854
- list_IDs_fg.append(k*ysh*xsh+l*xsh+m)
918
+ if bm.separation:
919
+ centerLabel = bm.label_data[k+bm.z_patch//2,l+bm.y_patch//2,m+bm.x_patch//2]
920
+ if centerLabel>0 and np.any(patch!=centerLabel):
921
+ list_IDs_fg.append(index)
922
+ else:
923
+ list_IDs_fg.append(index)
855
924
 
856
925
  if bm.val_img_data is not None:
857
926
 
@@ -862,25 +931,35 @@ def train_segmentation(bm):
862
931
  list_IDs_val_fg, list_IDs_val_bg = [], []
863
932
 
864
933
  # get validation IDs of patches
865
- if bm.balance and not bm.val_dice:
866
- for k in range(0, zsh_val-bm.z_patch+1, bm.validation_stride_size):
867
- for l in range(0, ysh_val-bm.y_patch+1, bm.validation_stride_size):
868
- for m in range(0, xsh_val-bm.x_patch+1, bm.validation_stride_size):
869
- if np.any(bm.val_label_data[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]):
870
- list_IDs_val_fg.append(k*ysh_val*xsh_val+l*xsh_val+m)
934
+ for k in range(0, zsh_val-bm.z_patch+1, bm.validation_stride_size):
935
+ for l in range(0, ysh_val-bm.y_patch+1, bm.validation_stride_size):
936
+ for m in range(0, xsh_val-bm.x_patch+1, bm.validation_stride_size):
937
+ patch = bm.val_label_data[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]
938
+ index = k*ysh_val*xsh_val+l*xsh_val+m
939
+ if not np.any(patch==-1): # ignore padded areas
940
+ if bm.balance and not bm.val_dice:
941
+ if bm.separation:
942
+ centerLabel = bm.val_label_data[k+bm.z_patch//2,l+bm.y_patch//2,m+bm.x_patch//2]
943
+ if centerLabel>0 and np.any(np.logical_and(patch!=centerLabel, patch>0)):
944
+ list_IDs_val_fg.append(index)
945
+ elif centerLabel>0 and np.any(patch!=centerLabel):
946
+ list_IDs_val_bg.append(index)
947
+ elif np.any(patch>0):
948
+ list_IDs_val_fg.append(index)
949
+ else:
950
+ list_IDs_val_bg.append(index)
871
951
  else:
872
- list_IDs_val_bg.append(k*ysh_val*xsh_val+l*xsh_val+m)
873
- else:
874
- for k in range(0, zsh_val-bm.z_patch+1, bm.validation_stride_size):
875
- for l in range(0, ysh_val-bm.y_patch+1, bm.validation_stride_size):
876
- for m in range(0, xsh_val-bm.x_patch+1, bm.validation_stride_size):
877
- if bm.separation:
878
- centerLabel = bm.val_label_data[k+bm.z_patch//2,l+bm.y_patch//2,m+bm.x_patch//2]
879
- patch = bm.val_label_data[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]
880
- if centerLabel>0 and np.any(patch!=centerLabel):
881
- list_IDs_val_fg.append(k*ysh_val*xsh_val+l*xsh_val+m)
882
- else:
883
- list_IDs_val_fg.append(k*ysh_val*xsh_val+l*xsh_val+m)
952
+ if bm.separation:
953
+ centerLabel = bm.val_label_data[k+bm.z_patch//2,l+bm.y_patch//2,m+bm.x_patch//2]
954
+ if centerLabel>0 and np.any(patch!=centerLabel):
955
+ list_IDs_val_fg.append(index)
956
+ else:
957
+ list_IDs_val_fg.append(index)
958
+
959
+ # remove padding label
960
+ bm.label_data[bm.label_data<0]=0
961
+ if bm.val_img_data is not None:
962
+ bm.val_label_data[bm.val_label_data<0]=0
884
963
 
885
964
  # number of labels
886
965
  nb_labels = len(allLabels)
@@ -968,11 +1047,11 @@ def train_segmentation(bm):
968
1047
  monitor='val_accuracy',
969
1048
  mode='max',
970
1049
  save_best_only=True)
971
- callbacks = [model_checkpoint_callback, meta_data]
1050
+ callbacks = [model_checkpoint_callback, HistoryCallback(bm), meta_data]
972
1051
  if bm.early_stopping > 0:
973
1052
  callbacks.insert(0, EarlyStopping(monitor='val_accuracy', mode='max', patience=bm.early_stopping))
974
1053
  else:
975
- callbacks = [ModelCheckpoint(filepath=str(bm.path_to_model)), meta_data]
1054
+ callbacks = [ModelCheckpoint(filepath=str(bm.path_to_model)), HistoryCallback(bm), meta_data]
976
1055
 
977
1056
  # monitor dice score on training data
978
1057
  if bm.train_dice:
@@ -983,15 +1062,11 @@ def train_segmentation(bm):
983
1062
  callbacks.insert(-1, CustomCallback(bm.img_id, bm.epochs))
984
1063
 
985
1064
  # train model
986
- history = model.fit(training_generator,
987
- epochs=bm.epochs,
988
- validation_data=validation_generator,
989
- callbacks=callbacks,
990
- workers=bm.workers)
991
-
992
- # save monitoring figure on train end
993
- if bm.val_img_data is None or not bm.val_dice:
994
- save_history(history.history, bm.path_to_model)
1065
+ model.fit(training_generator,
1066
+ epochs=bm.epochs,
1067
+ validation_data=validation_generator,
1068
+ callbacks=callbacks,
1069
+ workers=bm.workers)
995
1070
 
996
1071
  def load_prediction_data(bm, channels, normalize, normalization_parameters,
997
1072
  region_of_interest, img, img_header, load_blockwise=False, z=None):
@@ -1099,6 +1174,21 @@ def gradient(volData):
1099
1174
 
1100
1175
  def predict_segmentation(bm, region_of_interest, channels, normalization_parameters):
1101
1176
 
1177
+ from mpi4py import MPI
1178
+ comm = MPI.COMM_WORLD
1179
+ rank = comm.Get_rank()
1180
+ ngpus = comm.Get_size()
1181
+
1182
+ # Set the visible GPU by ID
1183
+ gpus = tf.config.experimental.list_physical_devices('GPU')
1184
+ if gpus:
1185
+ try:
1186
+ # Restrict TensorFlow to only use the specified GPU
1187
+ tf.config.experimental.set_visible_devices(gpus[rank], 'GPU')
1188
+ tf.config.experimental.set_memory_growth(gpus[rank], True)
1189
+ except RuntimeError as e:
1190
+ print(e)
1191
+
1102
1192
  # initialize results
1103
1193
  results = {}
1104
1194
 
@@ -1125,6 +1215,7 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
1125
1215
  load_blockwise = False
1126
1216
  if not bm.scaling and not bm.normalize and bm.path_to_image and not np.any(region_of_interest) and \
1127
1217
  os.path.splitext(bm.path_to_image)[1] in ['.tif', '.tiff'] and not bm.acwe:
1218
+ # get image shape
1128
1219
  tif = TiffFile(bm.path_to_image)
1129
1220
  zsh = len(tif.pages)
1130
1221
  ysh, xsh = tif.pages[0].shape
@@ -1240,17 +1331,23 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
1240
1331
  # stream data batchwise to GPU to reduce memory usage
1241
1332
  X = np.empty((bm.batch_size, bm.z_patch, bm.y_patch, bm.x_patch, channels), dtype=np.float32)
1242
1333
 
1243
- # allocate final array
1244
- if bm.return_probs:
1334
+ # allocate final probabilities array
1335
+ if rank==0 and bm.return_probs:
1245
1336
  final = np.zeros((zsh, ysh, xsh, nb_labels), dtype=np.float32)
1246
1337
 
1247
- # allocate result array
1248
- label = np.zeros((zsh, ysh, xsh), dtype=np.uint8)
1338
+ # allocate final result array
1339
+ if rank==0:
1340
+ label = np.zeros((zsh, ysh, xsh), dtype=np.uint8)
1249
1341
 
1250
1342
  # predict segmentation block by block
1251
1343
  z_indices = range(0, zsh-bm.z_patch+1, bm.stride_size)
1252
1344
  for j, z in enumerate(z_indices):
1253
-
1345
+ # handle len(z_indices) % ngpus != 0
1346
+ if len(z_indices)-1-j < ngpus:
1347
+ nprocs = len(z_indices)-j
1348
+ else:
1349
+ nprocs = ngpus
1350
+ if j % ngpus == rank:
1254
1351
  # load blockwise from TIFF
1255
1352
  if load_blockwise:
1256
1353
  img, _, _, _, _, _, _ = load_prediction_data(bm,
@@ -1284,8 +1381,11 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
1284
1381
  # number of patches
1285
1382
  nb_patches = len(list_IDs_block)
1286
1383
 
1287
- # allocate tmp probabilities array
1288
- probs = np.zeros((bm.z_patch, ysh, xsh, nb_labels), dtype=np.float32)
1384
+ # allocate block array
1385
+ if bm.separation:
1386
+ block_label = np.zeros((bm.z_patch, ysh, xsh), dtype=np.uint8)
1387
+ else:
1388
+ block_probs = np.zeros((bm.z_patch, ysh, xsh, nb_labels), dtype=np.float32)
1289
1389
 
1290
1390
  # get one batch of image patches
1291
1391
  for step in range(nb_patches//bm.batch_size):
@@ -1317,157 +1417,186 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
1317
1417
  if step*bm.batch_size+i < max_i_block:
1318
1418
  if bm.separation:
1319
1419
  patch = np.argmax(Y[i], axis=-1).astype(np.uint8)
1320
- label[z:z+bm.z_patch,l:l+bm.y_patch,m:m+bm.x_patch] += gradient(patch)
1420
+ block_label[:,l:l+bm.y_patch,m:m+bm.x_patch] += gradient(patch)
1321
1421
  else:
1322
- probs[:,l:l+bm.y_patch,m:m+bm.x_patch] += Y[i]
1323
-
1324
- if not bm.separation:
1325
- # overlap in z direction
1326
- if bm.stride_size < bm.z_patch:
1327
- if j>0:
1328
- probs[:bm.stride_size] += overlap
1329
- overlap = probs[bm.stride_size:].copy()
1330
-
1331
- # calculate result
1332
- if z==z_indices[-1]:
1333
- label[z:z+bm.z_patch] = np.argmax(probs, axis=-1).astype(np.uint8)
1334
- if bm.return_probs:
1335
- final[z:z+bm.z_patch] = probs
1422
+ block_probs[:,l:l+bm.y_patch,m:m+bm.x_patch] += Y[i]
1423
+
1424
+ # communicate results
1425
+ if bm.separation:
1426
+ if rank==0:
1427
+ label[z:z+bm.z_patch] += block_label
1428
+ for source in range(1, nprocs):
1429
+ receivedata = np.empty_like(block_label)
1430
+ for i in range(bm.z_patch):
1431
+ comm.Recv([receivedata[i], MPI.BYTE], source=source, tag=i)
1432
+ block_start = z_indices[j+source]
1433
+ label[block_start:block_start+bm.z_patch] += receivedata
1336
1434
  else:
1337
- block_zsh = min(bm.stride_size, bm.z_patch)
1338
- label[z:z+block_zsh] = np.argmax(probs[:block_zsh], axis=-1).astype(np.uint8)
1339
- if bm.return_probs:
1340
- final[z:z+block_zsh] = probs[:block_zsh]
1341
-
1342
- # refine mask data with result
1343
- if bm.refinement:
1344
- # loop over boundary patches
1345
- for i, ID in enumerate(list_IDs):
1346
- if i < max_i:
1347
- k = ID // (ysh*xsh)
1348
- rest = ID % (ysh*xsh)
1349
- l = rest // xsh
1350
- m = rest % xsh
1351
- mask[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch] = label[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]
1352
- label = mask
1353
-
1354
- # remove appendix
1355
- if bm.return_probs:
1356
- final = final[:-z_rest,:-y_rest,:-x_rest]
1357
- label = label[:-z_rest,:-y_rest,:-x_rest]
1358
- zsh, ysh, xsh = label.shape
1359
-
1360
- # return probabilities
1361
- if bm.return_probs:
1362
- counter = np.zeros((zsh, ysh, xsh, nb_labels), dtype=np.float32)
1363
- nb = 0
1364
- for k in range(0, zsh-bm.z_patch+1, bm.stride_size):
1365
- for l in range(0, ysh-bm.y_patch+1, bm.stride_size):
1366
- for m in range(0, xsh-bm.x_patch+1, bm.stride_size):
1367
- counter[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch] += 1
1368
- nb += 1
1369
- counter[counter==0] = 1
1370
- probabilities = final / counter
1435
+ for i in range(bm.z_patch):
1436
+ comm.Send([block_label[i].copy(), MPI.BYTE], dest=0, tag=i)
1437
+ else:
1438
+ if rank==0:
1439
+ for source in range(nprocs):
1440
+ if source>0:
1441
+ probs = np.empty_like(block_probs)
1442
+ for i in range(bm.z_patch):
1443
+ comm.Recv([probs[i], MPI.FLOAT], source=source, tag=i)
1444
+ else:
1445
+ probs = block_probs
1446
+
1447
+ # overlap in z direction
1448
+ if bm.stride_size < bm.z_patch:
1449
+ if j+source>0:
1450
+ probs[:bm.stride_size] += overlap
1451
+ overlap = probs[bm.stride_size:].copy()
1452
+
1453
+ # calculate result
1454
+ block_z = z_indices[j+source]
1455
+ if j+source==len(z_indices)-1:
1456
+ label[block_z:block_z+bm.z_patch] = np.argmax(probs, axis=-1).astype(np.uint8)
1457
+ if bm.return_probs:
1458
+ final[block_z:block_z+bm.z_patch] = probs
1459
+ else:
1460
+ block_zsh = min(bm.stride_size, bm.z_patch)
1461
+ label[block_z:block_z+block_zsh] = np.argmax(probs[:block_zsh], axis=-1).astype(np.uint8)
1462
+ if bm.return_probs:
1463
+ final[block_z:block_z+block_zsh] = probs[:block_zsh]
1464
+ else:
1465
+ for i in range(bm.z_patch):
1466
+ comm.Send([block_probs[i].copy(), MPI.FLOAT], dest=0, tag=i)
1467
+ if rank==0:
1468
+
1469
+ # refine mask data with result
1470
+ if bm.refinement:
1471
+ # loop over boundary patches
1472
+ for i, ID in enumerate(list_IDs):
1473
+ if i < max_i:
1474
+ k = ID // (ysh*xsh)
1475
+ rest = ID % (ysh*xsh)
1476
+ l = rest // xsh
1477
+ m = rest % xsh
1478
+ mask[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch] = label[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]
1479
+ label = mask
1480
+
1481
+ # remove appendix
1482
+ if bm.return_probs:
1483
+ final = final[:-z_rest,:-y_rest,:-x_rest]
1484
+ label = label[:-z_rest,:-y_rest,:-x_rest]
1485
+ zsh, ysh, xsh = label.shape
1486
+
1487
+ # return probabilities
1488
+ if bm.return_probs:
1489
+ counter = np.zeros((zsh, ysh, xsh, nb_labels), dtype=np.float32)
1490
+ nb = 0
1491
+ for k in range(0, zsh-bm.z_patch+1, bm.stride_size):
1492
+ for l in range(0, ysh-bm.y_patch+1, bm.stride_size):
1493
+ for m in range(0, xsh-bm.x_patch+1, bm.stride_size):
1494
+ counter[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch] += 1
1495
+ nb += 1
1496
+ counter[counter==0] = 1
1497
+ probabilities = final / counter
1498
+ if bm.scaling:
1499
+ probabilities = img_resize(probabilities, z_shape, y_shape, x_shape)
1500
+ if np.any(region_of_interest):
1501
+ min_z,max_z,min_y,max_y,min_x,max_x,original_zsh,original_ysh,original_xsh = region_of_interest[:]
1502
+ tmp = np.zeros((original_zsh, original_ysh, original_xsh, nb_labels), dtype=np.float32)
1503
+ tmp[min_z:max_z,min_y:max_y,min_x:max_x] = probabilities
1504
+ probabilities = np.copy(tmp, order='C')
1505
+ results['probs'] = probabilities
1506
+
1507
+ # rescale final to input size
1371
1508
  if bm.scaling:
1372
- probabilities = img_resize(probabilities, z_shape, y_shape, x_shape)
1509
+ label = img_resize(label, z_shape, y_shape, x_shape, labels=True)
1510
+
1511
+ # revert automatic cropping
1373
1512
  if np.any(region_of_interest):
1374
1513
  min_z,max_z,min_y,max_y,min_x,max_x,original_zsh,original_ysh,original_xsh = region_of_interest[:]
1375
- tmp = np.zeros((original_zsh, original_ysh, original_xsh, nb_labels), dtype=np.float32)
1376
- tmp[min_z:max_z,min_y:max_y,min_x:max_x] = probabilities
1377
- probabilities = np.copy(tmp, order='C')
1378
- results['probs'] = probabilities
1379
-
1380
- # rescale final to input size
1381
- if bm.scaling:
1382
- label = img_resize(label, z_shape, y_shape, x_shape, labels=True)
1514
+ tmp = np.zeros((original_zsh, original_ysh, original_xsh), dtype=np.uint8)
1515
+ tmp[min_z:max_z,min_y:max_y,min_x:max_x] = label
1516
+ label = np.copy(tmp, order='C')
1517
+
1518
+ # get result
1519
+ if not bm.separation:
1520
+ label = get_labels(label, bm.allLabels)
1521
+ results['regular'] = label
1522
+
1523
+ # load header from file
1524
+ if bm.header_file and os.path.exists(bm.header_file):
1525
+ _, bm.header = load_data(bm.header_file)
1526
+ # update file extension
1527
+ if bm.header is not None and bm.path_to_image:
1528
+ bm.extension = os.path.splitext(bm.header_file)[1]
1529
+ if bm.extension == '.gz':
1530
+ bm.extension = '.nii.gz'
1531
+ bm.path_to_final = os.path.splitext(bm.path_to_final)[0] + bm.extension
1532
+ if bm.django_env and not bm.remote and not bm.tarfile:
1533
+ bm.path_to_final = unique_file_path(bm.path_to_final)
1534
+
1535
+ # handle amira header
1536
+ if bm.header is not None:
1537
+ if bm.extension == '.am':
1538
+ bm.header = set_image_dimensions(bm.header[0], label)
1539
+ if bm.img_header is not None:
1540
+ try:
1541
+ bm.header = set_physical_size(bm.header, bm.img_header[0])
1542
+ except:
1543
+ pass
1544
+ bm.header = [bm.header]
1545
+ else:
1546
+ # build new header
1547
+ if bm.img_header is None:
1548
+ zsh, ysh, xsh = label.shape
1549
+ bm.img_header = sitk.Image(xsh, ysh, zsh, bm.header.GetPixelID())
1550
+ # copy metadata
1551
+ for key in bm.header.GetMetaDataKeys():
1552
+ if not (re.match(r'Segment\d+_Extent$', key) or key=='Segmentation_ConversionParameters'):
1553
+ bm.img_header.SetMetaData(key, bm.header.GetMetaData(key))
1554
+ bm.header = bm.img_header
1555
+ results['header'] = bm.header
1556
+
1557
+ # save result
1558
+ if bm.path_to_image:
1559
+ save_data(bm.path_to_final, label, header=bm.header, compress=bm.compression)
1383
1560
 
1384
- # revert automatic cropping
1385
- if np.any(region_of_interest):
1386
- min_z,max_z,min_y,max_y,min_x,max_x,original_zsh,original_ysh,original_xsh = region_of_interest[:]
1387
- tmp = np.zeros((original_zsh, original_ysh, original_xsh), dtype=np.uint8)
1388
- tmp[min_z:max_z,min_y:max_y,min_x:max_x] = label
1389
- label = np.copy(tmp, order='C')
1390
-
1391
- # get result
1392
- if not bm.separation:
1393
- label = get_labels(label, bm.allLabels)
1394
- results['regular'] = label
1395
-
1396
- # load header from file
1397
- if bm.header_file and os.path.exists(bm.header_file):
1398
- _, bm.header = load_data(bm.header_file)
1399
- # update file extension
1400
- if bm.header is not None and bm.path_to_image:
1401
- bm.extension = os.path.splitext(bm.header_file)[1]
1561
+ # paths to optional results
1562
+ filename, bm.extension = os.path.splitext(bm.path_to_final)
1402
1563
  if bm.extension == '.gz':
1403
1564
  bm.extension = '.nii.gz'
1404
- bm.path_to_final = os.path.splitext(bm.path_to_final)[0] + bm.extension
1405
- if bm.django_env and not bm.remote and not bm.tarfile:
1406
- bm.path_to_final = unique_file_path(bm.path_to_final)
1407
-
1408
- # handle amira header
1409
- if bm.header is not None:
1410
- if bm.extension == '.am':
1411
- bm.header = set_image_dimensions(bm.header[0], label)
1412
- if bm.img_header is not None:
1413
- try:
1414
- bm.header = set_physical_size(bm.header, bm.img_header[0])
1415
- except:
1416
- pass
1417
- bm.header = [bm.header]
1418
- else:
1419
- # build new header
1420
- if bm.img_header is None:
1421
- zsh, ysh, xsh = label.shape
1422
- bm.img_header = sitk.Image(xsh, ysh, zsh, bm.header.GetPixelID())
1423
- # copy metadata
1424
- for key in bm.header.GetMetaDataKeys():
1425
- if not (re.match(r'Segment\d+_Extent$', key) or key=='Segmentation_ConversionParameters'):
1426
- bm.img_header.SetMetaData(key, bm.header.GetMetaData(key))
1427
- bm.header = bm.img_header
1428
- results['header'] = bm.header
1429
-
1430
- # save result
1431
- if bm.path_to_image:
1432
- save_data(bm.path_to_final, label, header=bm.header, compress=bm.compression)
1433
-
1434
- # paths to optional results
1435
- filename, bm.extension = os.path.splitext(bm.path_to_final)
1436
- if bm.extension == '.gz':
1437
- bm.extension = '.nii.gz'
1438
- filename = filename[:-4]
1439
- path_to_cleaned = filename + '.cleaned' + bm.extension
1440
- path_to_filled = filename + '.filled' + bm.extension
1441
- path_to_cleaned_filled = filename + '.cleaned.filled' + bm.extension
1442
- path_to_refined = filename + '.refined' + bm.extension
1443
- path_to_acwe = filename + '.acwe' + bm.extension
1444
-
1445
- # remove outliers
1446
- if bm.clean:
1447
- cleaned_result = clean(label, bm.clean)
1448
- results['cleaned'] = cleaned_result
1449
- if bm.path_to_image:
1450
- save_data(path_to_cleaned, cleaned_result, header=bm.header, compress=bm.compression)
1451
- if bm.fill:
1452
- filled_result = clean(label, bm.fill)
1453
- results['filled'] = filled_result
1454
- if bm.path_to_image:
1455
- save_data(path_to_filled, filled_result, header=bm.header, compress=bm.compression)
1456
- if bm.clean and bm.fill:
1457
- cleaned_filled_result = cleaned_result + (filled_result - label)
1458
- results['cleaned_filled'] = cleaned_filled_result
1459
- if bm.path_to_image:
1460
- save_data(path_to_cleaned_filled, cleaned_filled_result, header=bm.header, compress=bm.compression)
1461
-
1462
- # post-processing with active contour
1463
- if bm.acwe:
1464
- acwe_result = activeContour(bm.img_data, label, bm.acwe_alpha, bm.acwe_smooth, bm.acwe_steps)
1465
- refined_result = activeContour(bm.img_data, label, simple=True)
1466
- results['acwe'] = acwe_result
1467
- results['refined'] = refined_result
1468
- if bm.path_to_image:
1469
- save_data(path_to_acwe, acwe_result, header=bm.header, compress=bm.compression)
1470
- save_data(path_to_refined, refined_result, header=bm.header, compress=bm.compression)
1471
-
1472
- return results, bm
1565
+ filename = filename[:-4]
1566
+ path_to_cleaned = filename + '.cleaned' + bm.extension
1567
+ path_to_filled = filename + '.filled' + bm.extension
1568
+ path_to_cleaned_filled = filename + '.cleaned.filled' + bm.extension
1569
+ path_to_refined = filename + '.refined' + bm.extension
1570
+ path_to_acwe = filename + '.acwe' + bm.extension
1571
+
1572
+ # remove outliers
1573
+ if bm.clean:
1574
+ cleaned_result = clean(label, bm.clean)
1575
+ results['cleaned'] = cleaned_result
1576
+ if bm.path_to_image:
1577
+ save_data(path_to_cleaned, cleaned_result, header=bm.header, compress=bm.compression)
1578
+ if bm.fill:
1579
+ filled_result = clean(label, bm.fill)
1580
+ results['filled'] = filled_result
1581
+ if bm.path_to_image:
1582
+ save_data(path_to_filled, filled_result, header=bm.header, compress=bm.compression)
1583
+ if bm.clean and bm.fill:
1584
+ cleaned_filled_result = cleaned_result + (filled_result - label)
1585
+ results['cleaned_filled'] = cleaned_filled_result
1586
+ if bm.path_to_image:
1587
+ save_data(path_to_cleaned_filled, cleaned_filled_result, header=bm.header, compress=bm.compression)
1588
+
1589
+ # post-processing with active contour
1590
+ if bm.acwe:
1591
+ acwe_result = activeContour(bm.img_data, label, bm.acwe_alpha, bm.acwe_smooth, bm.acwe_steps)
1592
+ refined_result = activeContour(bm.img_data, label, simple=True)
1593
+ results['acwe'] = acwe_result
1594
+ results['refined'] = refined_result
1595
+ if bm.path_to_image:
1596
+ save_data(path_to_acwe, acwe_result, header=bm.header, compress=bm.compression)
1597
+ save_data(path_to_refined, refined_result, header=bm.header, compress=bm.compression)
1598
+
1599
+ return results, bm
1600
+ else:
1601
+ return None, None
1473
1602
 
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: biomedisa
3
- Version: 24.8.5
3
+ Version: 24.8.7
4
4
  Summary: Segmentation of 3D volumetric image data
5
5
  Author: Philipp Lösel
6
6
  Author-email: philipp.loesel@anu.edu.au
@@ -10,7 +10,7 @@ Project-URL: GitHub, https://github.com/biomedisa/biomedisa
10
10
  Classifier: Programming Language :: Python :: 3
11
11
  Classifier: License :: OSI Approved :: European Union Public Licence 1.2 (EUPL 1.2)
12
12
  Classifier: Operating System :: OS Independent
13
- Requires-Python: >=3.8
13
+ Requires-Python: >=3.10
14
14
  Description-Content-Type: text/markdown
15
15
  License-File: LICENSE
16
16
 
@@ -19,6 +19,7 @@ License-File: LICENSE
19
19
  - [Overview](#overview)
20
20
  - [Hardware Requirements](#hardware-requirements)
21
21
  - [Installation (command-line based)](#installation-command-line-based)
22
+ - [Installation (3D Slicer extension)](#installation-3d-slicer-extension)
22
23
  - [Installation (browser based)](#installation-browser-based)
23
24
  - [Download Data](#download-data)
24
25
  - [Revisions](#revisions)
@@ -33,20 +34,21 @@ License-File: LICENSE
33
34
  - [License](#license)
34
35
 
35
36
  ## Overview
36
- Biomedisa (https://biomedisa.info) is a free and easy-to-use open-source application for segmenting large 3D volumetric images such as CT and MRI scans, developed at [The Australian National University CTLab](https://ctlab.anu.edu.au/). Biomedisa's smart interpolation of sparsely pre-segmented slices enables accurate semi-automated segmentation by considering the complete underlying image data. Additionally, Biomedisa enables deep learning for fully automated segmentation across similar samples and structures. It is compatible with segmentation tools like Amira/Avizo, ImageJ/Fiji and 3D Slicer. If you are using Biomedisa or the data for your research please cite: Lösel, P.D. et al. [Introducing Biomedisa as an open-source online platform for biomedical image segmentation.](https://www.nature.com/articles/s41467-020-19303-w) *Nat. Commun.* **11**, 5577 (2020).
37
+ Biomedisa (https://biomedisa.info) is a free and easy-to-use open-source application for segmenting large 3D volumetric images such as CT and MRI scans, developed at [The Australian National University CTLab](https://ctlab.anu.edu.au/). Biomedisa's smart interpolation of sparsely pre-segmented slices enables accurate semi-automated segmentation by considering the complete underlying image data. Additionally, Biomedisa enables deep learning for fully automated segmentation across similar samples and structures. It is compatible with segmentation tools like Amira/Avizo, ImageJ/Fiji, and 3D Slicer. If you are using Biomedisa or the data for your research please cite: Lösel, P.D. et al. [Introducing Biomedisa as an open-source online platform for biomedical image segmentation.](https://www.nature.com/articles/s41467-020-19303-w) *Nat. Commun.* **11**, 5577 (2020).
37
38
 
38
39
  ## Hardware Requirements
39
- + One or more NVIDIA GPUs with compute capability 3.0 or higher.
40
+ + One or more NVIDIA GPUs.
40
41
 
41
42
  ## Installation (command-line based)
42
- + [Ubuntu 22.04 + Smart Interpolation](https://github.com/biomedisa/biomedisa/blob/master/README/ubuntu2204_interpolation_cli.md)
43
- + [Ubuntu 22.04 + Smart Interpolation + Deep Learning](https://github.com/biomedisa/biomedisa/blob/master/README/ubuntu2204_cuda11.8_gpu_cli.md)
44
- + [Windows 10 + Smart Interpolation + Deep Learning](https://github.com/biomedisa/biomedisa/blob/master/README/windows10_cuda_gpu_cli.md)
43
+ + [Ubuntu 22/24 + Smart Interpolation](https://github.com/biomedisa/biomedisa/blob/master/README/ubuntu_interpolation_cli.md)
44
+ + [Ubuntu 22/24 + Deep Learning](https://github.com/biomedisa/biomedisa/blob/master/README/ubuntu_deeplearning_cli.md)
45
+ + [Ubuntu 22/24 + Smart Interpolation + Deep Learning](https://github.com/biomedisa/biomedisa/blob/master/README/ubuntu_cli.md)
46
+ + [Windows 10/11 + Smart Interpolation + Deep Learning](https://github.com/biomedisa/biomedisa/blob/master/README/windows10_cuda_gpu_cli.md)
45
47
  + [Windows (WSL) + Smart Interpolation + Deep Learning](https://github.com/biomedisa/biomedisa/blob/master/README/windows_wsl.md)
46
48
 
47
49
  ## Installation (3D Slicer extension)
48
50
  + [Ubuntu 22.04 + Smart Interpolation + Deep Learning](https://github.com/biomedisa/biomedisa/blob/master/README/ubuntu2204_cuda11.8_gpu_slicer.md)
49
- + [Windows 10 + Smart Interpolation](https://github.com/biomedisa/biomedisa/blob/master/README/windows10_cuda_gpu_slicer.md)
51
+ + [Windows 10/11 + Smart Interpolation](https://github.com/biomedisa/biomedisa/blob/master/README/windows10_cuda_gpu_slicer.md)
50
52
 
51
53
  ## Installation (browser based)
52
54
  + [Ubuntu 22.04](https://github.com/biomedisa/biomedisa/blob/master/README/ubuntu2204_cuda11.8.md)
@@ -165,7 +167,7 @@ save_data('final.Head5.am', results['regular'], results['header'])
165
167
 
166
168
  #### Command-line based (prediction)
167
169
  ```
168
- python -m biomedisa.deeplearning C:\Users\%USERNAME%\Downloads\testing_axial_crop_pat13.nii.gz C:\Users\%USERNAME%\Downloads\heart.h5 -p
170
+ python -m biomedisa.deeplearning C:\Users\%USERNAME%\Downloads\testing_axial_crop_pat13.nii.gz C:\Users\%USERNAME%\Downloads\heart.h5
169
171
  ```
170
172
 
171
173
  ## Mesh Generator
@@ -1,10 +1,10 @@
1
1
  biomedisa/__init__.py,sha256=hw4mzEjGFXm-vxus2DBfKFW0nKoG0ibL5SH6ShfchrY,1526
2
2
  biomedisa/__main__.py,sha256=a1--8vhtztWEloHVtbM43FZLCfrFo4BELgdsgtWE8ls,536
3
- biomedisa/deeplearning.py,sha256=UD4IeaxITLLqzoaQ1Ul5HI_bN8wINouOGxf14yZ7SWQ,28008
3
+ biomedisa/deeplearning.py,sha256=r4YzFUss_2u6I7clIg9QD4DDx7bq8AFUnUgU1TCwS1c,28273
4
4
  biomedisa/interpolation.py,sha256=i10aqwEl-wsVU_nQ-zyubhAs27NSKF4ial7LyhaBLv0,17273
5
5
  biomedisa/mesh.py,sha256=8-iuVsrfW5JovaMrAez7qSxv1LCU3eiqOdik0s0DV1w,16062
6
- biomedisa/features/DataGenerator.py,sha256=MINc7emhELPxACOYdnLFHgh-RHUKODLH0DH7zz6u6mc,13116
7
- biomedisa/features/DataGeneratorCrop.py,sha256=23R4Z-8tB1CsjYTYfhHGovlJpAny_q9OV9hq8kc2GJg,5454
6
+ biomedisa/features/DataGenerator.py,sha256=m7vsKkLhRsVF1BE3Y8YGVx-xx0DWjbBw_inIdZBq6pQ,13111
7
+ biomedisa/features/DataGeneratorCrop.py,sha256=KtGqNadghOd59wIU9hATM_5YgSks95rS1kJ2lsSSX7w,6612
8
8
  biomedisa/features/PredictDataGenerator.py,sha256=JH8SPGQm-Y7_Drec2fw3jBUupvpIkQ1FvkDXP7mUjDY,4074
9
9
  biomedisa/features/PredictDataGeneratorCrop.py,sha256=HF5tJbGtlJMHr7lMT9IiIdLG2CTjXstbKoOjlZJ93Is,3431
10
10
  biomedisa/features/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -12,10 +12,10 @@ biomedisa/features/active_contour.py,sha256=FnPuvYck_KL4xYRFqwzm4Grdm288EdlcLFt8
12
12
  biomedisa/features/assd.py,sha256=q9NUQXEoA4Pi3d8b5fmys615CWu06Sm0N9-OGwJOFnw,6537
13
13
  biomedisa/features/biomedisa_helper.py,sha256=oIsOmIJ8xUQVwgbzMgw65dXcZmzwePpFQoIdbgLTnF8,32727
14
14
  biomedisa/features/create_slices.py,sha256=uSDH1OcEYc5BFPZHSy3UpS4P2DuoVnxOZ-l7wmyT_Po,13108
15
- biomedisa/features/crop_helper.py,sha256=R0_uZ8rxg1l0-HxP1fYBWWCSUjvSvTBSgW47Nq8eK4s,23789
15
+ biomedisa/features/crop_helper.py,sha256=wNVbb1c6qNmJFbDJ2jrjjvqJw-EMLkJt1HLjEolMwAA,24089
16
16
  biomedisa/features/curvop_numba.py,sha256=AjKQJcUBoURTB8pq1HmugQYpBwBELthhcEu51_r_xPI,7049
17
17
  biomedisa/features/django_env.py,sha256=LNrZ6rBHZ5I0FaWa5xN8K-ASPgq0r5dGDEUI56HzJxE,8615
18
- biomedisa/features/keras_helper.py,sha256=Emf-Fx9rB3jECL7XgK8DYMRwO27ynijt5YdnwMlDMXY,62747
18
+ biomedisa/features/keras_helper.py,sha256=r9Knb5GlJCU3nQ601wBJF_9pOpdTL20KUO-x0xRP6WM,68853
19
19
  biomedisa/features/nc_reader.py,sha256=RoRMwu3ELSNfoV3qZtaT2OWACnXb2EhNFu_DAF1T93o,7406
20
20
  biomedisa/features/pid.py,sha256=Jmn1VIp0fBlgBrqZ-yUIQVVb5-NAxNBdibXALVr2PPI,2545
21
21
  biomedisa/features/process_image.py,sha256=VtS3fGDvglqJiiJLPK1toe76J58j914NJ8XQKg3CRwo,11091
@@ -37,8 +37,8 @@ biomedisa/features/random_walk/pyopencl_large.py,sha256=q79AxG3p3qFjxfiAZfUK9I5B
37
37
  biomedisa/features/random_walk/pyopencl_small.py,sha256=opNlS-qzOa9qWafBNJdvf6r1aRAFf7_JXf6ISDnkdXE,17068
38
38
  biomedisa/features/random_walk/rw_large.py,sha256=ZnITvk00Y11ZZlGuBRaJO1EwU0wYBdEwdpj9vvXCqF4,19805
39
39
  biomedisa/features/random_walk/rw_small.py,sha256=RPzZe24YrEwYelJukDjvqaoD_SyhgdriEi7uV3kZGXI,14881
40
- biomedisa-24.8.5.dist-info/LICENSE,sha256=sehayP6UhydNnmstfL4yFR3genMRdpuUh6uZVWJN1H0,14152
41
- biomedisa-24.8.5.dist-info/METADATA,sha256=IMCcEHNKZKWZqsMN8vu4h6S812_vkYyKj0NNAiebrrY,10569
42
- biomedisa-24.8.5.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
43
- biomedisa-24.8.5.dist-info/top_level.txt,sha256=opsf1Eb4vCguPSxev4HHSeiUKCccT_C_RcUCdAYbHWQ,10
44
- biomedisa-24.8.5.dist-info/RECORD,,
40
+ biomedisa-24.8.7.dist-info/LICENSE,sha256=sehayP6UhydNnmstfL4yFR3genMRdpuUh6uZVWJN1H0,14152
41
+ biomedisa-24.8.7.dist-info/METADATA,sha256=7qJuRtt6SHy4MnZDyn-lqgSPXIjjSTrFZYRbXMbW5Mc,10708
42
+ biomedisa-24.8.7.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
43
+ biomedisa-24.8.7.dist-info/top_level.txt,sha256=opsf1Eb4vCguPSxev4HHSeiUKCccT_C_RcUCdAYbHWQ,10
44
+ biomedisa-24.8.7.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.2.0)
2
+ Generator: setuptools (75.8.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5