biomedisa-24.5.23-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. biomedisa/__init__.py +49 -0
  2. biomedisa/__main__.py +18 -0
  3. biomedisa/deeplearning.py +529 -0
  4. biomedisa/features/DataGenerator.py +299 -0
  5. biomedisa/features/DataGeneratorCrop.py +121 -0
  6. biomedisa/features/PredictDataGenerator.py +87 -0
  7. biomedisa/features/PredictDataGeneratorCrop.py +74 -0
  8. biomedisa/features/__init__.py +0 -0
  9. biomedisa/features/active_contour.py +430 -0
  10. biomedisa/features/amira_to_np/__init__.py +0 -0
  11. biomedisa/features/amira_to_np/amira_data_stream.py +980 -0
  12. biomedisa/features/amira_to_np/amira_grammar.py +369 -0
  13. biomedisa/features/amira_to_np/amira_header.py +290 -0
  14. biomedisa/features/amira_to_np/amira_helper.py +72 -0
  15. biomedisa/features/assd.py +167 -0
  16. biomedisa/features/biomedisa_helper.py +842 -0
  17. biomedisa/features/create_slices.py +277 -0
  18. biomedisa/features/crop_helper.py +581 -0
  19. biomedisa/features/curvop_numba.py +149 -0
  20. biomedisa/features/django_env.py +171 -0
  21. biomedisa/features/keras_helper.py +1195 -0
  22. biomedisa/features/nc_reader.py +179 -0
  23. biomedisa/features/pid.py +52 -0
  24. biomedisa/features/process_image.py +251 -0
  25. biomedisa/features/pycuda_test.py +85 -0
  26. biomedisa/features/random_walk/__init__.py +0 -0
  27. biomedisa/features/random_walk/gpu_kernels.py +184 -0
  28. biomedisa/features/random_walk/pycuda_large.py +826 -0
  29. biomedisa/features/random_walk/pycuda_large_allx.py +806 -0
  30. biomedisa/features/random_walk/pycuda_small.py +414 -0
  31. biomedisa/features/random_walk/pycuda_small_allx.py +493 -0
  32. biomedisa/features/random_walk/pyopencl_large.py +760 -0
  33. biomedisa/features/random_walk/pyopencl_small.py +441 -0
  34. biomedisa/features/random_walk/rw_large.py +389 -0
  35. biomedisa/features/random_walk/rw_small.py +307 -0
  36. biomedisa/features/remove_outlier.py +396 -0
  37. biomedisa/features/split_volume.py +167 -0
  38. biomedisa/interpolation.py +369 -0
  39. biomedisa/mesh.py +403 -0
  40. biomedisa-24.5.23.dist-info/LICENSE +191 -0
  41. biomedisa-24.5.23.dist-info/METADATA +261 -0
  42. biomedisa-24.5.23.dist-info/RECORD +44 -0
  43. biomedisa-24.5.23.dist-info/WHEEL +5 -0
  44. biomedisa-24.5.23.dist-info/top_level.txt +1 -0
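All modules are installed under the single top-level package `biomedisa` (see top_level.txt). As a quick orientation, an illustrative sketch of how the helpers listed above are imported from the installed wheel (module paths taken from the file list; this sketch is not text from the package itself):

    # Illustrative sketch only; module paths follow the file list above.
    import biomedisa
    from biomedisa.features.biomedisa_helper import load_data, save_data
    from biomedisa import interpolation, mesh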
biomedisa/features/biomedisa_helper.py
@@ -0,0 +1,842 @@
+ ##########################################################################
+ ##                                                                      ##
+ ##  Copyright (c) 2019-2024 Philipp Lösel. All rights reserved.         ##
+ ##                                                                      ##
+ ##  This file is part of the open source project biomedisa.             ##
+ ##                                                                      ##
+ ##  Licensed under the European Union Public Licence (EUPL)             ##
+ ##  v1.2, or - as soon as they will be approved by the                  ##
+ ##  European Commission - subsequent versions of the EUPL;              ##
+ ##                                                                      ##
+ ##  You may redistribute it and/or modify it under the terms            ##
+ ##  of the EUPL v1.2. You may not use this work except in               ##
+ ##  compliance with this Licence.                                       ##
+ ##                                                                      ##
+ ##  You can obtain a copy of the Licence at:                            ##
+ ##                                                                      ##
+ ##  https://joinup.ec.europa.eu/page/eupl-text-11-12                    ##
+ ##                                                                      ##
+ ##  Unless required by applicable law or agreed to in                   ##
+ ##  writing, software distributed under the Licence is                  ##
+ ##  distributed on an "AS IS" basis, WITHOUT WARRANTIES                 ##
+ ##  OR CONDITIONS OF ANY KIND, either express or implied.               ##
+ ##                                                                      ##
+ ##  See the Licence for the specific language governing                 ##
+ ##  permissions and limitations under the Licence.                      ##
+ ##                                                                      ##
+ ##########################################################################
+
+ import os
+ import biomedisa
+ from biomedisa.features.amira_to_np.amira_helper import amira_to_np, np_to_amira
+ from biomedisa.features.nc_reader import nc_to_np, np_to_nc
+ from tifffile import imread, imwrite
+ from medpy.io import load, save
+ import SimpleITK as sitk
+ from PIL import Image
+ import numpy as np
+ import glob
+ import random
+ import cv2
+ import time
+ import zipfile
+ import numba
+ import subprocess
+ import re
+ import math
+ import tempfile
+
+ def silent_remove(filename):
+     try:
+         os.remove(filename)
+     except OSError:
+         pass
+
+ # create a unique filename
+ def unique_file_path(path, dir_path=biomedisa.BASE_DIR+'/private_storage/'):
+
+     # get extension
+     username = os.path.basename(os.path.dirname(path))
+     filename = os.path.basename(path)
+     filename, extension = os.path.splitext(filename)
+     if extension == '.gz':
+         filename, extension = os.path.splitext(filename)
+         if extension == '.nii':
+             extension = '.nii.gz'
+         elif extension == '.tar':
+             extension = '.tar.gz'
+
+     # get suffix
+     suffix = re.search("-[0-999]"+extension, path)
+     if suffix:
+         suffix = suffix.group()
+         filename = os.path.basename(path)
+         filename = filename[:-len(suffix)]
+         i = int(suffix[1:-len(extension)]) + 1
+     else:
+         suffix = extension
+         i = 1
+
+     # get finaltype
+     addon = ''
+     for feature in ['.filled','.smooth','.acwe','.cleaned','.8bit','.refined', '.cropped',
+                     '.uncertainty','.smooth.cleaned','.cleaned.filled','.denoised']:
+         if filename[-len(feature):] == feature:
+             addon = feature
+
+     if addon:
+         filename = filename[:-len(addon)]
+
+     # maximum length of path
+     pic_path = f'images/{username}/{filename}'
+     limit = 100 - len(addon) - len(suffix)
+     path = dir_path + pic_path[:limit] + addon + suffix
+
+     # check if file already exists
+     file_already_exists = os.path.exists(path)
+     while file_already_exists:
+         limit = 100 - len(addon) - len('-') - len(str(i)) - len(extension)
+         path = dir_path + pic_path[:limit] + addon + '-' + str(i) + extension
+         file_already_exists = os.path.exists(path)
+         i += 1
+
+     return path
+
+ def Dice_score(ground_truth, result, average_dice=False):
+     if average_dice:
+         dice = 0
+         allLabels = np.unique(ground_truth)
+         for l in allLabels[1:]:
+             dice += 2 * np.logical_and(ground_truth==l, result==l).sum() / float((ground_truth==l).sum() + (result==l).sum())
+         dice /= float(len(allLabels)-1)
+     else:
+         dice = 2 * np.logical_and(ground_truth==result, (ground_truth+result)>0).sum() / \
+                float((ground_truth>0).sum() + (result>0).sum())
+     return dice
+
+ def ASSD(ground_truth, result):
+     try:
+         from biomedisa.features.assd import ASSD_one_label
+         number_of_elements = 0
+         distances = 0
+         hausdorff = 0
+         for label in np.unique(ground_truth)[1:]:
+             d, n, h = ASSD_one_label(ground_truth, result, label)
+             number_of_elements += n
+             distances += d
+             hausdorff = max(h, hausdorff)
+         assd = distances / float(number_of_elements)
+         return assd, hausdorff
+     except:
+         print('Error: no CUDA device found. ASSD is not available.')
+         return None, None
+
+ def img_resize(a, z_shape, y_shape, x_shape, interpolation=None, labels=False):
+     if len(a.shape) > 3:
+         zsh, ysh, xsh, csh = a.shape
+     else:
+         zsh, ysh, xsh = a.shape
+     if interpolation == None:
+         if z_shape < zsh or y_shape < ysh or x_shape < xsh:
+             interpolation = cv2.INTER_AREA
+         else:
+             interpolation = cv2.INTER_CUBIC
+
+     def __resize__(arr):
+         b = np.empty((zsh, y_shape, x_shape), dtype=arr.dtype)
+         for k in range(zsh):
+             b[k] = cv2.resize(arr[k], (x_shape, y_shape), interpolation=interpolation)
+         c = np.empty((y_shape, z_shape, x_shape), dtype=arr.dtype)
+         b = np.swapaxes(b, 0, 1)
+         b = np.copy(b, order='C')
+         for k in range(y_shape):
+             c[k] = cv2.resize(b[k], (x_shape, z_shape), interpolation=interpolation)
+         c = np.swapaxes(c, 1, 0)
+         c = np.copy(c, order='C')
+         return c
+
+     if labels:
+         data = np.zeros((z_shape, y_shape, x_shape), dtype=a.dtype)
+         for k in np.unique(a):
+             if k!=0:
+                 tmp = np.zeros(a.shape, dtype=np.uint8)
+                 tmp[a==k] = 1
+                 tmp = __resize__(tmp)
+                 data[tmp==1] = k
+     elif len(a.shape) > 3:
+         data = np.empty((z_shape, y_shape, x_shape, csh), dtype=a.dtype)
+         for channel in range(csh):
+             data[:,:,:,channel] = __resize__(a[:,:,:,channel])
+     else:
+         data = __resize__(a)
+     return data
+
+ @numba.jit(nopython=True)
+ def smooth_img_3x3(img):
+     zsh, ysh, xsh = img.shape
+     out = np.copy(img)
+     for z in range(zsh):
+         for y in range(ysh):
+             for x in range(xsh):
+                 tmp,i = 0,0
+                 for k in range(-1,2):
+                     for l in range(-1,2):
+                         for m in range(-1,2):
+                             if 0<=z+k<zsh and 0<=y+l<ysh and 0<=x+m<xsh:
+                                 tmp += img[z+k,y+l,x+m]
+                                 i += 1
+                 out[z,y,x] = tmp / i
+     return out
+
+ def set_labels_to_zero(label, labels_to_compute, labels_to_remove):
+
+     # compute only specific labels (set rest to zero)
+     labels_to_compute = labels_to_compute.split(',')
+     if not any([x in ['all', 'All', 'ALL'] for x in labels_to_compute]):
+         allLabels = np.unique(label)
+         labels_to_del = [k for k in allLabels if str(k) not in labels_to_compute and k > 0]
+         for k in labels_to_del:
+             label[label == k] = 0
+
+     # ignore specific labels (set to zero)
+     labels_to_remove = labels_to_remove.split(',')
+     if not any([x in ['none', 'None', 'NONE'] for x in labels_to_remove]):
+         for k in labels_to_remove:
+             k = int(k)
+             if np.any(label == k):
+                 label[label == k] = 0
+
+     return label
+
+ def img_to_uint8(img):
+     if img.dtype != 'uint8':
+         img = img.astype(np.float32)
+         img -= np.amin(img)
+         img /= np.amax(img)
+         img *= 255.0
+         img = img.astype(np.uint8)
+     return img
+
+ def id_generator(size, chars='abcdefghijklmnopqrstuvwxyz0123456789'):
+     return ''.join(random.choice(chars) for x in range(size))
+
+ def rgb2gray(img, channel='last'):
+     """Convert a RGB image to gray scale."""
+     if channel=='last':
+         out = 0.2989*img[:,:,0] + 0.587*img[:,:,1] + 0.114*img[:,:,2]
+     elif channel=='first':
+         out = 0.2989*img[0,:,:] + 0.587*img[1,:,:] + 0.114*img[2,:,:]
+     out = out.astype(img.dtype)
+     return out
+
+ def recursive_file_permissions(path_to_dir):
+     files = glob.glob(path_to_dir+'/**/*', recursive=True) + [path_to_dir]
+     for file in files:
+         try:
+             if os.path.isdir(file):
+                 os.chmod(file, 0o770)
+             else:
+                 os.chmod(file, 0o660)
+         except:
+             pass
+
+ def load_data(path_to_data, process='None', return_extension=False):
+
+     if not os.path.exists(path_to_data):
+         print(f"Error: No such file or directory '{path_to_data}'")
+
+     # get file extension
+     extension = os.path.splitext(path_to_data)[1]
+     if extension == '.gz':
+         extension = '.nii.gz'
+     elif extension == '.bz2':
+         extension = os.path.splitext(os.path.splitext(path_to_data)[0])[1]
+
+     if extension == '.am':
+         try:
+             data, header = amira_to_np(path_to_data)
+             header = [header]
+             if len(data) > 1:
+                 for arr in data[1:]:
+                     header.append(arr)
+             data = data[0]
+         except Exception as e:
+             print(e)
+             data, header = None, None
+
+     elif extension == '.nc':
+         try:
+             data, header = nc_to_np(path_to_data)
+         except Exception as e:
+             print(e)
+             data, header = None, None
+
+     elif extension in ['.hdr', '.mhd', '.mha', '.nrrd', '.nii', '.nii.gz']:
+         try:
+             header = sitk.ReadImage(path_to_data)
+             data = sitk.GetArrayViewFromImage(header).copy()
+         except Exception as e:
+             print(e)
+             data, header = None, None
+
+     elif extension == '.zip' or os.path.isdir(path_to_data):
+         with tempfile.TemporaryDirectory() as temp_dir:
+
+             # extract files
+             if extension=='.zip':
+                 try:
+                     zip_ref = zipfile.ZipFile(path_to_data, 'r')
+                     zip_ref.extractall(path=temp_dir)
+                     zip_ref.close()
+                 except Exception as e:
+                     print(e)
+                     print('Using unzip package...')
+                     try:
+                         success = subprocess.Popen(['unzip',path_to_data,'-d',temp_dir]).wait()
+                         if success != 0:
+                             data, header = None, None
+                     except Exception as e:
+                         print(e)
+                         data, header = None, None
+                 path_to_data = temp_dir
+
+             # load files
+             if os.path.isdir(path_to_data):
+                 files = []
+                 for data_type in ['.[pP][nN][gG]','.[tT][iI][fF]','.[tT][iI][fF][fF]','.[dD][cC][mM]','.[dD][iI][cC][oO][mM]','.[bB][mM][pP]','.[jJ][pP][gG]','.[jJ][pP][eE][gG]','.nc','.nc.bz2']:
+                     files += [file for file in glob.glob(path_to_data+'/**/*'+data_type, recursive=True) if not os.path.basename(file).startswith('.')]
+                 nc_extension = False
+                 for file in files:
+                     if os.path.splitext(file)[1] == '.nc' or os.path.splitext(os.path.splitext(file)[0])[1] == '.nc':
+                         nc_extension = True
+                 if nc_extension:
+                     try:
+                         data, header = nc_to_np(path_to_data)
+                     except Exception as e:
+                         print(e)
+                         data, header = None, None
+                 else:
+                     try:
+                         # remove unreadable files or directories
+                         for name in files:
+                             if os.path.isfile(name):
+                                 try:
+                                     img, _ = load(name)
+                                 except:
+                                     files.remove(name)
+                             else:
+                                 files.remove(name)
+                         files.sort()
+
+                         # get data size
+                         img, _ = load(files[0])
+                         if len(img.shape)==3:
+                             ysh, xsh, csh = img.shape[0], img.shape[1], img.shape[2]
+                             channel = 'last'
+                             if ysh < csh:
+                                 csh, ysh, xsh = img.shape[0], img.shape[1], img.shape[2]
+                                 channel = 'first'
+                         else:
+                             ysh, xsh = img.shape[0], img.shape[1]
+                             csh, channel = 0, None
+
+                         # load data slice by slice
+                         data = np.empty((len(files), ysh, xsh), dtype=img.dtype)
+                         header, image_data_shape = [], []
+                         for k, file_name in enumerate(files):
+                             img, img_header = load(file_name)
+                             if csh==3:
+                                 img = rgb2gray(img, channel)
+                             elif csh==1 and channel=='last':
+                                 img = img[:,:,0]
+                             elif csh==1 and channel=='first':
+                                 img = img[0,:,:]
+                             data[k] = img
+                             header.append(img_header)
+                         header = [header, files, data.dtype]
+                         data = np.swapaxes(data, 1, 2)
+                         data = np.copy(data, order='C')
+                     except Exception as e:
+                         print(e)
+                         data, header = None, None
+
+     elif extension == '.mrc':
+         try:
+             import mrcfile
+             with mrcfile.open(path_to_data, permissive=True) as mrc:
+                 data = mrc.data
+             data = np.flip(data,1)
+             extension, header = '.tif', None
+         except Exception as e:
+             print(e)
+             data, header = None, None
+
+     elif extension in ['.tif', '.tiff']:
+         try:
+             data = imread(path_to_data)
+             header = None
+         except Exception as e:
+             print(e)
+             data, header = None, None
+
+     else:
+         data, header = None, None
+
+     if return_extension:
+         return data, header, extension
+     else:
+         return data, header
+
+ def _error_(bm, message):
+     if bm.django_env:
+         from biomedisa.features.django_env import create_error_object
+         create_error_object(message, bm.remote, bm.queue, bm.img_id)
+         with open(bm.path_to_logfile, 'a') as logfile:
+             print('%s %s %s %s' %(time.ctime(), bm.username, bm.shortfilename, message), file=logfile)
+     print('Error:', message)
+     bm.success = False
+     return bm
+
+ def pre_processing(bm):
+
+     # load data
+     if bm.data is None:
+         bm.data, _ = load_data(bm.path_to_data, bm.process)
+
+     # error handling
+     if bm.data is None:
+         return _error_(bm, 'Invalid image data.')
+
+     # load label data
+     if bm.labelData is None:
+         bm.labelData, bm.header, bm.final_image_type = load_data(bm.path_to_labels, bm.process, True)
+
+     # error handling
+     if bm.labelData is None:
+         return _error_(bm, 'Invalid label data.')
+
+     if len(bm.labelData.shape) != 3:
+         return _error_(bm, 'Label must be three-dimensional.')
+
+     if bm.data.shape != bm.labelData.shape:
+         return _error_(bm, 'Image and label must have the same x,y,z-dimensions.')
+
+     # get labels
+     bm.allLabels = np.unique(bm.labelData)
+     index = np.argwhere(bm.allLabels<0)
+     bm.allLabels = np.delete(bm.allLabels, index)
+
+     if bm.django_env and np.any(bm.allLabels > 255):
+         return _error_(bm, 'No labels higher than 255 allowed.')
+
+     if np.any(bm.allLabels > 255):
+         bm.labelData[bm.labelData > 255] = 0
+         index = np.argwhere(bm.allLabels > 255)
+         bm.allLabels = np.delete(bm.allLabels, index)
+         print('Warning: Only labels <=255 are allowed. Labels higher than 255 will be removed.')
+
+     # add background label if not existing
+     if not np.any(bm.allLabels==0):
+         bm.allLabels = np.append(0, bm.allLabels)
+
+     # compute only specific labels
+     labels_to_compute = (bm.only).split(',')
+     if not any([x in ['all', 'All', 'ALL'] for x in labels_to_compute]):
+         labels_to_remove = [k for k in bm.allLabels if str(k) not in labels_to_compute and k > 0]
+         for k in labels_to_remove:
+             bm.labelData[bm.labelData == k] = 0
+             index = np.argwhere(bm.allLabels==k)
+             bm.allLabels = np.delete(bm.allLabels, index)
+
+     # ignore specific labels
+     labels_to_remove = (bm.ignore).split(',')
+     if not any([x in ['none', 'None', 'NONE'] for x in labels_to_remove]):
+         for k in labels_to_remove:
+             try:
+                 k = int(k)
+                 bm.labelData[bm.labelData == k] = 0
+                 index = np.argwhere(bm.allLabels==k)
+                 bm.allLabels = np.delete(bm.allLabels, index)
+             except:
+                 pass
+
+     # number of labels
+     bm.nol = len(bm.allLabels)
+
+     if bm.nol < 2:
+         return _error_(bm, 'No labeled slices found.')
+
+     bm.success = True
+     return bm
+
+ def save_data(path_to_final, final, header=None, final_image_type=None, compress=True):
+     if final_image_type == None:
+         final_image_type = os.path.splitext(path_to_final)[1]
+         if final_image_type == '.gz':
+             final_image_type = '.nii.gz'
+     if final_image_type == '.am':
+         final = [final]
+         if len(header) > 1:
+             for arr in header[1:]:
+                 final.append(arr)
+         header = header[0]
+         np_to_amira(path_to_final, final, header)
+     elif final_image_type == '.nc':
+         np_to_nc(path_to_final, final, header)
+     elif final_image_type in ['.hdr', '.mhd', '.mha', '.nrrd', '.nii', '.nii.gz']:
+         simg = sitk.GetImageFromArray(final)
+         simg.CopyInformation(header)
+         sitk.WriteImage(simg, path_to_final, useCompression=compress)
+     elif final_image_type in ['.zip', 'directory', '']:
+         with tempfile.TemporaryDirectory() as temp_dir:
+             # make results directory
+             if final_image_type == '.zip':
+                 results_dir = temp_dir
+             else:
+                 results_dir = path_to_final
+                 if not os.path.isdir(results_dir):
+                     os.makedirs(results_dir)
+                     os.chmod(results_dir, 0o770)
+             # save data as NC blocks
+             if os.path.splitext(header[1][0])[1] == '.nc':
+                 np_to_nc(results_dir, final, header)
+                 file_names = header[1]
+             # save data as PNG, TIF, DICOM slices
+             else:
+                 header, file_names, final_dtype = header[0], header[1], header[2]
+                 final = final.astype(final_dtype)
+                 final = np.swapaxes(final, 2, 1)
+                 for k, file in enumerate(file_names):
+                     save(final[k], results_dir + '/' + os.path.basename(file), header[k])
+             # zip data
+             if final_image_type == '.zip':
+                 with zipfile.ZipFile(path_to_final, 'w') as zip:
+                     for file in file_names:
+                         zip.write(results_dir + '/' + os.path.basename(file), os.path.basename(file))
+     else:
+         imageSize = int(final.nbytes * 10e-7)
+         bigtiff = True if imageSize > 2000 else False
+         try:
+             compress = 'zlib' if compress else None
+             imwrite(path_to_final, final, bigtiff=bigtiff, compression=compress)
+         except:
+             compress = 6 if compress else 0
+             imwrite(path_to_final, final, bigtiff=bigtiff, compress=compress)
+
+ def color_to_gray(labelData):
+     if len(labelData.shape) == 4 and labelData.shape[1] == 3:
+         labelData = labelData.astype(np.float16)
+         labelData -= np.amin(labelData)
+         labelData /= np.amax(labelData)
+         labelData = 0.299 * labelData[:,0] + 0.587 * labelData[:,1] + 0.114 * labelData[:,2]
+         labelData *= 255.0
+         labelData = labelData.astype(np.uint8)
+         labelData = delbackground(labelData)
+     elif len(labelData.shape) == 4 and labelData.shape[3] == 3:
+         labelData = labelData.astype(np.float16)
+         labelData -= np.amin(labelData)
+         labelData /= np.amax(labelData)
+         labelData = 0.299 * labelData[:,:,:,0] + 0.587 * labelData[:,:,:,1] + 0.114 * labelData[:,:,:,2]
+         labelData *= 255.0
+         labelData = labelData.astype(np.uint8)
+         labelData = delbackground(labelData)
+     return labelData
+
+ def delbackground(labels):
+     allLabels, labelcounts = np.unique(labels, return_counts=True)
+     index = np.argmax(labelcounts)
+     labels[labels==allLabels[index]] = 0
+     return labels
+
+ def _get_platform(bm):
+
+     # import PyCUDA
+     if bm.platform in ['cuda', None]:
+         try:
+             import pycuda.driver as cuda
+             import pycuda.gpuarray as gpuarray
+             cuda.init()
+             bm.available_devices = cuda.Device.count()
+             if bm.available_devices > 0:
+                 dev = cuda.Device(0)
+                 ctx = dev.make_context()
+                 a_gpu = gpuarray.to_gpu(np.random.randn(4,4).astype(np.float32))
+                 a_doubled = (2*a_gpu).get()
+                 ctx.pop()
+                 del ctx
+                 bm.platform = 'cuda'
+                 return bm
+             elif bm.platform == 'cuda':
+                 print('Error: No CUDA device found.')
+                 bm.success = False
+                 return bm
+         except:
+             pass
+
+     # import PyOpenCL
+     try:
+         import pyopencl as cl
+     except ImportError:
+         cl = None
+
+     # select the first detected device
+     if cl and bm.platform is None:
+         for vendor in ['NVIDIA', 'Intel', 'AMD', 'Apple']:
+             for dev, device_type in [('GPU',cl.device_type.GPU),('CPU',cl.device_type.CPU)]:
+                 all_platforms = cl.get_platforms()
+                 my_devices = []
+                 for p in all_platforms:
+                     if p.get_devices(device_type=device_type) and vendor in p.name:
+                         my_devices = p.get_devices(device_type=device_type)
+                 if my_devices:
+                     bm.platform = 'opencl_'+vendor+'_'+dev
+                     if 'OMPI_COMMAND' in os.environ and dev == 'CPU':
+                         print("Error: OpenCL CPU does not support MPI. Start Biomedisa without 'mpirun' or 'mpiexec'.")
+                         bm.success = False
+                         return bm
+                     else:
+                         bm.available_devices = len(my_devices)
+                         print('Detected platform:', bm.platform)
+                         print('Detected devices:', my_devices)
+                         return bm
+
+     # explicitly select the OpenCL device
+     elif cl and len(bm.platform.split('_')) == 3:
+         plat, vendor, dev = bm.platform.split('_')
+         device_type=cl.device_type.GPU if dev=='GPU' else cl.device_type.CPU
+         all_platforms = cl.get_platforms()
+         my_devices = []
+         for p in all_platforms:
+             if p.get_devices(device_type=device_type) and vendor in p.name:
+                 my_devices = p.get_devices(device_type=device_type)
+         if my_devices:
+             if 'OMPI_COMMAND' in os.environ and dev == 'CPU':
+                 print("Error: OpenCL CPU does not support MPI. Start Biomedisa without 'mpirun' or 'mpiexec'.")
+                 bm.success = False
+                 return bm
+             else:
+                 bm.available_devices = len(my_devices)
+                 print('Detected platform:', bm.platform)
+                 print('Detected devices:', my_devices)
+                 return bm
+
+     # stop the process if no device is detected
+     if bm.platform is None:
+         bm.platform = 'OpenCL or CUDA'
+     print(f'Error: No {bm.platform} device found.')
+     bm.success = False
+     return bm
+
+ def _get_device(platform, dev_id):
+     import pyopencl as cl
+     plat, vendor, dev = platform.split('_')
+     device_type=cl.device_type.GPU if dev=='GPU' else cl.device_type.CPU
+     all_platforms = cl.get_platforms()
+     for p in all_platforms:
+         if p.get_devices(device_type=device_type) and vendor in p.name:
+             my_devices = p.get_devices(device_type=device_type)
+     context = cl.Context(devices=my_devices)
+     queue = cl.CommandQueue(context, my_devices[dev_id % len(my_devices)])
+     return context, queue
+
+ def read_labeled_slices(arr):
+     data = np.zeros((0, arr.shape[1], arr.shape[2]), dtype=np.int32)
+     indices = []
+     for k, slc in enumerate(arr[:]):
+         if np.any(slc):
+             data = np.append(data, [arr[k]], axis=0)
+             indices.append(k)
+     return indices, data
+
+ def read_labeled_slices_allx(arr, ax):
+     gradient = np.zeros(arr.shape, dtype=np.int8)
+     ones = np.zeros_like(gradient)
+     ones[arr != 0] = 1
+     tmp = ones[:,:-1] - ones[:,1:]
+     tmp = np.abs(tmp)
+     gradient[:,:-1] += tmp
+     gradient[:,1:] += tmp
+     ones[gradient == 2] = 0
+     gradient.fill(0)
+     tmp = ones[:,:,:-1] - ones[:,:,1:]
+     tmp = np.abs(tmp)
+     gradient[:,:,:-1] += tmp
+     gradient[:,:,1:] += tmp
+     ones[gradient == 2] = 0
+     indices = []
+     data = np.zeros((0, arr.shape[1], arr.shape[2]), dtype=np.int32)
+     for k, slc in enumerate(ones[:]):
+         if np.any(slc):
+             data = np.append(data, [arr[k]], axis=0)
+             indices.append((k, ax))
+     return indices, data
+
+ def read_indices_allx(arr, ax):
+     gradient = np.zeros(arr.shape, dtype=np.int8)
+     ones = np.zeros_like(gradient)
+     ones[arr != 0] = 1
+     tmp = ones[:,:-1] - ones[:,1:]
+     tmp = np.abs(tmp)
+     gradient[:,:-1] += tmp
+     gradient[:,1:] += tmp
+     ones[gradient == 2] = 0
+     gradient.fill(0)
+     tmp = ones[:,:,:-1] - ones[:,:,1:]
+     tmp = np.abs(tmp)
+     gradient[:,:,:-1] += tmp
+     gradient[:,:,1:] += tmp
+     ones[gradient == 2] = 0
+     indices = []
+     for k, slc in enumerate(ones[:]):
+         if np.any(slc):
+             indices.append((k, ax))
+     return indices
+
+ def read_labeled_slices_large(arr):
+     data = np.zeros((0, arr.shape[1], arr.shape[2]), dtype=np.int32)
+     indices = []
+     i = 0
+     while i < arr.shape[0]:
+         slc = arr[i]
+         if np.any(slc):
+             data = np.append(data, [arr[i]], axis=0)
+             indices.append(i)
+             i += 5
+         else:
+             i += 1
+     return indices, data
+
+ def read_labeled_slices_allx_large(arr):
+     gradient = np.zeros(arr.shape, dtype=np.int8)
+     ones = np.zeros_like(gradient)
+     ones[arr > 0] = 1
+     tmp = ones[:,:-1] - ones[:,1:]
+     tmp = np.abs(tmp)
+     gradient[:,:-1] += tmp
+     gradient[:,1:] += tmp
+     ones[gradient == 2] = 0
+     gradient.fill(0)
+     tmp = ones[:,:,:-1] - ones[:,:,1:]
+     tmp = np.abs(tmp)
+     gradient[:,:,:-1] += tmp
+     gradient[:,:,1:] += tmp
+     ones[gradient == 2] = 0
+     indices = []
+     data = np.zeros((0, arr.shape[1], arr.shape[2]), dtype=np.int32)
+     for k, slc in enumerate(ones[:]):
+         if np.any(slc):
+             data = np.append(data, [arr[k]], axis=0)
+             indices.append(k)
+     return indices, data
+
+ def predict_blocksize(bm):
+     zsh, ysh, xsh = bm.labelData.shape
+     argmin_z, argmax_z, argmin_y, argmax_y, argmin_x, argmax_x = zsh, 0, ysh, 0, xsh, 0
+     for k in range(zsh):
+         y, x = np.nonzero(bm.labelData[k])
+         if x.any():
+             argmin_x = min(argmin_x, np.amin(x))
+             argmax_x = max(argmax_x, np.amax(x))
+             argmin_y = min(argmin_y, np.amin(y))
+             argmax_y = max(argmax_y, np.amax(y))
+             argmin_z = min(argmin_z, k)
+             argmax_z = max(argmax_z, k)
+     zmin, zmax = argmin_z, argmax_z
+     bm.argmin_x = argmin_x - 100 if argmin_x - 100 > 0 else 0
+     bm.argmax_x = argmax_x + 100 if argmax_x + 100 < xsh else xsh
+     bm.argmin_y = argmin_y - 100 if argmin_y - 100 > 0 else 0
+     bm.argmax_y = argmax_y + 100 if argmax_y + 100 < ysh else ysh
+     bm.argmin_z = argmin_z - 100 if argmin_z - 100 > 0 else 0
+     bm.argmax_z = argmax_z + 100 if argmax_z + 100 < zsh else zsh
+     return bm
+
+ def splitlargedata(data):
+     dataMemory = data.nbytes
+     dataListe = []
+     if dataMemory > 1500000000:
+         mod = dataMemory / float(1500000000)
+         mod2 = int(math.ceil(mod))
+         mod3 = divmod(data.shape[0], mod2)[0]
+         for k in range(mod2):
+             dataListe.append(data[mod3*k:mod3*(k+1)])
+         dataListe.append(data[mod3*mod2:])
+     else:
+         dataListe.append(data)
+     return dataListe
+
+ def sendToChildLarge(comm, indices, dest, dataListe, Labels, nbrw, sorw, blocks,
+                      allx, allLabels, smooth, uncertainty, platform):
+     from mpi4py import MPI
+     comm.send(len(dataListe), dest=dest, tag=0)
+     for k, tmp in enumerate(dataListe):
+         tmp = tmp.copy(order='C')
+         comm.send([tmp.shape[0], tmp.shape[1], tmp.shape[2], tmp.dtype], dest=dest, tag=10+(2*k))
+         if tmp.dtype == 'uint8':
+             comm.Send([tmp, MPI.BYTE], dest=dest, tag=10+(2*k+1))
+         else:
+             comm.Send([tmp, MPI.FLOAT], dest=dest, tag=10+(2*k+1))
+
+     comm.send([nbrw, sorw, allx, smooth, uncertainty, platform], dest=dest, tag=1)
+
+     if allx:
+         for k in range(3):
+             labelsListe = splitlargedata(Labels[k])
+             comm.send(len(labelsListe), dest=dest, tag=2+k)
+             for l, tmp in enumerate(labelsListe):
+                 tmp = tmp.copy(order='C')
+                 comm.send([tmp.shape[0], tmp.shape[1], tmp.shape[2]], dest=dest, tag=100+(2*k))
+                 comm.Send([tmp, MPI.INT], dest=dest, tag=100+(2*k+1))
+     else:
+         labelsListe = splitlargedata(Labels)
+         comm.send(len(labelsListe), dest=dest, tag=2)
+         for k, tmp in enumerate(labelsListe):
+             tmp = tmp.copy(order='C')
+             comm.send([tmp.shape[0], tmp.shape[1], tmp.shape[2]], dest=dest, tag=100+(2*k))
+             comm.Send([tmp, MPI.INT], dest=dest, tag=100+(2*k+1))
+
+     comm.send(allLabels, dest=dest, tag=99)
+     comm.send(indices, dest=dest, tag=8)
+     comm.send(blocks, dest=dest, tag=9)
+
+ def sendToChild(comm, indices, indices_child, dest, data, Labels, nbrw, sorw, allx, platform):
+     from mpi4py import MPI
+     data = data.copy(order='C')
+     comm.send([data.shape[0], data.shape[1], data.shape[2], data.dtype], dest=dest, tag=0)
+     if data.dtype == 'uint8':
+         comm.Send([data, MPI.BYTE], dest=dest, tag=1)
+     else:
+         comm.Send([data, MPI.FLOAT], dest=dest, tag=1)
+     comm.send([allx, nbrw, sorw, platform], dest=dest, tag=2)
+     if allx:
+         for k in range(3):
+             labels = Labels[k].copy(order='C')
+             comm.send([labels.shape[0], labels.shape[1], labels.shape[2]], dest=dest, tag=k+3)
+             comm.Send([labels, MPI.INT], dest=dest, tag=k+6)
+     else:
+         labels = Labels.copy(order='C')
+         comm.send([labels.shape[0], labels.shape[1], labels.shape[2]], dest=dest, tag=3)
+         comm.Send([labels, MPI.INT], dest=dest, tag=6)
+     comm.send(indices, dest=dest, tag=9)
+     comm.send(indices_child, dest=dest, tag=10)
+
+ def _split_indices(indices, ngpus):
+     ngpus = ngpus if ngpus < len(indices) else len(indices)
+     nindices = len(indices)
+     parts = []
+     for i in range(0, ngpus):
+         slice_idx = indices[i]
+         parts.append([slice_idx])
+     if ngpus < nindices:
+         for i in range(ngpus, nindices):
+             gid = i % ngpus
+             slice_idx = indices[i]
+             parts[gid].append(slice_idx)
+     return parts
+
+ def get_labels(pre_final, labels):
+     numos = np.unique(pre_final)
+     final = np.zeros_like(pre_final)
+     for k in numos[1:]:
+         final[pre_final == k] = labels[k]
+     return final
+
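A minimal usage sketch of the I/O helpers defined in this file (not part of the wheel itself; the file names are placeholders, and biomedisa and its dependencies must be installed):

    from biomedisa.features.biomedisa_helper import load_data, save_data, img_resize, Dice_score

    # load an image volume and a label volume; the reader is chosen by file extension
    img, img_header = load_data('volume.tif')
    labels, label_header, ext = load_data('labels.nii.gz', return_extension=True)

    # resample the labels onto the image grid, preserving the discrete label values
    labels = img_resize(labels, *img.shape, labels=True)

    # overall Dice score between two segmentations (here the labels against themselves)
    print('Dice:', Dice_score(labels, labels))

    # write the result in the same format as the input labels
    save_data('result' + ext, labels, header=label_header, final_image_type=ext)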