biomedisa-24.5.23-py3-none-any.whl → biomedisa-24.8.1-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, exactly as they appear in their public registry. It is provided for informational purposes only.
biomedisa/deeplearning.py CHANGED
@@ -65,11 +65,11 @@ def deep_learning(img_data, label_data=None, val_img_data=None, val_label_data=N
     path_to_images=None, path_to_labels=None, val_images=None, val_labels=None,
     path_to_model=None, predict=False, train=False, header_file=None,
     balance=False, crop_data=False, flip_x=False, flip_y=False, flip_z=False,
-    swapaxes=False, train_dice=False, val_dice=True, no_compression=False, ignore='none', only='all',
+    swapaxes=False, train_dice=False, val_dice=True, compression=True, ignore='none', only='all',
     network_filters='32-64-128-256-512', resnet=False, debug_cropping=False,
-    save_cropped=False, epochs=100, no_normalization=False, rotate=0.0, validation_split=0.0,
+    save_cropped=False, epochs=100, normalization=True, rotate=0.0, validation_split=0.0,
     learning_rate=0.01, stride_size=32, validation_stride_size=32, validation_freq=1,
-    batch_size=None, x_scale=256, y_scale=256, z_scale=256, no_scaling=False, early_stopping=0,
+    batch_size=None, x_scale=256, y_scale=256, z_scale=256, scaling=True, early_stopping=0,
     pretrained_model=None, fine_tune=False, workers=1, cropping_epochs=50,
     x_range=None, y_range=None, z_range=None, header=None, extension='.tif',
     img_header=None, img_extension='.tif', average_dice=False, django_env=False,
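The three negated keyword arguments were replaced by positively named ones with inverted defaults, so existing callers that pass them by keyword need a one-line update. A minimal sketch of the migration (the placeholder volumes are ours; the keyword names are taken from the signature above):

```python
import numpy as np
from biomedisa.deeplearning import deep_learning

# placeholder data, just to make the call shape concrete
img = np.random.rand(64, 64, 64).astype(np.float32)
labels = np.random.randint(0, 3, (64, 64, 64), dtype=np.uint8)

# 24.5.23:
#   deep_learning(img, labels, train=True,
#                 no_compression=True, no_normalization=True, no_scaling=True)
# 24.8.1:
deep_learning(img, labels, train=True,
              compression=False, normalization=False, scaling=False)
```

Any call that still names `no_compression`, `no_normalization`, or `no_scaling` now raises a `TypeError` for an unexpected keyword argument.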
@@ -91,17 +91,13 @@ def deep_learning(img_data, label_data=None, val_img_data=None, val_label_data=N
     for arg in key_copy:
         bm.__dict__[arg] = locals()[arg]
 
-    # compression
-    if bm.no_compression:
-        bm.compression = False
-    else:
-        bm.compression = True
-
     # normalization
-    if bm.no_normalization:
+    bm.normalize = 1 if bm.normalization else 0
+
+    # use patch normalization instead of normalizing the entire volume
+    if not bm.scaling:
         bm.normalize = 0
-    else:
-        bm.normalize = 1
+        bm.patch_normalization = True
 
     # django environment
     if bm.django_env:
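The replacement block also couples the two flags: disabling scaling now switches the pipeline to patch-wise normalization. A minimal restatement of that decision logic (the helper name is ours; `normalize` is the integer flag stored in the model metadata):

```python
def derive_normalization_flags(normalization=True, scaling=True):
    # mirrors the new block above; names follow the diff
    normalize = 1 if normalization else 0
    patch_normalization = False
    if not scaling:
        # volumes kept at native size are normalized per patch, not globally
        normalize = 0
        patch_normalization = True
    return normalize, patch_normalization

print(derive_normalization_flags())               # (1, False)
print(derive_normalization_flags(scaling=False))  # (0, True)
```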
@@ -217,14 +213,19 @@ def deep_learning(img_data, label_data=None, val_img_data=None, val_label_data=N
         hf = h5py.File(bm.path_to_model, 'r')
         meta = hf.get('meta')
         configuration = meta.get('configuration')
-        channels, bm.x_scale, bm.y_scale, bm.z_scale, normalize, mu, sig = np.array(configuration)[:]
-        channels, bm.x_scale, bm.y_scale, bm.z_scale, normalize, mu, sig = int(channels), int(bm.x_scale), \
-            int(bm.y_scale), int(bm.z_scale), int(normalize), float(mu), float(sig)
-        if '/meta/normalization' in hf:
-            normalization_parameters = np.array(meta.get('normalization'), dtype=float)
+        channels, bm.x_scale, bm.y_scale, bm.z_scale, bm.normalize, mu, sig = np.array(configuration)[:]
+        channels, bm.x_scale, bm.y_scale, bm.z_scale, bm.normalize, mu, sig = int(channels), int(bm.x_scale), \
+            int(bm.y_scale), int(bm.z_scale), int(bm.normalize), float(mu), float(sig)
+        if 'normalization' in meta:
+            normalization_parameters = np.array(meta['normalization'], dtype=float)
         else:
             normalization_parameters = np.array([[mu],[sig]])
-        allLabels = np.array(meta.get('labels'))
+        bm.allLabels = np.array(meta.get('labels'))
+        if 'patch_normalization' in meta:
+            bm.patch_normalization = bool(meta['patch_normalization'][()])
+        if 'scaling' in meta:
+            bm.scaling = bool(meta['scaling'][()])
+
         # check if amira header is available in the network
         if header is None and meta.get('header') is not None:
             header = [np.array(meta.get('header'))]
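Two h5py idioms changed here and are worth spelling out: membership is now tested on the `meta` group instead of via an absolute path on the file handle, and the new boolean datasets are read with `[()]`, h5py's syntax for extracting a scalar. A sketch under the assumption that the model file has the layout used above (the file name is illustrative):

```python
import h5py

with h5py.File('model.h5', 'r') as hf:   # illustrative path
    meta = hf['meta']
    # equivalent membership tests: group-relative vs. absolute
    assert ('normalization' in meta) == ('/meta/normalization' in hf)
    if 'patch_normalization' in meta:
        # [()] reads a scalar dataset into a plain value
        patch_normalization = bool(meta['patch_normalization'][()])
```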
@@ -290,16 +291,11 @@ def deep_learning(img_data, label_data=None, val_img_data=None, val_label_data=N
             region_of_interest, cropped_volume = ch.crop_data(bm.path_to_image, bm.path_to_model, bm.path_to_cropped_image,
                 bm.batch_size, bm.debug_cropping, bm.save_cropped, img_data, bm.x_range, bm.y_range, bm.z_range)
 
-        # load prediction data
-        img, img_header, z_shape, y_shape, x_shape, region_of_interest, img_data = load_prediction_data(bm.path_to_image,
-            channels, bm.x_scale, bm.y_scale, bm.z_scale, bm.no_scaling, normalize, normalization_parameters,
-            region_of_interest, img_data, img_header)
-
         # make prediction
-        results, bm = predict_semantic_segmentation(bm, img, bm.path_to_model,
-            bm.z_patch, bm.y_patch, bm.x_patch, z_shape, y_shape, x_shape, bm.compression, header,
-            img_header, bm.stride_size, allLabels, bm.batch_size, region_of_interest,
-            bm.no_scaling, extension, img_data)
+        results, bm = predict_semantic_segmentation(bm,
+            header, img_header,
+            region_of_interest, extension, img_data,
+            channels, normalization_parameters)
 
         # results
         if cropped_volume is not None:
@@ -403,7 +399,7 @@ if __name__ == '__main__':
                         help='Dice loss function')
     parser.add_argument('-ad','--average_dice', action='store_true', default=False,
                         help='Use averaged dice score of each label')
-    parser.add_argument('-nc', '--no_compression', action='store_true', default=False,
+    parser.add_argument('-nc', '--no-compression', dest='compression', action='store_false',
                         help='Disable compression of segmentation results')
     parser.add_argument('-i', '--ignore', type=str, default='none',
                         help='Ignore specific label(s), e.g. 2,5,6')
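The new argparse pattern stores the positive name while exposing the negative flag: `action='store_false'` makes the default `True` automatically, and `dest='compression'` means downstream code never sees a `no_*` attribute. A self-contained demo of the same pattern:

```python
import argparse

parser = argparse.ArgumentParser()
# store_false implies default=True, so the explicit default is gone
parser.add_argument('-nc', '--no-compression', dest='compression',
                    action='store_false',
                    help='Disable compression of segmentation results')

print(parser.parse_args([]).compression)                    # True
print(parser.parse_args(['--no-compression']).compression)  # False
```

Note that the long option itself changed spelling from `--no_compression` to `--no-compression`; argparse treats these as different option strings, so scripts using the old flag must be updated (the short `-nc` form is unchanged).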
@@ -421,12 +417,12 @@ if __name__ == '__main__':
                         help='Epochs the network is trained')
     parser.add_argument('-ce','--cropping_epochs', type=int, default=50,
                         help='Epochs the network for auto-cropping is trained')
-    parser.add_argument('-nn','--no_normalization', action='store_true', default=False,
-                        help='Disable image normalization')
+    parser.add_argument('-nn','--no-normalization', dest='normalization', action='store_false',
+                        help='Disable normalization of 3D image volumes')
     parser.add_argument('-r','--rotate', type=float, default=0.0,
                         help='Randomly rotate during training')
     parser.add_argument('-vs','--validation_split', type=float, default=0.0,
-                        help='Percentage of data used for validation')
+                        help='Percentage of data used for training')
     parser.add_argument('-lr','--learning_rate', type=float, default=0.01,
                         help='Learning rate')
     parser.add_argument('-ss','--stride_size', metavar="[1-64]", type=int, choices=range(1,65), default=32,
@@ -447,7 +443,7 @@ if __name__ == '__main__':
                         help='Images and labels are scaled at y-axis to this size before training')
     parser.add_argument('-zs','--z_scale', type=int, default=256,
                         help='Images and labels are scaled at z-axis to this size before training')
-    parser.add_argument('-ns','--no_scaling', action='store_true', default=False,
+    parser.add_argument('-ns','--no-scaling', dest='scaling', action='store_false',
                         help='Do not resize image and label data')
     parser.add_argument('-es','--early_stopping', type=int, default=0,
                         help='Training is terminated when the accuracy has not increased in the epochs defined by this')
@@ -483,9 +479,18 @@ if __name__ == '__main__':
                         help='Processing queue when using a remote server')
     parser.add_argument('-hf','--header_file', type=str, metavar='PATH', default=None,
                         help='Location of header file')
+    parser.add_argument('-ext','--extension', type=str, default='.tif',
+                        help='Save data for example as NRRD file using --extension=".nrrd"')
     bm = parser.parse_args()
-
     bm.success = True
+
+    # prediction or training
+    if not any([bm.train, bm.predict]):
+        bm.predict = False
+        bm.train = True
+        if os.path.splitext(bm.path)[1] == '.h5':
+            bm.predict = True
+            bm.train = False
     if bm.predict:
         bm.path_to_labels = None
         bm.path_to_model = bm.path
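The new default-mode block means neither `--train` nor `--predict` has to be given explicitly: a path ending in `.h5` is taken to be a trained network and implies prediction, anything else implies training. A standalone restatement (the function name is ours):

```python
import os

def infer_mode(path, train=False, predict=False):
    # mirrors the block above: only applies when neither flag is set
    if not any([train, predict]):
        train, predict = True, False
        if os.path.splitext(path)[1] == '.h5':
            train, predict = False, True
    return train, predict

print(infer_mode('heart.h5'))    # (False, True) -> predict
print(infer_mode('volume.tif'))  # (True, False) -> train
```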
@@ -286,7 +286,7 @@ class DataGenerator(tf.keras.utils.Sequence):
 
             # patch normalization
             if self.patch_normalization:
-                tmp_X = np.copy(tmp_X, order='C')
+                tmp_X = tmp_X.copy()
                 for c in range(self.n_channels):
                     tmp_X[:,:,:,c] -= np.mean(tmp_X[:,:,:,c])
                     tmp_X[:,:,:,c] /= max(np.std(tmp_X[:,:,:,c]), 1e-6)
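The copy-idiom change here is behavior-preserving: `ndarray.copy()` defaults to C order, unlike the `np.copy` function, whose default order is 'K'. For reference, a minimal sketch of the per-channel normalization itself (the helper name is ours; patches are assumed to have shape `(z, y, x, channels)`):

```python
import numpy as np

def normalize_patch(patch, eps=1e-6):
    patch = patch.copy()  # C-ordered copy, as in the generator above
    for c in range(patch.shape[-1]):
        patch[..., c] -= np.mean(patch[..., c])
        # the eps floor avoids division by ~0 on constant patches
        patch[..., c] /= max(np.std(patch[..., c]), eps)
    return patch

p = normalize_patch(np.random.rand(8, 8, 8, 1).astype(np.float32))
print(round(float(p.mean()), 6), round(float(p.std()), 3))  # ~0.0 ~1.0
```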
@@ -106,7 +106,7 @@ def reduce_blocksize(raw, slices):
     return raw, slices, argmin_z, argmax_z, argmin_y, argmax_y, argmin_x, argmax_x
 
 def activeContour(data, labelData, alpha=1.0, smooth=1, steps=3,
-    path_to_data=None, path_to_labels=None, no_compression=False,
+    path_to_data=None, path_to_labels=None, compression=True,
     ignore='none', only='all', simple=False,
     img_id=None, friend_id=None, remote=False):
 
@@ -126,12 +126,6 @@ def activeContour(data, labelData, alpha=1.0, smooth=1, steps=3,
     else:
         bm.django_env = False
 
-    # compression
-    if bm.no_compression:
-        bm.compression = False
-    else:
-        bm.compression = True
-
     # disable file saving when called as a function
     if bm.data is not None:
         bm.path_to_data = None
@@ -374,8 +368,7 @@ def init_active_contour(image_id, friend_id, label_id, simple=False):
     else:
         try:
             activeContour(None, None, path_to_data=image.pic.path, path_to_labels=friend.pic.path,
-                alpha=label.ac_alpha, smooth=label.ac_smooth, steps=label.ac_steps,
-                no_compression=(False if label.compression else True),
+                alpha=label.ac_alpha, smooth=label.ac_smooth, steps=label.ac_steps, compression=label.compression,
                 simple=simple, img_id=image_id, friend_id=friend_id, remote=False)
         except Exception as e:
             print(traceback.format_exc())
@@ -407,7 +400,7 @@ if __name__ == '__main__':
                         help='Number of smoothing steps')
     parser.add_argument('-st', '--steps', type=int, default=3,
                         help='Number of iterations')
-    parser.add_argument('-nc', '--no_compression', action='store_true', default=False,
+    parser.add_argument('-nc', '--no-compression', dest='compression', action='store_false',
                         help='Disable compression of segmentation results')
     parser.add_argument('-i', '--ignore', type=str, default='none',
                         help='Ignore specific label(s), e.g. 2,5,6')
@@ -317,19 +317,23 @@ def load_data(path_to_data, process='None', return_extension=False):
             data, header = None, None
         else:
             try:
-                # remove unreadable files or directories
-                for name in files:
-                    if os.path.isfile(name):
+                # load data slice by slice
+                file_names = []
+                img_slices = []
+                header = []
+                files.sort()
+                for file_name in files:
+                    if os.path.isfile(file_name):
                         try:
-                            img, _ = load(name)
+                            img, img_header = load(file_name)
+                            file_names.append(file_name)
+                            img_slices.append(img)
+                            header.append(img_header)
                         except:
-                            files.remove(name)
-                    else:
-                        files.remove(name)
-                files.sort()
+                            pass
 
                 # get data size
-                img, _ = load(files[0])
+                img = img_slices[0]
                 if len(img.shape)==3:
                     ysh, xsh, csh = img.shape[0], img.shape[1], img.shape[2]
                     channel = 'last'
@@ -340,11 +344,9 @@ def load_data(path_to_data, process='None', return_extension=False):
                     ysh, xsh = img.shape[0], img.shape[1]
                     csh, channel = 0, None
 
-                # load data slice by slice
-                data = np.empty((len(files), ysh, xsh), dtype=img.dtype)
-                header, image_data_shape = [], []
-                for k, file_name in enumerate(files):
-                    img, img_header = load(file_name)
+                # create 3D volume
+                data = np.empty((len(file_names), ysh, xsh), dtype=img.dtype)
+                for k, img in enumerate(img_slices):
                     if csh==3:
                         img = rgb2gray(img, channel)
                     elif csh==1 and channel=='last':
@@ -352,8 +354,7 @@ def load_data(path_to_data, process='None', return_extension=False):
                     elif csh==1 and channel=='first':
                         img = img[0,:,:]
                     data[k] = img
-                    header.append(img_header)
-                header = [header, files, data.dtype]
+                header = [header, file_names, data.dtype]
                 data = np.swapaxes(data, 1, 2)
                 data = np.copy(data, order='C')
         except Exception as e:
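The old loop called `files.remove(name)` while iterating over `files`, which skips the element after each removal, and it read every slice twice (once to probe, once to load). The rewrite reads each file once into parallel lists and simply skips unreadable ones. A condensed sketch of the new flow (`load` stands for any reader returning a 2D array plus header; the function name is ours):

```python
import os
import numpy as np

def stack_slices(files, load):
    file_names, img_slices, header = [], [], []
    files.sort()
    for file_name in files:
        if not os.path.isfile(file_name):
            continue
        try:
            img, img_header = load(file_name)
        except Exception:
            continue  # skip unreadable slices instead of mutating 'files'
        file_names.append(file_name)
        img_slices.append(img)
        header.append(img_header)
    # build the volume from the slices that were actually read
    data = np.empty((len(file_names),) + img_slices[0].shape,
                    dtype=img_slices[0].dtype)
    for k, img in enumerate(img_slices):
        data[k] = img
    return data, header, file_names
```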
@@ -415,25 +416,33 @@ def pre_processing(bm):
         if bm.labelData is None:
             return _error_(bm, 'Invalid label data.')
 
+        # dimension errors
         if len(bm.labelData.shape) != 3:
-            return _error_(bm, 'Label must be three-dimensional.')
-
+            return _error_(bm, 'Label data must be three-dimensional.')
         if bm.data.shape != bm.labelData.shape:
-            return _error_(bm, 'Image and label must have the same x,y,z-dimensions.')
+            return _error_(bm, 'Image and label data must have the same x,y,z-dimensions.')
+
+        # label data type
+        if bm.labelData.dtype in ['float16','float32','float64']:
+            if bm.django_env:
+                return _error_(bm, 'Label data must be of integer type.')
+            print(f'Warning: Potential label loss during conversion from {bm.labelData.dtype} to int32.')
+            bm.labelData = bm.labelData.astype(np.int32)
 
         # get labels
         bm.allLabels = np.unique(bm.labelData)
         index = np.argwhere(bm.allLabels<0)
         bm.allLabels = np.delete(bm.allLabels, index)
 
-        if bm.django_env and np.any(bm.allLabels > 255):
-            return _error_(bm, 'No labels higher than 255 allowed.')
-
+        # labels greater than 255
         if np.any(bm.allLabels > 255):
-            bm.labelData[bm.labelData > 255] = 0
-            index = np.argwhere(bm.allLabels > 255)
-            bm.allLabels = np.delete(bm.allLabels, index)
-            print('Warning: Only labels <=255 are allowed. Labels higher than 255 will be removed.')
+            if bm.django_env:
+                return _error_(bm, 'No labels greater than 255 allowed.')
+            else:
+                bm.labelData[bm.labelData > 255] = 0
+                index = np.argwhere(bm.allLabels > 255)
+                bm.allLabels = np.delete(bm.allLabels, index)
+                print('Warning: Only labels <=255 are allowed. Labels greater than 255 will be removed.')
 
         # add background label if not existing
         if not np.any(bm.allLabels==0):
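The new dtype check exists because casting float labels to `int32` truncates toward zero, silently merging fractional label values; hence the warning outside the Django environment and the hard error inside it. A two-line illustration:

```python
import numpy as np

labels = np.array([0.0, 1.0, 1.9, 2.4], dtype=np.float32)
print(labels.astype(np.int32))  # [0 1 1 2] -- 1.0 and 1.9 collapse into label 1
```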
@@ -485,7 +494,8 @@ def save_data(path_to_final, final, header=None, final_image_type=None, compress
         np_to_nc(path_to_final, final, header)
     elif final_image_type in ['.hdr', '.mhd', '.mha', '.nrrd', '.nii', '.nii.gz']:
         simg = sitk.GetImageFromArray(final)
-        simg.CopyInformation(header)
+        if header is not None:
+            simg.CopyInformation(header)
         sitk.WriteImage(simg, path_to_final, useCompression=compress)
     elif final_image_type in ['.zip', 'directory', '']:
         with tempfile.TemporaryDirectory() as temp_dir:
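`CopyInformation` transfers origin, spacing, and direction from a reference image and raises if handed `None`, so guarding it lets volumes without geometry metadata still be written. A minimal sketch of the guarded path (array and file name are illustrative):

```python
import numpy as np
import SimpleITK as sitk

final = np.zeros((4, 4, 4), dtype=np.uint8)
simg = sitk.GetImageFromArray(final)

header = None  # e.g. the input came without geometry metadata
if header is not None:
    simg.CopyInformation(header)  # origin, spacing, direction
sitk.WriteImage(simg, 'result.nrrd', useCompression=True)
```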
@@ -151,6 +151,7 @@ def create_slices(path_to_data, path_to_label, on_site=False):
             # increase contrast
             raw = img_to_uint8(raw)
             raw = contrast(raw)
+            zsh, ysh, xsh = raw.shape
 
             # create slices for slice viewer
             if not os.path.isdir(path_to_slices):
@@ -160,9 +161,9 @@ def create_slices(path_to_data, path_to_label, on_site=False):
                 os.chmod(path_to_slices, 0o770)
 
                 # save slices
-                for k in range(raw.shape[0]):
+                for k in range(zsh):
                     im = Image.fromarray(raw[k])
-                    im.save(path_to_slices + f'/{k}.png')
+                    im.save(path_to_slices + '/slice_' + str(k).zfill(len(str(zsh-1))) + '.png')
 
         if path_to_label and not os.path.isdir(path_to_label_slices):
 
@@ -263,7 +264,7 @@ def create_slices(path_to_data, path_to_label, on_site=False):
 
                     # save slice
                     im = Image.fromarray(out)
-                    im.save(path_to_label_slices + f'/{k}.png')
+                    im.save(path_to_label_slices + '/slice_' + str(k).zfill(len(str(zsh-1))) + '.png')
 
         except Exception as e:
             print(e)
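The renamed slice files are zero-padded to the width of the largest index, which keeps plain lexicographic sorting in numeric order; the old `f'/{k}.png'` names sorted as `0, 1, 10, 100, 11, ...`. A quick check of the new scheme:

```python
zsh = 120                      # number of slices
width = len(str(zsh - 1))      # -> 3 digits
names = ['slice_' + str(k).zfill(width) + '.png' for k in (0, 5, 100)]
print(names)                   # ['slice_000.png', 'slice_005.png', 'slice_100.png']
print(sorted(names) == names)  # True: lexicographic order is numeric order
```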
@@ -542,7 +542,8 @@ def load_and_train(normalize,path_to_img,path_to_labels,path_to_model,
         cropping_weights.append(arr)
 
     # configuration data
-    cropping_config = np.array([channels, x_scale, y_scale, z_scale, normalize, 0, 1])
+    cropping_config = np.array([channels, x_scale, y_scale, z_scale, normalize,
+        normalization_parameters[0,0], normalization_parameters[1,0]])
 
     return cropping_weights, cropping_config, normalization_parameters
549