biomedisa 24.8.11__py3-none-any.whl → 25.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -39,13 +39,13 @@ from tensorflow.keras.layers import (
39
39
  from tensorflow.keras import backend as K
40
40
  from tensorflow.keras.utils import to_categorical
41
41
  from tensorflow.keras.callbacks import Callback, ModelCheckpoint, EarlyStopping
42
- from biomedisa.features.DataGenerator import DataGenerator
42
+ from biomedisa.features.DataGenerator import DataGenerator, welford_mean_std
43
43
  from biomedisa.features.PredictDataGenerator import PredictDataGenerator
44
- from biomedisa.features.biomedisa_helper import (unique,
44
+ from biomedisa.features.biomedisa_helper import (unique, welford_mean_std,
45
45
  img_resize, load_data, save_data, set_labels_to_zero, id_generator, unique_file_path)
46
46
  from biomedisa.features.remove_outlier import clean, fill
47
47
  from biomedisa.features.active_contour import activeContour
48
- from tifffile import TiffFile, imread
48
+ from tifffile import TiffFile, imread, imwrite
49
49
  import matplotlib.pyplot as plt
50
50
  import SimpleITK as sitk
51
51
  import tensorflow as tf
@@ -220,11 +220,11 @@ def compute_position(position, zsh, ysh, xsh):
220
220
  position[k,l,m] = x+y+z
221
221
  return position
222
222
 
223
- def make_conv_block(nb_filters, input_tensor, block):
223
+ def make_conv_block(nb_filters, input_tensor, block, dtype):
224
224
  def make_stage(input_tensor, stage):
225
225
  name = 'conv_{}_{}'.format(block, stage)
226
226
  x = Conv3D(nb_filters, (3, 3, 3), activation='relu',
227
- padding='same', name=name, data_format="channels_last")(input_tensor)
227
+ padding='same', name=name, data_format="channels_last", dtype=dtype)(input_tensor)
228
228
  name = 'batch_norm_{}_{}'.format(block, stage)
229
229
  try:
230
230
  x = BatchNormalization(name=name, synchronized=True)(x)
@@ -266,56 +266,65 @@ def make_conv_block_resnet(nb_filters, input_tensor, block):
266
266
 
267
267
  return out
268
268
 
269
- def make_unet(input_shape, nb_labels, filters='32-64-128-256-512', resnet=False):
269
+ def make_unet(bm, input_shape, nb_labels):
270
+ # enable mixed_precision
271
+ if bm.mixed_precision:
272
+ dtype = "float16"
273
+ else:
274
+ dtype = "float32"
270
275
 
276
+ # input
271
277
  nb_plans, nb_rows, nb_cols, _ = input_shape
278
+ inputs = Input(input_shape, dtype=dtype)
272
279
 
273
- inputs = Input(input_shape)
274
-
275
- filters = filters.split('-')
280
+ # configure number of layers and filters
281
+ filters = bm.network_filters.split('-')
276
282
  filters = np.array(filters, dtype=int)
277
283
  latent_space_size = filters[-1]
278
284
  filters = filters[:-1]
285
+
286
+ # initialize blocks
279
287
  convs = []
280
288
 
289
+ # encoder
281
290
  i = 1
282
291
  for f in filters:
283
292
  if i==1:
284
- if resnet:
293
+ if bm.resnet:
285
294
  conv = make_conv_block_resnet(f, inputs, i)
286
295
  else:
287
- conv = make_conv_block(f, inputs, i)
296
+ conv = make_conv_block(f, inputs, i, dtype)
288
297
  else:
289
- if resnet:
298
+ if bm.resnet:
290
299
  conv = make_conv_block_resnet(f, pool, i)
291
300
  else:
292
- conv = make_conv_block(f, pool, i)
301
+ conv = make_conv_block(f, pool, i, dtype)
293
302
  pool = MaxPooling3D(pool_size=(2, 2, 2))(conv)
294
303
  convs.append(conv)
295
304
  i += 1
296
305
 
297
- if resnet:
306
+ # latent space
307
+ if bm.resnet:
298
308
  conv = make_conv_block_resnet(latent_space_size, pool, i)
299
309
  else:
300
- conv = make_conv_block(latent_space_size, pool, i)
310
+ conv = make_conv_block(latent_space_size, pool, i, dtype)
301
311
  i += 1
302
312
 
313
+ # decoder
303
314
  for k, f in enumerate(filters[::-1]):
304
315
  up = Concatenate()([UpSampling3D(size=(2, 2, 2))(conv), convs[-(k+1)]])
305
- if resnet:
316
+ if bm.resnet:
306
317
  conv = make_conv_block_resnet(f, up, i)
307
318
  else:
308
- conv = make_conv_block(f, up, i)
319
+ conv = make_conv_block(f, up, i, dtype)
309
320
  i += 1
310
321
 
322
+ # final layer and output
311
323
  conv = Conv3D(nb_labels, (1, 1, 1), name=f'conv_{i}_1')(conv)
312
-
313
324
  x = Reshape((nb_plans * nb_rows * nb_cols, nb_labels))(conv)
314
325
  x = Activation('softmax')(x)
315
326
  outputs = Reshape((nb_plans, nb_rows, nb_cols, nb_labels))(x)
316
-
317
327
  model = Model(inputs=inputs, outputs=outputs)
318
-
319
328
  return model
320
329
 
321
330
  def get_labels(arr, allLabels):
@@ -415,6 +424,13 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
415
424
  print(f'{os.path.basename(label_names[0])}:', 'Labels:', label_values[1:], 'Sizes:', counts[1:])
416
425
  label = img_resize(label, bm.z_scale, bm.y_scale, bm.x_scale, labels=True)
417
426
 
427
+ # label channel must be 1 or 2 if using ignore mask
428
+ if len(label.shape)>3 and label.shape[3]>1 and not bm.ignore_mask:
429
+ InputError.message = 'Training labels must have one channel (gray values).'
430
+ raise InputError()
431
+ if len(label.shape)==3:
432
+ label = label.reshape(label.shape[0], label.shape[1], label.shape[2], 1)
433
+
418
434
  # if header is not single data stream Amira Mesh falling back to Multi-TIFF
419
435
  if extension != '.am':
420
436
  extension, header = '.tif', None
@@ -424,7 +440,7 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
424
440
  else:
425
441
  header = header[0]
426
442
 
427
- # load first img
443
+ # load first image
428
444
  if any(img_list):
429
445
  img, _ = load_data(img_names[0], 'first_queue')
430
446
  if img is None:
@@ -436,14 +452,15 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
436
452
  else:
437
453
  img = img_in
438
454
  img_names = ['img_1']
439
- if label_dim != img.shape[:3]:
455
+
456
+ # label and image dimensions must match
457
+ if label_dim[:3] != img.shape[:3]:
440
458
  InputError.message = f'Dimensions of "{os.path.basename(img_names[0])}" and "{os.path.basename(label_names[0])}" do not match'
441
459
  raise InputError()
442
460
 
443
- # ensure images have channels >=1
461
+ # image channels must be >=1
444
462
  if len(img.shape)==3:
445
- z_shape, y_shape, x_shape = img.shape
446
- img = img.reshape(z_shape, y_shape, x_shape, 1)
463
+ img = img.reshape(img.shape[0], img.shape[1], img.shape[2], 1)
447
464
  if channels is None:
448
465
  channels = img.shape[3]
449
466
  if channels != img.shape[3]:
@@ -462,30 +479,30 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
462
479
  # scale data to the range from 0 to 1
463
480
  if not bm.patch_normalization:
464
481
  img = img.astype(np.float32)
465
- for c in range(channels):
466
- img[:,:,:,c] -= np.amin(img[:,:,:,c])
467
- img[:,:,:,c] /= np.amax(img[:,:,:,c])
482
+ for ch in range(channels):
483
+ img[...,ch] -= np.amin(img[...,ch])
484
+ img[...,ch] /= np.amax(img[...,ch])
468
485
 
469
486
  # normalize first validation image
470
487
  if bm.normalize and np.any(normalization_parameters):
471
488
  img = img.astype(np.float32)
472
- for c in range(channels):
473
- mean, std = np.mean(img[:,:,:,c]), np.std(img[:,:,:,c])
474
- img[:,:,:,c] = (img[:,:,:,c] - mean) / std
475
- img[:,:,:,c] = img[:,:,:,c] * normalization_parameters[1,c] + normalization_parameters[0,c]
489
+ for ch in range(channels):
490
+ mean, std = welford_mean_std(img[...,ch])
491
+ img[...,ch] = (img[...,ch] - mean) / std
492
+ img[...,ch] = img[...,ch] * normalization_parameters[1,ch] + normalization_parameters[0,ch]
476
493
 
477
494
  # get normalization parameters from first image
478
495
  if normalization_parameters is None:
479
496
  normalization_parameters = np.zeros((2,channels))
480
497
  if bm.normalize:
481
- for c in range(channels):
482
- normalization_parameters[0,c] = np.mean(img[:,:,:,c])
483
- normalization_parameters[1,c] = np.std(img[:,:,:,c])
498
+ for ch in range(channels):
499
+ normalization_parameters[:,ch] = welford_mean_std(img[...,ch])
484
500
 
485
501
  # pad data
486
502
  if not bm.scaling:
487
503
  img_data_list = [img]
488
504
  label_data_list = [label]
505
+ img_dtype = img.dtype
489
506
  # no-scaling for list of images needs negative values as it encodes padded areas as -1
490
507
  label_dtype = label.dtype
491
508
  if label_dtype==np.uint8:
@@ -499,7 +516,7 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
499
516
 
500
517
  for k in range(1, number_of_images):
501
518
 
502
- # append label
519
+ # load label data and pre-process
503
520
  if any(label_list):
504
521
  a, _ = load_data(label_names[k], 'first_queue')
505
522
  if a is None:
@@ -516,11 +533,21 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
516
533
  label_values, counts = unique(a, return_counts=True)
517
534
  print(f'{os.path.basename(label_names[k])}:', 'Labels:', label_values[1:], 'Sizes:', counts[1:])
518
535
  a = img_resize(a, bm.z_scale, bm.y_scale, bm.x_scale, labels=True)
536
+
537
+ # label channel must be 1 or 2 if using ignore mask
538
+ if len(a.shape)>3 and a.shape[3]>1 and not bm.ignore_mask:
539
+ InputError.message = 'Training labels must have one channel (gray values).'
540
+ raise InputError()
541
+ if len(a.shape)==3:
542
+ a = a.reshape(a.shape[0], a.shape[1], a.shape[2], 1)
543
+
544
+ # append label data
545
+ if bm.scaling:
519
546
  label = np.append(label, a, axis=0)
520
547
  else:
521
548
  label_data_list.append(a)
522
549
 
523
- # append image
550
+ # load image data and pre-process
524
551
  if any(img_list):
525
552
  a, _ = load_data(img_names[k], 'first_queue')
526
553
  if a is None:
@@ -528,12 +555,11 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
528
555
  raise InputError()
529
556
  else:
530
557
  a = img_in[k]
531
- if label_dim != a.shape[:3]:
558
+ if label_dim[:3] != a.shape[:3]:
532
559
  InputError.message = f'Dimensions of "{os.path.basename(img_names[k])}" and "{os.path.basename(label_names[k])}" do not match'
533
560
  raise InputError()
534
561
  if len(a.shape)==3:
535
- z_shape, y_shape, x_shape = a.shape
536
- a = a.reshape(z_shape, y_shape, x_shape, 1)
562
+ a = a.reshape(a.shape[0], a.shape[1], a.shape[2], 1)
537
563
  if a.shape[3] != channels:
538
564
  InputError.message = f'Number of channels must be {channels} for "{os.path.basename(img_names[k])}"'
539
565
  raise InputError()
@@ -544,15 +570,17 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
544
570
  a = img_resize(a, bm.z_scale, bm.y_scale, bm.x_scale)
545
571
  if not bm.patch_normalization:
546
572
  a = a.astype(np.float32)
547
- for c in range(channels):
548
- a[:,:,:,c] -= np.amin(a[:,:,:,c])
549
- a[:,:,:,c] /= np.amax(a[:,:,:,c])
573
+ for ch in range(channels):
574
+ a[...,ch] -= np.amin(a[...,ch])
575
+ a[...,ch] /= np.amax(a[...,ch])
550
576
  if bm.normalize:
551
577
  a = a.astype(np.float32)
552
- for c in range(channels):
553
- mean, std = np.mean(a[:,:,:,c]), np.std(a[:,:,:,c])
554
- a[:,:,:,c] = (a[:,:,:,c] - mean) / std
555
- a[:,:,:,c] = a[:,:,:,c] * normalization_parameters[1,c] + normalization_parameters[0,c]
578
+ for ch in range(channels):
579
+ mean, std = welford_mean_std(a[...,ch])
580
+ a[...,ch] = (a[...,ch] - mean) / std
581
+ a[...,ch] = a[...,ch] * normalization_parameters[1,ch] + normalization_parameters[0,ch]
582
+
583
+ # append image data
556
584
  if bm.scaling:
557
585
  img = np.append(img, a, axis=0)
558
586
  else:
@@ -564,15 +592,14 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
564
592
  for img in img_data_list:
565
593
  target_y = max(target_y, img.shape[1])
566
594
  target_x = max(target_x, img.shape[2])
567
- img = np.empty((0, target_y, target_x, channels), dtype=np.float32)
568
- label = np.empty((0, target_y, target_x), dtype=label_dtype)
595
+ img = np.empty((0, target_y, target_x, channels), dtype=img_dtype)
596
+ label = np.empty((0, target_y, target_x, 2 if bm.ignore_mask else 1), dtype=label_dtype)
569
597
  for k in range(len(img_data_list)):
570
598
  pad_y = target_y - img_data_list[k].shape[1]
571
599
  pad_x = target_x - img_data_list[k].shape[2]
572
600
  pad_width = [(0, 0), (0, pad_y), (0, pad_x), (0, 0)]
573
601
  tmp = np.pad(img_data_list[k], pad_width, mode='constant', constant_values=0)
574
602
  img = np.append(img, tmp, axis=0)
575
- pad_width = [(0, 0), (0, pad_y), (0, pad_x)]
576
603
  tmp = np.pad(label_data_list[k].astype(label_dtype), pad_width, mode='constant', constant_values=-1)
577
604
  label = np.append(label, tmp, axis=0)
578
605
 
@@ -586,13 +613,13 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
586
613
  else:
587
614
  # get labels
588
615
  if allLabels is None:
589
- allLabels = unique(label)
616
+ allLabels = unique(label[...,0])
590
617
  index = np.argwhere(allLabels<0)
591
618
  allLabels = np.delete(allLabels, index)
592
619
 
593
620
  # labels must be in ascending order
594
621
  for k, l in enumerate(allLabels):
595
- label[label==l] = k
622
+ label[...,0][label[...,0]==l] = k
596
623
 
597
624
  return img, label, allLabels, normalization_parameters, header, extension, channels
598
625
 
@@ -724,10 +751,11 @@ class Metrics(Callback):
724
751
  m = rest % self.dim_img[2]
725
752
  tmp_X = self.img[k:k+self.dim_patch[0],l:l+self.dim_patch[1],m:m+self.dim_patch[2]]
726
753
  if self.patch_normalization:
727
- tmp_X = np.copy(tmp_X, order='C')
728
- for c in range(self.n_channels):
729
- tmp_X[:,:,:,c] -= np.mean(tmp_X[:,:,:,c])
730
- tmp_X[:,:,:,c] /= max(np.std(tmp_X[:,:,:,c]), 1e-6)
754
+ tmp_X = tmp_X.copy().astype(np.float32)
755
+ for ch in range(self.n_channels):
756
+ mean, std = welford_mean_std(tmp_X[...,ch])
757
+ tmp_X[...,ch] -= mean
758
+ tmp_X[...,ch] /= max(std, 1e-6)
731
759
  X_val[i] = tmp_X
732
760
 
733
761
  # Prediction segmentation
@@ -752,6 +780,7 @@ class Metrics(Callback):
752
780
  # get result
753
781
  result = np.argmax(result, axis=-1)
754
782
  result = result.astype(np.uint8)
783
+ result = result.reshape(*result.shape, 1)
755
784
 
756
785
  # calculate standard accuracy
757
786
  if not self.train:
@@ -771,17 +800,17 @@ class Metrics(Callback):
771
800
  logs['dice'] = dice
772
801
  else:
773
802
  # save best model only
774
- if epoch == 0 or round(dice,4) > max(self.history['val_dice']):
803
+ if epoch == 0 or dice > max(self.history['val_dice']):
775
804
  self.model.save(str(self.path_to_model))
776
805
 
777
806
  # add accuracy to history
778
- self.history['loss'].append(round(logs['loss'],4))
779
- self.history['accuracy'].append(round(logs['accuracy'],4))
807
+ self.history['loss'].append(logs['loss'])
808
+ self.history['accuracy'].append(logs['accuracy'])
780
809
  if self.train_dice:
781
- self.history['dice'].append(round(logs['dice'],4))
782
- self.history['val_accuracy'].append(round(accuracy,4))
783
- self.history['val_dice'].append(round(dice,4))
784
- self.history['val_loss'].append(round(val_loss,4))
810
+ self.history['dice'].append(logs['dice'])
811
+ self.history['val_accuracy'].append(accuracy)
812
+ self.history['val_dice'].append(dice)
813
+ self.history['val_loss'].append(val_loss)
785
814
 
786
815
  # tensorflow monitoring variables
787
816
  logs['val_loss'] = val_loss
@@ -798,11 +827,11 @@ class Metrics(Callback):
798
827
 
799
828
  # print accuracies
800
829
  print('\nValidation history:')
801
- print('train_acc:', self.history['accuracy'])
830
+ print("train_acc: [" + " ".join(f"{x:.4f}" for x in self.history['accuracy']) + "]")
802
831
  if self.train_dice:
803
- print('train_dice:', self.history['dice'])
804
- print('val_acc:', self.history['val_accuracy'])
805
- print('val_dice:', self.history['val_dice'])
832
+ print("train_dice: [" + " ".join(f"{x:.4f}" for x in self.history['dice']) + "]")
833
+ print("val_acc: [" + " ".join(f"{x:.4f}" for x in self.history['val_accuracy']) + "]")
834
+ print("val_dice: [" + " ".join(f"{x:.4f}" for x in self.history['val_dice']) + "]")
806
835
  print('')
807
836
 
808
837
  # early stopping
@@ -849,13 +878,13 @@ def categorical_crossentropy(true_labels, predicted_probs):
849
878
  # Clip predicted probabilities to avoid log(0) issues
850
879
  predicted_probs = np.clip(predicted_probs, 1e-7, 1 - 1e-7)
851
880
  predicted_probs = -np.log(predicted_probs)
852
- zsh,ysh,xsh = true_labels.shape
881
+ zsh, ysh, xsh, _ = true_labels.shape
853
882
  # Calculate categorical crossentropy
854
883
  loss = 0
855
884
  for z in range(zsh):
856
885
  for y in range(ysh):
857
886
  for x in range(xsh):
858
- l = true_labels[z,y,x]
887
+ l = true_labels[z,y,x,0]
859
888
  loss += predicted_probs[z,y,x,l]
860
889
  loss = loss / float(zsh*ysh*xsh)
861
890
  return loss
@@ -879,6 +908,42 @@ def dice_coef_loss(nb_labels):
879
908
  return loss
880
909
  return loss_fn
881
910
 
911
+ def custom_loss(y_true, y_pred):
912
+ # Extract labels and ignore mask
913
+ labels = tf.cast(y_true[..., 0], tf.int32) # First channel contains class labels
914
+ ignore_mask = tf.cast(y_true[..., 1], tf.float32) # Second channel contains mask (0 = ignore, 1 = include)
915
+
916
+ # Convert integer labels to one-hot encoding
917
+ y_true_one_hot = tf.one_hot(labels, depth=2)
918
+
919
+ # Clip y_pred to avoid log(0)
920
+ y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0)
921
+
922
+ # Compute categorical cross-entropy
923
+ loss = -tf.reduce_sum(y_true_one_hot * tf.math.log(y_pred), axis=-1)
924
+
925
+ # Apply ignore mask (ignore = 0 → loss is zero, include = 1 → loss is counted)
926
+ loss = loss * ignore_mask
927
+
928
+ # Return mean loss over valid (non-ignored) samples
929
+ return tf.reduce_sum(loss) / tf.reduce_sum(ignore_mask)
930
+
931
+ def custom_accuracy(y_true, y_pred):
932
+ labels = tf.cast(y_true[..., 0], tf.int32) # Extract actual values
933
+ ignore_mask = y_true[..., 1] # Extract mask (1 = include, 0 = ignore)
934
+
935
+ # Convert predictions to discrete values (assuming regression: round values)
936
+ y_pred_class = tf.argmax(y_pred, axis=-1, output_type=tf.int32)
937
+
938
+ # Compute correct predictions (1 where correct, 0 where incorrect)
939
+ correct_predictions = tf.cast(tf.equal(labels, y_pred_class), tf.float32)
940
+
941
+ # Apply ignore mask
942
+ masked_correct_predictions = correct_predictions * ignore_mask
943
+
944
+ # Compute accuracy only over valid (non-ignored) pixels
945
+ return tf.reduce_sum(masked_correct_predictions) / tf.reduce_sum(ignore_mask)
946
+
882
947
  def train_segmentation(bm):
