biomedisa 24.8.10__py3-none-any.whl → 25.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- biomedisa/deeplearning.py +35 -9
- biomedisa/features/DataGenerator.py +192 -123
- biomedisa/features/PredictDataGenerator.py +7 -5
- biomedisa/features/biomedisa_helper.py +59 -14
- biomedisa/features/crop_helper.py +7 -7
- biomedisa/features/keras_helper.py +281 -157
- biomedisa/features/random_walk/rw_large.py +6 -2
- biomedisa/features/random_walk/rw_small.py +7 -3
- biomedisa/features/remove_outlier.py +3 -3
- biomedisa/features/split_volume.py +12 -11
- biomedisa/interpolation.py +6 -9
- biomedisa/mesh.py +2 -2
- {biomedisa-24.8.10.dist-info → biomedisa-25.6.1.dist-info}/METADATA +3 -2
- {biomedisa-24.8.10.dist-info → biomedisa-25.6.1.dist-info}/RECORD +17 -17
- {biomedisa-24.8.10.dist-info → biomedisa-25.6.1.dist-info}/WHEEL +1 -1
- {biomedisa-24.8.10.dist-info → biomedisa-25.6.1.dist-info/licenses}/LICENSE +0 -0
- {biomedisa-24.8.10.dist-info → biomedisa-25.6.1.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
##########################################################################
|
2
2
|
## ##
|
3
|
-
## Copyright (c) 2019-
|
3
|
+
## Copyright (c) 2019-2025 Philipp Lösel. All rights reserved. ##
|
4
4
|
## ##
|
5
5
|
## This file is part of the open source project biomedisa. ##
|
6
6
|
## ##
|
@@ -39,13 +39,13 @@ from tensorflow.keras.layers import (
|
|
39
39
|
from tensorflow.keras import backend as K
|
40
40
|
from tensorflow.keras.utils import to_categorical
|
41
41
|
from tensorflow.keras.callbacks import Callback, ModelCheckpoint, EarlyStopping
|
42
|
-
from biomedisa.features.DataGenerator import DataGenerator
|
42
|
+
from biomedisa.features.DataGenerator import DataGenerator, welford_mean_std
|
43
43
|
from biomedisa.features.PredictDataGenerator import PredictDataGenerator
|
44
|
-
from biomedisa.features.biomedisa_helper import (
|
44
|
+
from biomedisa.features.biomedisa_helper import (unique, welford_mean_std,
|
45
45
|
img_resize, load_data, save_data, set_labels_to_zero, id_generator, unique_file_path)
|
46
46
|
from biomedisa.features.remove_outlier import clean, fill
|
47
47
|
from biomedisa.features.active_contour import activeContour
|
48
|
-
from tifffile import TiffFile, imread
|
48
|
+
from tifffile import TiffFile, imread, imwrite
|
49
49
|
import matplotlib.pyplot as plt
|
50
50
|
import SimpleITK as sitk
|
51
51
|
import tensorflow as tf
|
@@ -220,11 +220,11 @@ def compute_position(position, zsh, ysh, xsh):
|
|
220
220
|
position[k,l,m] = x+y+z
|
221
221
|
return position
|
222
222
|
|
223
|
-
def make_conv_block(nb_filters, input_tensor, block):
|
223
|
+
def make_conv_block(nb_filters, input_tensor, block, dtype):
|
224
224
|
def make_stage(input_tensor, stage):
|
225
225
|
name = 'conv_{}_{}'.format(block, stage)
|
226
226
|
x = Conv3D(nb_filters, (3, 3, 3), activation='relu',
|
227
|
-
padding='same', name=name, data_format="channels_last")(input_tensor)
|
227
|
+
padding='same', name=name, data_format="channels_last", dtype=dtype)(input_tensor)
|
228
228
|
name = 'batch_norm_{}_{}'.format(block, stage)
|
229
229
|
try:
|
230
230
|
x = BatchNormalization(name=name, synchronized=True)(x)
|
@@ -266,62 +266,70 @@ def make_conv_block_resnet(nb_filters, input_tensor, block):
|
|
266
266
|
|
267
267
|
return out
|
268
268
|
|
269
|
-
def make_unet(input_shape, nb_labels
|
269
|
+
def make_unet(bm, input_shape, nb_labels):
|
270
|
+
# enable mixed_precision
|
271
|
+
if bm.mixed_precision:
|
272
|
+
dtype = "float16"
|
273
|
+
else:
|
274
|
+
dtype = "float32"
|
270
275
|
|
276
|
+
# input
|
271
277
|
nb_plans, nb_rows, nb_cols, _ = input_shape
|
278
|
+
inputs = Input(input_shape, dtype=dtype)
|
272
279
|
|
273
|
-
|
274
|
-
|
275
|
-
filters = filters.split('-')
|
280
|
+
# configure number of layers and filters
|
281
|
+
filters = bm.network_filters.split('-')
|
276
282
|
filters = np.array(filters, dtype=int)
|
277
283
|
latent_space_size = filters[-1]
|
278
284
|
filters = filters[:-1]
|
285
|
+
|
286
|
+
# initialize blocks
|
279
287
|
convs = []
|
280
288
|
|
289
|
+
# encoder
|
281
290
|
i = 1
|
282
291
|
for f in filters:
|
283
292
|
if i==1:
|
284
|
-
if resnet:
|
293
|
+
if bm.resnet:
|
285
294
|
conv = make_conv_block_resnet(f, inputs, i)
|
286
295
|
else:
|
287
|
-
conv = make_conv_block(f, inputs, i)
|
296
|
+
conv = make_conv_block(f, inputs, i, dtype)
|
288
297
|
else:
|
289
|
-
if resnet:
|
298
|
+
if bm.resnet:
|
290
299
|
conv = make_conv_block_resnet(f, pool, i)
|
291
300
|
else:
|
292
|
-
conv = make_conv_block(f, pool, i)
|
301
|
+
conv = make_conv_block(f, pool, i, dtype)
|
293
302
|
pool = MaxPooling3D(pool_size=(2, 2, 2))(conv)
|
294
303
|
convs.append(conv)
|
295
304
|
i += 1
|
296
305
|
|
297
|
-
|
306
|
+
# latent space
|
307
|
+
if bm.resnet:
|
298
308
|
conv = make_conv_block_resnet(latent_space_size, pool, i)
|
299
309
|
else:
|
300
|
-
conv = make_conv_block(latent_space_size, pool, i)
|
310
|
+
conv = make_conv_block(latent_space_size, pool, i, dtype)
|
301
311
|
i += 1
|
302
312
|
|
313
|
+
# decoder
|
303
314
|
for k, f in enumerate(filters[::-1]):
|
304
315
|
up = Concatenate()([UpSampling3D(size=(2, 2, 2))(conv), convs[-(k+1)]])
|
305
|
-
if resnet:
|
316
|
+
if bm.resnet:
|
306
317
|
conv = make_conv_block_resnet(f, up, i)
|
307
318
|
else:
|
308
|
-
conv = make_conv_block(f, up, i)
|
319
|
+
conv = make_conv_block(f, up, i, dtype)
|
309
320
|
i += 1
|
310
321
|
|
322
|
+
# final layer and output
|
311
323
|
conv = Conv3D(nb_labels, (1, 1, 1), name=f'conv_{i}_1')(conv)
|
312
|
-
|
313
324
|
x = Reshape((nb_plans * nb_rows * nb_cols, nb_labels))(conv)
|
314
325
|
x = Activation('softmax')(x)
|
315
326
|
outputs = Reshape((nb_plans, nb_rows, nb_cols, nb_labels))(x)
|
316
|
-
|
317
327
|
model = Model(inputs=inputs, outputs=outputs)
|
318
|
-
|
319
328
|
return model
|
320
329
|
|
321
330
|
def get_labels(arr, allLabels):
|
322
|
-
np_unique = np.unique(arr)
|
323
331
|
final = np.zeros_like(arr)
|
324
|
-
for k in
|
332
|
+
for k in unique(arr):
|
325
333
|
final[arr == k] = allLabels[k]
|
326
334
|
return final
|
327
335
|
|
@@ -412,10 +420,17 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
|
|
412
420
|
argmin_z,argmax_z,argmin_y,argmax_y,argmin_x,argmax_x = predict_blocksize(label, x_puffer, y_puffer, z_puffer)
|
413
421
|
label = label[argmin_z:argmax_z,argmin_y:argmax_y,argmin_x:argmax_x].copy()
|
414
422
|
if bm.scaling:
|
415
|
-
label_values, counts =
|
423
|
+
label_values, counts = unique(label, return_counts=True)
|
416
424
|
print(f'{os.path.basename(label_names[0])}:', 'Labels:', label_values[1:], 'Sizes:', counts[1:])
|
417
425
|
label = img_resize(label, bm.z_scale, bm.y_scale, bm.x_scale, labels=True)
|
418
426
|
|
427
|
+
# label channel must be 1 or 2 if using ignore mask
|
428
|
+
if len(label.shape)>3 and label.shape[3]>1 and not bm.ignore_mask:
|
429
|
+
InputError.message = 'Training labels must have one channel (gray values).'
|
430
|
+
raise InputError()
|
431
|
+
if len(label.shape)==3:
|
432
|
+
label = label.reshape(label.shape[0], label.shape[1], label.shape[2], 1)
|
433
|
+
|
419
434
|
# if header is not single data stream Amira Mesh falling back to Multi-TIFF
|
420
435
|
if extension != '.am':
|
421
436
|
extension, header = '.tif', None
|
@@ -425,7 +440,7 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
|
|
425
440
|
else:
|
426
441
|
header = header[0]
|
427
442
|
|
428
|
-
# load first
|
443
|
+
# load first image
|
429
444
|
if any(img_list):
|
430
445
|
img, _ = load_data(img_names[0], 'first_queue')
|
431
446
|
if img is None:
|
@@ -437,14 +452,15 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
|
|
437
452
|
else:
|
438
453
|
img = img_in
|
439
454
|
img_names = ['img_1']
|
440
|
-
|
455
|
+
|
456
|
+
# label and image dimensions must match
|
457
|
+
if label_dim[:3] != img.shape[:3]:
|
441
458
|
InputError.message = f'Dimensions of "{os.path.basename(img_names[0])}" and "{os.path.basename(label_names[0])}" do not match'
|
442
459
|
raise InputError()
|
443
460
|
|
444
|
-
#
|
461
|
+
# image channels must be >=1
|
445
462
|
if len(img.shape)==3:
|
446
|
-
|
447
|
-
img = img.reshape(z_shape, y_shape, x_shape, 1)
|
463
|
+
img = img.reshape(img.shape[0], img.shape[1], img.shape[2], 1)
|
448
464
|
if channels is None:
|
449
465
|
channels = img.shape[3]
|
450
466
|
if channels != img.shape[3]:
|
@@ -463,30 +479,30 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
|
|
463
479
|
# scale data to the range from 0 to 1
|
464
480
|
if not bm.patch_normalization:
|
465
481
|
img = img.astype(np.float32)
|
466
|
-
for
|
467
|
-
img[
|
468
|
-
img[
|
482
|
+
for ch in range(channels):
|
483
|
+
img[...,ch] -= np.amin(img[...,ch])
|
484
|
+
img[...,ch] /= np.amax(img[...,ch])
|
469
485
|
|
470
486
|
# normalize first validation image
|
471
487
|
if bm.normalize and np.any(normalization_parameters):
|
472
488
|
img = img.astype(np.float32)
|
473
|
-
for
|
474
|
-
mean, std =
|
475
|
-
img[
|
476
|
-
img[
|
489
|
+
for ch in range(channels):
|
490
|
+
mean, std = welford_mean_std(img[...,ch])
|
491
|
+
img[...,ch] = (img[...,ch] - mean) / std
|
492
|
+
img[...,ch] = img[...,ch] * normalization_parameters[1,ch] + normalization_parameters[0,ch]
|
477
493
|
|
478
494
|
# get normalization parameters from first image
|
479
495
|
if normalization_parameters is None:
|
480
496
|
normalization_parameters = np.zeros((2,channels))
|
481
497
|
if bm.normalize:
|
482
|
-
for
|
483
|
-
normalization_parameters[
|
484
|
-
normalization_parameters[1,c] = np.std(img[:,:,:,c])
|
498
|
+
for ch in range(channels):
|
499
|
+
normalization_parameters[:,ch] = welford_mean_std(img[...,ch])
|
485
500
|
|
486
501
|
# pad data
|
487
502
|
if not bm.scaling:
|
488
503
|
img_data_list = [img]
|
489
504
|
label_data_list = [label]
|
505
|
+
img_dtype = img.dtype
|
490
506
|
# no-scaling for list of images needs negative values as it encodes padded areas as -1
|
491
507
|
label_dtype = label.dtype
|
492
508
|
if label_dtype==np.uint8:
|
@@ -500,7 +516,7 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
|
|
500
516
|
|
501
517
|
for k in range(1, number_of_images):
|
502
518
|
|
503
|
-
#
|
519
|
+
# load label data and pre-process
|
504
520
|
if any(label_list):
|
505
521
|
a, _ = load_data(label_names[k], 'first_queue')
|
506
522
|
if a is None:
|
@@ -514,14 +530,24 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
|
|
514
530
|
argmin_z,argmax_z,argmin_y,argmax_y,argmin_x,argmax_x = predict_blocksize(a, x_puffer, y_puffer, z_puffer)
|
515
531
|
a = np.copy(a[argmin_z:argmax_z,argmin_y:argmax_y,argmin_x:argmax_x], order='C')
|
516
532
|
if bm.scaling:
|
517
|
-
label_values, counts =
|
533
|
+
label_values, counts = unique(a, return_counts=True)
|
518
534
|
print(f'{os.path.basename(label_names[k])}:', 'Labels:', label_values[1:], 'Sizes:', counts[1:])
|
519
535
|
a = img_resize(a, bm.z_scale, bm.y_scale, bm.x_scale, labels=True)
|
536
|
+
|
537
|
+
# label channel must be 1 or 2 if using ignore mask
|
538
|
+
if len(a.shape)>3 and a.shape[3]>1 and not bm.ignore_mask:
|
539
|
+
InputError.message = 'Training labels must have one channel (gray values).'
|
540
|
+
raise InputError()
|
541
|
+
if len(a.shape)==3:
|
542
|
+
a = a.reshape(a.shape[0], a.shape[1], a.shape[2], 1)
|
543
|
+
|
544
|
+
# append label data
|
545
|
+
if bm.scaling:
|
520
546
|
label = np.append(label, a, axis=0)
|
521
547
|
else:
|
522
548
|
label_data_list.append(a)
|
523
549
|
|
524
|
-
#
|
550
|
+
# load image data and pre-process
|
525
551
|
if any(img_list):
|
526
552
|
a, _ = load_data(img_names[k], 'first_queue')
|
527
553
|
if a is None:
|
@@ -529,12 +555,11 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
|
|
529
555
|
raise InputError()
|
530
556
|
else:
|
531
557
|
a = img_in[k]
|
532
|
-
if label_dim != a.shape[:3]:
|
558
|
+
if label_dim[:3] != a.shape[:3]:
|
533
559
|
InputError.message = f'Dimensions of "{os.path.basename(img_names[k])}" and "{os.path.basename(label_names[k])}" do not match'
|
534
560
|
raise InputError()
|
535
561
|
if len(a.shape)==3:
|
536
|
-
|
537
|
-
a = a.reshape(z_shape, y_shape, x_shape, 1)
|
562
|
+
a = a.reshape(a.shape[0], a.shape[1], a.shape[2], 1)
|
538
563
|
if a.shape[3] != channels:
|
539
564
|
InputError.message = f'Number of channels must be {channels} for "{os.path.basename(img_names[k])}"'
|
540
565
|
raise InputError()
|
@@ -545,15 +570,17 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
|
|
545
570
|
a = img_resize(a, bm.z_scale, bm.y_scale, bm.x_scale)
|
546
571
|
if not bm.patch_normalization:
|
547
572
|
a = a.astype(np.float32)
|
548
|
-
for
|
549
|
-
a[
|
550
|
-
a[
|
573
|
+
for ch in range(channels):
|
574
|
+
a[...,ch] -= np.amin(a[...,ch])
|
575
|
+
a[...,ch] /= np.amax(a[...,ch])
|
551
576
|
if bm.normalize:
|
552
577
|
a = a.astype(np.float32)
|
553
|
-
for
|
554
|
-
mean, std =
|
555
|
-
a[
|
556
|
-
a[
|
578
|
+
for ch in range(channels):
|
579
|
+
mean, std = welford_mean_std(a[...,ch])
|
580
|
+
a[...,ch] = (a[...,ch] - mean) / std
|
581
|
+
a[...,ch] = a[...,ch] * normalization_parameters[1,ch] + normalization_parameters[0,ch]
|
582
|
+
|
583
|
+
# append image data
|
557
584
|
if bm.scaling:
|
558
585
|
img = np.append(img, a, axis=0)
|
559
586
|
else:
|
@@ -565,15 +592,14 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
|
|
565
592
|
for img in img_data_list:
|
566
593
|
target_y = max(target_y, img.shape[1])
|
567
594
|
target_x = max(target_x, img.shape[2])
|
568
|
-
img = np.empty((0, target_y, target_x, channels), dtype=
|
569
|
-
label = np.empty((0, target_y, target_x), dtype=label_dtype)
|
595
|
+
img = np.empty((0, target_y, target_x, channels), dtype=img_dtype)
|
596
|
+
label = np.empty((0, target_y, target_x, 2 if bm.ignore_mask else 1), dtype=label_dtype)
|
570
597
|
for k in range(len(img_data_list)):
|
571
598
|
pad_y = target_y - img_data_list[k].shape[1]
|
572
599
|
pad_x = target_x - img_data_list[k].shape[2]
|
573
600
|
pad_width = [(0, 0), (0, pad_y), (0, pad_x), (0, 0)]
|
574
601
|
tmp = np.pad(img_data_list[k], pad_width, mode='constant', constant_values=0)
|
575
602
|
img = np.append(img, tmp, axis=0)
|
576
|
-
pad_width = [(0, 0), (0, pad_y), (0, pad_x)]
|
577
603
|
tmp = np.pad(label_data_list[k].astype(label_dtype), pad_width, mode='constant', constant_values=-1)
|
578
604
|
label = np.append(label, tmp, axis=0)
|
579
605
|
|
@@ -587,13 +613,13 @@ def load_training_data(bm, img_list, label_list, channels, img_in=None, label_in
|
|
587
613
|
else:
|
588
614
|
# get labels
|
589
615
|
if allLabels is None:
|
590
|
-
allLabels =
|
616
|
+
allLabels = unique(label[...,0])
|
591
617
|
index = np.argwhere(allLabels<0)
|
592
618
|
allLabels = np.delete(allLabels, index)
|
593
619
|
|
594
620
|
# labels must be in ascending order
|
595
621
|
for k, l in enumerate(allLabels):
|
596
|
-
label[label==l] = k
|
622
|
+
label[...,0][label[...,0]==l] = k
|
597
623
|
|
598
624
|
return img, label, allLabels, normalization_parameters, header, extension, channels
|
599
625
|
|
@@ -725,10 +751,11 @@ class Metrics(Callback):
|
|
725
751
|
m = rest % self.dim_img[2]
|
726
752
|
tmp_X = self.img[k:k+self.dim_patch[0],l:l+self.dim_patch[1],m:m+self.dim_patch[2]]
|
727
753
|
if self.patch_normalization:
|
728
|
-
tmp_X =
|
729
|
-
for
|
730
|
-
|
731
|
-
tmp_X[
|
754
|
+
tmp_X = tmp_X.copy().astype(np.float32)
|
755
|
+
for ch in range(self.n_channels):
|
756
|
+
mean, std = welford_mean_std(tmp_X[...,ch])
|
757
|
+
tmp_X[...,ch] -= mean
|
758
|
+
tmp_X[...,ch] /= max(std, 1e-6)
|
732
759
|
X_val[i] = tmp_X
|
733
760
|
|
734
761
|
# Prediction segmentation
|
@@ -753,6 +780,7 @@ class Metrics(Callback):
|
|
753
780
|
# get result
|
754
781
|
result = np.argmax(result, axis=-1)
|
755
782
|
result = result.astype(np.uint8)
|
783
|
+
result = result.reshape(*result.shape, 1)
|
756
784
|
|
757
785
|
# calculate standard accuracy
|
758
786
|
if not self.train:
|
@@ -772,17 +800,17 @@ class Metrics(Callback):
|
|
772
800
|
logs['dice'] = dice
|
773
801
|
else:
|
774
802
|
# save best model only
|
775
|
-
if epoch == 0 or
|
803
|
+
if epoch == 0 or dice > max(self.history['val_dice']):
|
776
804
|
self.model.save(str(self.path_to_model))
|
777
805
|
|
778
806
|
# add accuracy to history
|
779
|
-
self.history['loss'].append(
|
780
|
-
self.history['accuracy'].append(
|
807
|
+
self.history['loss'].append(logs['loss'])
|
808
|
+
self.history['accuracy'].append(logs['accuracy'])
|
781
809
|
if self.train_dice:
|
782
|
-
self.history['dice'].append(
|
783
|
-
self.history['val_accuracy'].append(
|
784
|
-
self.history['val_dice'].append(
|
785
|
-
self.history['val_loss'].append(
|
810
|
+
self.history['dice'].append(logs['dice'])
|
811
|
+
self.history['val_accuracy'].append(accuracy)
|
812
|
+
self.history['val_dice'].append(dice)
|
813
|
+
self.history['val_loss'].append(val_loss)
|
786
814
|
|
787
815
|
# tensorflow monitoring variables
|
788
816
|
logs['val_loss'] = val_loss
|
@@ -799,11 +827,11 @@ class Metrics(Callback):
|
|
799
827
|
|
800
828
|
# print accuracies
|
801
829
|
print('\nValidation history:')
|
802
|
-
print(
|
830
|
+
print("train_acc: [" + " ".join(f"{x:.4f}" for x in self.history['accuracy']) + "]")
|
803
831
|
if self.train_dice:
|
804
|
-
print(
|
805
|
-
print(
|
806
|
-
print(
|
832
|
+
print("train_dice: [" + " ".join(f"{x:.4f}" for x in self.history['dice']) + "]")
|
833
|
+
print("val_acc: [" + " ".join(f"{x:.4f}" for x in self.history['val_accuracy']) + "]")
|
834
|
+
print("val_dice: [" + " ".join(f"{x:.4f}" for x in self.history['val_dice']) + "]")
|
807
835
|
print('')
|
808
836
|
|
809
837
|
# early stopping
|
@@ -850,13 +878,13 @@ def categorical_crossentropy(true_labels, predicted_probs):
|
|
850
878
|
# Clip predicted probabilities to avoid log(0) issues
|
851
879
|
predicted_probs = np.clip(predicted_probs, 1e-7, 1 - 1e-7)
|
852
880
|
predicted_probs = -np.log(predicted_probs)
|
853
|
-
zsh,ysh,xsh = true_labels.shape
|
881
|
+
zsh, ysh, xsh, _ = true_labels.shape
|
854
882
|
# Calculate categorical crossentropy
|
855
883
|
loss = 0
|
856
884
|
for z in range(zsh):
|
857
885
|
for y in range(ysh):
|
858
886
|
for x in range(xsh):
|
859
|
-
l = true_labels[z,y,x]
|
887
|
+
l = true_labels[z,y,x,0]
|
860
888
|
loss += predicted_probs[z,y,x,l]
|
861
889
|
loss = loss / float(zsh*ysh*xsh)
|
862
890
|
return loss
|
@@ -880,6 +908,42 @@ def dice_coef_loss(nb_labels):
|
|
880
908
|
return loss
|
881
909
|
return loss_fn
|
882
910
|
|
911
|
+
def custom_loss(y_true, y_pred):
|
912
|
+
# Extract labels and ignore mask
|
913
|
+
labels = tf.cast(y_true[..., 0], tf.int32) # First channel contains class labels
|
914
|
+
ignore_mask = tf.cast(y_true[..., 1], tf.float32) # Second channel contains mask (0 = ignore, 1 = include)
|
915
|
+
|
916
|
+
# Convert integer labels to one-hot encoding
|
917
|
+
y_true_one_hot = tf.one_hot(labels, depth=2)
|
918
|
+
|
919
|
+
# Clip y_pred to avoid log(0)
|
920
|
+
y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0)
|
921
|
+
|
922
|
+
# Compute categorical cross-entropy
|
923
|
+
loss = -tf.reduce_sum(y_true_one_hot * tf.math.log(y_pred), axis=-1)
|
924
|
+
|
925
|
+
# Apply ignore mask (ignore = 0 → loss is zero, include = 1 → loss is counted)
|
926
|
+
loss = loss * ignore_mask
|
927
|
+
|
928
|
+
# Return mean loss over valid (non-ignored) samples
|
929
|
+
return tf.reduce_sum(loss) / tf.reduce_sum(ignore_mask)
|
930
|
+
|
931
|
+
def custom_accuracy(y_true, y_pred):
|
932
|
+
labels = tf.cast(y_true[..., 0], tf.int32) # Extract actual values
|
933
|
+
ignore_mask = y_true[..., 1] # Extract mask (1 = include, 0 = ignore)
|
934
|
+
|
935
|
+
# Convert predictions to discrete values (assuming regression: round values)
|
936
|
+
y_pred_class = tf.argmax(y_pred, axis=-1, output_type=tf.int32)
|
937
|
+
|
938
|
+
# Compute correct predictions (1 where correct, 0 where incorrect)
|
939
|
+
correct_predictions = tf.cast(tf.equal(labels, y_pred_class), tf.float32)
|
940
|
+
|
941
|
+
# Apply ignore mask
|
942
|
+
masked_correct_predictions = correct_predictions * ignore_mask
|
943
|
+
|
944
|
+
# Compute accuracy only over valid (non-ignored) pixels
|
945
|
+
return tf.reduce_sum(masked_correct_predictions) / tf.reduce_sum(ignore_mask)
|
946
|
+
|
883
947
|
def train_segmentation(bm):
|
884
948
|
|
885
949
|
# training data
|
@@ -987,20 +1051,21 @@ def train_segmentation(bm):
|
|
987
1051
|
'dim_img': (zsh, ysh, xsh),
|
988
1052
|
'n_classes': nb_labels,
|
989
1053
|
'n_channels': bm.channels,
|
990
|
-
'augment': (bm.flip_x, bm.flip_y, bm.flip_z, bm.swapaxes, bm.rotate),
|
1054
|
+
'augment': (bm.flip_x, bm.flip_y, bm.flip_z, bm.swapaxes, bm.rotate, bm.rotate3d),
|
991
1055
|
'patch_normalization': bm.patch_normalization,
|
992
|
-
'separation': bm.separation
|
1056
|
+
'separation': bm.separation,
|
1057
|
+
'ignore_mask': bm.ignore_mask}
|
993
1058
|
|
994
1059
|
# data generator
|
995
1060
|
validation_generator = None
|
996
|
-
training_generator = DataGenerator(bm.img_data, bm.label_data, list_IDs_fg, list_IDs_bg, True, True,
|
1061
|
+
training_generator = DataGenerator(bm.img_data, bm.label_data, list_IDs_fg, list_IDs_bg, True, True, **params)
|
997
1062
|
if bm.val_img_data is not None:
|
998
1063
|
if bm.val_dice:
|
999
1064
|
val_metrics = Metrics(bm, bm.val_img_data, bm.val_label_data, list_IDs_val_fg, (zsh_val, ysh_val, xsh_val), nb_labels, False)
|
1000
1065
|
else:
|
1001
1066
|
params['dim_img'] = (zsh_val, ysh_val, xsh_val)
|
1002
|
-
params['augment'] = (False, False, False, False, 0)
|
1003
|
-
validation_generator = DataGenerator(bm.val_img_data, bm.val_label_data, list_IDs_val_fg, list_IDs_val_bg, True, False,
|
1067
|
+
params['augment'] = (False, False, False, False, 0, False)
|
1068
|
+
validation_generator = DataGenerator(bm.val_img_data, bm.val_label_data, list_IDs_val_fg, list_IDs_val_bg, True, False, **params)
|
1004
1069
|
|
1005
1070
|
# monitor dice score on training data
|
1006
1071
|
if bm.train_dice:
|
@@ -1018,7 +1083,7 @@ def train_segmentation(bm):
|
|
1018
1083
|
with strategy.scope():
|
1019
1084
|
|
1020
1085
|
# build model
|
1021
|
-
model = make_unet(input_shape, nb_labels
|
1086
|
+
model = make_unet(bm, input_shape, nb_labels)
|
1022
1087
|
model.summary()
|
1023
1088
|
|
1024
1089
|
# pretrained model
|
@@ -1037,13 +1102,28 @@ def train_segmentation(bm):
|
|
1037
1102
|
layer.trainable = False
|
1038
1103
|
|
1039
1104
|
# optimizer
|
1040
|
-
|
1105
|
+
optimizer = SGD(learning_rate=bm.learning_rate, decay=1e-6, momentum=0.9, nesterov=True)
|
1106
|
+
#optimizer = tf.keras.optimizers.Adam(learning_rate=bm.learning_rate, epsilon=1e-4) lr=0.0001
|
1107
|
+
if bm.mixed_precision:
|
1108
|
+
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer, dynamic=False, initial_scale=128)
|
1109
|
+
|
1110
|
+
# Rename the function to appear as "accuracy" in logs
|
1111
|
+
if bm.ignore_mask:
|
1112
|
+
custom_accuracy.__name__ = "accuracy"
|
1113
|
+
metrics=[custom_accuracy]
|
1114
|
+
else:
|
1115
|
+
metrics=['accuracy']
|
1116
|
+
|
1117
|
+
# loss function
|
1118
|
+
if bm.ignore_mask:
|
1119
|
+
loss=custom_loss
|
1120
|
+
else:
|
1121
|
+
loss=dice_coef_loss(nb_labels) if bm.dice_loss else 'categorical_crossentropy'
|
1041
1122
|
|
1042
1123
|
# comile model
|
1043
|
-
loss=dice_coef_loss(nb_labels) if bm.dice_loss else 'categorical_crossentropy'
|
1044
1124
|
model.compile(loss=loss,
|
1045
|
-
optimizer=
|
1046
|
-
metrics=
|
1125
|
+
optimizer=optimizer,
|
1126
|
+
metrics=metrics)
|
1047
1127
|
|
1048
1128
|
# save meta data
|
1049
1129
|
meta_data = MetaData(bm.path_to_model, configuration_data, allLabels,
|
@@ -1082,7 +1162,7 @@ def train_segmentation(bm):
|
|
1082
1162
|
callbacks=callbacks,
|
1083
1163
|
workers=bm.workers)
|
1084
1164
|
|
1085
|
-
def load_prediction_data(bm, channels,
|
1165
|
+
def load_prediction_data(bm, channels, normalization_parameters,
|
1086
1166
|
region_of_interest, img, img_header, load_blockwise=False, z=None):
|
1087
1167
|
|
1088
1168
|
# read image data
|
@@ -1110,10 +1190,9 @@ def load_prediction_data(bm, channels, normalize, normalization_parameters,
|
|
1110
1190
|
if bm.acwe:
|
1111
1191
|
img_data = img.copy()
|
1112
1192
|
|
1113
|
-
#
|
1193
|
+
# image data must have number of channels >=1
|
1114
1194
|
if len(img.shape)==3:
|
1115
|
-
|
1116
|
-
img = img.reshape(z_shape, y_shape, x_shape, 1)
|
1195
|
+
img = img.reshape(img.shape[0], img.shape[1], img.shape[2], 1)
|
1117
1196
|
if img.shape[3] != channels:
|
1118
1197
|
InputError.message = f'Number of channels must be {channels}.'
|
1119
1198
|
raise InputError()
|
@@ -1128,22 +1207,27 @@ def load_prediction_data(bm, channels, normalize, normalization_parameters,
|
|
1128
1207
|
region_of_interest = np.array([min_z,max_z,min_y,max_y,min_x,max_x,z_shape,y_shape,x_shape])
|
1129
1208
|
z_shape, y_shape, x_shape = max_z-min_z, max_y-min_y, max_x-min_x
|
1130
1209
|
|
1131
|
-
#
|
1132
|
-
img = img.astype(np.float32)
|
1210
|
+
# resize image data
|
1133
1211
|
if bm.scaling:
|
1212
|
+
img = img.astype(np.float32)
|
1134
1213
|
img = img_resize(img, bm.z_scale, bm.y_scale, bm.x_scale)
|
1135
1214
|
|
1215
|
+
# scale image data
|
1216
|
+
if not bm.patch_normalization:
|
1217
|
+
img = img.astype(np.float32)
|
1218
|
+
for ch in range(channels):
|
1219
|
+
img[...,ch] -= np.amin(img[...,ch])
|
1220
|
+
img[...,ch] /= np.amax(img[...,ch])
|
1221
|
+
|
1136
1222
|
# normalize image data
|
1137
|
-
|
1138
|
-
img
|
1139
|
-
|
1140
|
-
|
1141
|
-
|
1142
|
-
img[
|
1143
|
-
img[:,:,:,c] = img[:,:,:,c] * normalization_parameters[1,c] + normalization_parameters[0,c]
|
1223
|
+
if bm.normalize:
|
1224
|
+
img = img.astype(np.float32)
|
1225
|
+
for ch in range(channels):
|
1226
|
+
mean, std = welford_mean_std(img[...,ch])
|
1227
|
+
img[...,ch] = (img[...,ch] - mean) / std
|
1228
|
+
img[...,ch] = img[...,ch] * normalization_parameters[1,ch] + normalization_parameters[0,ch]
|
1144
1229
|
|
1145
|
-
|
1146
|
-
if normalize:
|
1230
|
+
# limit intensity range
|
1147
1231
|
img[img<0] = 0
|
1148
1232
|
img[img>1] = 1
|
1149
1233
|
|
@@ -1186,6 +1270,20 @@ def gradient(volData):
|
|
1186
1270
|
grad[grad>0]=1
|
1187
1271
|
return grad
|
1188
1272
|
|
1273
|
+
@numba.jit(nopython=True)
|
1274
|
+
def scale_probabilities(final):
|
1275
|
+
zsh, ysh, xsh, nb_labels = final.shape
|
1276
|
+
for k in range(zsh):
|
1277
|
+
for l in range(ysh):
|
1278
|
+
for m in range(xsh):
|
1279
|
+
scale_factor = 0
|
1280
|
+
for n in range(nb_labels):
|
1281
|
+
scale_factor += final[k,l,m,n]
|
1282
|
+
scale_factor = max(1, scale_factor)
|
1283
|
+
for n in range(nb_labels):
|
1284
|
+
final[k,l,m,n] /= scale_factor
|
1285
|
+
return final
|
1286
|
+
|
1189
1287
|
def predict_segmentation(bm, region_of_interest, channels, normalization_parameters):
|
1190
1288
|
|
1191
1289
|
from mpi4py import MPI
|
@@ -1193,13 +1291,26 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
|
|
1193
1291
|
rank = comm.Get_rank()
|
1194
1292
|
ngpus = comm.Get_size()
|
1195
1293
|
|
1294
|
+
# optional result paths
|
1295
|
+
if bm.path_to_image:
|
1296
|
+
filename, bm.extension = os.path.splitext(bm.path_to_final)
|
1297
|
+
if bm.extension == '.gz':
|
1298
|
+
bm.extension = '.nii.gz'
|
1299
|
+
filename = filename[:-4]
|
1300
|
+
path_to_cleaned = filename + '.cleaned' + bm.extension
|
1301
|
+
path_to_filled = filename + '.filled' + bm.extension
|
1302
|
+
path_to_cleaned_filled = filename + '.cleaned.filled' + bm.extension
|
1303
|
+
path_to_refined = filename + '.refined' + bm.extension
|
1304
|
+
path_to_acwe = filename + '.acwe' + bm.extension
|
1305
|
+
path_to_probs = filename + '.probs.tif'
|
1306
|
+
|
1196
1307
|
# Set the visible GPU by ID
|
1197
1308
|
gpus = tf.config.experimental.list_physical_devices('GPU')
|
1198
1309
|
if gpus:
|
1199
1310
|
try:
|
1200
1311
|
# Restrict TensorFlow to only use the specified GPU
|
1201
|
-
tf.config.experimental.set_visible_devices(gpus[rank], 'GPU')
|
1202
|
-
tf.config.experimental.set_memory_growth(gpus[rank], True)
|
1312
|
+
tf.config.experimental.set_visible_devices(gpus[rank % len(gpus)], 'GPU')
|
1313
|
+
tf.config.experimental.set_memory_growth(gpus[rank % len(gpus)], True)
|
1203
1314
|
except RuntimeError as e:
|
1204
1315
|
print(e)
|
1205
1316
|
|
@@ -1210,7 +1321,7 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
|
|
1210
1321
|
nb_labels = len(bm.allLabels)
|
1211
1322
|
results['allLabels'] = bm.allLabels
|
1212
1323
|
|
1213
|
-
#
|
1324
|
+
# custom objects
|
1214
1325
|
if bm.dice_loss:
|
1215
1326
|
def loss_fn(y_true, y_pred):
|
1216
1327
|
dice = 0
|
@@ -1221,25 +1332,30 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
|
|
1221
1332
|
loss = 1 - dice
|
1222
1333
|
return loss
|
1223
1334
|
custom_objects = {'dice_coef_loss': dice_coef_loss,'loss_fn': loss_fn}
|
1224
|
-
|
1335
|
+
elif bm.ignore_mask:
|
1336
|
+
custom_objects={'custom_loss': custom_loss}
|
1225
1337
|
else:
|
1226
|
-
|
1338
|
+
custom_objects=None
|
1339
|
+
|
1340
|
+
# load model
|
1341
|
+
model = load_model(bm.path_to_model, custom_objects=custom_objects)
|
1227
1342
|
|
1228
1343
|
# check if data can be loaded blockwise to save host memory
|
1229
1344
|
load_blockwise = False
|
1230
1345
|
if not bm.scaling and not bm.normalize and bm.path_to_image and not np.any(region_of_interest) and \
|
1231
1346
|
os.path.splitext(bm.path_to_image)[1] in ['.tif', '.tiff'] and not bm.acwe:
|
1347
|
+
|
1232
1348
|
# get image shape
|
1233
1349
|
tif = TiffFile(bm.path_to_image)
|
1234
1350
|
zsh = len(tif.pages)
|
1235
1351
|
ysh, xsh = tif.pages[0].shape
|
1236
1352
|
|
1237
1353
|
# load mask
|
1238
|
-
if bm.separation or bm.refinement:
|
1354
|
+
'''if bm.separation or bm.refinement:
|
1239
1355
|
mask, _ = load_data(bm.mask)
|
1240
1356
|
mask = mask.reshape(zsh, ysh, xsh, 1)
|
1241
1357
|
mask, _, _, _ = append_ghost_areas(bm, mask)
|
1242
|
-
mask = mask.reshape(mask.shape[:-1])
|
1358
|
+
mask = mask.reshape(mask.shape[:-1])'''
|
1243
1359
|
|
1244
1360
|
# determine new image size after appending ghost areas to make image dimensions divisible by patch size
|
1245
1361
|
z_rest = bm.z_patch - (zsh % bm.z_patch)
|
@@ -1259,7 +1375,7 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
|
|
1259
1375
|
xsh += x_rest
|
1260
1376
|
|
1261
1377
|
# get Ids of patches
|
1262
|
-
list_IDs = []
|
1378
|
+
'''list_IDs = []
|
1263
1379
|
for k in range(0, zsh-bm.z_patch+1, bm.stride_size):
|
1264
1380
|
for l in range(0, ysh-bm.y_patch+1, bm.stride_size):
|
1265
1381
|
for m in range(0, xsh-bm.x_patch+1, bm.stride_size):
|
@@ -1269,19 +1385,18 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
|
|
1269
1385
|
if centerLabel>0 and np.any(patch!=centerLabel):
|
1270
1386
|
list_IDs.append(k*ysh*xsh+l*xsh+m)
|
1271
1387
|
elif bm.refinement:
|
1272
|
-
|
1273
|
-
if np.any(patch==0) and np.any(patch!=0):
|
1388
|
+
if np.any(mask[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]):
|
1274
1389
|
list_IDs.append(k*ysh*xsh+l*xsh+m)
|
1275
1390
|
else:
|
1276
|
-
list_IDs.append(k*ysh*xsh+l*xsh+m)
|
1391
|
+
list_IDs.append(k*ysh*xsh+l*xsh+m)'''
|
1277
1392
|
|
1278
1393
|
# make length of list divisible by batch size
|
1279
|
-
max_i = len(list_IDs)
|
1394
|
+
'''max_i = len(list_IDs)
|
1280
1395
|
rest = bm.batch_size - (len(list_IDs) % bm.batch_size)
|
1281
|
-
list_IDs = list_IDs + list_IDs[:rest]
|
1396
|
+
list_IDs = list_IDs + list_IDs[:rest]'''
|
1282
1397
|
|
1283
1398
|
# prediction
|
1284
|
-
if
|
1399
|
+
if zsh*ysh*xsh > 256**3:
|
1285
1400
|
load_blockwise = True
|
1286
1401
|
|
1287
1402
|
# load image data and calculate patch IDs
|
@@ -1289,7 +1404,7 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
|
|
1289
1404
|
|
1290
1405
|
# load prediction data
|
1291
1406
|
img, bm.img_header, z_shape, y_shape, x_shape, region_of_interest, bm.img_data = load_prediction_data(
|
1292
|
-
bm, channels,
|
1407
|
+
bm, channels, normalization_parameters, region_of_interest, bm.img_data, bm.img_header)
|
1293
1408
|
|
1294
1409
|
# append ghost areas
|
1295
1410
|
img, z_rest, y_rest, x_rest = append_ghost_areas(bm, img)
|
@@ -1315,6 +1430,7 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
|
|
1315
1430
|
|
1316
1431
|
# load all patches on GPU memory
|
1317
1432
|
if not load_blockwise and nb_patches < 400:
|
1433
|
+
if rank==0:
|
1318
1434
|
|
1319
1435
|
# parameters
|
1320
1436
|
params = {'dim': (bm.z_patch, bm.y_patch, bm.x_patch),
|
@@ -1347,7 +1463,11 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
|
|
1347
1463
|
|
1348
1464
|
# allocate final probabilities array
|
1349
1465
|
if rank==0 and bm.return_probs:
|
1350
|
-
|
1466
|
+
if load_blockwise:
|
1467
|
+
if not os.path.exists(path_to_probs[:-4]):
|
1468
|
+
os.mkdir(path_to_probs[:-4])
|
1469
|
+
else:
|
1470
|
+
final = np.zeros((zsh, ysh, xsh, nb_labels), dtype=np.float32)
|
1351
1471
|
|
1352
1472
|
# allocate final result array
|
1353
1473
|
if rank==0:
|
@@ -1362,27 +1482,38 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
|
|
1362
1482
|
else:
|
1363
1483
|
nprocs = ngpus
|
1364
1484
|
if j % ngpus == rank:
|
1485
|
+
|
1365
1486
|
# load blockwise from TIFF
|
1366
1487
|
if load_blockwise:
|
1367
1488
|
img, _, _, _, _, _, _ = load_prediction_data(bm,
|
1368
|
-
channels,
|
1489
|
+
channels, normalization_parameters,
|
1369
1490
|
region_of_interest, bm.img_data, bm.img_header, load_blockwise, z)
|
1370
1491
|
img, _, _, _ = append_ghost_areas(bm, img)
|
1371
1492
|
|
1493
|
+
# load mask block
|
1494
|
+
if bm.separation or bm.refinement:
|
1495
|
+
mask = imread(bm.mask, key=range(z,min(len(tif.pages),z+bm.z_patch)))
|
1496
|
+
# pad zeros to make dimensions divisible by patch dimensions
|
1497
|
+
pad_z = bm.z_patch - mask.shape[0]
|
1498
|
+
pad_y = (bm.y_patch - (mask.shape[1] % bm.y_patch)) % bm.y_patch
|
1499
|
+
pad_x = (bm.x_patch - (mask.shape[2] % bm.x_patch)) % bm.x_patch
|
1500
|
+
pad_width = [(0, pad_z), (0, pad_y), (0, pad_x)]
|
1501
|
+
mask = np.pad(mask, pad_width, mode='constant', constant_values=0)
|
1502
|
+
|
1372
1503
|
# list of IDs
|
1373
1504
|
list_IDs_block = []
|
1374
1505
|
|
1375
1506
|
# get Ids of patches
|
1507
|
+
k = 0 if load_blockwise else z
|
1376
1508
|
for l in range(0, ysh-bm.y_patch+1, bm.stride_size):
|
1377
1509
|
for m in range(0, xsh-bm.x_patch+1, bm.stride_size):
|
1378
1510
|
if bm.separation:
|
1379
|
-
centerLabel = mask[
|
1380
|
-
patch = mask[
|
1511
|
+
centerLabel = mask[k+bm.z_patch//2,l+bm.y_patch//2,m+bm.x_patch//2]
|
1512
|
+
patch = mask[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]
|
1381
1513
|
if centerLabel>0 and np.any(patch!=centerLabel):
|
1382
1514
|
list_IDs_block.append(z*ysh*xsh+l*xsh+m)
|
1383
1515
|
elif bm.refinement:
|
1384
|
-
|
1385
|
-
if np.any(patch==0) and np.any(patch!=0):
|
1516
|
+
if np.any(mask[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]):
|
1386
1517
|
list_IDs_block.append(z*ysh*xsh+l*xsh+m)
|
1387
1518
|
else:
|
1388
1519
|
list_IDs_block.append(z*ysh*xsh+l*xsh+m)
|
@@ -1414,10 +1545,11 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
|
|
1414
1545
|
# get patch
|
1415
1546
|
tmp_X = img[k:k+bm.z_patch,l:l+bm.y_patch,m:m+bm.x_patch]
|
1416
1547
|
if bm.patch_normalization:
|
1417
|
-
tmp_X =
|
1418
|
-
for
|
1419
|
-
|
1420
|
-
tmp_X[
|
1548
|
+
tmp_X = tmp_X.copy().astype(np.float32)
|
1549
|
+
for ch in range(channels):
|
1550
|
+
mean, std = welford_mean_std(tmp_X[...,ch])
|
1551
|
+
tmp_X[...,ch] -= mean
|
1552
|
+
tmp_X[...,ch] /= max(std, 1e-6)
|
1421
1553
|
X[i] = tmp_X
|
1422
1554
|
|
1423
1555
|
# predict batch
|
@@ -1461,19 +1593,28 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
|
|
1461
1593
|
# overlap in z direction
|
1462
1594
|
if bm.stride_size < bm.z_patch:
|
1463
1595
|
if j+source>0:
|
1464
|
-
probs[
|
1596
|
+
probs[:-bm.stride_size] += overlap
|
1465
1597
|
overlap = probs[bm.stride_size:].copy()
|
1466
1598
|
|
1467
|
-
#
|
1599
|
+
# block z dimension
|
1468
1600
|
block_z = z_indices[j+source]
|
1469
|
-
if j+source==len(z_indices)-1:
|
1470
|
-
|
1471
|
-
if
|
1472
|
-
final[block_z:block_z+bm.z_patch] = probs
|
1601
|
+
if j+source==len(z_indices)-1: # last block
|
1602
|
+
block_zsh = bm.z_patch
|
1603
|
+
block_z_rest = z_rest if z_rest>0 else -block_zsh
|
1473
1604
|
else:
|
1474
1605
|
block_zsh = min(bm.stride_size, bm.z_patch)
|
1475
|
-
|
1476
|
-
|
1606
|
+
block_z_rest = -block_zsh
|
1607
|
+
|
1608
|
+
# calculate result
|
1609
|
+
label[block_z:block_z+block_zsh] = np.argmax(probs[:block_zsh], axis=-1).astype(np.uint8)
|
1610
|
+
|
1611
|
+
# return probabilities
|
1612
|
+
if bm.return_probs:
|
1613
|
+
if load_blockwise:
|
1614
|
+
block_output = scale_probabilities(probs[:block_zsh])
|
1615
|
+
block_output = block_output[:-block_z_rest,:-y_rest,:-x_rest]
|
1616
|
+
imwrite(path_to_probs[:-4] + f"/block-{j+source}.tif", block_output)
|
1617
|
+
else:
|
1477
1618
|
final[block_z:block_z+block_zsh] = probs[:block_zsh]
|
1478
1619
|
else:
|
1479
1620
|
for i in range(bm.z_patch):
|
@@ -1481,7 +1622,7 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
|
|
1481
1622
|
if rank==0:
|
1482
1623
|
|
1483
1624
|
# refine mask data with result
|
1484
|
-
if bm.refinement:
|
1625
|
+
'''if bm.refinement:
|
1485
1626
|
# loop over boundary patches
|
1486
1627
|
for i, ID in enumerate(list_IDs):
|
1487
1628
|
if i < max_i:
|
@@ -1490,25 +1631,17 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
|
|
1490
1631
|
l = rest // xsh
|
1491
1632
|
m = rest % xsh
|
1492
1633
|
mask[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch] = label[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch]
|
1493
|
-
label = mask
|
1634
|
+
label = mask'''
|
1494
1635
|
|
1495
|
-
# remove
|
1496
|
-
if bm.return_probs:
|
1636
|
+
# remove ghost areas
|
1637
|
+
if bm.return_probs and not load_blockwise:
|
1497
1638
|
final = final[:-z_rest,:-y_rest,:-x_rest]
|
1498
1639
|
label = label[:-z_rest,:-y_rest,:-x_rest]
|
1499
1640
|
zsh, ysh, xsh = label.shape
|
1500
1641
|
|
1501
1642
|
# return probabilities
|
1502
|
-
if bm.return_probs:
|
1503
|
-
|
1504
|
-
nb = 0
|
1505
|
-
for k in range(0, zsh-bm.z_patch+1, bm.stride_size):
|
1506
|
-
for l in range(0, ysh-bm.y_patch+1, bm.stride_size):
|
1507
|
-
for m in range(0, xsh-bm.x_patch+1, bm.stride_size):
|
1508
|
-
counter[k:k+bm.z_patch, l:l+bm.y_patch, m:m+bm.x_patch] += 1
|
1509
|
-
nb += 1
|
1510
|
-
counter[counter==0] = 1
|
1511
|
-
probabilities = final / counter
|
1643
|
+
if bm.return_probs and not load_blockwise:
|
1644
|
+
probabilities = scale_probabilities(final)
|
1512
1645
|
if bm.scaling:
|
1513
1646
|
probabilities = img_resize(probabilities, z_shape, y_shape, x_shape)
|
1514
1647
|
if np.any(region_of_interest):
|
@@ -1571,17 +1704,8 @@ def predict_segmentation(bm, region_of_interest, channels, normalization_paramet
|
|
1571
1704
|
# save result
|
1572
1705
|
if bm.path_to_image:
|
1573
1706
|
save_data(bm.path_to_final, label, header=bm.header, compress=bm.compression)
|
1574
|
-
|
1575
|
-
|
1576
|
-
filename, bm.extension = os.path.splitext(bm.path_to_final)
|
1577
|
-
if bm.extension == '.gz':
|
1578
|
-
bm.extension = '.nii.gz'
|
1579
|
-
filename = filename[:-4]
|
1580
|
-
path_to_cleaned = filename + '.cleaned' + bm.extension
|
1581
|
-
path_to_filled = filename + '.filled' + bm.extension
|
1582
|
-
path_to_cleaned_filled = filename + '.cleaned.filled' + bm.extension
|
1583
|
-
path_to_refined = filename + '.refined' + bm.extension
|
1584
|
-
path_to_acwe = filename + '.acwe' + bm.extension
|
1707
|
+
if bm.return_probs and not load_blockwise:
|
1708
|
+
imwrite(path_to_probs, probabilities)
|
1585
1709
|
|
1586
1710
|
# remove outliers
|
1587
1711
|
if bm.clean:
|