biomedisa 24.8.10__py3-none-any.whl → 25.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
biomedisa/deeplearning.py CHANGED
@@ -1,7 +1,7 @@
1
1
  #!/usr/bin/python3
2
2
  ##########################################################################
3
3
  ## ##
4
- ## Copyright (c) 2019-2024 Philipp Lösel. All rights reserved. ##
4
+ ## Copyright (c) 2019-2025 Philipp Lösel. All rights reserved. ##
5
5
  ## ##
6
6
  ## This file is part of the open source project biomedisa. ##
7
7
  ## ##
@@ -67,7 +67,7 @@ def deep_learning(img_data, label_data=None, val_img_data=None, val_label_data=N
67
67
  balance=False, crop_data=False, flip_x=False, flip_y=False, flip_z=False,
68
68
  swapaxes=False, train_dice=False, val_dice=True, compression=True, ignore='none', only='all',
69
69
  network_filters='32-64-128-256-512', resnet=False, debug_cropping=False,
70
- save_cropped=False, epochs=100, normalization=True, rotate=0.0, validation_split=0.0,
70
+ save_cropped=False, epochs=100, normalization=True, rotate=0.0, rotate3d=0.0, validation_split=0.0,
71
71
  learning_rate=0.01, stride_size=32, validation_stride_size=32, validation_freq=1,
72
72
  batch_size=None, x_scale=256, y_scale=256, z_scale=256, scaling=True, early_stopping=0,
73
73
  pretrained_model=None, fine_tune=False, workers=1, cropping_epochs=50,
@@ -77,7 +77,8 @@ def deep_learning(img_data, label_data=None, val_img_data=None, val_label_data=N
77
77
  z_patch=64, y_patch=64, x_patch=64, path_to_logfile=None, img_id=None, label_id=None,
78
78
  remote=False, queue=0, username=None, shortfilename=None, dice_loss=False,
79
79
  acwe=False, acwe_alpha=1.0, acwe_smooth=1, acwe_steps=3, clean=None, fill=None,
80
- separation=False, mask=None, refinement=False):
80
+ separation=False, mask=None, refinement=False, ignore_mask=False, mixed_precision=False,
81
+ slicer=False, path_to_data=None):
81
82
 
82
83
  # create biomedisa
83
84
  bm = Biomedisa()