883
948
 
884
949
  # training data
@@ -988,18 +1053,19 @@ def train_segmentation(bm):
988
1053
  'n_channels': bm.channels,
989
1054
  'augment': (bm.flip_x, bm.flip_y, bm.flip_z, bm.swapaxes, bm.rotate, bm.rotate3d),
990
1055
  'patch_normalization': bm.patch_normalization,
991
- 'separation': bm.separation}
1056
+ 'separation': bm.separation,
1057
+ 'ignore_mask': bm.ignore_mask}
992
1058
 
993
1059
  # data generator
994
1060
  validation_generator = None
995
- training_generator = DataGenerator(bm.img_data, bm.label_data, list_IDs_fg, list_IDs_bg, True, True, False, **params)
1061
+ training_generator = DataGenerator(bm.img_data, bm.label_data, list_IDs_fg, list_IDs_bg, True, True, **params)
996
1062
  if bm.val_img_data is not None:
997
1063
  if bm.val_dice:
998
1064
  val_metrics = Metrics(bm, bm.val_img_data, bm.val_label_data, list_IDs_val_fg, (zsh_val, ysh_val, xsh_val), nb_labels, False)
999
1065
  else:
1000
1066
  params['dim_img'] = (zsh_val, ysh_val, xsh_val)
1001
1067
  params['augment'] = (False, False, False, False, 0, False)
