biomedisa 24.8.10__py3-none-any.whl → 25.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  ##########################################################################
  ## ##
- ## Copyright (c) 2019-2024 Philipp Lösel. All rights reserved. ##
+ ## Copyright (c) 2019-2025 Philipp Lösel. All rights reserved. ##
  ## ##
  ## This file is part of the open source project biomedisa. ##
  ## ##
@@ -39,13 +39,13 @@ from tensorflow.keras.layers import (
  from tensorflow.keras import backend as K
  from tensorflow.keras.utils import to_categorical
  from tensorflow.keras.callbacks import Callback, ModelCheckpoint, EarlyStopping
- from biomedisa.features.DataGenerator import DataGenerator
+ from biomedisa.features.DataGenerator import DataGenerator, welford_mean_std
  from biomedisa.features.PredictDataGenerator import PredictDataGenerator
- from biomedisa.features.biomedisa_helper import (
+ from biomedisa.features.biomedisa_helper import (unique, welford_mean_std,
      img_resize, load_data, save_data, set_labels_to_zero, id_generator, unique_file_path)
  from biomedisa.features.remove_outlier import clean, fill
  from biomedisa.features.active_contour import activeContour
- from tifffile import TiffFile, imread
+ from tifffile import TiffFile, imread, imwrite
  import matplotlib.pyplot as plt
  import SimpleITK as sitk
  import tensorflow as tf
@@ -220,11 +220,11 @@ def compute_position(position, zsh, ysh, xsh):
                  position[k,l,m] = x+y+z
      return position

- def make_conv_block(nb_filters, input_tensor, block):
+ def make_conv_block(nb_filters, input_tensor, block, dtype):
      def make_stage(input_tensor, stage):
          name = 'conv_{}_{}'.format(block, stage)
          x = Conv3D(nb_filters, (3, 3, 3), activation='relu',
-                    padding='same', name=name, data_format="channels_last")(input_tensor)
+                    padding='same', name=name, data_format="channels_last", dtype=dtype)(input_tensor)
          name = 'batch_norm_{}_{}'.format(block, stage)
          try:
              x = BatchNormalization(name=name, synchronized=True)(x)
@@ -266,62 +266,70 @@ def make_conv_block_resnet(nb_filters, input_tensor, block):

      return out

- def make_unet(input_shape, nb_labels, filters='32-64-128-256-512', resnet=False):
+ def make_unet(bm, input_shape, nb_labels):
+     # enable mixed_precision
+     if bm.mixed_precision:
+         dtype = "float16"
+     else:
+         dtype = "float32"

+     # input
      nb_plans, nb_rows, nb_cols, _ = input_shape
+     inputs = Input(input_shape, dtype=dtype)

-     inputs = Input(input_shape)
-
-     filters = filters.split('-')
+     # configure number of layers and filters
+     filters = bm.network_filters.split('-')
      filters = np.array(filters, dtype=int)
      latent_space_size = filters[-1]
      filters = filters[:-1]
+
+     # initialize blocks
      convs = []

+     # encoder
      i = 1
      for f in filters:
          if i==1:
-             if resnet:
+             if bm.resnet:
                  conv = make_conv_block_resnet(f, inputs, i)
              else:
-                 conv = make_conv_block(f, inputs, i)
+                 conv = make_conv_block(f, inputs, i, dtype)
          else:
-             if resnet:
+             if bm.resnet:
                  conv = make_conv_block_resnet(f, pool, i)
              else:
-                 conv = make_conv_block(f, pool, i)
+                 conv = make_conv_block(f, pool, i, dtype)
          pool = MaxPooling3D(pool_size=(2, 2, 2))(conv)
          convs.append(conv)
          i += 1

-     if resnet:
+     # latent space
+     if bm.resnet:
          conv = make_conv_block_resnet(latent_space_size, pool, i)
      else:
-         conv = make_conv_block(latent_space_size, pool, i)
+         conv = make_conv_block(latent_space_size, pool, i, dtype)
      i += 1

+     # decoder
      for k, f in enumerate(filters[::-1]):
          up = Concatenate()([UpSampling3D(size=(2, 2, 2))(conv), convs[-(k+1)]])
-         if resnet:
+         if bm.resnet:
              conv = make_conv_block_resnet(f, up, i)
          else:
-             conv = make_conv_block(f, up, i)
+             conv = make_conv_block(f, up, i, dtype)
          i += 1

+     # final layer and output
      conv = Conv3D(nb_labels, (1, 1, 1), name=f'conv_{i}_1')(conv)
-
      x = Reshape((nb_plans * nb_rows * nb_cols, nb_labels))(conv)
      x = Activation('softmax')(x)
      outputs = Reshape((nb_plans, nb_rows, nb_cols, nb_labels))(x)
-
      model = Model(inputs=inputs, outputs=outputs)
-
      return model

  def get_labels(arr, allLabels):
-     np_unique = np.unique(arr)
      final = np.zeros_like(arr)
-     for k in np_unique:
+     for k in unique(arr):
          final[arr == k] = allLabels[k]
      return final

@@ -412,10 +420,17 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
          argmin_z,argmax_z,argmin_y,argmax_y,argmin_x,argmax_x = predict_blocksize(label, x_puffer, y_puffer, z_puffer)
          label = label[argmin_z:argmax_z,argmin_y:argmax_y,argmin_x:argmax_x].copy()
      if bm.scaling:
-         label_values, counts = np.unique(label, return_counts=True)
+         label_values, counts = unique(label, return_counts=True)
          print(f'{os.path.basename(label_names[0])}:', 'Labels:', label_values[1:], 'Sizes:', counts[1:])
          label = img_resize(label, bm.z_scale, bm.y_scale, bm.x_scale, labels=True)

+     # label channel must be 1 or 2 if using ignore mask
+     if len(label.shape)>3 and label.shape[3]>1 and not bm.ignore_mask:
+         InputError.message = 'Training labels must have one channel (gray values).'
+         raise InputError()
+     if len(label.shape)==3:
+         label = label.reshape(label.shape[0], label.shape[1], label.shape[2], 1)
+
      # if header is not single data stream Amira Mesh falling back to Multi-TIFF
      if extension != '.am':
          extension, header = '.tif', None
@@ -425,7 +440,7 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
      else:
          header = header[0]

-     # load first img
+     # load first image
      if any(img_list):
          img, _ = load_data(img_names[0], 'first_queue')
          if img is None:
@@ -437,14 +452,15 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
      else:
          img = img_in
          img_names = ['img_1']
-     if label_dim != img.shape[:3]:
+
+     # label and image dimensions must match
+     if label_dim[:3] != img.shape[:3]:
          InputError.message = f'Dimensions of "{os.path.basename(img_names[0])}" and "{os.path.basename(label_names[0])}" do not match'
          raise InputError()

-     # ensure images have channels >=1
+     # image channels must be >=1
      if len(img.shape)==3:
-         z_shape, y_shape, x_shape = img.shape
-         img = img.reshape(z_shape, y_shape, x_shape, 1)
+         img = img.reshape(img.shape[0], img.shape[1], img.shape[2], 1)
      if channels is None:
          channels = img.shape[3]
      if channels != img.shape[3]:
@@ -463,30 +479,30 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
      # scale data to the range from 0 to 1
      if not bm.patch_normalization:
          img = img.astype(np.float32)
-         for c in range(channels):
-             img[:,:,:,c] -= np.amin(img[:,:,:,c])
-             img[:,:,:,c] /= np.amax(img[:,:,:,c])
+         for ch in range(channels):
+             img[...,ch] -= np.amin(img[...,ch])
+             img[...,ch] /= np.amax(img[...,ch])

      # normalize first validation image
      if bm.normalize and np.any(normalization_parameters):
          img = img.astype(np.float32)
-         for c in range(channels):
-             mean, std = np.mean(img[:,:,:,c]), np.std(img[:,:,:,c])
-             img[:,:,:,c] = (img[:,:,:,c] - mean) / std
-             img[:,:,:,c] = img[:,:,:,c] * normalization_parameters[1,c] + normalization_parameters[0,c]
+         for ch in range(channels):
+             mean, std = welford_mean_std(img[...,ch])
+             img[...,ch] = (img[...,ch] - mean) / std
+             img[...,ch] = img[...,ch] * normalization_parameters[1,ch] + normalization_parameters[0,ch]

      # get normalization parameters from first image
      if normalization_parameters is None:
          normalization_parameters = np.zeros((2,channels))
          if bm.normalize:
-             for c in range(channels):
-                 normalization_parameters[0,c] = np.mean(img[:,:,:,c])
-                 normalization_parameters[1,c] = np.std(img[:,:,:,c])
+             for ch in range(channels):
+                 normalization_parameters[:,ch] = welford_mean_std(img[...,ch])

      # pad data
      if not bm.scaling:
          img_data_list = [img]
          label_data_list = [label]
+         img_dtype = img.dtype
          # no-scaling for list of images needs negative values as it encodes padded areas as -1
          label_dtype = label.dtype
          if label_dtype==np.uint8:
@@ -500,7 +516,7 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in

      for k in range(1, number_of_images):

-         # append label
+         # load label data and pre-process
          if any(label_list):
              a, _ = load_data(label_names[k], 'first_queue')
              if a is None:
@@ -514,14 +530,24 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
              argmin_z,argmax_z,argmin_y,argmax_y,argmin_x,argmax_x = predict_blocksize(a, x_puffer, y_puffer, z_puffer)
              a = np.copy(a[argmin_z:argmax_z,argmin_y:argmax_y,argmin_x:argmax_x], order='C')
          if bm.scaling:
-             label_values, counts = np.unique(a, return_counts=True)
+             label_values, counts = unique(a, return_counts=True)
              print(f'{os.path.basename(label_names[k])}:', 'Labels:', label_values[1:], 'Sizes:', counts[1:])
              a = img_resize(a, bm.z_scale, bm.y_scale, bm.x_scale, labels=True)
+
+         # label channel must be 1 or 2 if using ignore mask
+         if len(a.shape)>3 and a.shape[3]>1 and not bm.ignore_mask:
+             InputError.message = 'Training labels must have one channel (gray values).'
+             raise InputError()
+         if len(a.shape)==3:
+             a = a.reshape(a.shape[0], a.shape[1], a.shape[2], 1)
+
+         # append label data
+         if bm.scaling:
              label = np.append(label, a, axis=0)
          else:
              label_data_list.append(a)

-         # append image
+         # load image data and pre-process
          if any(img_list):
              a, _ = load_data(img_names[k], 'first_queue')
              if a is None:
@@ -529,12 +555,11 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
                  raise InputError()
          else:
              a = img_in[k]
-         if label_dim != a.shape[:3]:
+         if label_dim[:3] != a.shape[:3]:
              InputError.message = f'Dimensions of "{os.path.basename(img_names[k])}" and "{os.path.basename(label_names[k])}" do not match'
              raise InputError()
          if len(a.shape)==3:
-             z_shape, y_shape, x_shape = a.shape
-             a = a.reshape(z_shape, y_shape, x_shape, 1)
+             a = a.reshape(a.shape[0], a.shape[1], a.shape[2], 1)
          if a.shape[3] != channels:
              InputError.message = f'Number of channels must be {channels} for "{os.path.basename(img_names[k])}"'
              raise InputError()
@@ -545,15 +570,17 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
              a = img_resize(a, bm.z_scale, bm.y_scale, bm.x_scale)
          if not bm.patch_normalization:
              a = a.astype(np.float32)
-             for c in range(channels):
-                 a[:,:,:,c] -= np.amin(a[:,:,:,c])
-                 a[:,:,:,c] /= np.amax(a[:,:,:,c])
+             for ch in range(channels):
+                 a[...,ch] -= np.amin(a[...,ch])
+                 a[...,ch] /= np.amax(a[...,ch])
          if bm.normalize:
              a = a.astype(np.float32)
-             for c in range(channels):
-                 mean, std = np.mean(a[:,:,:,c]), np.std(a[:,:,:,c])
-                 a[:,:,:,c] = (a[:,:,:,c] - mean) / std
-                 a[:,:,:,c] = a[:,:,:,c] * normalization_parameters[1,c] + normalization_parameters[0,c]
+             for ch in range(channels):
+                 mean, std = welford_mean_std(a[...,ch])
+                 a[...,ch] = (a[...,ch] - mean) / std
+                 a[...,ch] = a[...,ch] * normalization_parameters[1,ch] + normalization_parameters[0,ch]
+
+         # append image data
          if bm.scaling:
              img = np.append(img, a, axis=0)
          else:
@@ -565,15 +592,14 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
          for img in img_data_list:
              target_y = max(target_y, img.shape[1])
              target_x = max(target_x, img.shape[2])
-         img = np.empty((0, target_y, target_x, channels), dtype=np.float32)
-         label = np.empty((0, target_y, target_x), dtype=label_dtype)
+         img = np.empty((0, target_y, target_x, channels), dtype=img_dtype)
+         label = np.empty((0, target_y, target_x, 2 if bm.ignore_mask else 1), dtype=label_dtype)
          for k in range(len(img_data_list)):
              pad_y = target_y - img_data_list[k].shape[1]
              pad_x = target_x - img_data_list[k].shape[2]
              pad_width = [(0, 0), (0, pad_y), (0, pad_x), (0, 0)]
              tmp = np.pad(img_data_list[k], pad_width, mode='constant', constant_values=0)
              img = np.append(img, tmp, axis=0)
-             pad_width = [(0, 0), (0, pad_y), (0, pad_x)]
              tmp = np.pad(label_data_list[k].astype(label_dtype), pad_width, mode='constant', constant_values=-1)
              label = np.append(label, tmp, axis=0)

@@ -587,13 +613,13 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
      else:
          # get labels
          if allLabels is None:
-             allLabels = np.unique(label)
+             allLabels = unique(label[...,0])
              index = np.argwhere(allLabels<0)
              allLabels = np.delete(allLabels, index)

          # labels must be in ascending order
          for k, l in enumerate(allLabels):
-             label[label==l] = k
+             label[...,0][label[...,0]==l] = k

      return img, label, allLabels, normalization_parameters, header, extension, channels

@@ -725,10 +751,11 @@ class Metrics(Callback):
              m = rest % self.dim_img[2]
              tmp_X = self.img[k:k+self.dim_patch[0],l:l+self.dim_patch[1],m:m+self.dim_patch[2]]
              if self.patch_normalization:
-                 tmp_X = np.copy(tmp_X, order='C')
-                 for c in range(self.n_channels):
-                     tmp_X[:,:,:,c] -= np.mean(tmp_X[:,:,:,c])
-                     tmp_X[:,:,:,c] /= max(np.std(tmp_X[:,:,:,c]), 1e-6)
+                 tmp_X = tmp_X.copy().astype(np.float32)
+                 for ch in range(self.n_channels):
+                     mean, std = welford_mean_std(tmp_X[...,ch])
+                     tmp_X[...,ch] -= mean
+                     tmp_X[...,ch] /= max(std, 1e-6)
              X_val[i] = tmp_X

          # Prediction segmentation
@@ -753,6 +780,7 @@ class Metrics(Callback):
          # get result
          result = np.argmax(result, axis=-1)
          result = result.astype(np.uint8)
+         result = result.reshape(*result.shape, 1)

          # calculate standard accuracy
          if not self.train:
@@ -772,17 +800,17 @@ class Metrics(Callback):
              logs['dice'] = dice
          else:
              # save best model only
-             if epoch == 0 or round(dice,4) > max(self.history['val_dice']):
+             if epoch == 0 or dice > max(self.history['val_dice']):
                  self.model.save(str(self.path_to_model))

              # add accuracy to history
-             self.history['loss'].append(round(logs['loss'],4))
-             self.history['accuracy'].append(round(logs['accuracy'],4))
+             self.history['loss'].append(logs['loss'])
+             self.history['accuracy'].append(logs['accuracy'])
              if self.train_dice:
-                 self.history['dice'].append(round(logs['dice'],4))
-             self.history['val_accuracy'].append(round(accuracy,4))
-             self.history['val_dice'].append(round(dice,4))
-             self.history['val_loss'].append(round(val_loss,4))
+                 self.history['dice'].append(logs['dice'])
+             self.history['val_accuracy'].append(accuracy)
+             self.history['val_dice'].append(dice)
+             self.history['val_loss'].append(val_loss)

              # tensorflow monitoring variables
              logs['val_loss'] = val_loss
@@ -799,11 +827,11 @@ class Metrics(Callback):

              # print accuracies
              print('\nValidation history:')
-             print('train_acc:', self.history['accuracy'])
+             print("train_acc: [" + " ".join(f"{x:.4f}" for x in self.history['accuracy']) + "]")
              if self.train_dice:
-                 print('train_dice:', self.history['dice'])
-             print('val_acc:', self.history['val_accuracy'])
-             print('val_dice:', self.history['val_dice'])
+                 print("train_dice: [" + " ".join(f"{x:.4f}" for x in self.history['dice']) + "]")
+             print("val_acc: [" + " ".join(f"{x:.4f}" for x in self.history['val_accuracy']) + "]")
+             print("val_dice: [" + " ".join(f"{x:.4f}" for x in self.history['val_dice']) + "]")
              print('')

              # early stopping
@@ -850,13 +878,13 @@ def categorical_crossentropy(true_labels, predicted_probs):
      # Clip predicted probabilities to avoid log(0) issues
      predicted_probs = np.clip(predicted_probs, 1e-7, 1 - 1e-7)
      predicted_probs = -np.log(predicted_probs)
-     zsh,ysh,xsh = true_labels.shape
+     zsh, ysh, xsh, _ = true_labels.shape
      # Calculate categorical crossentropy
      loss = 0
      for z in range(zsh):
          for y in range(ysh):
              for x in range(xsh):
-                 l = true_labels[z,y,x]
+                 l = true_labels[z,y,x,0]
                  loss += predicted_probs[z,y,x,l]
      loss = loss / float(zsh*ysh*xsh)
      return loss
@@ -880,6 +908,42 @@ def dice_coef_loss(nb_labels):
          return loss
      return loss_fn

+ def custom_loss(y_true, y_pred):
+     # Extract labels and ignore mask
+     labels = tf.cast(y_true[..., 0], tf.int32) # First channel contains class labels
+     ignore_mask = tf.cast(y_true[..., 1], tf.float32) # Second channel contains mask (0 = ignore, 1 = include)
+
+     # Convert integer labels to one-hot encoding
+     y_true_one_hot = tf.one_hot(labels, depth=2)
+
+     # Clip y_pred to avoid log(0)
+     y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0)
+
+     # Compute categorical cross-entropy
+     loss = -tf.reduce_sum(y_true_one_hot * tf.math.log(y_pred), axis=-1)
+
+     # Apply ignore mask (ignore = 0 → loss is zero, include = 1 → loss is counted)
+     loss = loss * ignore_mask
+
+     # Return mean loss over valid (non-ignored) samples
+     return tf.reduce_sum(loss) / tf.reduce_sum(ignore_mask)
+
+ def custom_accuracy(y_true, y_pred):
+     labels = tf.cast(y_true[..., 0], tf.int32) # Extract actual values
+     ignore_mask = y_true[..., 1] # Extract mask (1 = include, 0 = ignore)
+
+     # Convert predictions to discrete values (assuming regression: round values)
+     y_pred_class = tf.argmax(y_pred, axis=-1, output_type=tf.int32)
+
+     # Compute correct predictions (1 where correct, 0 where incorrect)
+     correct_predictions = tf.cast(tf.equal(labels, y_pred_class), tf.float32)
+
+     # Apply ignore mask
+     masked_correct_predictions = correct_predictions * ignore_mask
+
+     # Compute accuracy only over valid (non-ignored) pixels
+     return tf.reduce_sum(masked_correct_predictions) / tf.reduce_sum(ignore_mask)
+
  def train_segmentation(bm):

      # training data
@@ -987,20 +1051,21 @@ def train_segmentation(bm):
                'dim_img': (zsh, ysh, xsh),
                'n_classes': nb_labels,
                'n_channels': bm.channels,
-               'augment': (bm.flip_x, bm.flip_y, bm.flip_z, bm.swapaxes, bm.rotate),
+               'augment': (bm.flip_x, bm.flip_y, bm.flip_z, bm.swapaxes, bm.rotate, bm.rotate3d),
                'patch_normalization': bm.patch_normalization,
-               'separation': bm.separation}
+               'separation': bm.separation,
+               'ignore_mask': bm.ignore_mask}

      # data generator
      validation_generator = None
-     training_generator = DataGenerator(bm.img_data, bm.label_data, list_IDs_fg, list_IDs_bg, True, True, False, **params)
+     training_generator = DataGenerator(bm.img_data, bm.label_data, list_IDs_fg, list_IDs_bg, True, True, **params)
      if bm.val_img_data is not None:
          if bm.val_dice:
              val_metrics = Metrics(bm, bm.val_img_data, bm.val_label_data, list_IDs_val_fg, (zsh_val, ysh_val, xsh_val), nb_labels, False)
          else:
              params['dim_img'] = (zsh_val, ysh_val, xsh_val)
-             params['augment'] = (False, False, False, False, 0)
+             params['augment'] = (False, False, False, False, 0, False)
-             validation_generator = DataGenerator(bm.val_img_data, bm.val_label_data, list_IDs_val_fg, list_IDs_val_bg, True, False, False, **params)
+             validation_generator = DataGenerator(bm.val_img_data, bm.val_label_data, list_IDs_val_fg, list_IDs_val_bg, True, False, **params)

      # monitor dice score on training data
      if bm.train_dice:
@@ -1018,7 +1083,7 @@ def train_segmentation(bm):
      with strategy.scope():

          # build model
-         model = make_unet(input_shape, nb_labels, bm.network_filters, bm.resnet)
+         model = make_unet(bm, input_shape, nb_labels)
          model.summary()

          # pretrained model
@@ -1037,13 +1102,28 @@ def train_segmentation(bm):
                  layer.trainable = False

          # optimizer
-         sgd = SGD(learning_rate=bm.learning_rate, decay=1e-6, momentum=0.9, nesterov=True)
+         optimizer = SGD(learning_rate=bm.learning_rate, decay=1e-6, momentum=0.9, nesterov=True)
+         #optimizer = tf.keras.optimizers.Adam(learning_rate=bm.learning_rate, epsilon=1e-4) lr=0.0001
+         if bm.mixed_precision:
+             optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer, dynamic=False, initial_scale=128)
+
+         # Rename the function to appear as "accuracy" in logs
+         if bm.ignore_mask:
+             custom_accuracy.__name__ = "accuracy"
+             metrics=[custom_accuracy]
+         else:
+             metrics=['accuracy']
+
+         # loss function
+         if bm.ignore_mask:
+             loss=custom_loss
+         else:
+             loss=dice_coef_loss(nb_labels) if bm.dice_loss else 'categorical_crossentropy'

          # comile model
-         loss=dice_coef_loss(nb_labels) if bm.dice_loss else 'categorical_crossentropy'
          model.compile(loss=loss,
-                       optimizer=sgd,
-                       metrics=['accuracy'])
+                       optimizer=optimizer,
+                       metrics=metrics)

          # save meta data
          meta_data = MetaData(bm.path_to_model, configuration_data, allLabels,
@@ -1082,7 +1162,7 @@ def train_segmentation(bm):
                          callbacks=callbacks,
                          workers=bm.workers)

- def load_prediction_data(bm, channels, normalize, normalization_parameters,
+ def load_prediction_data(bm, channels, normalization_parameters,
          region_of_interest, img, img_header, load_blockwise=False, z=None):

      # read image data
@@ -1110,10 +1190,9 @@ def load_prediction_data(bm, channels, normalize, normalization_parameters,
      if bm.acwe:
          img_data = img.copy()

-     # handle all images using number of channels >=1
+     # image data must have number of channels >=1
      if len(img.shape)==3:
-         z_shape, y_shape, x_shape = img.shape
-         img = img.reshape(z_shape, y_shape, x_shape, 1)
+         img = img.reshape(img.shape[0], img.shape[1], img.shape[2], 1)
      if img.shape[3] != channels:
          InputError.message = f'Number of channels must be {channels}.'
          raise InputError()
@@ -1128,22 +1207,27 @@ def load_prediction_data(bm, channels, normalize, normalization_parameters,
          region_of_interest = np.array([min_z,max_z,min_y,max_y,min_x,max_x,z_shape,y_shape,x_shape])
          z_shape, y_shape, x_shape = max_z-min_z, max_y-min_y, max_x-min_x

-     # scale/resize image data
-     img = img.astype(np.float32)
+     # resize image data
      if bm.scaling:
+         img = img.astype(np.float32)
          img = img_resize(img, bm.z_scale, bm.y_scale, bm.x_scale)

+     # scale image data
+     if not bm.patch_normalization:
+         img = img.astype(np.float32)
+         for ch in range(channels):
+             img[...,ch] -= np.amin(img[...,ch])
+             img[...,ch] /= np.amax(img[...,ch])
+
      # normalize image data
-     for c in range(channels):
-         img[:,:,:,c] -= np.amin(img[:,:,:,c])
-         img[:,:,:,c] /= np.amax(img[:,:,:,c])
-         if normalize:
-             mean, std = np.mean(img[:,:,:,c]), np.std(img[:,:,:,c])
-             img[:,:,:,c] = (img[:,:,:,c] - mean) / std
-             img[:,:,:,c] = img[:,:,:,c] * normalization_parameters[1,c] + normalization_parameters[0,c]
+     if bm.normalize:
+         img = img.astype(np.float32)
+         for ch in range(channels):
+             mean, std = welford_mean_std(img[...,ch])
+             img[...,ch] = (img[...,ch] - mean) / std
+             img[...,ch] = img[...,ch] * normalization_parameters[1,ch] + normalization_parameters[0,ch]

-     # limit intensity range
-     if normalize:
+     # limit intensity range
      img[img<0] = 0
      img[img>1] = 1

@@ -1186,6 +1270,20 @@ def gradient(volData):
      grad[grad>0]=1
      return grad

+ @numba.jit(nopython=True)
+ def scale_probabilities(final):
+     zsh, ysh, xsh, nb_labels = final.shape
+     for k in range(zsh):
+         for l in range(ysh):
+             for m in range(xsh):
+                 scale_factor = 0
+                 for n in range(nb_labels):
+                     scale_factor += final[k,l,m,n]
+                 scale_factor = max(1, scale_factor)
+                 for n in range(nb_labels):
+                     final[k,l,m,n] /= scale_factor
+     return final
+
  def predict_segmentation(bm, region_of_interest, channels, normalization_parameters):

      from mpi4py import MPI
@@ -1193,13 +1291,26 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
      rank = comm.Get_rank()
      ngpus = comm.Get_size()

+     # optional result paths
+     if bm.path_to_image:
+         filename, bm.extension = os.path.splitext(bm.path_to_final)
+         if bm.extension == '.gz':
+             bm.extension = '.nii.gz'
+             filename = filename[:-4]
+         path_to_cleaned = filename + '.cleaned' + bm.extension
+         path_to_filled = filename + '.filled' + bm.extension
+         path_to_cleaned_filled = filename + '.cleaned.filled' + bm.extension
+         path_to_refined = filename + '.refined' + bm.extension
+         path_to_acwe = filename + '.acwe' + bm.extension
+         path_to_probs = filename + '.probs.tif'
+
      # Set the visible GPU by ID
      gpus = tf.config.experimental.list_physical_devices('GPU')
      if gpus:
          try:
              # Restrict TensorFlow to only use the specified GPU
-             tf.config.experimental.set_visible_devices(gpus[rank], 'GPU')
-             tf.config.experimental.set_memory_growth(gpus[rank], True)
+             tf.config.experimental.set_visible_devices(gpus[rank % len(gpus)], 'GPU')
+             tf.config.experimental.set_memory_growth(gpus[rank % len(gpus)], True)
          except RuntimeError as e:
              print(e)

@@ -1210,7 +1321,7 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
      nb_labels = len(bm.allLabels)
      results['allLabels'] = bm.allLabels

-     # load model
+     # custom objects
      if bm.dice_loss:
          def loss_fn(y_true, y_pred):
              dice = 0
@@ -1221,25 +1332,30 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
              loss = 1 - dice
              return loss
          custom_objects = {'dice_coef_loss': dice_coef_loss,'loss_fn': loss_fn}
-         model = load_model(bm.path_to_model, custom_objects=custom_objects)
+     elif bm.ignore_mask:
+         custom_objects={'custom_loss': custom_loss}
      else:
-         model = load_model(bm.path_to_model)
+         custom_objects=None
+
+     # load model
+     model = load_model(bm.path_to_model, custom_objects=custom_objects)

      # check if data can be loaded blockwise to save host memory
      load_blockwise = False
      if not bm.scaling and not bm.normalize and bm.path_to_image and not np.any(region_of_interest) and \
          os.path.splitext(bm.path_to_image)[1] in ['.tif', '.tiff'] and not bm.acwe:
+
          # get image shape
          tif = TiffFile(bm.path_to_image)
          zsh = len(tif.pages)
          ysh, xsh = tif.pages[0].shape

          # load mask
-         if bm.separation or bm.refinement:
+         '''if bm.separation or bm.refinement:
              mask, _ = load_data(bm.mask)
              mask = mask.reshape(zsh, ysh, xsh, 1)
              mask, _, _, _ = append_ghost_areas(bm, mask)
-             mask = mask.reshape(mask.shape[:-1])
+             mask = mask.reshape(mask.shape[:-1])'''

          # determine new image size after appending ghost areas to make image dimensions divisible by patch size
          z_rest = bm.z_patch - (zsh % bm.z_patch)
@@ -1259,7 +1375,7 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
              xsh += x_rest

          # get Ids of patches
-         list_IDs = []
+         '''list_IDs = []
          for k in range(0, zsh-bm.z_patch+1, bm.stride_size):
              for l in range(0, ysh-bm.y_patch+1, bm.stride_size):
                  for m in range(0, xsh-bm.x_patch+1, bm.stride_size):
@@ -1269,19 +1385,18 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
                          if centerLabel>0 and np.any(patch!=centerLabel):
                              list_IDs.append(k*ysh*xsh+l*xsh+m)
                      elif bm.refinement:
-                         patch = mask[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]
-                         if np.any(patch==0) and np.any(patch!=0):
+                         if np.any(mask[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]):
                              list_IDs.append(k*ysh*xsh+l*xsh+m)
                      else:
-                         list_IDs.append(k*ysh*xsh+l*xsh+m)
+                         list_IDs.append(k*ysh*xsh+l*xsh+m)'''

          # make length of list divisible by batch size
-         max_i = len(list_IDs)
+         '''max_i = len(list_IDs)
          rest = bm.batch_size - (len(list_IDs) % bm.batch_size)
-         list_IDs = list_IDs + list_IDs[:rest]
+         list_IDs = list_IDs + list_IDs[:rest]'''

          # prediction
-         if len(list_IDs) > 400:
+         if zsh*ysh*xsh > 256**3:
              load_blockwise = True

      # load image data and calculate patch IDs
@@ -1289,7 +1404,7 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet

      # load prediction data
      img, bm.img_header, z_shape, y_shape, x_shape, region_of_interest, bm.img_data = load_prediction_data(
-         bm, channels, bm.normalize, normalization_parameters, region_of_interest, bm.img_data, bm.img_header)
+         bm, channels, normalization_parameters, region_of_interest, bm.img_data, bm.img_header)

      # append ghost areas
      img, z_rest, y_rest, x_rest = append_ghost_areas(bm, img)
@@ -1315,6 +1430,7 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet

      # load all patches on GPU memory
      if not load_blockwise and nb_patches < 400:
+         if rank==0:

              # parameters
              params = {'dim': (bm.z_patch, bm.y_patch, bm.x_patch),
@@ -1347,7 +1463,11 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet

      # allocate final probabilities array
      if rank==0 and bm.return_probs:
-         final = np.zeros((zsh, ysh, xsh, nb_labels), dtype=np.float32)
+         if load_blockwise:
+             if not os.path.exists(path_to_probs[:-4]):
+                 os.mkdir(path_to_probs[:-4])
+         else:
+             final = np.zeros((zsh, ysh, xsh, nb_labels), dtype=np.float32)

      # allocate final result array
      if rank==0:
@@ -1362,27 +1482,38 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
          else:
              nprocs = ngpus
          if j % ngpus == rank:
+
              # load blockwise from TIFF
              if load_blockwise:
                  img, _, _, _, _, _, _ = load_prediction_data(bm,
-                     channels, bm.normalize, normalization_parameters,
+                     channels, normalization_parameters,
                      region_of_interest, bm.img_data, bm.img_header, load_blockwise, z)
                  img, _, _, _ = append_ghost_areas(bm, img)

+             # load mask block
+             if bm.separation or bm.refinement:
+                 mask = imread(bm.mask, key=range(z,min(len(tif.pages),z+bm.z_patch)))
+                 # pad zeros to make dimensions divisible by patch dimensions
+                 pad_z = bm.z_patch - mask.shape[0]
+                 pad_y = (bm.y_patch - (mask.shape[1] % bm.y_patch)) % bm.y_patch
+                 pad_x = (bm.x_patch - (mask.shape[2] % bm.x_patch)) % bm.x_patch
+                 pad_width = [(0, pad_z), (0, pad_y), (0, pad_x)]
+                 mask = np.pad(mask, pad_width, mode='constant', constant_values=0)
+
              # list of IDs
              list_IDs_block = []

              # get Ids of patches
+             k = 0 if load_blockwise else z
              for l in range(0, ysh-bm.y_patch+1, bm.stride_size):
                  for m in range(0, xsh-bm.x_patch+1, bm.stride_size):
                      if bm.separation:
-                         centerLabel = mask[z+bm.z_patch//2,l+bm.y_patch//2,m+bm.x_patch//2]
-                         patch = mask[z:z+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]
+                         centerLabel = mask[k+bm.z_patch//2,l+bm.y_patch//2,m+bm.x_patch//2]
+                         patch = mask[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]
                          if centerLabel>0 and np.any(patch!=centerLabel):
                              list_IDs_block.append(z*ysh*xsh+l*xsh+m)
                      elif bm.refinement:
-                         patch = mask[z:z+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]
-                         if np.any(patch==0) and np.any(patch!=0):
+                         if np.any(mask[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]):
                              list_IDs_block.append(z*ysh*xsh+l*xsh+m)
                      else:
                          list_IDs_block.append(z*ysh*xsh+l*xsh+m)
@@ -1414,10 +1545,11 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
                      # get patch
                      tmp_X = img[k:k+bm.z_patch,l:l+bm.y_patch,m:m+bm.x_patch]
                      if bm.patch_normalization:
-                         tmp_X = np.copy(tmp_X, order='C')
-                         for c in range(channels):
-                             tmp_X[:,:,:,c] -= np.mean(tmp_X[:,:,:,c])
-                             tmp_X[:,:,:,c] /= max(np.std(tmp_X[:,:,:,c]), 1e-6)
+                         tmp_X = tmp_X.copy().astype(np.float32)
+                         for ch in range(channels):
+                             mean, std = welford_mean_std(tmp_X[...,ch])
+                             tmp_X[...,ch] -= mean
+                             tmp_X[...,ch] /= max(std, 1e-6)
                      X[i] = tmp_X

                  # predict batch
@@ -1461,19 +1593,28 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
                  # overlap in z direction
                  if bm.stride_size < bm.z_patch:
                      if j+source>0:
-                         probs[:bm.stride_size] += overlap
+                         probs[:-bm.stride_size] += overlap
                      overlap = probs[bm.stride_size:].copy()

-                 # calculate result
+                 # block z dimension
                  block_z = z_indices[j+source]
-                 if j+source==len(z_indices)-1:
-                     label[block_z:block_z+bm.z_patch] = np.argmax(probs, axis=-1).astype(np.uint8)
-                     if bm.return_probs:
-                         final[block_z:block_z+bm.z_patch] = probs
+                 if j+source==len(z_indices)-1: # last block
+                     block_zsh = bm.z_patch
+                     block_z_rest = z_rest if z_rest>0 else -block_zsh
                  else:
                      block_zsh = min(bm.stride_size, bm.z_patch)
-                     label[block_z:block_z+block_zsh] = np.argmax(probs[:block_zsh], axis=-1).astype(np.uint8)
-                     if bm.return_probs:
+                     block_z_rest = -block_zsh
+
+                 # calculate result
+                 label[block_z:block_z+block_zsh] = np.argmax(probs[:block_zsh], axis=-1).astype(np.uint8)
+
+                 # return probabilities
+                 if bm.return_probs:
+                     if load_blockwise:
+                         block_output = scale_probabilities(probs[:block_zsh])
+                         block_output = block_output[:-block_z_rest,:-y_rest,:-x_rest]
+                         imwrite(path_to_probs[:-4] + f"/block-{j+source}.tif", block_output)
+                     else:
                          final[block_z:block_z+block_zsh] = probs[:block_zsh]
          else:
              for i in range(bm.z_patch):
@@ -1481,7 +1622,7 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet

      if rank==0:

          # refine mask data with result
-         if bm.refinement:
+         '''if bm.refinement:
              # loop over boundary patches
              for i, ID in enumerate(list_IDs):
@@ -1490,25 +1631,17 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
                      l = rest // xsh
                      m = rest % xsh
                      mask[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch] = label[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]
-         label = mask
+         label = mask'''

-         # remove appendix
-         if bm.return_probs:
+         # remove ghost areas
+         if bm.return_probs and not load_blockwise:
              final = final[:-z_rest,:-y_rest,:-x_rest]
          label = label[:-z_rest,:-y_rest,:-x_rest]
          zsh, ysh, xsh = label.shape

          # return probabilities
-         if bm.return_probs:
-             counter = np.zeros((zsh, ysh, xsh, nb_labels), dtype=np.float32)
-             nb = 0
-             for k in range(0, zsh-bm.z_patch+1, bm.stride_size):
-                 for l in range(0, ysh-bm.y_patch+1, bm.stride_size):
-                     for m in range(0, xsh-bm.x_patch+1, bm.stride_size):
-                         counter[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch] += 1
-                         nb += 1
-             counter[counter==0] = 1
-             probabilities = final / counter
+         if bm.return_probs and not load_blockwise:
+             probabilities = scale_probabilities(final)
          if bm.scaling:
              probabilities = img_resize(probabilities, z_shape, y_shape, x_shape)
          if np.any(region_of_interest):
@@ -1571,17 +1704,8 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
          # save result
          if bm.path_to_image:
              save_data(bm.path_to_final, label, header=bm.header, compress=bm.compression)
-
-             # paths to optional results
-             filename, bm.extension = os.path.splitext(bm.path_to_final)
-             if bm.extension == '.gz':
-                 bm.extension = '.nii.gz'
-                 filename = filename[:-4]
-             path_to_cleaned = filename + '.cleaned' + bm.extension
-             path_to_filled = filename + '.filled' + bm.extension
-             path_to_cleaned_filled = filename + '.cleaned.filled' + bm.extension
-             path_to_refined = filename + '.refined' + bm.extension
-             path_to_acwe = filename + '.acwe' + bm.extension
+             if bm.return_probs and not load_blockwise:
+                 imwrite(path_to_probs, probabilities)

          # remove outliers
          if bm.clean:
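
Note on the new welford_mean_std helper: several hunks above replace np.mean/np.std with welford_mean_std, imported from biomedisa.features.DataGenerator and biomedisa.features.biomedisa_helper, but the helper's implementation is not part of this diff. The sketch below only illustrates the one-pass Welford estimate of mean and standard deviation that the name suggests; its exact signature and numerics in biomedisa are assumptions.

import numpy as np

def welford_mean_std_sketch(arr):
    # One-pass (Welford) mean and population standard deviation.
    # Illustrative only; the actual biomedisa helper may differ.
    mean, m2, n = 0.0, 0.0, 0
    for x in arr.ravel():
        n += 1
        delta = x - mean
        mean += delta / n
        m2 += delta * (x - mean)  # uses the updated mean, as in Welford's update
    std = (m2 / n) ** 0.5 if n else 0.0  # population std, matching np.std's default
    return mean, std

# quick check against NumPy
data = np.random.rand(4, 5, 6).astype(np.float32)
m, s = welford_mean_std_sketch(data)
assert np.isclose(m, data.mean(), atol=1e-5) and np.isclose(s, data.std(), atol=1e-5)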