biomedisa 24.8.10__py3-none-any.whl → 25.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- biomedisa/deeplearning.py +35 -9
- biomedisa/features/DataGenerator.py +192 -123
- biomedisa/features/PredictDataGenerator.py +7 -5
- biomedisa/features/biomedisa_helper.py +59 -14
- biomedisa/features/crop_helper.py +7 -7
- biomedisa/features/keras_helper.py +281 -157
- biomedisa/features/random_walk/rw_large.py +6 -2
- biomedisa/features/random_walk/rw_small.py +7 -3
- biomedisa/features/remove_outlier.py +3 -3
- biomedisa/features/split_volume.py +12 -11
- biomedisa/interpolation.py +6 -9
- biomedisa/mesh.py +2 -2
- {biomedisa-24.8.10.dist-info → biomedisa-25.6.1.dist-info}/METADATA +3 -2
- {biomedisa-24.8.10.dist-info → biomedisa-25.6.1.dist-info}/RECORD +17 -17
- {biomedisa-24.8.10.dist-info → biomedisa-25.6.1.dist-info}/WHEEL +1 -1
- {biomedisa-24.8.10.dist-info → biomedisa-25.6.1.dist-info/licenses}/LICENSE +0 -0
- {biomedisa-24.8.10.dist-info → biomedisa-25.6.1.dist-info}/top_level.txt +0 -0
biomedisa/deeplearning.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
#!/usr/bin/python3
|
2
2
|
##########################################################################
|
3
3
|
## ##
|
4
|
-
## Copyright (c) 2019-
|
4
|
+
## Copyright (c) 2019-2025 Philipp Lösel. All rights reserved. ##
|
5
5
|
## ##
|
6
6
|
## This file is part of the open source project biomedisa. ##
|
7
7
|
## ##
|
@@ -67,7 +67,7 @@ def deep_learning(img_data, label_data=None, val_img_data=None, val_label_data=N
|
|
67
67
|
balance=False, crop_data=False, flip_x=False, flip_y=False, flip_z=False,
|
68
68
|
swapaxes=False, train_dice=False, val_dice=True, compression=True, ignore='none', only='all',
|
69
69
|
network_filters='32-64-128-256-512', resnet=False, debug_cropping=False,
|
70
|
-
save_cropped=False, epochs=100, normalization=True, rotate=0.0, validation_split=0.0,
|
70
|
+
save_cropped=False, epochs=100, normalization=True, rotate=0.0, rotate3d=0.0, validation_split=0.0,
|
71
71
|
learning_rate=0.01, stride_size=32, validation_stride_size=32, validation_freq=1,
|
72
72
|
batch_size=None, x_scale=256, y_scale=256, z_scale=256, scaling=True, early_stopping=0,
|
73
73
|
pretrained_model=None, fine_tune=False, workers=1, cropping_epochs=50,
|
@@ -77,7 +77,8 @@ def deep_learning(img_data, label_data=None, val_img_data=None, val_label_data=N
|
|
77
77
|
z_patch=64, y_patch=64, x_patch=64, path_to_logfile=None, img_id=None, label_id=None,
|
78
78
|
remote=False, queue=0, username=None, shortfilename=None, dice_loss=False,
|
79
79
|
acwe=False, acwe_alpha=1.0, acwe_smooth=1, acwe_steps=3, clean=None, fill=None,
|
80
|
-
separation=False, mask=None, refinement=False
|
80
|
+
separation=False, mask=None, refinement=False, ignore_mask=False, mixed_precision=False,
|
81
|
+
slicer=False, path_to_data=None):
|
81
82
|
|
82
83
|
# create biomedisa
|
83
84
|
bm = Biomedisa()
|
@@ -91,6 +92,7 @@ def deep_learning(img_data, label_data=None, val_img_data=None, val_label_data=N
|
|
91
92
|
key_copy = tuple(locals().keys())
|
92
93
|
for arg in key_copy:
|
93
94
|
bm.__dict__[arg] = locals()[arg]
|
95
|
+
bm.path_to_data = bm.path_to_images
|
94
96
|
|
95
97
|
# normalization
|
96
98
|
bm.normalize = 1 if bm.normalization else 0
|
@@ -214,10 +216,16 @@ def deep_learning(img_data, label_data=None, val_img_data=None, val_label_data=N
|
|
214
216
|
|
215
217
|
if bm.predict:
|
216
218
|
|
217
|
-
#
|
218
|
-
|
219
|
-
|
220
|
-
|
219
|
+
# load model
|
220
|
+
try:
|
221
|
+
hf = h5py.File(bm.path_to_model, 'r')
|
222
|
+
meta = hf.get('meta')
|
223
|
+
configuration = meta.get('configuration')
|
224
|
+
bm.allLabels = np.array(meta.get('labels'))
|
225
|
+
except:
|
226
|
+
raise RuntimeError("Invalid model.")
|
227
|
+
|
228
|
+
# get configuration
|
221
229
|
channels, bm.x_scale, bm.y_scale, bm.z_scale, bm.normalize, mu, sig = np.array(configuration)[:]
|
222
230
|
channels, bm.x_scale, bm.y_scale, bm.z_scale, bm.normalize, mu, sig = int(channels), int(bm.x_scale), \
|
223
231
|
int(bm.y_scale), int(bm.z_scale), int(bm.normalize), float(mu), float(sig)
|
@@ -225,7 +233,6 @@ def deep_learning(img_data, label_data=None, val_img_data=None, val_label_data=N
|
|
225
233
|
normalization_parameters = np.array(meta['normalization'], dtype=float)
|
226
234
|
else:
|
227
235
|
normalization_parameters = np.array([[mu],[sig]])
|
228
|
-
bm.allLabels = np.array(meta.get('labels'))
|
229
236
|
if 'patch_normalization' in meta:
|
230
237
|
bm.patch_normalization = bool(meta['patch_normalization'][()])
|
231
238
|
if 'scaling' in meta:
|
@@ -436,6 +443,8 @@ if __name__ == '__main__':
|
|
436
443
|
help='Disable normalization of 3D image volumes')
|
437
444
|
parser.add_argument('-r','--rotate', type=float, default=0.0,
|
438
445
|
help='Randomly rotate during training')
|
446
|
+
parser.add_argument('--rotate3d', action='store_true', default=False,
|
447
|
+
help='Randomly rotate through three dimensions during training. Uniformly distributed over the sphere.')
|
439
448
|
parser.add_argument('-vs','--validation_split', type=float, default=0.0,
|
440
449
|
help='Percentage of data used for training')
|
441
450
|
parser.add_argument('-lr','--learning_rate', type=float, default=0.01,
|
@@ -504,9 +513,20 @@ if __name__ == '__main__':
|
|
504
513
|
help='Save data in formats like NRRD or TIFF using --extension=".nrrd"')
|
505
514
|
parser.add_argument('-ptm','--path_to_model', type=str, metavar='PATH', default=None,
|
506
515
|
help='Specify the model location for training')
|
516
|
+
parser.add_argument('-im','--ignore_mask', action='store_true', default=False,
|
517
|
+
help='Use a binary mask in the second channel of the label file to define ignored (0) and considered (1) areas during training')
|
518
|
+
parser.add_argument('-mp','--mixed_precision', action='store_true', default=False,
|
519
|
+
help='Use mixed precision in model')
|
520
|
+
parser.add_argument('--slicer', action='store_true', default=False,
|
521
|
+
help='Required for starting Biomedisa from 3D Slicer')
|
507
522
|
bm = parser.parse_args()
|
508
523
|
bm.success = True
|
509
524
|
|
525
|
+
if bm.rotate3d and not bm.scaling:
|
526
|
+
raise RuntimeError("You cannot do true 3d rotation without rescaling the data yet.")
|
527
|
+
# To fix this, have the loading function pass in a list of where the images end,
|
528
|
+
# and use that to figure out the z-centres in biomedisa.features.DataGenerator.rotate_*_3d
|
529
|
+
|
510
530
|
# prediction or training
|
511
531
|
if not any([bm.train, bm.predict]):
|
512
532
|
bm.predict = False
|
@@ -536,6 +556,12 @@ if __name__ == '__main__':
|
|
536
556
|
bm.django_env = False
|
537
557
|
|
538
558
|
kwargs = vars(bm)
|
559
|
+
bm.path_to_data = bm.path_to_images
|
560
|
+
|
561
|
+
# verify model
|
562
|
+
if bm.predict and os.path.splitext(bm.path)[1] != '.h5':
|
563
|
+
bm = _error_(bm, "Invalid model.")
|
564
|
+
raise RuntimeError("Invalid model.")
|
539
565
|
|
540
566
|
# train or predict segmentation
|
541
567
|
try:
|
@@ -554,5 +580,5 @@ if __name__ == '__main__':
|
|
554
580
|
bm = _error_(bm, 'GPU out of memory. Reduce your batch size')
|
555
581
|
except Exception as e:
|
556
582
|
print(traceback.format_exc())
|
557
|
-
bm = _error_(bm, e)
|
583
|
+
bm = _error_(bm, str(e))
|
558
584
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
##########################################################################
|
2
2
|
## ##
|
3
|
-
## Copyright (c) 2019-
|
3
|
+
## Copyright (c) 2019-2025 Philipp Lösel. All rights reserved. ##
|
4
4
|
## ##
|
5
5
|
## This file is part of the open source project biomedisa. ##
|
6
6
|
## ##
|
@@ -26,11 +26,30 @@
|
|
26
26
|
## ##
|
27
27
|
##########################################################################
|
28
28
|
|
29
|
+
from biomedisa.features.biomedisa_helper import welford_mean_std
|
29
30
|
import numpy as np
|
30
31
|
import tensorflow as tf
|
31
32
|
import numba
|
32
33
|
import random
|
33
|
-
|
34
|
+
|
35
|
+
def get_random_rotation_matrix(batch_size):
    """Draw ``batch_size`` random 3x3 rotation matrices.

    Follows Arvo's "fast random rotation matrices" construction. It starts with a
    random rotation about the z axis, applies a Householder reflection through a
    randomly oriented plane, and negates the product so that the result is a
    proper rotation (orthogonal, determinant +1).

    Parameters
    ----------
    batch_size : int
        Number of matrices to draw (one per sample in a batch).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(batch_size, 3, 3)``.
    """
    angle_xy = np.random.uniform(0, 2. * np.pi, batch_size)
    Householder_1 = np.random.uniform(0, 2. * np.pi, batch_size)
    Householder_2 = np.random.uniform(0, 1., batch_size)
    cos_xy = np.cos(angle_xy)
    sin_xy = np.sin(angle_xy)
    # Rotation about the z axis. The 2x2 part must be [[c, s], [-s, c]].
    # Writing [[c, s], [c, -s]] gives a non-orthogonal (possibly singular)
    # matrix, and the patches get sheared instead of rotated.
    RR = np.zeros([batch_size, 3, 3])
    RR[:, 0, 0] = cos_xy
    RR[:, 0, 1] = sin_xy
    RR[:, 1, 0] = -sin_xy
    RR[:, 1, 1] = cos_xy
    RR[:, 2, 2] = 1.
    # Unit vector defining the Householder reflection H = I - 2 v v^T
    vv = np.zeros([batch_size, 3, 1])
    vv[:, 0, 0] = np.cos(Householder_1) * np.sqrt(Householder_2)
    vv[:, 1, 0] = np.sin(Householder_1) * np.sqrt(Householder_2)
    vv[:, 2, 0] = np.sqrt(1. - Householder_2)
    HH = np.eye(3)[np.newaxis, :, :] - 2. * np.matmul(vv, vv.transpose(0, 2, 1))
    # det(H) = -1 and det(RR) = +1, so negating the product gives det = +1
    return -np.matmul(HH, RR)
|
34
53
|
|
35
54
|
@numba.jit(nopython=True)#parallel=True
|
36
55
|
def rotate_img_patch(src,trg,k,l,m,cos_a,sin_a,z_patch,y_patch,x_patch,imageHeight,imageWidth):
|
@@ -61,6 +80,56 @@ def rotate_img_patch(src,trg,k,l,m,cos_a,sin_a,z_patch,y_patch,x_patch,imageHeig
|
|
61
80
|
trg[z-k,y-l,x-m] = val
|
62
81
|
return trg
|
63
82
|
|
83
|
+
@numba.jit(nopython=True)#parallel=True
def rotate_img_patch_3d(src,trg,k,l,m,rm_xx,rm_xy,rm_xz,rm_yx,rm_yy,rm_yz,rm_zx,rm_zy,rm_zz,z_patch,y_patch,x_patch,imageVertStride,imageDepth,imageHeight,imageWidth):
    # Fill ``trg`` with the patch of image volume ``src`` whose origin is
    # (k, l, m) in (z, y, x) order. The patch is resampled under the 3x3
    # rotation matrix given by the nine ``rm_**`` entries, using trilinear
    # interpolation. Results go to trg[z-k, y-l, x-m]; ``trg`` is returned.
    #
    # x and y rotate about the centre of the image plane. z rotates about the
    # centre of the ``imageVertStride``-slice slab that contains the current
    # slice. Source coordinates outside the volume are clamped to the border.
    for z in range(k,k+z_patch):
        zCentreRotation = imageVertStride * (z // imageVertStride) + imageVertStride/2
        zA = z - zCentreRotation
        for y in range(l,l+y_patch):
            yA = y - imageHeight/2
            for x in range(m,m+x_patch):
                xA = x - imageWidth/2
                # rotate the centred offset, then shift back to source coordinates
                xR = xA * rm_xx + yA * rm_xy + zA * rm_xz
                yR = xA * rm_yx + yA * rm_yy + zA * rm_yz
                zR = xA * rm_zx + yA * rm_zy + zA * rm_zz
                src_x = xR + imageWidth/2
                src_y = yR + imageHeight/2
                src_z = zR + zCentreRotation
                # Trilinear interpolation. Use floor, not int(): int() truncates
                # toward zero. For a coordinate in (-1, 0), that made the
                # fractional weight negative, so the sample was extrapolated
                # instead of clamped to the border voxel.
                src_x0 = np.floor(src_x)
                src_x1 = src_x0 + 1
                src_y0 = np.floor(src_y)
                src_y1 = src_y0 + 1
                src_z0 = np.floor(src_z)
                src_z1 = src_z0 + 1
                sx = src_x - src_x0
                sy = src_y - src_y0
                sz = src_z - src_z0
                # clamp neighbour indices into the volume (border replication)
                idx_src_x0 = int(min(max(0,src_x0),imageWidth-1))
                idx_src_x1 = int(min(max(0,src_x1),imageWidth-1))
                idx_src_y0 = int(min(max(0,src_y0),imageHeight-1))
                idx_src_y1 = int(min(max(0,src_y1),imageHeight-1))
                idx_src_z0 = int(min(max(0,src_z0),imageDepth-1))
                idx_src_z1 = int(min(max(0,src_z1),imageDepth-1))

                val = (1-sy) * (1-sx) * (1-sz) * float(src[idx_src_z0,idx_src_y0,idx_src_x0])
                val += (sy) * (1-sx) * (1-sz) * float(src[idx_src_z0,idx_src_y1,idx_src_x0])
                val += (1-sy) * (sx) * (1-sz) * float(src[idx_src_z0,idx_src_y0,idx_src_x1])
                val += (sy) * (sx) * (1-sz) * float(src[idx_src_z0,idx_src_y1,idx_src_x1])
                val += (1-sy) * (1-sx) * (sz) * float(src[idx_src_z1,idx_src_y0,idx_src_x0])
                val += (sy) * (1-sx) * (sz) * float(src[idx_src_z1,idx_src_y1,idx_src_x0])
                val += (1-sy) * (sx) * (sz) * float(src[idx_src_z1,idx_src_y0,idx_src_x1])
                val += (sy) * (sx) * (sz) * float(src[idx_src_z1,idx_src_y1,idx_src_x1])
                trg[z-k,y-l,x-m] = val
    return trg
|
126
|
+
|
127
|
+
# Standalone copy of the z-centring rule, kept so it can be tested in
# isolation. The 3D rotation kernels inline the same expression instead of
# calling this, because numba can be finicky about nested nopython calls.
@numba.jit(nopython=True)#parallel=True
def centre_of_z_rotation(z, imageVertStride):
    # Return the z coordinate that slice ``z`` is rotated about: the middle of
    # the ``imageVertStride``-slice slab that contains ``z``.
    slab_start = (z // imageVertStride) * imageVertStride
    return slab_start + imageVertStride/2
|
132
|
+
|
64
133
|
@numba.jit(nopython=True)#parallel=True
|
65
134
|
def rotate_label_patch(src,trg,k,l,m,cos_a,sin_a,z_patch,y_patch,x_patch,imageHeight,imageWidth):
|
66
135
|
for y in range(l,l+y_patch):
|
@@ -80,34 +149,36 @@ def rotate_label_patch(src,trg,k,l,m,cos_a,sin_a,z_patch,y_patch,x_patch,imageHe
|
|
80
149
|
trg[z-k,y-l,x-m] = src[z,idx_src_y,idx_src_x]
|
81
150
|
return trg
|
82
151
|
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
152
|
+
@numba.jit(nopython=True)#parallel=True
def rotate_label_patch_3d(src,trg,k,l,m,rm_xx,rm_xy,rm_xz,rm_yx,rm_yy,rm_yz,rm_zx,rm_zy,rm_zz,z_patch,y_patch,x_patch,imageVertStride,imageDepth,imageHeight,imageWidth):
    # Fill ``trg`` with the patch of label volume ``src`` whose origin is
    # (k, l, m) in (z, y, x) order. The patch is resampled under the 3x3
    # rotation matrix given by the nine ``rm_**`` entries, using
    # nearest-neighbour lookup so label values are never blended. Results go
    # to trg[z-k, y-l, x-m]; ``trg`` is returned.
    #
    # x and y rotate about the centre of the image plane. z rotates about the
    # centre of the ``imageVertStride``-slice slab that contains the current
    # slice. Source coordinates outside the volume are clamped to the border.
    half_width = imageWidth / 2
    half_height = imageHeight / 2
    for z in range(k, k + z_patch):
        slab_centre = imageVertStride * (z // imageVertStride) + imageVertStride / 2
        off_z = z - slab_centre
        for y in range(l, l + y_patch):
            off_y = y - half_height
            for x in range(m, m + x_patch):
                off_x = x - half_width
                # rotate the centred offset, shift back, snap to the nearest voxel
                near_x = round(off_x * rm_xx + off_y * rm_xy + off_z * rm_xz + half_width)
                near_y = round(off_x * rm_yx + off_y * rm_yy + off_z * rm_yz + half_height)
                near_z = round(off_x * rm_zx + off_y * rm_zy + off_z * rm_zz + slab_centre)
                # clamp into the volume (border replication)
                ix = int(min(max(0, near_x), imageWidth - 1))
                iy = int(min(max(0, near_y), imageHeight - 1))
                iz = int(min(max(0, near_z), imageDepth - 1))
                trg[z - k, y - l, x - m] = src[iz, iy, ix]
    return trg
|
106
176
|
|
107
177
|
class DataGenerator(tf.keras.utils.Sequence):
|
108
178
|
'Generates data for Keras'
|
109
|
-
def __init__(self, img, label, list_IDs_fg, list_IDs_bg, shuffle, train,
|
110
|
-
dim_img=(32,32,32), n_classes=10, n_channels=1, augment=(False,False,False,False,0
|
179
|
+
def __init__(self, img, label, list_IDs_fg, list_IDs_bg, shuffle, train, batch_size=32, dim=(32,32,32),
|
180
|
+
dim_img=(32,32,32), n_classes=10, n_channels=1, augment=(False,False,False,False,0,False),
|
181
|
+
patch_normalization=False, separation=False, ignore_mask=False):
|
111
182
|
'Initialization'
|
112
183
|
self.dim = dim
|
113
184
|
self.dim_img = dim_img
|
@@ -121,10 +192,10 @@ class DataGenerator(tf.keras.utils.Sequence):
|
|
121
192
|
self.shuffle = shuffle
|
122
193
|
self.augment = augment
|
123
194
|
self.train = train
|
124
|
-
self.classification = classification
|
125
195
|
self.on_epoch_end()
|
126
196
|
self.patch_normalization = patch_normalization
|
127
197
|
self.separation = separation
|
198
|
+
self.ignore_mask = ignore_mask
|
128
199
|
|
129
200
|
def __len__(self):
|
130
201
|
'Denotes the number of batches per epoch'
|
@@ -181,15 +252,15 @@ class DataGenerator(tf.keras.utils.Sequence):
|
|
181
252
|
def __data_generation(self, list_IDs_temp):
|
182
253
|
'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
|
183
254
|
|
184
|
-
#
|
255
|
+
# number of label channels
|
256
|
+
label_channels = 2 if self.ignore_mask else 1
|
257
|
+
|
258
|
+
# allocate memory
|
185
259
|
X = np.empty((self.batch_size, *self.dim, self.n_channels), dtype=np.float32)
|
186
|
-
|
187
|
-
y = np.empty((self.batch_size, 1), dtype=np.int32)
|
188
|
-
else:
|
189
|
-
y = np.empty((self.batch_size, *self.dim, 1), dtype=np.int32)
|
260
|
+
y = np.empty((self.batch_size, *self.dim, label_channels), dtype=np.int32)
|
190
261
|
|
191
262
|
# get augmentation parameter
|
192
|
-
flip_x, flip_y, flip_z, swapaxes, rotate = self.augment
|
263
|
+
flip_x, flip_y, flip_z, swapaxes, rotate, rotate3d = self.augment
|
193
264
|
n_aug = np.sum([flip_z, flip_y, flip_x])
|
194
265
|
flips = np.where([flip_z, flip_y, flip_x])[0]
|
195
266
|
|
@@ -198,6 +269,8 @@ class DataGenerator(tf.keras.utils.Sequence):
|
|
198
269
|
angle = np.random.uniform(-1,1,self.batch_size) * 3.1416/180*rotate
|
199
270
|
cos_angle = np.cos(angle)
|
200
271
|
sin_angle = np.sin(angle)
|
272
|
+
if rotate3d:
|
273
|
+
rot_mtx = get_random_rotation_matrix(self.batch_size)
|
201
274
|
|
202
275
|
# Generate data
|
203
276
|
for i, ID in enumerate(list_IDs_temp):
|
@@ -209,96 +282,92 @@ class DataGenerator(tf.keras.utils.Sequence):
|
|
209
282
|
m = rest % self.dim_img[2]
|
210
283
|
|
211
284
|
# get patch
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
tmp_X = np.swapaxes(tmp_X,1,2)
|
238
|
-
|
239
|
-
# assign to batch
|
240
|
-
X[i,:,:,:,0] = tmp_X
|
241
|
-
y[i,0] = tmp_y
|
242
|
-
|
243
|
-
else:
|
244
|
-
# get patch
|
245
|
-
tmp_X = self.img[k:k+self.dim[0],l:l+self.dim[1],m:m+self.dim[2]]
|
246
|
-
tmp_y = self.label[k:k+self.dim[0],l:l+self.dim[1],m:m+self.dim[2]]
|
247
|
-
|
248
|
-
# center label gets value 1
|
249
|
-
if self.separation:
|
250
|
-
centerLabel = tmp_y[self.dim[0]//2,self.dim[1]//2,self.dim[2]//2]
|
251
|
-
tmp_y = tmp_y.copy()
|
252
|
-
tmp_y[tmp_y!=centerLabel]=0
|
253
|
-
tmp_y[tmp_y==centerLabel]=1
|
254
|
-
|
255
|
-
# augmentation
|
256
|
-
if self.train:
|
257
|
-
|
258
|
-
# rotate in xy plane
|
259
|
-
if rotate:
|
260
|
-
tmp_X = np.empty((*self.dim, self.n_channels), dtype=np.float32)
|
261
|
-
tmp_y = np.empty(self.dim, dtype=np.int32)
|
262
|
-
cos_a = cos_angle[i]
|
263
|
-
sin_a = sin_angle[i]
|
264
|
-
for c in range(self.n_channels):
|
265
|
-
tmp_X[:,:,:,c] = rotate_img_patch(self.img[:,:,:,c],tmp_X[:,:,:,c],k,l,m,cos_a,sin_a,
|
266
|
-
self.dim[0],self.dim[1],self.dim[2],
|
267
|
-
self.dim_img[1],self.dim_img[2])
|
268
|
-
tmp_y = rotate_label_patch(self.label,tmp_y,k,l,m,cos_a,sin_a,
|
285
|
+
tmp_X = self.img[k:k+self.dim[0],l:l+self.dim[1],m:m+self.dim[2]]
|
286
|
+
tmp_y = self.label[k:k+self.dim[0],l:l+self.dim[1],m:m+self.dim[2]]
|
287
|
+
|
288
|
+
# center label gets value 1
|
289
|
+
if self.separation:
|
290
|
+
centerLabel = tmp_y[self.dim[0]//2,self.dim[1]//2,self.dim[2]//2]
|
291
|
+
tmp_y = tmp_y.copy()
|
292
|
+
tmp_y[tmp_y!=centerLabel]=0
|
293
|
+
tmp_y[tmp_y==centerLabel]=1
|
294
|
+
|
295
|
+
# augmentation
|
296
|
+
if self.train:
|
297
|
+
|
298
|
+
# rotate in xy plane
|
299
|
+
if rotate:
|
300
|
+
tmp_X = np.empty((*self.dim, self.n_channels), dtype=np.float32)
|
301
|
+
tmp_y = np.empty((*self.dim, label_channels), dtype=np.int32)
|
302
|
+
cos_a = cos_angle[i]
|
303
|
+
sin_a = sin_angle[i]
|
304
|
+
for ch in range(self.n_channels):
|
305
|
+
tmp_X[...,ch] = rotate_img_patch(self.img[...,ch],tmp_X[...,ch],k,l,m,cos_a,sin_a,
|
306
|
+
self.dim[0],self.dim[1],self.dim[2],
|
307
|
+
self.dim_img[1],self.dim_img[2])
|
308
|
+
for ch in range(label_channels):
|
309
|
+
tmp_y[...,ch] = rotate_label_patch(self.label[...,ch],tmp_y[...,ch],k,l,m,cos_a,sin_a,
|
269
310
|
self.dim[0],self.dim[1],self.dim[2],
|
270
311
|
self.dim_img[1],self.dim_img[2])
|
271
312
|
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
313
|
+
# rotate through a random 3d angle, uniformly distributed on a sphere.
|
314
|
+
if rotate3d:
|
315
|
+
tmp_X = np.empty((*self.dim, self.n_channels), dtype=np.float32)
|
316
|
+
tmp_y = np.empty((*self.dim, label_channels), dtype=np.int32)
|
317
|
+
rm_xx = rot_mtx[i, 0, 0]
|
318
|
+
rm_xy = rot_mtx[i, 0, 1]
|
319
|
+
rm_xz = rot_mtx[i, 0, 2]
|
320
|
+
rm_yx = rot_mtx[i, 1, 0]
|
321
|
+
rm_yy = rot_mtx[i, 1, 1]
|
322
|
+
rm_yz = rot_mtx[i, 1, 2]
|
323
|
+
rm_zx = rot_mtx[i, 2, 0]
|
324
|
+
rm_zy = rot_mtx[i, 2, 1]
|
325
|
+
rm_zz = rot_mtx[i, 2, 2]
|
326
|
+
for ch in range(self.n_channels):
|
327
|
+
tmp_X[...,ch] = rotate_img_patch_3d(self.img[...,ch],tmp_X[...,ch],k,l,m,
|
328
|
+
rm_xx,rm_xy,rm_xz,rm_yx,rm_yy,rm_yz,rm_zx,rm_zy,rm_zz,
|
329
|
+
self.dim[0],self.dim[1],self.dim[2],
|
330
|
+
256, self.dim_img[0],self.dim_img[1],self.dim_img[2])
|
331
|
+
for ch in range(label_channels):
|
332
|
+
tmp_y[...,ch] = rotate_label_patch_3d(self.label[...,ch],tmp_y[...,ch],k,l,m,
|
333
|
+
rm_xx,rm_xy,rm_xz,rm_yx,rm_yy,rm_yz,rm_zx,rm_zy,rm_zz,
|
334
|
+
self.dim[0],self.dim[1],self.dim[2],
|
335
|
+
256, self.dim_img[0],self.dim_img[1],self.dim_img[2])
|
336
|
+
|
337
|
+
# flip patch along axes
|
338
|
+
v = np.random.randint(n_aug+1)
|
339
|
+
if np.any([flip_x, flip_y, flip_z]) and v>0:
|
340
|
+
flip = flips[v-1]
|
341
|
+
tmp_X = np.flip(tmp_X, flip)
|
342
|
+
tmp_y = np.flip(tmp_y, flip)
|
343
|
+
|
344
|
+
# swap axes
|
345
|
+
if swapaxes:
|
346
|
+
v = np.random.randint(4)
|
347
|
+
if v==1:
|
348
|
+
tmp_X = np.swapaxes(tmp_X,0,1)
|
349
|
+
tmp_y = np.swapaxes(tmp_y,0,1)
|
350
|
+
elif v==2:
|
351
|
+
tmp_X = np.swapaxes(tmp_X,0,2)
|
352
|
+
tmp_y = np.swapaxes(tmp_y,0,2)
|
353
|
+
elif v==3:
|
354
|
+
tmp_X = np.swapaxes(tmp_X,1,2)
|
355
|
+
tmp_y = np.swapaxes(tmp_y,1,2)
|
356
|
+
|
357
|
+
# patch normalization
|
358
|
+
if self.patch_normalization:
|
359
|
+
tmp_X = tmp_X.copy().astype(np.float32)
|
360
|
+
for ch in range(self.n_channels):
|
361
|
+
mean, std = welford_mean_std(tmp_X[...,ch])
|
362
|
+
tmp_X[...,ch] -= mean
|
363
|
+
tmp_X[...,ch] /= max(std, 1e-6)
|
364
|
+
|
365
|
+
# assign to batch
|
366
|
+
X[i] = tmp_X
|
367
|
+
y[i] = tmp_y
|
368
|
+
|
369
|
+
if self.ignore_mask:
|
370
|
+
return X, y
|
371
|
+
else:
|
372
|
+
return X, tf.keras.utils.to_categorical(y, num_classes=self.n_classes)
|
304
373
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
##########################################################################
|
2
2
|
## ##
|
3
|
-
## Copyright (c) 2019-
|
3
|
+
## Copyright (c) 2019-2025 Philipp Lösel. All rights reserved. ##
|
4
4
|
## ##
|
5
5
|
## This file is part of the open source project biomedisa. ##
|
6
6
|
## ##
|
@@ -26,6 +26,7 @@
|
|
26
26
|
## ##
|
27
27
|
##########################################################################
|
28
28
|
|
29
|
+
from biomedisa.features.biomedisa_helper import welford_mean_std
|
29
30
|
import numpy as np
|
30
31
|
import tensorflow as tf
|
31
32
|
|
@@ -77,10 +78,11 @@ class PredictDataGenerator(tf.keras.utils.Sequence):
|
|
77
78
|
# get patch
|
78
79
|
tmp_X = self.img[k:k+self.dim[0],l:l+self.dim[1],m:m+self.dim[2]]
|
79
80
|
if self.patch_normalization:
|
80
|
-
tmp_X =
|
81
|
-
for
|
82
|
-
|
83
|
-
tmp_X[
|
81
|
+
tmp_X = tmp_X.copy().astype(np.float32)
|
82
|
+
for ch in range(self.n_channels):
|
83
|
+
mean, std = welford_mean_std(tmp_X[...,ch])
|
84
|
+
tmp_X[...,ch] -= mean
|
85
|
+
tmp_X[...,ch] /= max(std, 1e-6)
|
84
86
|
X[i] = tmp_X
|
85
87
|
|
86
88
|
return X
|