1002
- validation_generator = DataGenerator(bm.val_img_data, bm.val_label_data, list_IDs_val_fg, list_IDs_val_bg, True, False, False, **params)
1068
+ validation_generator = DataGenerator(bm.val_img_data, bm.val_label_data, list_IDs_val_fg, list_IDs_val_bg, True, False, **params)
1003
1069
 
1004
1070
  # monitor dice score on training data
1005
1071
  if bm.train_dice:
@@ -1017,7 +1083,7 @@ def train_segmentation(bm):
1017
1083
  with strategy.scope():
1018
1084
 
1019
1085
  # build model
1020
- model = make_unet(input_shape, nb_labels, bm.network_filters, bm.resnet)
1086
+ model = make_unet(bm, input_shape, nb_labels)
1021
1087
  model.summary()
1022
1088
 
1023
1089
  # pretrained model
@@ -1036,13 +1102,28 @@ def train_segmentation(bm):
1036
1102
  layer.trainable = False
1037
1103
 
1038
1104
  # optimizer
1039
- sgd = SGD(learning_rate=bm.learning_rate, decay=1e-6, momentum=0.9, nesterov=True)
1105
+ optimizer = SGD(learning_rate=bm.learning_rate, decay=1e-6, momentum=0.9, nesterov=True)
1106
+ #optimizer = tf.keras.optimizers.Adam(learning_rate=bm.learning_rate, epsilon=1e-4) lr=0.0001
1107
+ if bm.mixed_precision:
1108
+ optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer, dynamic=False, initial_scale=128)
1109
+
1110
+ # Rename the function to appear as "accuracy" in logs
1111
+ if bm.ignore_mask:
1112
+ custom_accuracy.__name__ = "accuracy"
1113
+ metrics=[custom_accuracy]
1114
+ else:
1115
+ metrics=['accuracy']
1116
+
1117
+ # loss function
1118
+ if bm.ignore_mask:
1119
+ loss=custom_loss
1120
+ else:
1121
+ loss=dice_coef_loss(nb_labels) if bm.dice_loss else 'categorical_crossentropy'
1040
1122
 