@@ -91,6 +92,7 @@ def deep_learning(img_data, label_data=None, val_img_data=None, val_label_data=N
91
92
  key_copy = tuple(locals().keys())
92
93
  for arg in key_copy:
93
94
  bm.__dict__[arg] = locals()[arg]
95
+ bm.path_to_data = bm.path_to_images
94
96
 
95
97
  # normalization
96
98
  bm.normalize = 1 if bm.normalization else 0
@@ -214,10 +216,16 @@ def deep_learning(img_data, label_data=None, val_img_data=None, val_label_data=N
214
216
 
215
217
  if bm.predict:
216
218
 
217
- # get meta data
218
- hf = h5py.File(bm.path_to_model, 'r')
219
- meta = hf.get('meta')
220
- configuration = meta.get('configuration')
219
+ # load model
220
+ try:
221
+ hf = h5py.File(bm.path_to_model, 'r')
222
+ meta = hf.get('meta')
223
+ configuration = meta.get('configuration')
224
+ bm.allLabels = np.array(meta.get('labels'))
225
+ except:
226
+ raise RuntimeError("Invalid model.")
227
+
228
+ # get configuration
221
229
  channels, bm.x_scale, bm.y_scale, bm.z_scale, bm.normalize, mu, sig = np.array(configuration)[:]
222
230
  channels, bm.x_scale, bm.y_scale, bm.z_scale, bm.normalize, mu, sig = int(channels), int(bm.x_scale), \
223
231
  int(bm.y_scale), int(bm.z_scale), int(bm.normalize), float(mu), float(sig)
@@ -225,7 +233,6 @@ def deep_learning(img_data, label_data=None, val_img_data=None, val_label_data=N
225
233
  normalization_parameters = np.array(meta['normalization'], dtype=float)
226
234
  else:
227
235
  normalization_parameters = np.array([[mu],[sig]])
228
- bm.allLabels = np.array(meta.get('labels'))
229
236
  if 'patch_normalization' in meta:
230
237
  bm.patch_normalization = bool(meta['patch_normalization'][()])
231
238
  if 'scaling' in meta:
@@ -436,6 +443,8 @@ if __name__ == '__main__':
436
443
  help='Disable normalization of 3D image volumes')
437
444
  parser.add_argument('-r','--rotate', type=float, default=0.0,
438
445
  help='Randomly rotate during training')
446
+ parser.add_argument('--rotate3d', action='store_true', default=False,
447
+ help='Randomly rotate through three dimensions during training. Uniformly distributed over the sphere.')
439
448
  parser.add_argument('-vs','--validation_split', type=float, default=0.0,
440
449
  help='Percentage of data used for training')
441
450
  parser.add_argument('-lr','--learning_rate', type=float, default=0.01,
@@ -504,9 +513,20 @@ if __name__ == '__main__':
504
513
  help='Save data in formats like NRRD or TIFF using --extension=".nrrd"')
505
514
  parser.add_argument('-ptm','--path_to_model', type=str, metavar='PATH', default=None,
506
515
  help='Specify the model location for training')
516
+ parser.add_argument('-im','--ignore_mask', action='store_true', default=False,
517
+ help='Use a binary mask in the second channel of the label file to define ignored (0) and considered (1) areas during training')
518
+ parser.add_argument('-mp','--mixed_precision', action='store_true', default=False,
519
+ help='Use mixed precision in model')
520
+ parser.add_argument('--slicer', action='store_true', default=False,
521
+ help='Required for starting Biomedisa from 3D Slicer')
507
522
  bm = parser.parse_args()
508
523
  bm.success = True
509
524
 
525
+ if bm.rotate3d and not bm.scaling:
526
+ raise RuntimeError("You cannot do true 3d rotation without rescaling the data yet.")
527
+ # To fix this, have the loading function pass in a list of where the images end,
528
+ # and use that to figure out the z-centres in biomedisa.features.DataGenerator.rotate_*_3d
529
+
510
530
  # prediction or training
511
531
  if not any([bm.train, bm.predict]):
512
532
  bm.predict = False
@@ -536,6 +556,12 @@ if __name__ == '__main__':
536
556
  bm.django_env = False
537
557
 
538
558
  kwargs = vars(bm)
559
+ bm.path_to_data = bm.path_to_images
560
+
561
+ # verify model
562
+ if bm.predict and os.path.splitext(bm.path)[1] != '.h5':
563
+ bm = _error_(bm, "Invalid model.")
564
+ raise RuntimeError("Invalid model.")
539
565
 
540
566
  # train or predict segmentation
541
567
  try:
@@ -554,5 +580,5 @@ if __name__ == '__main__':
554
580
  bm = _error_(bm, 'GPU out of memory. Reduce your batch size')
555
581
  except Exception as e:
556
582
  print(traceback.format_exc())
557
- bm = _error_(bm, e)
583
+ bm = _error_(bm, str(e))
558
584
 
@@ -1,6 +1,6 @@
1
1
  ##########################################################################
2
2
  ## ##
3
- ## Copyright (c) 2019-2024 Philipp Lösel. All rights reserved. ##
3
+ ## Copyright (c) 2019-2025 Philipp Lösel. All rights reserved. ##
4
4
  ## ##
5
5
  ## This file is part of the open source project biomedisa. ##
6
6
  ## ##
@@ -26,11 +26,30 @@
26
26
  ## ##
27
27
  ##########################################################################
28
28
 
29
+ from biomedisa.features.biomedisa_helper import welford_mean_std
29
30
  import numpy as np
30
31
  import tensorflow as tf
31
32
  import numba
32
33
  import random
33
- import scipy
34
+
35
def get_random_rotation_matrix(batch_size, rng=None):
    """Draw ``batch_size`` random 3D rotation matrices, uniformly distributed over SO(3).

    Uses Arvo's "fast random rotation matrices" construction: a uniformly random
    rotation about the z-axis (``RR``) followed by a Householder reflection ``HH``
    that sends the z-axis to a uniformly random direction on the sphere. The
    product ``-HH @ RR`` is a proper rotation (orthogonal, determinant +1).

    Parameters
    ----------
    batch_size : int
        Number of matrices to generate.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. Defaults to the global ``np.random`` state, drawing
        in the same order as before this parameter existed.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(batch_size, 3, 3)``.
    """
    if rng is None:
        rng = np.random
    angle_xy = rng.uniform(0, 2. * np.pi, batch_size)
    householder_angle = rng.uniform(0, 2. * np.pi, batch_size)
    householder_dist = rng.uniform(0, 1., batch_size)

    # Rotation about the z-axis by angle_xy. The second row must be (-sin, cos);
    # (cos, -sin) would make RR non-orthogonal and the result not a rotation.
    cos_a = np.cos(angle_xy)
    sin_a = np.sin(angle_xy)
    RR = np.zeros([batch_size, 3, 3])
    RR[:, 0, 0] = cos_a
    RR[:, 0, 1] = sin_a
    RR[:, 1, 0] = -sin_a
    RR[:, 1, 1] = cos_a
    RR[:, 2, 2] = 1.

    # Householder matrix HH = I - 2 v v^T for a random unit vector v
    sqrt_dist = np.sqrt(householder_dist)
    vv = np.zeros([batch_size, 3, 1])
    vv[:, 0, 0] = np.cos(householder_angle) * sqrt_dist
    vv[:, 1, 0] = np.sin(householder_angle) * sqrt_dist
    vv[:, 2, 0] = np.sqrt(1. - householder_dist)
    HH = np.eye(3)[np.newaxis, :, :] - 2. * np.matmul(vv, vv.transpose(0, 2, 1))

    # det(HH) = -1 and the matrices are 3x3, so negating makes det = +1
    return -np.matmul(HH, RR)
34
53
 
35
54
  @numba.jit(nopython=True)#parallel=True
36
55
  def rotate_img_patch(src,trg,k,l,m,cos_a,sin_a,z_patch,y_patch,x_patch,imageHeight,imageWidth):
@@ -61,6 +80,56 @@ def rotate_img_patch(src,trg,k,l,m,cos_a,sin_a,z_patch,y_patch,x_patch,imageHeig
61
80
  trg[z-k,y-l,x-m] = val
62
81
  return trg
63
82
 
83
@numba.jit(nopython=True)#parallel=True
def rotate_img_patch_3d(src,trg,k,l,m,rm_xx,rm_xy,rm_xz,rm_yx,rm_yy,rm_yz,rm_zx,rm_zy,rm_zz,z_patch,y_patch,x_patch,imageVertStride,imageDepth,imageHeight,imageWidth):
    """Fill ``trg`` with a 3D-rotated patch of ``src`` using trilinear interpolation.

    For every target voxel ``(z, y, x)`` with ``z`` in ``[k, k+z_patch)``, ``y`` in
    ``[l, l+y_patch)`` and ``x`` in ``[m, m+x_patch)``, the voxel's offset from the
    rotation centre is multiplied by the matrix given by the ``rm_*`` arguments
    (``rm_<row><col>``, rows and columns ordered x, y, z). The resulting source
    position in ``src`` is sampled trilinearly and written to ``trg[z-k, y-l, x-m]``.

    The rotation centre is ``(imageWidth/2, imageHeight/2)`` in the xy plane. Along z
    it is the midpoint of the ``imageVertStride``-slice block that contains ``z``.
    NOTE(review): this presumably assumes several volumes stacked along z, each
    ``imageVertStride`` slices deep. Confirm with the caller.

    Source indices are clamped to ``[0, imageWidth-1]``, ``[0, imageHeight-1]`` and
    ``[0, imageDepth-1]`` (edge replication). NOTE(review): z is clamped to the whole
    array, not to the current block, so samples near a block edge may be read from
    the neighbouring stacked volume.

    Returns ``trg``, which is also written in place.
    """
    for z in range(k,k+z_patch):
        zCentreRotation = imageVertStride * (z // imageVertStride) + imageVertStride/2
        zA = z - zCentreRotation
        for y in range(l,l+y_patch):
            yA = y - imageHeight/2
            for x in range(m,m+x_patch):
                xA = x - imageWidth/2
                # rotate the offset from the centre, then shift back to array coordinates
                xR = xA * rm_xx + yA * rm_xy + zA * rm_xz
                yR = xA * rm_yx + yA * rm_yy + zA * rm_yz
                zR = xA * rm_zx + yA * rm_zy + zA * rm_zz
                src_x = xR + imageWidth/2
                src_y = yR + imageHeight/2
                src_z = zR + zCentreRotation
                # Trilinear interpolation. Use floor, not int(): int() truncates toward
                # zero, so for coordinates in (-1, 0) the fractional weight would be
                # negative and the value would be extrapolated instead of edge-replicated.
                src_x0 = np.floor(src_x)
                src_x1 = src_x0 + 1
                src_y0 = np.floor(src_y)
                src_y1 = src_y0 + 1
                src_z0 = np.floor(src_z)
                src_z1 = src_z0 + 1
                sx = src_x - src_x0
                sy = src_y - src_y0
                sz = src_z - src_z0
                idx_src_x0 = int(min(max(0,src_x0),imageWidth-1))
                idx_src_x1 = int(min(max(0,src_x1),imageWidth-1))
                idx_src_y0 = int(min(max(0,src_y0),imageHeight-1))
                idx_src_y1 = int(min(max(0,src_y1),imageHeight-1))
                idx_src_z0 = int(min(max(0,src_z0),imageDepth-1))
                idx_src_z1 = int(min(max(0,src_z1),imageDepth-1))

                val = (1-sy) * (1-sx) * (1-sz) * float(src[idx_src_z0,idx_src_y0,idx_src_x0])
                val += (sy) * (1-sx) * (1-sz) * float(src[idx_src_z0,idx_src_y1,idx_src_x0])
                val += (1-sy) * (sx) * (1-sz) * float(src[idx_src_z0,idx_src_y0,idx_src_x1])
                val += (sy) * (sx) * (1-sz) * float(src[idx_src_z0,idx_src_y1,idx_src_x1])
                val += (1-sy) * (1-sx) * (sz) * float(src[idx_src_z1,idx_src_y0,idx_src_x0])
                val += (sy) * (1-sx) * (sz) * float(src[idx_src_z1,idx_src_y1,idx_src_x0])
                val += (1-sy) * (sx) * (sz) * float(src[idx_src_z1,idx_src_y0,idx_src_x1])
                val += (sy) * (sx) * (sz) * float(src[idx_src_z1,idx_src_y1,idx_src_x1])
                trg[z-k,y-l,x-m] = val
    return trg
126
+
127
# Standalone copy of the z rotation-centre formula that the 3D patch-rotation
# kernels compute inline. It exists so the formula can be tested on its own; the
# kernels do not call it because nested nopython calls have been unreliable in numba.
@numba.jit(nopython=True)#parallel=True
def centre_of_z_rotation(z, imageVertStride):
    """Return the z-coordinate of the rotation centre for slice ``z``.

    ``z`` is snapped down to the start of its ``imageVertStride``-slice block, and
    half a stride is added, giving the block's midpoint.
    """
    blockStart = (z // imageVertStride) * imageVertStride
    return blockStart + imageVertStride / 2
132
+
64
133
  @numba.jit(nopython=True)#parallel=True
65
134
  def rotate_label_patch(src,trg,k,l,m,cos_a,sin_a,z_patch,y_patch,x_patch,imageHeight,imageWidth):
66
135
  for y in range(l,l+y_patch):
@@ -80,34 +149,36 @@ def rotate_label_patch(src,trg,k,l,m,cos_a,sin_a,z_patch,y_patch,x_patch,imageHe
80
149
  trg[z-k,y-l,x-m] = src[z,idx_src_y,idx_src_x]
81
150
  return trg
82
151
 
83
- def random_rotation_3d(image, max_angle=180):
84
- """ Randomly rotate an image by a random angle (-max_angle, max_angle).
85
-
86
- Arguments:
87
- max_angle: `float`. The maximum rotation angle.
88
-
89
- Returns:
90
- batch of rotated 3D images
91
- """
92
-
93
- # rotate along x-axis
94
- angle = random.uniform(-max_angle, max_angle)
95
- image2 = scipy.ndimage.rotate(image, angle, mode='nearest', axes=(0, 1), reshape=False)
96
-
97
- # rotate along y-axis
98
- angle = random.uniform(-max_angle, max_angle)
99
- image3 = scipy.ndimage.rotate(image2, angle, mode='nearest', axes=(0, 2), reshape=False)
100
-
101
- # rotate along z-axis
102
- angle = random.uniform(-max_angle, max_angle)
103
- image_rot = scipy.ndimage.rotate(image3, angle, mode='nearest', axes=(1, 2), reshape=False)
104
-
105
- return image_rot
152
@numba.jit(nopython=True)#parallel=True
def rotate_label_patch_3d(src,trg,k,l,m,rm_xx,rm_xy,rm_xz,rm_yx,rm_yy,rm_yz,rm_zx,rm_zy,rm_zz,z_patch,y_patch,x_patch,imageVertStride,imageDepth,imageHeight,imageWidth):
    """Fill ``trg`` with a 3D-rotated patch of the label volume ``src``.

    For each target voxel ``(k+zi, l+yi, m+xi)`` the offset from the rotation centre
    is multiplied by the ``rm_*`` matrix (``rm_<row><col>``, ordered x, y, z). The
    source voxel is then picked by nearest-neighbour rounding, so label values are
    copied, never blended. The result is written to ``trg[zi, yi, xi]``.

    The rotation centre is ``(imageWidth/2, imageHeight/2)`` in the xy plane. Along z
    it is the midpoint of the ``imageVertStride``-slice block that contains the slice.
    Source indices are clamped to the array bounds (edge replication).
    NOTE(review): the z clamp spans the whole array rather than the current block,
    so voxels near a block edge may be taken from a neighbouring stacked volume.

    Returns ``trg``, which is also written in place.
    """
    cx = imageWidth / 2
    cy = imageHeight / 2
    lastX = imageWidth - 1
    lastY = imageHeight - 1
    lastZ = imageDepth - 1
    for zi in range(z_patch):
        zAbs = k + zi
        cz = imageVertStride * (zAbs // imageVertStride) + imageVertStride / 2
        dz = zAbs - cz
        for yi in range(y_patch):
            dy = (l + yi) - cy
            for xi in range(x_patch):
                dx = (m + xi) - cx
                # rotate, shift back to array coordinates, then round to the nearest voxel
                nx = round(dx * rm_xx + dy * rm_xy + dz * rm_xz + cx)
                ny = round(dx * rm_yx + dy * rm_yy + dz * rm_yz + cy)
                nz = round(dx * rm_zx + dy * rm_zy + dz * rm_zz + cz)
                trg[zi, yi, xi] = src[int(min(max(0, nz), lastZ)),
                                      int(min(max(0, ny), lastY)),
                                      int(min(max(0, nx), lastX))]
    return trg
+ return trg
106
176
 
107
177
  class DataGenerator(tf.keras.utils.Sequence):
108
178
  'Generates data for Keras'
109
- def __init__(self, img, label, list_IDs_fg, list_IDs_bg, shuffle, train, classification, batch_size=32, dim=(32,32,32),
110
- dim_img=(32,32,32), n_classes=10, n_channels=1, augment=(False,False,False,False,0), patch_normalization=False, separation=False):
179
+ def __init__(self, img, label, list_IDs_fg, list_IDs_bg, shuffle, train, batch_size=32, dim=(32,32,32),
180
+ dim_img=(32,32,32), n_classes=10, n_channels=1, augment=(False,False,False,False,0,False),
181
+ patch_normalization=False, separation=False, ignore_mask=False):
111
182
  'Initialization'
112
183
  self.dim = dim
113
184
  self.dim_img = dim_img
@@ -121,10 +192,10 @@ class DataGenerator(tf.keras.utils.Sequence):
121
192
  self.shuffle = shuffle
122
193
  self.augment = augment
123
194
  self.train = train
124
- self.classification = classification
125
195
  self.on_epoch_end()
126
196
  self.patch_normalization = patch_normalization
127
197
  self.separation = separation
198
+ self.ignore_mask = ignore_mask
128
199
 
129
200
  def __len__(self):
130
201
  'Denotes the number of batches per epoch'
@@ -181,15 +252,15 @@ class DataGenerator(tf.keras.utils.Sequence):
181
252
  def __data_generation(self, list_IDs_temp):
182
253
  'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
183
254
 
184
- # Initialization
255
+ # number of label channels
256
+ label_channels = 2 if self.ignore_mask else 1
257
+
258
+ # allocate memory
185
259
  X = np.empty((self.batch_size, *self.dim, self.n_channels), dtype=np.float32)
186
- if self.classification:
187
- y = np.empty((self.batch_size, 1), dtype=np.int32)
188
- else:
189
- y = np.empty((self.batch_size, *self.dim, 1), dtype=np.int32)
260
+ y = np.empty((self.batch_size, *self.dim, label_channels), dtype=np.int32)
190
261
 
191
262
  # get augmentation parameter
192
- flip_x, flip_y, flip_z, swapaxes, rotate = self.augment
263
+ flip_x, flip_y, flip_z, swapaxes, rotate, rotate3d = self.augment
193
264
  n_aug = np.sum([flip_z, flip_y, flip_x])
194
265
  flips = np.where([flip_z, flip_y, flip_x])[0]
195
266
 
@@ -198,6 +269,8 @@ class DataGenerator(tf.keras.utils.Sequence):
198
269
  angle = np.random.uniform(-1,1,self.batch_size) * 3.1416/180*rotate
199
270
  cos_angle = np.cos(angle)
200
271
  sin_angle = np.sin(angle)
272
+ if rotate3d:
273
+ rot_mtx = get_random_rotation_matrix(self.batch_size)
201
274
 
202
275
  # Generate data
203
276
  for i, ID in enumerate(list_IDs_temp):
@@ -209,96 +282,92 @@ class DataGenerator(tf.keras.utils.Sequence):
209
282
  m = rest % self.dim_img[2]
210
283
 
211
284
  # get patch
212
- if self.classification:
213
- tmp_X = self.img[k:k+self.dim[0],l:l+self.dim[1],m:m+self.dim[2]]
214
- tmp_y = self.label[k,l,m]
215
-
216
- # augmentation
217
- if self.train:
218
-
219
- # rotate in 3D
220
- if rotate:
221
- tmp_X = random_rotation_3d(tmp_X, max_angle=rotate)
222
-
223
- # flip patch along axes
224
- v = np.random.randint(n_aug+1)
225
- if np.any([flip_x, flip_y, flip_z]) and v>0:
226
- flip = flips[v-1]
227
- tmp_X = np.flip(tmp_X, flip)
228
-
229
- # swap axes
230
- if swapaxes:
231
- v = np.random.randint(4)
232
- if v==1:
233
- tmp_X = np.swapaxes(tmp_X,0,1)
234
- elif v==2:
235
- tmp_X = np.swapaxes(tmp_X,0,2)
236
- elif v==3:
237
- tmp_X = np.swapaxes(tmp_X,1,2)
238
-
239
- # assign to batch
240
- X[i,:,:,:,0] = tmp_X
241
- y[i,0] = tmp_y
242
-
243
- else:
244
- # get patch
245
- tmp_X = self.img[k:k+self.dim[0],l:l+self.dim[1],m:m+self.dim[2]]
246
- tmp_y = self.label[k:k+self.dim[0],l:l+self.dim[1],m:m+self.dim[2]]
247
-
248
- # center label gets value 1
249
- if self.separation:
250
- centerLabel = tmp_y[self.dim[0]//2,self.dim[1]//2,self.dim[2]//2]
251
- tmp_y = tmp_y.copy()
252
- tmp_y[tmp_y!=centerLabel]=0
253
- tmp_y[tmp_y==centerLabel]=1
254
-
255
- # augmentation
256
- if self.train:
257
-
258
- # rotate in xy plane
259
- if rotate:
260
- tmp_X = np.empty((*self.dim, self.n_channels), dtype=np.float32)
261
- tmp_y = np.empty(self.dim, dtype=np.int32)
262
- cos_a = cos_angle[i]
263
- sin_a = sin_angle[i]
264
- for c in range(self.n_channels):
265
- tmp_X[:,:,:,c] = rotate_img_patch(self.img[:,:,:,c],tmp_X[:,:,:,c],k,l,m,cos_a,sin_a,
266
- self.dim[0],self.dim[1],self.dim[2],
267
- self.dim_img[1],self.dim_img[2])
268
- tmp_y = rotate_label_patch(self.label,tmp_y,k,l,m,cos_a,sin_a,
285
+ tmp_X = self.img[k:k+self.dim[0],l:l+self.dim[1],m:m+self.dim[2]]
286
+ tmp_y = self.label[k:k+self.dim[0],l:l+self.dim[1],m:m+self.dim[2]]
287
+
288
+ # center label gets value 1
289
+ if self.separation:
290
+ centerLabel = tmp_y[self.dim[0]//2,self.dim[1]//2,self.dim[2]//2]
291
+ tmp_y = tmp_y.copy()
292
+ tmp_y[tmp_y!=centerLabel]=0
293
+ tmp_y[tmp_y==centerLabel]=1
294
+
295
+ # augmentation
296
+ if self.train:
297
+
298
+ # rotate in xy plane
299
+ if rotate:
300
+ tmp_X = np.empty((*self.dim, self.n_channels), dtype=np.float32)
301
+ tmp_y = np.empty((*self.dim, label_channels), dtype=np.int32)
302
+ cos_a = cos_angle[i]
303
+ sin_a = sin_angle[i]
304
+ for ch in range(self.n_channels):
305
+ tmp_X[...,ch] = rotate_img_patch(self.img[...,ch],tmp_X[...,ch],k,l,m,cos_a,sin_a,
306
+ self.dim[0],self.dim[1],self.dim[2],
307
+ self.dim_img[1],self.dim_img[2])
308
+ for ch in range(label_channels):
309
+ tmp_y[...,ch] = rotate_label_patch(self.label[...,ch],tmp_y[...,ch],k,l,m,cos_a,sin_a,
269
310
  self.dim[0],self.dim[1],self.dim[2],
270
311
  self.dim_img[1],self.dim_img[2])
271
312
 
272
- # flip patch along axes
273
- v = np.random.randint(n_aug+1)
274
- if np.any([flip_x, flip_y, flip_z]) and v>0:
275
- flip = flips[v-1]
276
- tmp_X = np.flip(tmp_X, flip)
277
- tmp_y = np.flip(tmp_y, flip)
278
-
279
- # swap axes
280
- if swapaxes:
281
- v = np.random.randint(4)
282
- if v==1:
283
- tmp_X = np.swapaxes(tmp_X,0,1)
284
- tmp_y = np.swapaxes(tmp_y,0,1)
285
- elif v==2:
286
- tmp_X = np.swapaxes(tmp_X,0,2)
287
- tmp_y = np.swapaxes(tmp_y,0,2)
288
- elif v==3:
289
- tmp_X = np.swapaxes(tmp_X,1,2)
290
- tmp_y = np.swapaxes(tmp_y,1,2)
291
-
292
- # patch normalization
293
- if self.patch_normalization:
294
- tmp_X = tmp_X.copy().astype(np.float32)
295
- for c in range(self.n_channels):
296
- tmp_X[:,:,:,c] -= np.mean(tmp_X[:,:,:,c])
297
- tmp_X[:,:,:,c] /= max(np.std(tmp_X[:,:,:,c]), 1e-6)
298
-
299
- # assign to batch
300
- X[i] = tmp_X
301
- y[i,:,:,:,0] = tmp_y
302
-
303
- return X, tf.keras.utils.to_categorical(y, num_classes=self.n_classes)
313
+ # rotate through a random 3d angle, uniformly distributed on a sphere.
314
+ if rotate3d:
315
+ tmp_X = np.empty((*self.dim, self.n_channels), dtype=np.float32)
316
+ tmp_y = np.empty((*self.dim, label_channels), dtype=np.int32)
317
+ rm_xx = rot_mtx[i, 0, 0]
318
+ rm_xy = rot_mtx[i, 0, 1]
319
+ rm_xz = rot_mtx[i, 0, 2]
320
+ rm_yx = rot_mtx[i, 1, 0]
321
+ rm_yy = rot_mtx[i, 1, 1]
322
+ rm_yz = rot_mtx[i, 1, 2]
323
+ rm_zx = rot_mtx[i, 2, 0]
324
+ rm_zy = rot_mtx[i, 2, 1]
325
+ rm_zz = rot_mtx[i, 2, 2]
326
+ for ch in range(self.n_channels):
327
+ tmp_X[...,ch] = rotate_img_patch_3d(self.img[...,ch],tmp_X[...,ch],k,l,m,
328
+ rm_xx,rm_xy,rm_xz,rm_yx,rm_yy,rm_yz,rm_zx,rm_zy,rm_zz,
329
+ self.dim[0],self.dim[1],self.dim[2],
330
+ 256, self.dim_img[0],self.dim_img[1],self.dim_img[2])
331
+ for ch in range(label_channels):
332
+ tmp_y[...,ch] = rotate_label_patch_3d(self.label[...,ch],tmp_y[...,ch],k,l,m,
333
+ rm_xx,rm_xy,rm_xz,rm_yx,rm_yy,rm_yz,rm_zx,rm_zy,rm_zz,
334
+ self.dim[0],self.dim[1],self.dim[2],
335
+ 256, self.dim_img[0],self.dim_img[1],self.dim_img[2])
336
+
337
+ # flip patch along axes
338
+ v = np.random.randint(n_aug+1)
339
+ if np.any([flip_x, flip_y, flip_z]) and v>0:
340
+ flip = flips[v-1]
341
+ tmp_X = np.flip(tmp_X, flip)
342
+ tmp_y = np.flip(tmp_y, flip)
343
+
344
+ # swap axes
345
+ if swapaxes:
346
+ v = np.random.randint(4)
347
+ if v==1:
348
+ tmp_X = np.swapaxes(tmp_X,0,1)
349
+ tmp_y = np.swapaxes(tmp_y,0,1)
350
+ elif v==2:
351
+ tmp_X = np.swapaxes(tmp_X,0,2)
352
+ tmp_y = np.swapaxes(tmp_y,0,2)
353
+ elif v==3:
354
+ tmp_X = np.swapaxes(tmp_X,1,2)
355
+ tmp_y = np.swapaxes(tmp_y,1,2)
356
+
357
+ # patch normalization
358
+ if self.patch_normalization:
359
+ tmp_X = tmp_X.copy().astype(np.float32)
360
+ for ch in range(self.n_channels):
361
+ mean, std = welford_mean_std(tmp_X[...,ch])
362
+ tmp_X[...,ch] -= mean
363
+ tmp_X[...,ch] /= max(std, 1e-6)
364
+
365
+ # assign to batch
366
+ X[i] = tmp_X
367
+ y[i] = tmp_y
368
+
369
+ if self.ignore_mask:
370
+ return X, y
371
+ else:
372
+ return X, tf.keras.utils.to_categorical(y, num_classes=self.n_classes)
304
373
 
@@ -1,6 +1,6 @@
1
1
  ##########################################################################
2
2
  ## ##
3
- ## Copyright (c) 2019-2024 Philipp Lösel. All rights reserved. ##
3
+ ## Copyright (c) 2019-2025 Philipp Lösel. All rights reserved. ##
4
4
  ## ##
5
5
  ## This file is part of the open source project biomedisa. ##
6
6
  ## ##
@@ -26,6 +26,7 @@
26
26
  ## ##
27
27
  ##########################################################################
28
28
 
29
+ from biomedisa.features.biomedisa_helper import welford_mean_std
29
30
  import numpy as np
30
31
  import tensorflow as tf
31
32
 
@@ -77,10 +78,11 @@ class PredictDataGenerator(tf.keras.utils.Sequence):
77
78
  # get patch
78
79
  tmp_X = self.img[k:k+self.dim[0],l:l+self.dim[1],m:m+self.dim[2]]
79
80
  if self.patch_normalization:
80
- tmp_X = np.copy(tmp_X, order='C')
81
- for c in range(self.n_channels):
82
- tmp_X[:,:,:,c] -= np.mean(tmp_X[:,:,:,c])
83
- tmp_X[:,:,:,c] /= max(np.std(tmp_X[:,:,:,c]), 1e-6)
81
+ tmp_X = tmp_X.copy().astype(np.float32)
82
+ for ch in range(self.n_channels):
83
+ mean, std = welford_mean_std(tmp_X[...,ch])
84
+ tmp_X[...,ch] -= mean
85
+ tmp_X[...,ch] /= max(std, 1e-6)
84
86
  X[i] = tmp_X
85
87
 
86
88
  return X