biomedisa 24.8.11__py3-none-any.whl → 25.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- biomedisa/deeplearning.py +27 -8
- biomedisa/features/DataGenerator.py +93 -144
- biomedisa/features/PredictDataGenerator.py +7 -5
- biomedisa/features/biomedisa_helper.py +67 -25
- biomedisa/features/crop_helper.py +7 -7
- biomedisa/features/keras_helper.py +276 -151
- biomedisa/features/random_walk/pyopencl_large.py +1 -1
- biomedisa/features/random_walk/pyopencl_small.py +3 -3
- biomedisa/features/random_walk/rw_large.py +6 -2
- biomedisa/features/random_walk/rw_small.py +7 -3
- biomedisa/interpolation.py +3 -1
- {biomedisa-24.8.11.dist-info → biomedisa-25.7.1.dist-info}/METADATA +6 -4
- {biomedisa-24.8.11.dist-info → biomedisa-25.7.1.dist-info}/RECORD +16 -16
- {biomedisa-24.8.11.dist-info → biomedisa-25.7.1.dist-info}/WHEEL +1 -1
- {biomedisa-24.8.11.dist-info → biomedisa-25.7.1.dist-info/licenses}/LICENSE +0 -0
- {biomedisa-24.8.11.dist-info → biomedisa-25.7.1.dist-info}/top_level.txt +0 -0
@@ -39,13 +39,13 @@ from tensorflow.keras.layers import (
|
|
39
39
|
from tensorflow.keras import backend as K
|
40
40
|
from tensorflow.keras.utils import to_categorical
|
41
41
|
from tensorflow.keras.callbacks import Callback, ModelCheckpoint, EarlyStopping
|
42
|
-
from biomedisa.features.DataGenerator import DataGenerator
|
42
|
+
from biomedisa.features.DataGenerator import DataGenerator, welford_mean_std
|
43
43
|
from biomedisa.features.PredictDataGenerator import PredictDataGenerator
|
44
|
-
from biomedisa.features.biomedisa_helper import (unique,
|
44
|
+
from biomedisa.features.biomedisa_helper import (unique, welford_mean_std,
|
45
45
|
img_resize, load_data, save_data, set_labels_to_zero, id_generator, unique_file_path)
|
46
46
|
from biomedisa.features.remove_outlier import clean, fill
|
47
47
|
from biomedisa.features.active_contour import activeContour
|
48
|
-
from tifffile import TiffFile, imread
|
48
|
+
from tifffile import TiffFile, imread, imwrite
|
49
49
|
import matplotlib.pyplot as plt
|
50
50
|
import SimpleITK as sitk
|
51
51
|
import tensorflow as tf
|
@@ -220,11 +220,11 @@ def compute_position(position, zsh, ysh, xsh):
|
|
220
220
|
position[k,l,m] = x+y+z
|
221
221
|
return position
|
222
222
|
|
223
|
-
def make_conv_block(nb_filters, input_tensor, block):
|
223
|
+
def make_conv_block(nb_filters, input_tensor, block, dtype):
|
224
224
|
def make_stage(input_tensor, stage):
|
225
225
|
name = 'conv_{}_{}'.format(block, stage)
|
226
226
|
x = Conv3D(nb_filters, (3, 3, 3), activation='relu',
|
227
|
-
padding='same', name=name, data_format="channels_last")(input_tensor)
|
227
|
+
padding='same', name=name, data_format="channels_last", dtype=dtype)(input_tensor)
|
228
228
|
name = 'batch_norm_{}_{}'.format(block, stage)
|
229
229
|
try:
|
230
230
|
x = BatchNormalization(name=name, synchronized=True)(x)
|
@@ -266,56 +266,65 @@ def make_conv_block_resnet(nb_filters, input_tensor, block):
|
|
266
266
|
|
267
267
|
return out
|
268
268
|
|
269
|
-
def make_unet(input_shape, nb_labels
|
269
|
+
def make_unet(bm, input_shape, nb_labels):
|
270
|
+
# enable mixed_precision
|
271
|
+
if bm.mixed_precision:
|
272
|
+
dtype = "float16"
|
273
|
+
else:
|
274
|
+
dtype = "float32"
|
270
275
|
|
276
|
+
# input
|
271
277
|
nb_plans, nb_rows, nb_cols, _ = input_shape
|
278
|
+
inputs = Input(input_shape, dtype=dtype)
|
272
279
|
|
273
|
-
|
274
|
-
|
275
|
-
filters = filters.split('-')
|
280
|
+
# configure number of layers and filters
|
281
|
+
filters = bm.network_filters.split('-')
|
276
282
|
filters = np.array(filters, dtype=int)
|
277
283
|
latent_space_size = filters[-1]
|
278
284
|
filters = filters[:-1]
|
285
|
+
|
286
|
+
# initialize blocks
|
279
287
|
convs = []
|
280
288
|
|
289
|
+
# encoder
|
281
290
|
i = 1
|
282
291
|
for f in filters:
|
283
292
|
if i==1:
|
284
|
-
if resnet:
|
293
|
+
if bm.resnet:
|
285
294
|
conv = make_conv_block_resnet(f, inputs, i)
|
286
295
|
else:
|
287
|
-
conv = make_conv_block(f, inputs, i)
|
296
|
+
conv = make_conv_block(f, inputs, i, dtype)
|
288
297
|
else:
|
289
|
-
if resnet:
|
298
|
+
if bm.resnet:
|
290
299
|
conv = make_conv_block_resnet(f, pool, i)
|
291
300
|
else:
|
292
|
-
conv = make_conv_block(f, pool, i)
|
301
|
+
conv = make_conv_block(f, pool, i, dtype)
|
293
302
|
pool = MaxPooling3D(pool_size=(2, 2, 2))(conv)
|
294
303
|
convs.append(conv)
|
295
304
|
i += 1
|
296
305
|
|
297
|
-
|
306
|
+
# latent space
|
307
|
+
if bm.resnet:
|
298
308
|
conv = make_conv_block_resnet(latent_space_size, pool, i)
|
299
309
|
else:
|
300
|
-
conv = make_conv_block(latent_space_size, pool, i)
|
310
|
+
conv = make_conv_block(latent_space_size, pool, i, dtype)
|
301
311
|
i += 1
|
302
312
|
|
313
|
+
# decoder
|
303
314
|
for k, f in enumerate(filters[::-1]):
|
304
315
|
up = Concatenate()([UpSampling3D(size=(2, 2, 2))(conv), convs[-(k+1)]])
|
305
|
-
if resnet:
|
316
|
+
if bm.resnet:
|
306
317
|
conv = make_conv_block_resnet(f, up, i)
|
307
318
|
else:
|
308
|
-
conv = make_conv_block(f, up, i)
|
319
|
+
conv = make_conv_block(f, up, i, dtype)
|
309
320
|
i += 1
|
310
321
|
|
322
|
+
# final layer and output
|
311
323
|
conv = Conv3D(nb_labels, (1, 1, 1), name=f'conv_{i}_1')(conv)
|
312
|
-
|
313
324
|
x = Reshape((nb_plans * nb_rows * nb_cols, nb_labels))(conv)
|
314
325
|
x = Activation('softmax')(x)
|
315
326
|
outputs = Reshape((nb_plans, nb_rows, nb_cols, nb_labels))(x)
|
316
|
-
|
317
327
|
model = Model(inputs=inputs, outputs=outputs)
|
318
|
-
|
319
328
|
return model
|
320
329
|
|
321
330
|
def get_labels(arr, allLabels):
|
@@ -415,6 +424,13 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
|
|
415
424
|
print(f'{os.path.basename(label_names[0])}:', 'Labels:', label_values[1:], 'Sizes:', counts[1:])
|
416
425
|
label = img_resize(label, bm.z_scale, bm.y_scale, bm.x_scale, labels=True)
|
417
426
|
|
427
|
+
# label channel must be 1 or 2 if using ignore mask
|
428
|
+
if len(label.shape)>3 and label.shape[3]>1 and not bm.ignore_mask:
|
429
|
+
InputError.message = 'Training labels must have one channel (gray values).'
|
430
|
+
raise InputError()
|
431
|
+
if len(label.shape)==3:
|
432
|
+
label = label.reshape(label.shape[0], label.shape[1], label.shape[2], 1)
|
433
|
+
|
418
434
|
# if header is not single data stream Amira Mesh falling back to Multi-TIFF
|
419
435
|
if extension != '.am':
|
420
436
|
extension, header = '.tif', None
|
@@ -424,7 +440,7 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
|
|
424
440
|
else:
|
425
441
|
header = header[0]
|
426
442
|
|
427
|
-
# load first
|
443
|
+
# load first image
|
428
444
|
if any(img_list):
|
429
445
|
img, _ = load_data(img_names[0], 'first_queue')
|
430
446
|
if img is None:
|
@@ -436,14 +452,15 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
|
|
436
452
|
else:
|
437
453
|
img = img_in
|
438
454
|
img_names = ['img_1']
|
439
|
-
|
455
|
+
|
456
|
+
# label and image dimensions must match
|
457
|
+
if label_dim[:3] != img.shape[:3]:
|
440
458
|
InputError.message = f'Dimensions of "{os.path.basename(img_names[0])}" and "{os.path.basename(label_names[0])}" do not match'
|
441
459
|
raise InputError()
|
442
460
|
|
443
|
-
#
|
461
|
+
# image channels must be >=1
|
444
462
|
if len(img.shape)==3:
|
445
|
-
|
446
|
-
img = img.reshape(z_shape, y_shape, x_shape, 1)
|
463
|
+
img = img.reshape(img.shape[0], img.shape[1], img.shape[2], 1)
|
447
464
|
if channels is None:
|
448
465
|
channels = img.shape[3]
|
449
466
|
if channels != img.shape[3]:
|
@@ -462,30 +479,30 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
|
|
462
479
|
# scale data to the range from 0 to 1
|
463
480
|
if not bm.patch_normalization:
|
464
481
|
img = img.astype(np.float32)
|
465
|
-
for
|
466
|
-
img[
|
467
|
-
img[
|
482
|
+
for ch in range(channels):
|
483
|
+
img[...,ch] -= np.amin(img[...,ch])
|
484
|
+
img[...,ch] /= np.amax(img[...,ch])
|
468
485
|
|
469
486
|
# normalize first validation image
|
470
487
|
if bm.normalize and np.any(normalization_parameters):
|
471
488
|
img = img.astype(np.float32)
|
472
|
-
for
|
473
|
-
mean, std =
|
474
|
-
img[
|
475
|
-
img[
|
489
|
+
for ch in range(channels):
|
490
|
+
mean, std = welford_mean_std(img[...,ch])
|
491
|
+
img[...,ch] = (img[...,ch] - mean) / std
|
492
|
+
img[...,ch] = img[...,ch] * normalization_parameters[1,ch] + normalization_parameters[0,ch]
|
476
493
|
|
477
494
|
# get normalization parameters from first image
|
478
495
|
if normalization_parameters is None:
|
479
496
|
normalization_parameters = np.zeros((2,channels))
|
480
497
|
if bm.normalize:
|
481
|
-
for
|
482
|
-
normalization_parameters[
|
483
|
-
normalization_parameters[1,c] = np.std(img[:,:,:,c])
|
498
|
+
for ch in range(channels):
|
499
|
+
normalization_parameters[:,ch] = welford_mean_std(img[...,ch])
|
484
500
|
|
485
501
|
# pad data
|
486
502
|
if not bm.scaling:
|
487
503
|
img_data_list = [img]
|
488
504
|
label_data_list = [label]
|
505
|
+
img_dtype = img.dtype
|
489
506
|
# no-scaling for list of images needs negative values as it encodes padded areas as -1
|
490
507
|
label_dtype = label.dtype
|
491
508
|
if label_dtype==np.uint8:
|
@@ -499,7 +516,7 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
|
|
499
516
|
|
500
517
|
for k in range(1, number_of_images):
|
501
518
|
|
502
|
-
#
|
519
|
+
# load label data and pre-process
|
503
520
|
if any(label_list):
|
504
521
|
a, _ = load_data(label_names[k], 'first_queue')
|
505
522
|
if a is None:
|
@@ -516,11 +533,21 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
|
|
516
533
|
label_values, counts = unique(a, return_counts=True)
|
517
534
|
print(f'{os.path.basename(label_names[k])}:', 'Labels:', label_values[1:], 'Sizes:', counts[1:])
|
518
535
|
a = img_resize(a, bm.z_scale, bm.y_scale, bm.x_scale, labels=True)
|
536
|
+
|
537
|
+
# label channel must be 1 or 2 if using ignore mask
|
538
|
+
if len(a.shape)>3 and a.shape[3]>1 and not bm.ignore_mask:
|
539
|
+
InputError.message = 'Training labels must have one channel (gray values).'
|
540
|
+
raise InputError()
|
541
|
+
if len(a.shape)==3:
|
542
|
+
a = a.reshape(a.shape[0], a.shape[1], a.shape[2], 1)
|
543
|
+
|
544
|
+
# append label data
|
545
|
+
if bm.scaling:
|
519
546
|
label = np.append(label, a, axis=0)
|
520
547
|
else:
|
521
548
|
label_data_list.append(a)
|
522
549
|
|
523
|
-
#
|
550
|
+
# load image data and pre-process
|
524
551
|
if any(img_list):
|
525
552
|
a, _ = load_data(img_names[k], 'first_queue')
|
526
553
|
if a is None:
|
@@ -528,12 +555,11 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
|
|
528
555
|
raise InputError()
|
529
556
|
else:
|
530
557
|
a = img_in[k]
|
531
|
-
if label_dim != a.shape[:3]:
|
558
|
+
if label_dim[:3] != a.shape[:3]:
|
532
559
|
InputError.message = f'Dimensions of "{os.path.basename(img_names[k])}" and "{os.path.basename(label_names[k])}" do not match'
|
533
560
|
raise InputError()
|
534
561
|
if len(a.shape)==3:
|
535
|
-
|
536
|
-
a = a.reshape(z_shape, y_shape, x_shape, 1)
|
562
|
+
a = a.reshape(a.shape[0], a.shape[1], a.shape[2], 1)
|
537
563
|
if a.shape[3] != channels:
|
538
564
|
InputError.message = f'Number of channels must be {channels} for "{os.path.basename(img_names[k])}"'
|
539
565
|
raise InputError()
|
@@ -544,15 +570,17 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
|
|
544
570
|
a = img_resize(a, bm.z_scale, bm.y_scale, bm.x_scale)
|
545
571
|
if not bm.patch_normalization:
|
546
572
|
a = a.astype(np.float32)
|
547
|
-
for
|
548
|
-
a[
|
549
|
-
a[
|
573
|
+
for ch in range(channels):
|
574
|
+
a[...,ch] -= np.amin(a[...,ch])
|
575
|
+
a[...,ch] /= np.amax(a[...,ch])
|
550
576
|
if bm.normalize:
|
551
577
|
a = a.astype(np.float32)
|
552
|
-
for
|
553
|
-
mean, std =
|
554
|
-
a[
|
555
|
-
a[
|
578
|
+
for ch in range(channels):
|
579
|
+
mean, std = welford_mean_std(a[...,ch])
|
580
|
+
a[...,ch] = (a[...,ch] - mean) / std
|
581
|
+
a[...,ch] = a[...,ch] * normalization_parameters[1,ch] + normalization_parameters[0,ch]
|
582
|
+
|
583
|
+
# append image data
|
556
584
|
if bm.scaling:
|
557
585
|
img = np.append(img, a, axis=0)
|
558
586
|
else:
|
@@ -564,15 +592,14 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
|
|
564
592
|
for img in img_data_list:
|
565
593
|
target_y = max(target_y, img.shape[1])
|
566
594
|
target_x = max(target_x, img.shape[2])
|
567
|
-
img = np.empty((0, target_y, target_x, channels), dtype=
|
568
|
-
label = np.empty((0, target_y, target_x), dtype=label_dtype)
|
595
|
+
img = np.empty((0, target_y, target_x, channels), dtype=img_dtype)
|
596
|
+
label = np.empty((0, target_y, target_x, 2 if bm.ignore_mask else 1), dtype=label_dtype)
|
569
597
|
for k in range(len(img_data_list)):
|
570
598
|
pad_y = target_y - img_data_list[k].shape[1]
|
571
599
|
pad_x = target_x - img_data_list[k].shape[2]
|
572
600
|
pad_width = [(0, 0), (0, pad_y), (0, pad_x), (0, 0)]
|
573
601
|
tmp = np.pad(img_data_list[k], pad_width, mode='constant', constant_values=0)
|
574
602
|
img = np.append(img, tmp, axis=0)
|
575
|
-
pad_width = [(0, 0), (0, pad_y), (0, pad_x)]
|
576
603
|
tmp = np.pad(label_data_list[k].astype(label_dtype), pad_width, mode='constant', constant_values=-1)
|
577
604
|
label = np.append(label, tmp, axis=0)
|
578
605
|
|
@@ -586,13 +613,13 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
|
|
586
613
|
else:
|
587
614
|
# get labels
|
588
615
|
if allLabels is None:
|
589
|
-
allLabels = unique(label)
|
616
|
+
allLabels = unique(label[...,0])
|
590
617
|
index = np.argwhere(allLabels<0)
|
591
618
|
allLabels = np.delete(allLabels, index)
|
592
619
|
|
593
620
|
# labels must be in ascending order
|
594
621
|
for k, l in enumerate(allLabels):
|
595
|
-
label[label==l] = k
|
622
|
+
label[...,0][label[...,0]==l] = k
|
596
623
|
|
597
624
|
return img, label, allLabels, normalization_parameters, header, extension, channels
|
598
625
|
|
@@ -724,10 +751,11 @@ class Metrics(Callback):
|
|
724
751
|
m = rest % self.dim_img[2]
|
725
752
|
tmp_X = self.img[k:k+self.dim_patch[0],l:l+self.dim_patch[1],m:m+self.dim_patch[2]]
|
726
753
|
if self.patch_normalization:
|
727
|
-
tmp_X =
|
728
|
-
for
|
729
|
-
|
730
|
-
tmp_X[
|
754
|
+
tmp_X = tmp_X.copy().astype(np.float32)
|
755
|
+
for ch in range(self.n_channels):
|
756
|
+
mean, std = welford_mean_std(tmp_X[...,ch])
|
757
|
+
tmp_X[...,ch] -= mean
|
758
|
+
tmp_X[...,ch] /= max(std, 1e-6)
|
731
759
|
X_val[i] = tmp_X
|
732
760
|
|
733
761
|
# Prediction segmentation
|
@@ -752,6 +780,7 @@ class Metrics(Callback):
|
|
752
780
|
# get result
|
753
781
|
result = np.argmax(result, axis=-1)
|
754
782
|
result = result.astype(np.uint8)
|
783
|
+
result = result.reshape(*result.shape, 1)
|
755
784
|
|
756
785
|
# calculate standard accuracy
|
757
786
|
if not self.train:
|
@@ -771,17 +800,17 @@ class Metrics(Callback):
|
|
771
800
|
logs['dice'] = dice
|
772
801
|
else:
|
773
802
|
# save best model only
|
774
|
-
if epoch == 0 or
|
803
|
+
if epoch == 0 or dice > max(self.history['val_dice']):
|
775
804
|
self.model.save(str(self.path_to_model))
|
776
805
|
|
777
806
|
# add accuracy to history
|
778
|
-
self.history['loss'].append(
|
779
|
-
self.history['accuracy'].append(
|
807
|
+
self.history['loss'].append(logs['loss'])
|
808
|
+
self.history['accuracy'].append(logs['accuracy'])
|
780
809
|
if self.train_dice:
|
781
|
-
self.history['dice'].append(
|
782
|
-
self.history['val_accuracy'].append(
|
783
|
-
self.history['val_dice'].append(
|
784
|
-
self.history['val_loss'].append(
|
810
|
+
self.history['dice'].append(logs['dice'])
|
811
|
+
self.history['val_accuracy'].append(accuracy)
|
812
|
+
self.history['val_dice'].append(dice)
|
813
|
+
self.history['val_loss'].append(val_loss)
|
785
814
|
|
786
815
|
# tensorflow monitoring variables
|
787
816
|
logs['val_loss'] = val_loss
|
@@ -798,11 +827,11 @@ class Metrics(Callback):
|
|
798
827
|
|
799
828
|
# print accuracies
|
800
829
|
print('\nValidation history:')
|
801
|
-
print(
|
830
|
+
print("train_acc: [" + " ".join(f"{x:.4f}" for x in self.history['accuracy']) + "]")
|
802
831
|
if self.train_dice:
|
803
|
-
print(
|
804
|
-
print(
|
805
|
-
print(
|
832
|
+
print("train_dice: [" + " ".join(f"{x:.4f}" for x in self.history['dice']) + "]")
|
833
|
+
print("val_acc: [" + " ".join(f"{x:.4f}" for x in self.history['val_accuracy']) + "]")
|
834
|
+
print("val_dice: [" + " ".join(f"{x:.4f}" for x in self.history['val_dice']) + "]")
|
806
835
|
print('')
|
807
836
|
|
808
837
|
# early stopping
|
@@ -849,13 +878,13 @@ def categorical_crossentropy(true_labels, predicted_probs):
|
|
849
878
|
# Clip predicted probabilities to avoid log(0) issues
|
850
879
|
predicted_probs = np.clip(predicted_probs, 1e-7, 1 - 1e-7)
|
851
880
|
predicted_probs = -np.log(predicted_probs)
|
852
|
-
zsh,ysh,xsh = true_labels.shape
|
881
|
+
zsh, ysh, xsh, _ = true_labels.shape
|
853
882
|
# Calculate categorical crossentropy
|
854
883
|
loss = 0
|
855
884
|
for z in range(zsh):
|
856
885
|
for y in range(ysh):
|
857
886
|
for x in range(xsh):
|
858
|
-
l = true_labels[z,y,x]
|
887
|
+
l = true_labels[z,y,x,0]
|
859
888
|
loss += predicted_probs[z,y,x,l]
|
860
889
|
loss = loss / float(zsh*ysh*xsh)
|
861
890
|
return loss
|
@@ -879,6 +908,42 @@ def dice_coef_loss(nb_labels):
|
|
879
908
|
return loss
|
880
909
|
return loss_fn
|
881
910
|
|
911
|
+
def custom_loss(y_true, y_pred):
|
912
|
+
# Extract labels and ignore mask
|
913
|
+
labels = tf.cast(y_true[..., 0], tf.int32) # First channel contains class labels
|
914
|
+
ignore_mask = tf.cast(y_true[..., 1], tf.float32) # Second channel contains mask (0 = ignore, 1 = include)
|
915
|
+
|
916
|
+
# Convert integer labels to one-hot encoding
|
917
|
+
y_true_one_hot = tf.one_hot(labels, depth=2)
|
918
|
+
|
919
|
+
# Clip y_pred to avoid log(0)
|
920
|
+
y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0)
|
921
|
+
|
922
|
+
# Compute categorical cross-entropy
|
923
|
+
loss = -tf.reduce_sum(y_true_one_hot * tf.math.log(y_pred), axis=-1)
|
924
|
+
|
925
|
+
# Apply ignore mask (ignore = 0 → loss is zero, include = 1 → loss is counted)
|
926
|
+
loss = loss * ignore_mask
|
927
|
+
|
928
|
+
# Return mean loss over valid (non-ignored) samples
|
929
|
+
return tf.reduce_sum(loss) / tf.reduce_sum(ignore_mask)
|
930
|
+
|
931
|
+
def custom_accuracy(y_true, y_pred):
|
932
|
+
labels = tf.cast(y_true[..., 0], tf.int32) # Extract actual values
|
933
|
+
ignore_mask = y_true[..., 1] # Extract mask (1 = include, 0 = ignore)
|
934
|
+
|
935
|
+
# Convert predictions to discrete values (assuming regression: round values)
|
936
|
+
y_pred_class = tf.argmax(y_pred, axis=-1, output_type=tf.int32)
|
937
|
+
|
938
|
+
# Compute correct predictions (1 where correct, 0 where incorrect)
|
939
|
+
correct_predictions = tf.cast(tf.equal(labels, y_pred_class), tf.float32)
|
940
|
+
|
941
|
+
# Apply ignore mask
|
942
|
+
masked_correct_predictions = correct_predictions * ignore_mask
|
943
|
+
|
944
|
+
# Compute accuracy only over valid (non-ignored) pixels
|
945
|
+
return tf.reduce_sum(masked_correct_predictions) / tf.reduce_sum(ignore_mask)
|
946
|
+
|
882
947
|
def train_segmentation(bm):
|
883
948
|
|
884
949
|
# training data
|
@@ -988,18 +1053,19 @@ def train_segmentation(bm):
|
|
988
1053
|
'n_channels': bm.channels,
|
989
1054
|
'augment': (bm.flip_x, bm.flip_y, bm.flip_z, bm.swapaxes, bm.rotate, bm.rotate3d),
|
990
1055
|
'patch_normalization': bm.patch_normalization,
|
991
|
-
'separation': bm.separation
|
1056
|
+
'separation': bm.separation,
|
1057
|
+
'ignore_mask': bm.ignore_mask}
|
992
1058
|
|
993
1059
|
# data generator
|
994
1060
|
validation_generator = None
|
995
|
-
training_generator = DataGenerator(bm.img_data, bm.label_data, list_IDs_fg, list_IDs_bg, True, True,
|
1061
|
+
training_generator = DataGenerator(bm.img_data, bm.label_data, list_IDs_fg, list_IDs_bg, True, True, **params)
|
996
1062
|
if bm.val_img_data is not None:
|
997
1063
|
if bm.val_dice:
|
998
1064
|
val_metrics = Metrics(bm, bm.val_img_data, bm.val_label_data, list_IDs_val_fg, (zsh_val, ysh_val, xsh_val), nb_labels, False)
|
999
1065
|
else:
|
1000
1066
|
params['dim_img'] = (zsh_val, ysh_val, xsh_val)
|
1001
1067
|
params['augment'] = (False, False, False, False, 0, False)
|
1002
|
-
validation_generator = DataGenerator(bm.val_img_data, bm.val_label_data, list_IDs_val_fg, list_IDs_val_bg, True, False,
|
1068
|
+
validation_generator = DataGenerator(bm.val_img_data, bm.val_label_data, list_IDs_val_fg, list_IDs_val_bg, True, False, **params)
|
1003
1069
|
|
1004
1070
|
# monitor dice score on training data
|
1005
1071
|
if bm.train_dice:
|
@@ -1017,7 +1083,7 @@ def train_segmentation(bm):
|
|
1017
1083
|
with strategy.scope():
|
1018
1084
|
|
1019
1085
|
# build model
|
1020
|
-
model = make_unet(input_shape, nb_labels
|
1086
|
+
model = make_unet(bm, input_shape, nb_labels)
|
1021
1087
|
model.summary()
|
1022
1088
|
|
1023
1089
|
# pretrained model
|
@@ -1036,13 +1102,28 @@ def train_segmentation(bm):
|
|
1036
1102
|
layer.trainable = False
|
1037
1103
|
|
1038
1104
|
# optimizer
|
1039
|
-
|
1105
|
+
optimizer = SGD(learning_rate=bm.learning_rate, decay=1e-6, momentum=0.9, nesterov=True)
|
1106
|
+
#optimizer = tf.keras.optimizers.Adam(learning_rate=bm.learning_rate, epsilon=1e-4) lr=0.0001
|
1107
|
+
if bm.mixed_precision:
|
1108
|
+
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer, dynamic=False, initial_scale=128)
|
1109
|
+
|
1110
|
+
# Rename the function to appear as "accuracy" in logs
|
1111
|
+
if bm.ignore_mask:
|
1112
|
+
custom_accuracy.__name__ = "accuracy"
|
1113
|
+
metrics=[custom_accuracy]
|
1114
|
+
else:
|
1115
|
+
metrics=['accuracy']
|
1116
|
+
|
1117
|
+
# loss function
|
1118
|
+
if bm.ignore_mask:
|
1119
|
+
loss=custom_loss
|
1120
|
+
else:
|
1121
|
+
loss=dice_coef_loss(nb_labels) if bm.dice_loss else 'categorical_crossentropy'
|
1040
1122
|
|
1041
1123
|
# comile model
|
1042
|
-
loss=dice_coef_loss(nb_labels) if bm.dice_loss else 'categorical_crossentropy'
|
1043
1124
|
model.compile(loss=loss,
|
1044
|
-
optimizer=
|
1045
|
-
metrics=
|
1125
|
+
optimizer=optimizer,
|
1126
|
+
metrics=metrics)
|
1046
1127
|
|
1047
1128
|
# save meta data
|
1048
1129
|
meta_data = MetaData(bm.path_to_model, configuration_data, allLabels,
|
@@ -1081,7 +1162,7 @@ def train_segmentation(bm):
|
|
1081
1162
|
callbacks=callbacks,
|
1082
1163
|
workers=bm.workers)
|
1083
1164
|
|
1084
|
-
def load_prediction_data(bm, channels,
|
1165
|
+
def load_prediction_data(bm, channels, normalization_parameters,
|
1085
1166
|
region_of_interest, img, img_header, load_blockwise=False, z=None):
|
1086
1167
|
|
1087
1168
|
# read image data
|
@@ -1109,10 +1190,9 @@ def load_prediction_data(bm, channels, normalize, normalization_parameters,
|
|
1109
1190
|
if bm.acwe:
|
1110
1191
|
img_data = img.copy()
|
1111
1192
|
|
1112
|
-
#
|
1193
|
+
# image data must have number of channels >=1
|
1113
1194
|
if len(img.shape)==3:
|
1114
|
-
|
1115
|
-
img = img.reshape(z_shape, y_shape, x_shape, 1)
|
1195
|
+
img = img.reshape(img.shape[0], img.shape[1], img.shape[2], 1)
|
1116
1196
|
if img.shape[3] != channels:
|
1117
1197
|
InputError.message = f'Number of channels must be {channels}.'
|
1118
1198
|
raise InputError()
|
@@ -1127,22 +1207,27 @@ def load_prediction_data(bm, channels, normalize, normalization_parameters,
|
|
1127
1207
|
region_of_interest = np.array([min_z,max_z,min_y,max_y,min_x,max_x,z_shape,y_shape,x_shape])
|
1128
1208
|
z_shape, y_shape, x_shape = max_z-min_z, max_y-min_y, max_x-min_x
|
1129
1209
|
|
1130
|
-
#
|
1131
|
-
img = img.astype(np.float32)
|
1210
|
+
# resize image data
|
1132
1211
|
if bm.scaling:
|
1212
|
+
img = img.astype(np.float32)
|
1133
1213
|
img = img_resize(img, bm.z_scale, bm.y_scale, bm.x_scale)
|
1134
1214
|
|
1215
|
+
# scale image data
|
1216
|
+
if not bm.patch_normalization:
|
1217
|
+
img = img.astype(np.float32)
|
1218
|
+
for ch in range(channels):
|
1219
|
+
img[...,ch] -= np.amin(img[...,ch])
|
1220
|
+
img[...,ch] /= np.amax(img[...,ch])
|
1221
|
+
|
1135
1222
|
# normalize image data
|
1136
|
-
|
1137
|
-
img
|
1138
|
-
|
1139
|
-
|
1140
|
-
|
1141
|
-
img[
|
1142
|
-
img[:,:,:,c] = img[:,:,:,c] * normalization_parameters[1,c] + normalization_parameters[0,c]
|
1223
|
+
if bm.normalize:
|
1224
|
+
img = img.astype(np.float32)
|
1225
|
+
for ch in range(channels):
|
1226
|
+
mean, std = welford_mean_std(img[...,ch])
|
1227
|
+
img[...,ch] = (img[...,ch] - mean) / std
|
1228
|
+
img[...,ch] = img[...,ch] * normalization_parameters[1,ch] + normalization_parameters[0,ch]
|
1143
1229
|
|
1144
|
-
|
1145
|
-
if normalize:
|
1230
|
+
# limit intensity range
|
1146
1231
|
img[img<0] = 0
|
1147
1232
|
img[img>1] = 1
|
1148
1233
|
|
@@ -1185,6 +1270,20 @@ def gradient(volData):
|
|
1185
1270
|
grad[grad>0]=1
|
1186
1271
|
return grad
|
1187
1272
|
|
1273
|
+
@numba.jit(nopython=True)
|
1274
|
+
def scale_probabilities(final):
|
1275
|
+
zsh, ysh, xsh, nb_labels = final.shape
|
1276
|
+
for k in range(zsh):
|
1277
|
+
for l in range(ysh):
|
1278
|
+
for m in range(xsh):
|
1279
|
+
scale_factor = 0
|
1280
|
+
for n in range(nb_labels):
|
1281
|
+
scale_factor += final[k,l,m,n]
|
1282
|
+
scale_factor = max(1, scale_factor)
|
1283
|
+
for n in range(nb_labels):
|
1284
|
+
final[k,l,m,n] /= scale_factor
|
1285
|
+
return final
|
1286
|
+
|
1188
1287
|
def predict_segmentation(bm, region_of_interest, channels, normalization_parameters):
|
1189
1288
|
|
1190
1289
|
from mpi4py import MPI
|
@@ -1192,13 +1291,26 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
|
|
1192
1291
|
rank = comm.Get_rank()
|
1193
1292
|
ngpus = comm.Get_size()
|
1194
1293
|
|
1294
|
+
# optional result paths
|
1295
|
+
if bm.path_to_image:
|
1296
|
+
filename, bm.extension = os.path.splitext(bm.path_to_final)
|
1297
|
+
if bm.extension == '.gz':
|
1298
|
+
bm.extension = '.nii.gz'
|
1299
|
+
filename = filename[:-4]
|
1300
|
+
path_to_cleaned = filename + '.cleaned' + bm.extension
|
1301
|
+
path_to_filled = filename + '.filled' + bm.extension
|
1302
|
+
path_to_cleaned_filled = filename + '.cleaned.filled' + bm.extension
|
1303
|
+
path_to_refined = filename + '.refined' + bm.extension
|
1304
|
+
path_to_acwe = filename + '.acwe' + bm.extension
|
1305
|
+
path_to_probs = filename + '.probs.tif'
|
1306
|
+
|
1195
1307
|
# Set the visible GPU by ID
|
1196
1308
|
gpus = tf.config.experimental.list_physical_devices('GPU')
|
1197
1309
|
if gpus:
|
1198
1310
|
try:
|
1199
1311
|
# Restrict TensorFlow to only use the specified GPU
|
1200
|
-
tf.config.experimental.set_visible_devices(gpus[rank], 'GPU')
|
1201
|
-
tf.config.experimental.set_memory_growth(gpus[rank], True)
|
1312
|
+
tf.config.experimental.set_visible_devices(gpus[rank % len(gpus)], 'GPU')
|
1313
|
+
tf.config.experimental.set_memory_growth(gpus[rank % len(gpus)], True)
|
1202
1314
|
except RuntimeError as e:
|
1203
1315
|
print(e)
|
1204
1316
|
|
@@ -1209,7 +1321,7 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
|
|
1209
1321
|
nb_labels = len(bm.allLabels)
|
1210
1322
|
results['allLabels'] = bm.allLabels
|
1211
1323
|
|
1212
|
-
#
|
1324
|
+
# custom objects
|
1213
1325
|
if bm.dice_loss:
|
1214
1326
|
def loss_fn(y_true, y_pred):
|
1215
1327
|
dice = 0
|
@@ -1220,25 +1332,30 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
|
|
1220
1332
|
loss = 1 - dice
|
1221
1333
|
return loss
|
1222
1334
|
custom_objects = {'dice_coef_loss': dice_coef_loss,'loss_fn': loss_fn}
|
1223
|
-
|
1335
|
+
elif bm.ignore_mask:
|
1336
|
+
custom_objects={'custom_loss': custom_loss}
|
1224
1337
|
else:
|
1225
|
-
|
1338
|
+
custom_objects=None
|
1339
|
+
|
1340
|
+
# load model
|
1341
|
+
model = load_model(bm.path_to_model, custom_objects=custom_objects)
|
1226
1342
|
|
1227
1343
|
# check if data can be loaded blockwise to save host memory
|
1228
1344
|
load_blockwise = False
|
1229
1345
|
if not bm.scaling and not bm.normalize and bm.path_to_image and not np.any(region_of_interest) and \
|
1230
1346
|
os.path.splitext(bm.path_to_image)[1] in ['.tif', '.tiff'] and not bm.acwe:
|
1347
|
+
|
1231
1348
|
# get image shape
|
1232
1349
|
tif = TiffFile(bm.path_to_image)
|
1233
1350
|
zsh = len(tif.pages)
|
1234
1351
|
ysh, xsh = tif.pages[0].shape
|
1235
1352
|
|
1236
1353
|
# load mask
|
1237
|
-
if bm.separation or bm.refinement:
|
1354
|
+
'''if bm.separation or bm.refinement:
|
1238
1355
|
mask, _ = load_data(bm.mask)
|
1239
1356
|
mask = mask.reshape(zsh, ysh, xsh, 1)
|
1240
1357
|
mask, _, _, _ = append_ghost_areas(bm, mask)
|
1241
|
-
mask = mask.reshape(mask.shape[:-1])
|
1358
|
+
mask = mask.reshape(mask.shape[:-1])'''
|
1242
1359
|
|
1243
1360
|
# determine new image size after appending ghost areas to make image dimensions divisible by patch size
|
1244
1361
|
z_rest = bm.z_patch - (zsh % bm.z_patch)
|
@@ -1258,7 +1375,7 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
|
|
1258
1375
|
xsh += x_rest
|
1259
1376
|
|
1260
1377
|
# get Ids of patches
|
1261
|
-
list_IDs = []
|
1378
|
+
'''list_IDs = []
|
1262
1379
|
for k in range(0, zsh-bm.z_patch+1, bm.stride_size):
|
1263
1380
|
for l in range(0, ysh-bm.y_patch+1, bm.stride_size):
|
1264
1381
|
for m in range(0, xsh-bm.x_patch+1, bm.stride_size):
|
@@ -1268,19 +1385,18 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
|
|
1268
1385
|
if centerLabel>0 and np.any(patch!=centerLabel):
|
1269
1386
|
list_IDs.append(k*ysh*xsh+l*xsh+m)
|
1270
1387
|
elif bm.refinement:
|
1271
|
-
|
1272
|
-
if np.any(patch==0) and np.any(patch!=0):
|
1388
|
+
if np.any(mask[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]):
|
1273
1389
|
list_IDs.append(k*ysh*xsh+l*xsh+m)
|
1274
1390
|
else:
|
1275
|
-
list_IDs.append(k*ysh*xsh+l*xsh+m)
|
1391
|
+
list_IDs.append(k*ysh*xsh+l*xsh+m)'''
|
1276
1392
|
|
1277
1393
|
# make length of list divisible by batch size
|
1278
|
-
max_i = len(list_IDs)
|
1394
|
+
'''max_i = len(list_IDs)
|
1279
1395
|
rest = bm.batch_size - (len(list_IDs) % bm.batch_size)
|
1280
|
-
list_IDs = list_IDs + list_IDs[:rest]
|
1396
|
+
list_IDs = list_IDs + list_IDs[:rest]'''
|
1281
1397
|
|
1282
1398
|
# prediction
|
1283
|
-
if
|
1399
|
+
if zsh*ysh*xsh > 256**3:
|
1284
1400
|
load_blockwise = True
|
1285
1401
|
|
1286
1402
|
# load image data and calculate patch IDs
|
@@ -1288,7 +1404,7 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
|
|
1288
1404
|
|
1289
1405
|
# load prediction data
|
1290
1406
|
img, bm.img_header, z_shape, y_shape, x_shape, region_of_interest, bm.img_data = load_prediction_data(
|
1291
|
-
bm, channels,
|
1407
|
+
bm, channels, normalization_parameters, region_of_interest, bm.img_data, bm.img_header)
|
1292
1408
|
|
1293
1409
|
# append ghost areas
|
1294
1410
|
img, z_rest, y_rest, x_rest = append_ghost_areas(bm, img)
|
@@ -1314,6 +1430,7 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
|
|
1314
1430
|
|
1315
1431
|
# load all patches on GPU memory
|
1316
1432
|
if not load_blockwise and nb_patches < 400:
|
1433
|
+
if rank==0:
|
1317
1434
|
|
1318
1435
|
# parameters
|
1319
1436
|
params = {'dim': (bm.z_patch, bm.y_patch, bm.x_patch),
|
@@ -1346,7 +1463,11 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
|
|
1346
1463
|
|
1347
1464
|
# allocate final probabilities array
|
1348
1465
|
if rank==0 and bm.return_probs:
|
1349
|
-
|
1466
|
+
if load_blockwise:
|
1467
|
+
if not os.path.exists(path_to_probs[:-4]):
|
1468
|
+
os.mkdir(path_to_probs[:-4])
|
1469
|
+
else:
|
1470
|
+
final = np.zeros((zsh, ysh, xsh, nb_labels), dtype=np.float32)
|
1350
1471
|
|
1351
1472
|
# allocate final result array
|
1352
1473
|
if rank==0:
|
@@ -1361,27 +1482,38 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
|
|
1361
1482
|
else:
|
1362
1483
|
nprocs = ngpus
|
1363
1484
|
if j % ngpus == rank:
|
1485
|
+
|
1364
1486
|
# load blockwise from TIFF
|
1365
1487
|
if load_blockwise:
|
1366
1488
|
img, _, _, _, _, _, _ = load_prediction_data(bm,
|
1367
|
-
channels,
|
1489
|
+
channels, normalization_parameters,
|
1368
1490
|
region_of_interest, bm.img_data, bm.img_header, load_blockwise, z)
|
1369
1491
|
img, _, _, _ = append_ghost_areas(bm, img)
|
1370
1492
|
|
1493
|
+
# load mask block
|
1494
|
+
if bm.separation or bm.refinement:
|
1495
|
+
mask = imread(bm.mask, key=range(z,min(len(tif.pages),z+bm.z_patch)))
|
1496
|
+
# pad zeros to make dimensions divisible by patch dimensions
|
1497
|
+
pad_z = bm.z_patch - mask.shape[0]
|
1498
|
+
pad_y = (bm.y_patch - (mask.shape[1] % bm.y_patch)) % bm.y_patch
|
1499
|
+
pad_x = (bm.x_patch - (mask.shape[2] % bm.x_patch)) % bm.x_patch
|
1500
|
+
pad_width = [(0, pad_z), (0, pad_y), (0, pad_x)]
|
1501
|
+
mask = np.pad(mask, pad_width, mode='constant', constant_values=0)
|
1502
|
+
|
1371
1503
|
# list of IDs
|
1372
1504
|
list_IDs_block = []
|
1373
1505
|
|
1374
1506
|
# get Ids of patches
|
1507
|
+
k = 0 if load_blockwise else z
|
1375
1508
|
for l in range(0, ysh-bm.y_patch+1, bm.stride_size):
|
1376
1509
|
for m in range(0, xsh-bm.x_patch+1, bm.stride_size):
|
1377
1510
|
if bm.separation:
|
1378
|
-
centerLabel = mask[
|
1379
|
-
patch = mask[
|
1511
|
+
centerLabel = mask[k+bm.z_patch//2,l+bm.y_patch//2,m+bm.x_patch//2]
|
1512
|
+
patch = mask[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]
|
1380
1513
|
if centerLabel>0 and np.any(patch!=centerLabel):
|
1381
1514
|
list_IDs_block.append(z*ysh*xsh+l*xsh+m)
|
1382
1515
|
elif bm.refinement:
|
1383
|
-
|
1384
|
-
if np.any(patch==0) and np.any(patch!=0):
|
1516
|
+
if np.any(mask[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]):
|
1385
1517
|
list_IDs_block.append(z*ysh*xsh+l*xsh+m)
|
1386
1518
|
else:
|
1387
1519
|
list_IDs_block.append(z*ysh*xsh+l*xsh+m)
|
@@ -1413,10 +1545,11 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
|
|
1413
1545
|
# get patch
|
1414
1546
|
tmp_X = img[k:k+bm.z_patch,l:l+bm.y_patch,m:m+bm.x_patch]
|
1415
1547
|
if bm.patch_normalization:
|
1416
|
-
tmp_X =
|
1417
|
-
for
|
1418
|
-
|
1419
|
-
tmp_X[
|
1548
|
+
tmp_X = tmp_X.copy().astype(np.float32)
|
1549
|
+
for ch in range(channels):
|
1550
|
+
mean, std = welford_mean_std(tmp_X[...,ch])
|
1551
|
+
tmp_X[...,ch] -= mean
|
1552
|
+
tmp_X[...,ch] /= max(std, 1e-6)
|
1420
1553
|
X[i] = tmp_X
|
1421
1554
|
|
1422
1555
|
# predict batch
|
@@ -1460,19 +1593,28 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
|
|
1460
1593
|
# overlap in z direction
|
1461
1594
|
if bm.stride_size < bm.z_patch:
|
1462
1595
|
if j+source>0:
|
1463
|
-
probs[
|
1596
|
+
probs[:-bm.stride_size] += overlap
|
1464
1597
|
overlap = probs[bm.stride_size:].copy()
|
1465
1598
|
|
1466
|
-
#
|
1599
|
+
# block z dimension
|
1467
1600
|
block_z = z_indices[j+source]
|
1468
|
-
if j+source==len(z_indices)-1:
|
1469
|
-
|
1470
|
-
if
|
1471
|
-
final[block_z:block_z+bm.z_patch] = probs
|
1601
|
+
if j+source==len(z_indices)-1: # last block
|
1602
|
+
block_zsh = bm.z_patch
|
1603
|
+
block_z_rest = z_rest if z_rest>0 else -block_zsh
|
1472
1604
|
else:
|
1473
1605
|
block_zsh = min(bm.stride_size, bm.z_patch)
|
1474
|
-
|
1475
|
-
|
1606
|
+
block_z_rest = -block_zsh
|
1607
|
+
|
1608
|
+
# calculate result
|
1609
|
+
label[block_z:block_z+block_zsh] = np.argmax(probs[:block_zsh], axis=-1).astype(np.uint8)
|
1610
|
+
|
1611
|
+
# return probabilities
|
1612
|
+
if bm.return_probs:
|
1613
|
+
if load_blockwise:
|
1614
|
+
block_output = scale_probabilities(probs[:block_zsh])
|
1615
|
+
block_output = block_output[:-block_z_rest,:-y_rest,:-x_rest]
|
1616
|
+
imwrite(path_to_probs[:-4] + f"/block-{j+source}.tif", block_output)
|
1617
|
+
else:
|
1476
1618
|
final[block_z:block_z+block_zsh] = probs[:block_zsh]
|
1477
1619
|
else:
|
1478
1620
|
for i in range(bm.z_patch):
|
@@ -1480,7 +1622,7 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
|
|
1480
1622
|
if rank==0:
|
1481
1623
|
|
1482
1624
|
# refine mask data with result
|
1483
|
-
if bm.refinement:
|
1625
|
+
'''if bm.refinement:
|
1484
1626
|
# loop over boundary patches
|
1485
1627
|
for i, ID in enumerate(list_IDs):
|
1486
1628
|
if i < max_i:
|
@@ -1489,27 +1631,19 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
|
|
1489
1631
|
l = rest // xsh
|
1490
1632
|
m = rest % xsh
|
1491
1633
|
mask[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch] = label[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]
|
1492
|
-
label = mask
|
1634
|
+
label = mask'''
|
1493
1635
|
|
1494
|
-
# remove
|
1495
|
-
if bm.return_probs:
|
1636
|
+
# remove ghost areas
|
1637
|
+
if bm.return_probs and not load_blockwise:
|
1496
1638
|
final = final[:-z_rest,:-y_rest,:-x_rest]
|
1497
1639
|
label = label[:-z_rest,:-y_rest,:-x_rest]
|
1498
1640
|
zsh, ysh, xsh = label.shape
|
1499
1641
|
|
1500
1642
|
# return probabilities
|
1501
|
-
if bm.return_probs:
|
1502
|
-
|
1503
|
-
nb = 0
|
1504
|
-
for k in range(0, zsh-bm.z_patch+1, bm.stride_size):
|
1505
|
-
for l in range(0, ysh-bm.y_patch+1, bm.stride_size):
|
1506
|
-
for m in range(0, xsh-bm.x_patch+1, bm.stride_size):
|
1507
|
-
counter[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch] += 1
|
1508
|
-
nb += 1
|
1509
|
-
counter[counter==0] = 1
|
1510
|
-
probabilities = final / counter
|
1643
|
+
if bm.return_probs and not load_blockwise:
|
1644
|
+
probabilities = scale_probabilities(final)
|
1511
1645
|
if bm.scaling:
|
1512
|
-
probabilities = img_resize(probabilities, z_shape, y_shape, x_shape)
|
1646
|
+
probabilities = img_resize(probabilities, z_shape, y_shape, x_shape, interpolation=cv2.INTER_LINEAR)
|
1513
1647
|
if np.any(region_of_interest):
|
1514
1648
|
min_z,max_z,min_y,max_y,min_x,max_x,original_zsh,original_ysh,original_xsh = region_of_interest[:]
|
1515
1649
|
tmp = np.zeros((original_zsh, original_ysh, original_xsh, nb_labels), dtype=np.float32)
|
@@ -1570,17 +1704,8 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
|
|
1570
1704
|
# save result
|
1571
1705
|
if bm.path_to_image:
|
1572
1706
|
save_data(bm.path_to_final, label, header=bm.header, compress=bm.compression)
|
1573
|
-
|
1574
|
-
|
1575
|
-
filename, bm.extension = os.path.splitext(bm.path_to_final)
|
1576
|
-
if bm.extension == '.gz':
|
1577
|
-
bm.extension = '.nii.gz'
|
1578
|
-
filename = filename[:-4]
|
1579
|
-
path_to_cleaned = filename + '.cleaned' + bm.extension
|
1580
|
-
path_to_filled = filename + '.filled' + bm.extension
|
1581
|
-
path_to_cleaned_filled = filename + '.cleaned.filled' + bm.extension
|
1582
|
-
path_to_refined = filename + '.refined' + bm.extension
|
1583
|
-
path_to_acwe = filename + '.acwe' + bm.extension
|
1707
|
+
if bm.return_probs and not load_blockwise:
|
1708
|
+
imwrite(path_to_probs, probabilities)
|
1584
1709
|
|
1585
1710
|
# remove outliers
|
1586
1711
|
if bm.clean:
|