1041
1123
  # comile model
1042
- loss=dice_coef_loss(nb_labels) if bm.dice_loss else 'categorical_crossentropy'
1043
1124
  model.compile(loss=loss,
1044
- optimizer=sgd,
1045
- metrics=['accuracy'])
1125
+ optimizer=optimizer,
1126
+ metrics=metrics)
1046
1127
 
1047
1128
  # save meta data
1048
1129
  meta_data = MetaData(bm.path_to_model, configuration_data, allLabels,
@@ -1081,7 +1162,7 @@ def train_segmentation(bm):
1081
1162
  callbacks=callbacks,
1082
1163
  workers=bm.workers)
1083
1164
 
1084
- def load_prediction_data(bm, channels, normalize, normalization_parameters,
1165
+ def load_prediction_data(bm, channels, normalization_parameters,
1085
1166
  region_of_interest, img, img_header, load_blockwise=False, z=None):
1086
1167
 
1087
1168
  # read image data
@@ -1109,10 +1190,9 @@ def load_prediction_data(bm, channels, normalize, normalization_parameters,
1109
1190
  if bm.acwe:
1110
1191
  img_data = img.copy()
1111
1192
 
1112
- # handle all images using number of channels >=1
1193
+ # image data must have number of channels >=1
1113
1194
  if len(img.shape)==3:
1114
- z_shape, y_shape, x_shape = img.shape
1115
- img = img.reshape(z_shape, y_shape, x_shape, 1)
1195
+ img = img.reshape(img.shape[0], img.shape[1], img.shape[2], 1)
1116
1196
  if img.shape[3] != channels:
1117
1197
  InputError.message = f'Number of channels must be {channels}.'
1118
1198
  raise InputError()
@@ -1127,22 +1207,27 @@ def load_prediction_data(bm, channels, normalize, normalization_parameters,
1127
1207
  region_of_interest = np.array([min_z,max_z,min_y,max_y,min_x,max_x,z_shape,y_shape,x_shape])
1128
1208
  z_shape, y_shape, x_shape = max_z-min_z, max_y-min_y, max_x-min_x
1129
1209
 
1130
- # scale/resize image data
1131
- img = img.astype(np.float32)
1210
+ # resize image data
1132
1211
  if bm.scaling:
1212
+ img = img.astype(np.float32)
1133
1213
  img = img_resize(img, bm.z_scale, bm.y_scale, bm.x_scale)
1134
1214
 
1215
+ # scale image data
1216
+ if not bm.patch_normalization:
1217
+ img = img.astype(np.float32)
1218
+ for ch in range(channels):
1219
+ img[...,ch] -= np.amin(img[...,ch])
1220
+ img[...,ch] /= np.amax(img[...,ch])
1221
+
1135
1222
  # normalize image data
1136
- for c in range(channels):
1137
- img[:,:,:,c] -= np.amin(img[:,:,:,c])
1138
- img[:,:,:,c] /= np.amax(img[:,:,:,c])
1139
- if normalize:
1140
- mean, std = np.mean(img[:,:,:,c]), np.std(img[:,:,:,c])
1141
- img[:,:,:,c] = (img[:,:,:,c] - mean) / std
1142
- img[:,:,:,c] = img[:,:,:,c] * normalization_parameters[1,c] + normalization_parameters[0,c]
1223
+ if bm.normalize:
1224
+ img = img.astype(np.float32)
1225
+ for ch in range(channels):
1226
+ mean, std = welford_mean_std(img[...,ch])
1227
+ img[...,ch] = (img[...,ch] - mean) / std
1228
+ img[...,ch] = img[...,ch] * normalization_parameters[1,ch] + normalization_parameters[0,ch]
1143
1229
 
1144
- # limit intensity range
1145
- if normalize:
1230
+ # limit intensity range
1146
1231
  img[img<0] = 0
1147
1232
  img[img>1] = 1
1148
1233
 
@@ -1185,6 +1270,20 @@ def gradient(volData):
1185
1270
  grad[grad>0]=1
1186
1271
  return grad
1187
1272
 
1273
+ @numba.jit(nopython=True)
1274
+ def scale_probabilities(final):
1275
+ zsh, ysh, xsh, nb_labels = final.shape
1276
+ for k in range(zsh):
1277
+ for l in range(ysh):
1278
+ for m in range(xsh):
1279
+ scale_factor = 0
1280
+ for n in range(nb_labels):
1281
+ scale_factor += final[k,l,m,n]
1282
+ scale_factor = max(1, scale_factor)
1283
+ for n in range(nb_labels):
1284
+ final[k,l,m,n] /= scale_factor
1285
+ return final
1286
+
1188
1287
  def predict_segmentation(bm, region_of_interest, channels, normalization_parameters):
1189
1288
 
1190
1289
  from mpi4py import MPI
@@ -1192,13 +1291,26 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
1192
1291
  rank = comm.Get_rank()
1193
1292
  ngpus = comm.Get_size()
1194
1293
 
1294
+ # optional result paths
1295
+ if bm.path_to_image:
1296
+ filename, bm.extension = os.path.splitext(bm.path_to_final)
1297
+ if bm.extension == '.gz':
1298
+ bm.extension = '.nii.gz'
1299
+ filename = filename[:-4]
1300
+ path_to_cleaned = filename + '.cleaned' + bm.extension
1301
+ path_to_filled = filename + '.filled' + bm.extension
1302
+ path_to_cleaned_filled = filename + '.cleaned.filled' + bm.extension
1303
+ path_to_refined = filename + '.refined' + bm.extension
1304
+ path_to_acwe = filename + '.acwe' + bm.extension
1305
+ path_to_probs = filename + '.probs.tif'
1306
+
1195
1307
  # Set the visible GPU by ID
1196
1308
  gpus = tf.config.experimental.list_physical_devices('GPU')
1197
1309
  if gpus:
1198
1310
  try:
1199
1311
  # Restrict TensorFlow to only use the specified GPU
1200
- tf.config.experimental.set_visible_devices(gpus[rank], 'GPU')
1201
- tf.config.experimental.set_memory_growth(gpus[rank], True)
1312
+ tf.config.experimental.set_visible_devices(gpus[rank % len(gpus)], 'GPU')
1313
+ tf.config.experimental.set_memory_growth(gpus[rank % len(gpus)], True)
1202
1314
  except RuntimeError as e:
1203
1315
  print(e)
1204
1316
 
@@ -1209,7 +1321,7 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
1209
1321
  nb_labels = len(bm.allLabels)
1210
1322
  results['allLabels'] = bm.allLabels
1211
1323
 
1212
- # load model
1324
+ # custom objects
1213
1325
  if bm.dice_loss:
1214
1326
  def loss_fn(y_true, y_pred):
1215
1327
  dice = 0
@@ -1220,25 +1332,30 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
1220
1332
  loss = 1 - dice
1221
1333
  return loss
1222
1334
  custom_objects = {'dice_coef_loss': dice_coef_loss,'loss_fn': loss_fn}
1223
- model = load_model(bm.path_to_model, custom_objects=custom_objects)
1335
+ elif bm.ignore_mask:
1336
+ custom_objects={'custom_loss': custom_loss}
1224
1337
  else:
1225
- model = load_model(bm.path_to_model)
1338
+ custom_objects=None
1339
+
1340
+ # load model
1341
+ model = load_model(bm.path_to_model, custom_objects=custom_objects)
1226
1342
 
1227
1343
  # check if data can be loaded blockwise to save host memory
1228
1344
  load_blockwise = False
1229
1345
  if not bm.scaling and not bm.normalize and bm.path_to_image and not np.any(region_of_interest) and \
1230
1346
  os.path.splitext(bm.path_to_image)[1] in ['.tif', '.tiff'] and not bm.acwe:
1347
+
1231
1348
  # get image shape
1232
1349
  tif = TiffFile(bm.path_to_image)
1233
1350
  zsh = len(tif.pages)
1234
1351
  ysh, xsh = tif.pages[0].shape
1235
1352
 
1236
1353
  # load mask
1237
- if bm.separation or bm.refinement:
1354
+ '''if bm.separation or bm.refinement:
1238
1355
  mask, _ = load_data(bm.mask)
1239
1356
  mask = mask.reshape(zsh, ysh, xsh, 1)
1240
1357
  mask, _, _, _ = append_ghost_areas(bm, mask)
1241
- mask = mask.reshape(mask.shape[:-1])
1358
+ mask = mask.reshape(mask.shape[:-1])'''
1242
1359
 
1243
1360
  # determine new image size after appending ghost areas to make image dimensions divisible by patch size
1244
1361
  z_rest = bm.z_patch - (zsh % bm.z_patch)
@@ -1258,7 +1375,7 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
1258
1375
  xsh += x_rest
1259
1376
 
1260
1377
  # get Ids of patches
1261
- list_IDs = []
1378
+ '''list_IDs = []
1262
1379
  for k in range(0, zsh-bm.z_patch+1, bm.stride_size):
1263
1380
  for l in range(0, ysh-bm.y_patch+1, bm.stride_size):
1264
1381
  for m in range(0, xsh-bm.x_patch+1, bm.stride_size):
@@ -1268,19 +1385,18 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
1268
1385
  if centerLabel>0 and np.any(patch!=centerLabel):
1269
1386
  list_IDs.append(k*ysh*xsh+l*xsh+m)
1270
1387
  elif bm.refinement:
1271
- patch = mask[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]
1272
- if np.any(patch==0) and np.any(patch!=0):
1388
+ if np.any(mask[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]):
1273
1389
  list_IDs.append(k*ysh*xsh+l*xsh+m)
1274
1390
  else:
1275
- list_IDs.append(k*ysh*xsh+l*xsh+m)
1391
+ list_IDs.append(k*ysh*xsh+l*xsh+m)'''
1276
1392
 
1277
1393
  # make length of list divisible by batch size
1278
- max_i = len(list_IDs)
1394
+ '''max_i = len(list_IDs)
1279
1395
  rest = bm.batch_size - (len(list_IDs) % bm.batch_size)
1280
- list_IDs = list_IDs + list_IDs[:rest]
1396
+ list_IDs = list_IDs + list_IDs[:rest]'''
1281
1397
 
1282
1398
  # prediction
1283
- if len(list_IDs) > 400:
1399
+ if zsh*ysh*xsh > 256**3:
1284
1400
  load_blockwise = True
1285
1401
 
1286
1402
  # load image data and calculate patch IDs
@@ -1288,7 +1404,7 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
1288
1404
 
1289
1405
  # load prediction data
1290
1406
  img, bm.img_header, z_shape, y_shape, x_shape, region_of_interest, bm.img_data = load_prediction_data(
1291
- bm, channels, bm.normalize, normalization_parameters, region_of_interest, bm.img_data, bm.img_header)
1407
+ bm, channels, normalization_parameters, region_of_interest, bm.img_data, bm.img_header)
1292
1408
 
1293
1409
  # append ghost areas
1294
1410
  img, z_rest, y_rest, x_rest = append_ghost_areas(bm, img)
@@ -1314,6 +1430,7 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
1314
1430
 
1315
1431
  # load all patches on GPU memory
1316
1432
  if not load_blockwise and nb_patches < 400:
1433
+ if rank==0:
1317
1434
 
1318
1435
  # parameters
1319
1436
  params = {'dim': (bm.z_patch, bm.y_patch, bm.x_patch),
@@ -1346,7 +1463,11 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
1346
1463
 
1347
1464
  # allocate final probabilities array
1348
1465
  if rank==0 and bm.return_probs:
1349
- final = np.zeros((zsh, ysh, xsh, nb_labels), dtype=np.float32)
1466
+ if load_blockwise:
1467
+ if not os.path.exists(path_to_probs[:-4]):
1468
+ os.mkdir(path_to_probs[:-4])
1469
+ else:
1470
+ final = np.zeros((zsh, ysh, xsh, nb_labels), dtype=np.float32)
1350
1471
 
1351
1472
  # allocate final result array
1352
1473
  if rank==0:
@@ -1361,27 +1482,38 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
1361
1482
  else:
1362
1483
  nprocs = ngpus
1363
1484
  if j % ngpus == rank:
1485
+
1364
1486
  # load blockwise from TIFF
1365
1487
  if load_blockwise:
1366
1488
  img, _, _, _, _, _, _ = load_prediction_data(bm,
1367
- channels, bm.normalize, normalization_parameters,
1489
+ channels, normalization_parameters,
1368
1490
  region_of_interest, bm.img_data, bm.img_header, load_blockwise, z)
1369
1491
  img, _, _, _ = append_ghost_areas(bm, img)
1370
1492
 
1493
+ # load mask block
1494
+ if bm.separation or bm.refinement:
1495
+ mask = imread(bm.mask, key=range(z,min(len(tif.pages),z+bm.z_patch)))
1496
+ # pad zeros to make dimensions divisible by patch dimensions
1497
+ pad_z = bm.z_patch - mask.shape[0]
1498
+ pad_y = (bm.y_patch - (mask.shape[1] % bm.y_patch)) % bm.y_patch
1499
+ pad_x = (bm.x_patch - (mask.shape[2] % bm.x_patch)) % bm.x_patch
1500
+ pad_width = [(0, pad_z), (0, pad_y), (0, pad_x)]
1501
+ mask = np.pad(mask, pad_width, mode='constant', constant_values=0)
1502
+
1371
1503
  # list of IDs
1372
1504
  list_IDs_block = []
1373
1505
 
1374
1506
  # get Ids of patches
1507
+ k = 0 if load_blockwise else z
1375
1508
  for l in range(0, ysh-bm.y_patch+1, bm.stride_size):
1376
1509
  for m in range(0, xsh-bm.x_patch+1, bm.stride_size):
1377
1510
  if bm.separation:
1378
- centerLabel = mask[z+bm.z_patch//2,l+bm.y_patch//2,m+bm.x_patch//2]
1379
- patch = mask[z:z+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]
1511
+ centerLabel = mask[k+bm.z_patch//2,l+bm.y_patch//2,m+bm.x_patch//2]
1512
+ patch = mask[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]
1380
1513
  if centerLabel>0 and np.any(patch!=centerLabel):
1381
1514
  list_IDs_block.append(z*ysh*xsh+l*xsh+m)
1382
1515
  elif bm.refinement:
1383
- patch = mask[z:z+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]
1384
- if np.any(patch==0) and np.any(patch!=0):
1516
+ if np.any(mask[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]):
1385
1517
  list_IDs_block.append(z*ysh*xsh+l*xsh+m)
1386
1518
  else:
1387
1519
  list_IDs_block.append(z*ysh*xsh+l*xsh+m)
@@ -1413,10 +1545,11 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
1413
1545
  # get patch
1414
1546
  tmp_X = img[k:k+bm.z_patch,l:l+bm.y_patch,m:m+bm.x_patch]
1415
1547
  if bm.patch_normalization:
1416
- tmp_X = np.copy(tmp_X, order='C')
1417
- for c in range(channels):
1418
- tmp_X[:,:,:,c] -= np.mean(tmp_X[:,:,:,c])
1419
- tmp_X[:,:,:,c] /= max(np.std(tmp_X[:,:,:,c]), 1e-6)
1548
+ tmp_X = tmp_X.copy().astype(np.float32)
1549
+ for ch in range(channels):
1550
+ mean, std = welford_mean_std(tmp_X[...,ch])
1551
+ tmp_X[...,ch] -= mean
1552
+ tmp_X[...,ch] /= max(std, 1e-6)
1420
1553
  X[i] = tmp_X
1421
1554
 
1422
1555
  # predict batch
@@ -1460,19 +1593,28 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
1460
1593
  # overlap in z direction
1461
1594
  if bm.stride_size < bm.z_patch:
1462
1595
  if j+source>0:
1463
- probs[:bm.stride_size] += overlap
1596
+ probs[:-bm.stride_size] += overlap
1464
1597
  overlap = probs[bm.stride_size:].copy()
1465
1598
 
1466
- # calculate result
1599
+ # block z dimension
1467
1600
  block_z = z_indices[j+source]
1468
- if j+source==len(z_indices)-1:
1469
- label[block_z:block_z+bm.z_patch] = np.argmax(probs, axis=-1).astype(np.uint8)
1470
- if bm.return_probs:
1471
- final[block_z:block_z+bm.z_patch] = probs
1601
+ if j+source==len(z_indices)-1: # last block
1602
+ block_zsh = bm.z_patch
1603
+ block_z_rest = z_rest if z_rest>0 else -block_zsh
1472
1604
  else:
1473
1605
  block_zsh = min(bm.stride_size, bm.z_patch)
1474
- label[block_z:block_z+block_zsh] = np.argmax(probs[:block_zsh], axis=-1).astype(np.uint8)
1475
- if bm.return_probs:
1606
+ block_z_rest = -block_zsh
1607
+
1608
+ # calculate result
1609
+ label[block_z:block_z+block_zsh] = np.argmax(probs[:block_zsh], axis=-1).astype(np.uint8)
1610
+
1611
+ # return probabilities
1612
+ if bm.return_probs:
1613
+ if load_blockwise:
1614
+ block_output = scale_probabilities(probs[:block_zsh])
1615
+ block_output = block_output[:-block_z_rest,:-y_rest,:-x_rest]
1616
+ imwrite(path_to_probs[:-4] + f"/block-{j+source}.tif", block_output)
1617
+ else:
1476
1618
  final[block_z:block_z+block_zsh] = probs[:block_zsh]
1477
1619
  else:
1478
1620
  for i in range(bm.z_patch):
@@ -1480,7 +1622,7 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
1480
1622
  if rank==0:
1481
1623
 
1482
1624
  # refine mask data with result
1483
- if bm.refinement:
1625
+ '''if bm.refinement:
1484
1626
  # loop over boundary patches
1485
1627
  for i, ID in enumerate(list_IDs):
1486
1628
  if i < max_i:
@@ -1489,25 +1631,17 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
1489
1631
  l = rest // xsh
1490
1632
  m = rest % xsh
1491
1633
  mask[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch] = label[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]
1492
- label = mask
1634
+ label = mask'''
1493
1635
 
1494
- # remove appendix
1495
- if bm.return_probs:
1636
+ # remove ghost areas
1637
+ if bm.return_probs and not load_blockwise:
1496
1638
  final = final[:-z_rest,:-y_rest,:-x_rest]
1497
1639
  label = label[:-z_rest,:-y_rest,:-x_rest]
1498
1640
  zsh, ysh, xsh = label.shape
1499
1641
 
1500
1642
  # return probabilities
1501
- if bm.return_probs:
1502
- counter = np.zeros((zsh, ysh, xsh, nb_labels), dtype=np.float32)
1503
- nb = 0
1504
- for k in range(0, zsh-bm.z_patch+1, bm.stride_size):
1505
- for l in range(0, ysh-bm.y_patch+1, bm.stride_size):
1506
- for m in range(0, xsh-bm.x_patch+1, bm.stride_size):
1507
- counter[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch] += 1
1508
- nb += 1
1509
- counter[counter==0] = 1
1510
- probabilities = final / counter
1643
+ if bm.return_probs and not load_blockwise:
1644
+ probabilities = scale_probabilities(final)
1511
1645
  if bm.scaling:
1512
1646
  probabilities = img_resize(probabilities, z_shape, y_shape, x_shape)
1513
1647
  if np.any(region_of_interest):
@@ -1570,17 +1704,8 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
1570
1704
  # save result
1571
1705
  if bm.path_to_image:
1572
1706
  save_data(bm.path_to_final, label, header=bm.header, compress=bm.compression)
1573
-
1574
- # paths to optional results
1575
- filename, bm.extension = os.path.splitext(bm.path_to_final)
1576
- if bm.extension == '.gz':
1577
- bm.extension = '.nii.gz'
1578
- filename = filename[:-4]
1579
- path_to_cleaned = filename + '.cleaned' + bm.extension
1580
- path_to_filled = filename + '.filled' + bm.extension
1581
- path_to_cleaned_filled = filename + '.cleaned.filled' + bm.extension
1582
- path_to_refined = filename + '.refined' + bm.extension
1583
- path_to_acwe = filename + '.acwe' + bm.extension
1707
+ if bm.return_probs and not load_blockwise:
1708
+ imwrite(path_to_probs, probabilities)
1584
1709
 
1585
1710
  # remove outliers
1586
1711
  if bm.clean: