biomedisa 2024.5.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. biomedisa/__init__.py +53 -0
  2. biomedisa/__main__.py +18 -0
  3. biomedisa/biomedisa_features/DataGenerator.py +299 -0
  4. biomedisa/biomedisa_features/DataGeneratorCrop.py +121 -0
  5. biomedisa/biomedisa_features/PredictDataGenerator.py +87 -0
  6. biomedisa/biomedisa_features/PredictDataGeneratorCrop.py +74 -0
  7. biomedisa/biomedisa_features/__init__.py +0 -0
  8. biomedisa/biomedisa_features/active_contour.py +434 -0
  9. biomedisa/biomedisa_features/amira_to_np/__init__.py +0 -0
  10. biomedisa/biomedisa_features/amira_to_np/amira_data_stream.py +980 -0
  11. biomedisa/biomedisa_features/amira_to_np/amira_grammar.py +369 -0
  12. biomedisa/biomedisa_features/amira_to_np/amira_header.py +290 -0
  13. biomedisa/biomedisa_features/amira_to_np/amira_helper.py +72 -0
  14. biomedisa/biomedisa_features/assd.py +167 -0
  15. biomedisa/biomedisa_features/biomedisa_helper.py +801 -0
  16. biomedisa/biomedisa_features/create_slices.py +286 -0
  17. biomedisa/biomedisa_features/crop_helper.py +586 -0
  18. biomedisa/biomedisa_features/curvop_numba.py +149 -0
  19. biomedisa/biomedisa_features/django_env.py +172 -0
  20. biomedisa/biomedisa_features/keras_helper.py +1219 -0
  21. biomedisa/biomedisa_features/nc_reader.py +179 -0
  22. biomedisa/biomedisa_features/pid.py +52 -0
  23. biomedisa/biomedisa_features/process_image.py +253 -0
  24. biomedisa/biomedisa_features/pycuda_test.py +84 -0
  25. biomedisa/biomedisa_features/random_walk/__init__.py +0 -0
  26. biomedisa/biomedisa_features/random_walk/gpu_kernels.py +183 -0
  27. biomedisa/biomedisa_features/random_walk/pycuda_large.py +826 -0
  28. biomedisa/biomedisa_features/random_walk/pycuda_large_allx.py +806 -0
  29. biomedisa/biomedisa_features/random_walk/pycuda_small.py +414 -0
  30. biomedisa/biomedisa_features/random_walk/pycuda_small_allx.py +493 -0
  31. biomedisa/biomedisa_features/random_walk/pyopencl_large.py +760 -0
  32. biomedisa/biomedisa_features/random_walk/pyopencl_small.py +441 -0
  33. biomedisa/biomedisa_features/random_walk/rw_large.py +390 -0
  34. biomedisa/biomedisa_features/random_walk/rw_small.py +310 -0
  35. biomedisa/biomedisa_features/remove_outlier.py +399 -0
  36. biomedisa/biomedisa_features/split_volume.py +274 -0
  37. biomedisa/deeplearning.py +519 -0
  38. biomedisa/interpolation.py +371 -0
  39. biomedisa/mesh.py +406 -0
  40. biomedisa-2024.5.14.dist-info/LICENSE +191 -0
  41. biomedisa-2024.5.14.dist-info/METADATA +306 -0
  42. biomedisa-2024.5.14.dist-info/RECORD +44 -0
  43. biomedisa-2024.5.14.dist-info/WHEEL +5 -0
  44. biomedisa-2024.5.14.dist-info/top_level.txt +1 -0
@@ -0,0 +1,801 @@
1
+ ##########################################################################
2
+ ## ##
3
+ ## Copyright (c) 2024 Philipp Lösel. All rights reserved. ##
4
+ ## ##
5
+ ## This file is part of the open source project biomedisa. ##
6
+ ## ##
7
+ ## Licensed under the European Union Public Licence (EUPL) ##
8
+ ## v1.2, or - as soon as they will be approved by the ##
9
+ ## European Commission - subsequent versions of the EUPL; ##
10
+ ## ##
11
+ ## You may redistribute it and/or modify it under the terms ##
12
+ ## of the EUPL v1.2. You may not use this work except in ##
13
+ ## compliance with this Licence. ##
14
+ ## ##
15
+ ## You can obtain a copy of the Licence at: ##
16
+ ## ##
17
+ ## https://joinup.ec.europa.eu/page/eupl-text-11-12 ##
18
+ ## ##
19
+ ## Unless required by applicable law or agreed to in ##
20
+ ## writing, software distributed under the Licence is ##
21
+ ## distributed on an "AS IS" basis, WITHOUT WARRANTIES ##
22
+ ## OR CONDITIONS OF ANY KIND, either express or implied. ##
23
+ ## ##
24
+ ## See the Licence for the specific language governing ##
25
+ ## permissions and limitations under the Licence. ##
26
+ ## ##
27
+ ##########################################################################
28
+
29
+ import os
30
+ BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
31
+ from biomedisa_features.amira_to_np.amira_helper import amira_to_np, np_to_amira
32
+ from biomedisa_features.nc_reader import nc_to_np, np_to_nc
33
+ from tifffile import imread, imwrite
34
+ from medpy.io import load, save
35
+ import SimpleITK as sitk
36
+ from PIL import Image
37
+ import numpy as np
38
+ import glob
39
+ import random
40
+ import cv2
41
+ import time
42
+ import zipfile
43
+ import numba
44
+ import shutil
45
+ import subprocess
46
+ import math
47
+
48
def silent_remove(filename):
    """Delete *filename*, quietly ignoring any OSError (e.g. file missing)."""
    try:
        os.remove(filename)
    except OSError:
        # best effort: a missing or non-removable file is not an error here
        return
53
+
54
def Dice_score(ground_truth, result, average_dice=False):
    """Dice similarity coefficient between two label volumes.

    With ``average_dice`` the per-label Dice scores (background excluded)
    are averaged; otherwise a single global score over all foreground
    voxels is returned.
    """
    if average_dice:
        # mean of per-label Dice over every non-background label
        labels = np.unique(ground_truth)[1:]
        total = 0.0
        for lbl in labels:
            overlap = np.logical_and(ground_truth == lbl, result == lbl).sum()
            size_sum = (ground_truth == lbl).sum() + (result == lbl).sum()
            total += 2 * overlap / float(size_sum)
        return total / float(len(labels))
    # global score: voxels that agree and are foreground in either volume
    matches = np.logical_and(ground_truth == result, (ground_truth + result) > 0).sum()
    denom = (ground_truth > 0).sum() + (result > 0).sum()
    return 2 * matches / float(denom)
65
+
66
def ASSD(ground_truth, result):
    """Average symmetric surface distance and Hausdorff distance.

    Aggregates ``ASSD_one_label`` over every non-background label of
    *ground_truth*. Returns ``(None, None)`` when the CUDA-based helper
    is unavailable or fails.
    """
    try:
        from biomedisa_features.assd import ASSD_one_label
        total_distance = 0
        total_elements = 0
        hausdorff = 0
        for lbl in np.unique(ground_truth)[1:]:
            d, n, h = ASSD_one_label(ground_truth, result, lbl)
            total_distance += d
            total_elements += n
            hausdorff = max(hausdorff, h)
        return total_distance / float(total_elements), hausdorff
    except:
        print('Error: no CUDA device found. ASSD is not available.')
        return None, None
82
+
83
def img_resize(a, z_shape, y_shape, x_shape, interpolation=None, labels=False):
    """Resize a 3D (or 3D multi-channel) volume to the given shape with OpenCV.

    Resizing is done as two passes of 2D ``cv2.resize``: first within each
    z-slice (y/x), then within each y-slice (z/x). If *interpolation* is
    None, INTER_AREA is used for downscaling and INTER_CUBIC otherwise.
    With ``labels=True`` every label is resized as a separate binary mask
    so labels are not blended by interpolation.
    """
    if len(a.shape) > 3:
        zsh, ysh, xsh, csh = a.shape
    else:
        zsh, ysh, xsh = a.shape
    if interpolation == None:
        # shrink -> area interpolation, enlarge -> cubic
        if z_shape < zsh or y_shape < ysh or x_shape < xsh:
            interpolation = cv2.INTER_AREA
        else:
            interpolation = cv2.INTER_CUBIC

    def __resize__(arr):
        # pass 1: resize y/x within each of the zsh slices
        b = np.empty((zsh, y_shape, x_shape), dtype=arr.dtype)
        for k in range(zsh):
            b[k] = cv2.resize(arr[k], (x_shape, y_shape), interpolation=interpolation)
        # pass 2: resize z/x within each of the y_shape slices
        c = np.empty((y_shape, z_shape, x_shape), dtype=arr.dtype)
        b = np.swapaxes(b, 0, 1)
        b = np.copy(b, order='C')
        for k in range(y_shape):
            c[k] = cv2.resize(b[k], (x_shape, z_shape), interpolation=interpolation)
        # restore (z, y, x) ordering
        c = np.swapaxes(c, 1, 0)
        c = np.copy(c, order='C')
        return c

    if labels:
        # resize each label as a binary mask to avoid mixing label values
        data = np.zeros((z_shape, y_shape, x_shape), dtype=a.dtype)
        for k in np.unique(a):
            if k!=0:
                tmp = np.zeros(a.shape, dtype=np.uint8)
                tmp[a==k] = 1
                tmp = __resize__(tmp)
                data[tmp==1] = k
    elif len(a.shape) > 3:
        # resize each channel independently
        data = np.empty((z_shape, y_shape, x_shape, csh), dtype=a.dtype)
        for channel in range(csh):
            data[:,:,:,channel] = __resize__(a[:,:,:,channel])
    else:
        data = __resize__(a)
    return data
122
+
123
@numba.jit(nopython=True)
def smooth_img_3x3(img):
    """Mean-filter a 3D image with a 3x3x3 box kernel (numba-compiled).

    Border voxels average only over the part of the neighbourhood that
    lies inside the volume. Returns a new array with the input's dtype;
    *img* itself is not modified.
    """
    zsh, ysh, xsh = img.shape
    out = np.copy(img)
    for z in range(zsh):
        for y in range(ysh):
            for x in range(xsh):
                tmp,i = 0,0
                # accumulate the valid part of the 3x3x3 neighbourhood
                for k in range(-1,2):
                    for l in range(-1,2):
                        for m in range(-1,2):
                            if 0<=z+k<zsh and 0<=y+l<ysh and 0<=x+m<xsh:
                                tmp += img[z+k,y+l,x+m]
                                i += 1
                out[z,y,x] = tmp / i
    return out
139
+
140
def set_labels_to_zero(label, labels_to_compute, labels_to_remove):
    """Restrict *label* to the requested label set (modified in place).

    Parameters
    ----------
    label : ndarray
        Label volume; modified in place and returned.
    labels_to_compute : str
        Comma separated label values to keep; any of 'all'/'All'/'ALL'
        keeps everything.
    labels_to_remove : str
        Comma separated label values to set to zero; any of
        'none'/'None'/'NONE' disables removal.

    Fix: non-numeric tokens in *labels_to_remove* are now skipped
    (consistent with the guarded int conversion in ``pre_processing``);
    previously ``int(k)`` raised ValueError on e.g. stray whitespace.
    """

    # compute only specific labels (set rest to zero)
    labels_to_compute = labels_to_compute.split(',')
    if not any([x in ['all', 'All', 'ALL'] for x in labels_to_compute]):
        allLabels = np.unique(label)
        labels_to_del = [k for k in allLabels if str(k) not in labels_to_compute and k > 0]
        for k in labels_to_del:
            label[label == k] = 0

    # ignore specific labels (set to zero)
    labels_to_remove = labels_to_remove.split(',')
    if not any([x in ['none', 'None', 'NONE'] for x in labels_to_remove]):
        for k in labels_to_remove:
            try:
                k = int(k)
            except ValueError:
                # skip tokens that are not valid integers
                continue
            if np.any(label == k):
                label[label == k] = 0

    return label
159
+
160
def img_to_uint8(img):
    """Scale an image to the full uint8 range [0, 255].

    Images that are already uint8 are returned unchanged (same object).
    Fix: a constant image previously produced 0/0 -> NaN (with a runtime
    warning) before the uint8 cast; the intensity range is now checked so
    constant input deterministically maps to zeros.
    """
    if img.dtype != 'uint8':
        img = img.astype(np.float32)
        img -= np.amin(img)
        amax = np.amax(img)
        if amax > 0:
            # normalize to [0, 1]; skipped for constant images (all zeros)
            img /= amax
        img *= 255.0
        img = img.astype(np.uint8)
    return img
168
+
169
def id_generator(size, chars='abcdefghijklmnopqrstuvwxyz0123456789'):
    """Build a random identifier of *size* characters drawn from *chars*."""
    picked = [random.choice(chars) for _ in range(size)]
    return ''.join(picked)
171
+
172
def rgb2gray(img, channel='last'):
    """Convert an RGB image to gray scale using ITU-R 601 luma weights.

    *channel* selects whether the color axis is the last ('last') or the
    first ('first') dimension. The result keeps the input dtype.
    """
    if channel == 'last':
        r, g, b = img[:, :, 0], img[:, :, 1], img[:, :, 2]
        gray = 0.2989 * r + 0.587 * g + 0.114 * b
    elif channel == 'first':
        r, g, b = img[0, :, :], img[1, :, :], img[2, :, :]
        gray = 0.2989 * r + 0.587 * g + 0.114 * b
    return gray.astype(img.dtype)
180
+
181
def recursive_file_permissions(path_to_dir):
    """Best-effort chmod: directories to 0o770, files to 0o660, recursively.

    Entries that cannot be modified (e.g. owned by another user) are
    silently skipped.
    """
    entries = glob.glob(path_to_dir + '/**/*', recursive=True)
    entries.append(path_to_dir)
    for entry in entries:
        mode = 0o770 if os.path.isdir(entry) else 0o660
        try:
            os.chmod(entry, mode)
        except:
            # ignore entries we are not permitted to change
            pass
191
+
192
def load_data(path_to_data, process='None', return_extension=False):
    """Load an image or label volume from disk.

    Supports Amira (.am), NetCDF (.nc), ITK formats (.hdr/.mhd/.mha/
    .nrrd/.nii/.nii.gz), zip archives or directories of 2D slices
    (PNG/TIF/DICOM/BMP/JPEG/NC), MRC and TIFF.

    Returns ``(data, header)`` — or ``(data, header, extension)`` when
    *return_extension* is True. On any failure both data and header are
    ``None``. The *process* argument is currently unused here.
    """

    if not os.path.exists(path_to_data):
        print(f"Error: No such file or directory '{path_to_data}'")

    # get file extension ('.nii.gz' and '.bz2' need a double split)
    extension = os.path.splitext(path_to_data)[1]
    if extension == '.gz':
        extension = '.nii.gz'
    elif extension == '.bz2':
        extension = os.path.splitext(os.path.splitext(path_to_data)[0])[1]

    if extension == '.am':
        try:
            # Amira files may carry extra arrays; keep them in the header list
            data, header = amira_to_np(path_to_data)
            header = [header]
            if len(data) > 1:
                for arr in data[1:]:
                    header.append(arr)
            data = data[0]
        except Exception as e:
            print(e)
            data, header = None, None

    elif extension == '.nc':
        try:
            data, header = nc_to_np(path_to_data)
        except Exception as e:
            print(e)
            data, header = None, None

    elif extension in ['.hdr', '.mhd', '.mha', '.nrrd', '.nii', '.nii.gz']:
        try:
            # the SimpleITK image object itself serves as the header
            header = sitk.ReadImage(path_to_data)
            data = sitk.GetArrayViewFromImage(header).copy()
        except Exception as e:
            print(e)
            data, header = None, None

    elif extension == '.zip' or os.path.isdir(path_to_data):
        # extract archive into a unique temporary directory
        if extension=='.zip':
            path_to_dir = BASE_DIR + '/tmp/' + id_generator(40)
            try:
                zip_ref = zipfile.ZipFile(path_to_data, 'r')
                zip_ref.extractall(path=path_to_dir)
                zip_ref.close()
            except Exception as e:
                print(e)
                # fall back to the external 'unzip' tool
                print('Using unzip package...')
                try:
                    success = subprocess.Popen(['unzip',path_to_data,'-d',path_to_dir]).wait()
                    if success != 0:
                        if os.path.isdir(path_to_dir):
                            shutil.rmtree(path_to_dir)
                        data, header = None, None
                except Exception as e:
                    print(e)
                    if os.path.isdir(path_to_dir):
                        shutil.rmtree(path_to_dir)
                    data, header = None, None
            path_to_data = path_to_dir

        # load slice files from the directory
        if os.path.isdir(path_to_data):
            # collect supported slice files, case-insensitive extensions,
            # skipping hidden files
            files = []
            for data_type in ['.[pP][nN][gG]','.[tT][iI][fF]','.[tT][iI][fF][fF]','.[dD][cC][mM]','.[dD][iI][cC][oO][mM]','.[bB][mM][pP]','.[jJ][pP][gG]','.[jJ][pP][eE][gG]','.nc','.nc.bz2']:
                files += [file for file in glob.glob(path_to_data+'/**/*'+data_type, recursive=True) if not os.path.basename(file).startswith('.')]
            nc_extension = False
            for file in files:
                if os.path.splitext(file)[1] == '.nc' or os.path.splitext(os.path.splitext(file)[0])[1] == '.nc':
                    nc_extension = True
            if nc_extension:
                # NetCDF blocks are loaded by the nc reader as one volume
                try:
                    data, header = nc_to_np(path_to_data)
                except Exception as e:
                    print(e)
                    data, header = None, None
            else:
                try:
                    # remove unreadable files or directories
                    # NOTE(review): removing from 'files' while iterating it
                    # skips the element following each removed entry — confirm
                    # this is acceptable here
                    for name in files:
                        if os.path.isfile(name):
                            try:
                                img, _ = load(name)
                            except:
                                files.remove(name)
                        else:
                            files.remove(name)
                    files.sort()

                    # probe the first slice for size and channel layout
                    img, _ = load(files[0])
                    if len(img.shape)==3:
                        ysh, xsh, csh = img.shape[0], img.shape[1], img.shape[2]
                        channel = 'last'
                        if ysh < csh:
                            # fewer "rows" than "channels" -> channel-first layout
                            csh, ysh, xsh = img.shape[0], img.shape[1], img.shape[2]
                            channel = 'first'
                    else:
                        ysh, xsh = img.shape[0], img.shape[1]
                        csh, channel = 0, None

                    # load data slice by slice (RGB converted to gray)
                    data = np.empty((len(files), ysh, xsh), dtype=img.dtype)
                    header, image_data_shape = [], []   # image_data_shape is never used
                    for k, file_name in enumerate(files):
                        img, img_header = load(file_name)
                        if csh==3:
                            img = rgb2gray(img, channel)
                        elif csh==1 and channel=='last':
                            img = img[:,:,0]
                        elif csh==1 and channel=='first':
                            img = img[0,:,:]
                        data[k] = img
                        header.append(img_header)
                    # header layout consumed by save_data: [per-slice headers,
                    # file names, dtype]
                    header = [header, files, data.dtype]
                    # NOTE(review): axes 1 and 2 are swapped to match the
                    # orientation used elsewhere in the package — confirm
                    data = np.swapaxes(data, 1, 2)
                    data = np.copy(data, order='C')
                except Exception as e:
                    print(e)
                    data, header = None, None

        # remove extracted files
        if extension=='.zip' and os.path.isdir(path_to_data):
            shutil.rmtree(path_to_data)

    elif extension == '.mrc':
        try:
            import mrcfile
            with mrcfile.open(path_to_data, permissive=True) as mrc:
                data = mrc.data
            data = np.flip(data,1)
            # treat the volume like a TIFF from here on; no header available
            extension, header = '.tif', None
        except Exception as e:
            print(e)
            data, header = None, None

    elif extension in ['.tif', '.tiff']:
        try:
            data = imread(path_to_data)
            header = None
        except Exception as e:
            print(e)
            data, header = None, None

    else:
        # unsupported extension
        data, header = None, None

    if return_extension:
        return data, header, extension
    else:
        return data, header
345
+
346
+ def _error_(bm, message):
347
+ if bm.django_env:
348
+ from biomedisa_features.django_env import create_error_object
349
+ create_error_object(message, bm.remote, bm.queue, bm.img_id)
350
+ with open(bm.path_to_logfile, 'a') as logfile:
351
+ print('%s %s %s %s' %(time.ctime(), bm.username, bm.shortfilename, message), file=logfile)
352
+ print('Error:', message)
353
+ bm.success = False
354
+ return bm
355
+
356
def pre_processing(bm):
    """Load and validate image/label data and the label set on *bm*.

    Fills ``bm.data``, ``bm.labelData``, ``bm.header``,
    ``bm.final_image_type``, ``bm.allLabels`` and ``bm.nol``. On any
    validation failure the error is reported through ``_error_`` (which
    sets ``bm.success = False``); otherwise ``bm.success`` is True.
    Returns *bm* in all cases.
    """

    # load data
    if bm.data is None:
        bm.data, _ = load_data(bm.path_to_data, bm.process)

    # error handling
    if bm.data is None:
        return _error_(bm, 'Invalid image data.')

    # load label data
    if bm.labelData is None:
        bm.labelData, bm.header, bm.final_image_type = load_data(bm.path_to_labels, bm.process, True)

    # error handling
    if bm.labelData is None:
        return _error_(bm, 'Invalid label data.')

    if len(bm.labelData.shape) != 3:
        return _error_(bm, 'Label must be three-dimensional.')

    if bm.data.shape != bm.labelData.shape:
        return _error_(bm, 'Image and label must have the same x,y,z-dimensions.')

    # get labels; negative label values are discarded
    bm.allLabels = np.unique(bm.labelData)
    index = np.argwhere(bm.allLabels<0)
    bm.allLabels = np.delete(bm.allLabels, index)

    # the web platform rejects labels above 255 outright
    if bm.django_env and np.any(bm.allLabels > 255):
        return _error_(bm, 'No labels higher than 255 allowed.')

    # outside the web platform they are dropped with a warning
    if np.any(bm.allLabels > 255):
        bm.labelData[bm.labelData > 255] = 0
        index = np.argwhere(bm.allLabels > 255)
        bm.allLabels = np.delete(bm.allLabels, index)
        print('Warning: Only labels <=255 are allowed. Labels higher than 255 will be removed.')

    # add background label if not existing
    if not np.any(bm.allLabels==0):
        bm.allLabels = np.append(0, bm.allLabels)

    # compute only specific labels (bm.only is a comma separated string)
    labels_to_compute = (bm.only).split(',')
    if not any([x in ['all', 'All', 'ALL'] for x in labels_to_compute]):
        labels_to_remove = [k for k in bm.allLabels if str(k) not in labels_to_compute and k > 0]
        for k in labels_to_remove:
            bm.labelData[bm.labelData == k] = 0
            index = np.argwhere(bm.allLabels==k)
            bm.allLabels = np.delete(bm.allLabels, index)

    # ignore specific labels (bm.ignore is a comma separated string)
    labels_to_remove = (bm.ignore).split(',')
    if not any([x in ['none', 'None', 'NONE'] for x in labels_to_remove]):
        for k in labels_to_remove:
            try:
                k = int(k)
                bm.labelData[bm.labelData == k] = 0
                index = np.argwhere(bm.allLabels==k)
                bm.allLabels = np.delete(bm.allLabels, index)
            except:
                # non-numeric tokens are silently skipped
                pass

    # number of labels (background included)
    bm.nol = len(bm.allLabels)

    if bm.nol < 2:
        return _error_(bm, 'No labeled slices found.')

    bm.success = True
    return bm
427
+
428
def save_data(path_to_final, final, header=None, final_image_type=None, compress=True):
    """Write a volume to disk in the format implied by the file extension.

    *final_image_type* overrides the extension derived from
    *path_to_final*. The expected *header* layout depends on the format:
    the object/list produced by ``load_data`` for the same format.
    Falls back to TIFF for any unrecognized extension.
    """
    if final_image_type == None:
        final_image_type = os.path.splitext(path_to_final)[1]
        if final_image_type == '.gz':
            final_image_type = '.nii.gz'
    if final_image_type == '.am':
        # reattach any extra Amira arrays stored in the header list
        final = [final]
        if len(header) > 1:
            for arr in header[1:]:
                final.append(arr)
        header = header[0]
        np_to_amira(path_to_final, final, header)
    elif final_image_type == '.nc':
        np_to_nc(path_to_final, final, header)
    elif final_image_type in ['.hdr', '.mhd', '.mha', '.nrrd', '.nii', '.nii.gz']:
        # header is the SimpleITK image the volume was loaded from
        simg = sitk.GetImageFromArray(final)
        simg.CopyInformation(header)
        sitk.WriteImage(simg, path_to_final, useCompression=compress)
    elif final_image_type in ['.zip', 'directory', '']:
        # make results directory
        if final_image_type == '.zip':
            results_dir = BASE_DIR + '/tmp/' + id_generator(40)
            os.makedirs(results_dir)
            os.chmod(results_dir, 0o770)
        else:
            results_dir = path_to_final
            if not os.path.isdir(results_dir):
                os.makedirs(results_dir)
                os.chmod(results_dir, 0o770)
        # save data as NC blocks
        if os.path.splitext(header[1][0])[1] == '.nc':
            np_to_nc(results_dir, final, header)
            file_names = header[1]
        # save data as PNG, TIF, DICOM slices
        else:
            # header layout from load_data: [per-slice headers, file names, dtype]
            header, file_names, final_dtype = header[0], header[1], header[2]
            final = final.astype(final_dtype)
            # undo the axis swap applied by load_data
            final = np.swapaxes(final, 2, 1)
            for k, file in enumerate(file_names):
                save(final[k], results_dir + '/' + os.path.basename(file), header[k])
        # zip data
        if final_image_type == '.zip':
            with zipfile.ZipFile(path_to_final, 'w') as zip:
                for file in file_names:
                    zip.write(results_dir + '/' + os.path.basename(file), os.path.basename(file))
            if os.path.isdir(results_dir):
                shutil.rmtree(results_dir)
    else:
        # default: (big)TIFF; switch to BigTIFF above ~2 GB
        imageSize = int(final.nbytes * 10e-7)
        bigtiff = True if imageSize > 2000 else False
        try:
            # newer tifffile API uses 'compression'
            compress = 'zlib' if compress else None
            imwrite(path_to_final, final, bigtiff=bigtiff, compression=compress)
        except:
            # older tifffile API uses the integer 'compress' level
            compress = 6 if compress else 0
            imwrite(path_to_final, final, bigtiff=bigtiff, compress=compress)
484
+
485
def color_to_gray(labelData):
    """Convert an RGB label volume to 8-bit grayscale labels.

    Handles both channel-first (z, 3, y, x) and channel-last (z, y, x, 3)
    volumes; any other shape is returned unchanged. After the weighted
    RGB mix the most frequent value is treated as background and zeroed
    via ``delbackground``.
    """
    if len(labelData.shape) == 4 and labelData.shape[1] == 3:
        # channel-first volume: normalize, mix channels, rescale to uint8
        labelData = labelData.astype(np.float16)
        labelData -= np.amin(labelData)
        labelData /= np.amax(labelData)
        labelData = 0.299 * labelData[:,0] + 0.587 * labelData[:,1] + 0.114 * labelData[:,2]
        labelData *= 255.0
        labelData = labelData.astype(np.uint8)
        labelData = delbackground(labelData)
    elif len(labelData.shape) == 4 and labelData.shape[3] == 3:
        # channel-last volume: same pipeline with the color axis last
        labelData = labelData.astype(np.float16)
        labelData -= np.amin(labelData)
        labelData /= np.amax(labelData)
        labelData = 0.299 * labelData[:,:,:,0] + 0.587 * labelData[:,:,:,1] + 0.114 * labelData[:,:,:,2]
        labelData *= 255.0
        labelData = labelData.astype(np.uint8)
        labelData = delbackground(labelData)
    return labelData
503
+
504
def delbackground(labels):
    """Zero out the most frequent label value (assumed background), in place."""
    values, counts = np.unique(labels, return_counts=True)
    background = values[np.argmax(counts)]
    labels[labels == background] = 0
    return labels
509
+
510
def _get_platform(bm):
    """Detect a compute platform (CUDA or OpenCL) and store it on *bm*.

    Sets ``bm.platform`` and ``bm.available_devices`` on success; sets
    ``bm.success = False`` when the requested platform is unavailable or
    none is found. Returns *bm* in every case.

    ``bm.platform`` may be None (auto-detect), 'cuda', or an explicit
    OpenCL spec of the form 'opencl_<Vendor>_<GPU|CPU>'.
    """

    # import PyCUDA; a tiny GPU round-trip verifies the device works
    if bm.platform in ['cuda', None]:
        try:
            import pycuda.driver as cuda
            import pycuda.gpuarray as gpuarray
            cuda.init()
            bm.available_devices = cuda.Device.count()
            if bm.available_devices > 0:
                dev = cuda.Device(0)
                ctx = dev.make_context()
                a_gpu = gpuarray.to_gpu(np.random.randn(4,4).astype(np.float32))
                a_doubled = (2*a_gpu).get()
                ctx.pop()
                del ctx
                bm.platform = 'cuda'
                return bm
            elif bm.platform == 'cuda':
                # CUDA was explicitly requested but no device exists
                print('Error: No CUDA device found.')
                bm.success = False
                return bm
        except:
            # PyCUDA missing or broken: fall through to OpenCL
            pass

    # import PyOpenCL
    try:
        import pyopencl as cl
    except ImportError:
        cl = None

    # select the first detected device (vendor priority, GPU before CPU)
    if cl and bm.platform is None:
        for vendor in ['NVIDIA', 'Intel', 'AMD', 'Apple']:
            for dev, device_type in [('GPU',cl.device_type.GPU),('CPU',cl.device_type.CPU)]:
                all_platforms = cl.get_platforms()
                my_devices = []
                for p in all_platforms:
                    if p.get_devices(device_type=device_type) and vendor in p.name:
                        my_devices = p.get_devices(device_type=device_type)
                if my_devices:
                    bm.platform = 'opencl_'+vendor+'_'+dev
                    if 'OMPI_COMMAND' in os.environ and dev == 'CPU':
                        # OpenCL-on-CPU cannot be combined with MPI here
                        print("Error: OpenCL CPU does not support MPI. Start Biomedisa without 'mpirun' or 'mpiexec'.")
                        bm.success = False
                        return bm
                    else:
                        bm.available_devices = len(my_devices)
                        print('Detected platform:', bm.platform)
                        print('Detected devices:', my_devices)
                        return bm

    # explicitly select the OpenCL device ('opencl_<Vendor>_<GPU|CPU>')
    elif cl and len(bm.platform.split('_')) == 3:
        plat, vendor, dev = bm.platform.split('_')
        device_type=cl.device_type.GPU if dev=='GPU' else cl.device_type.CPU
        all_platforms = cl.get_platforms()
        my_devices = []
        for p in all_platforms:
            if p.get_devices(device_type=device_type) and vendor in p.name:
                my_devices = p.get_devices(device_type=device_type)
        if my_devices:
            if 'OMPI_COMMAND' in os.environ and dev == 'CPU':
                # OpenCL-on-CPU cannot be combined with MPI here
                print("Error: OpenCL CPU does not support MPI. Start Biomedisa without 'mpirun' or 'mpiexec'.")
                bm.success = False
                return bm
            else:
                bm.available_devices = len(my_devices)
                print('Detected platform:', bm.platform)
                print('Detected devices:', my_devices)
                return bm

    # stop the process if no device is detected
    if bm.platform is None:
        bm.platform = 'OpenCL or CUDA'
    print(f'Error: No {bm.platform} device found.')
    bm.success = False
    return bm
588
+
589
def _get_device(platform, dev_id):
    """Create an OpenCL context and command queue for a platform spec.

    *platform* has the form 'opencl_<Vendor>_<GPU|CPU>' (as produced by
    ``_get_platform``); *dev_id* picks a device modulo the number of
    matching devices. Returns ``(context, queue)`` for the first matching
    OpenCL platform.

    NOTE(review): if no platform matches, the function falls off the end
    and implicitly returns None — callers must not rely on that path.
    """
    import pyopencl as cl
    plat, vendor, dev = platform.split('_')
    device_type=cl.device_type.GPU if dev=='GPU' else cl.device_type.CPU
    all_platforms = cl.get_platforms()
    for p in all_platforms:
        if p.get_devices(device_type=device_type) and vendor in p.name:
            my_devices = p.get_devices(device_type=device_type)
            context = cl.Context(devices=my_devices)
            queue = cl.CommandQueue(context, my_devices[dev_id % len(my_devices)])
            return context, queue
600
+
601
def read_labeled_slices(arr):
    """Collect the slices along axis 0 that contain any nonzero voxel.

    Returns ``(indices, data)`` where *indices* lists the slice positions
    and *data* stacks those slices.

    Performance: slices are gathered once and joined with a single
    ``np.concatenate`` instead of the previous quadratic per-slice
    ``np.append``. The empty int32 seed array keeps the result dtype
    identical to the old behavior (int32 promoted with ``arr.dtype``).
    """
    empty = np.zeros((0, arr.shape[1], arr.shape[2]), dtype=np.int32)
    indices = [k for k in range(arr.shape[0]) if np.any(arr[k])]
    if indices:
        data = np.concatenate([empty] + [arr[k:k+1] for k in indices], axis=0)
    else:
        data = empty
    return indices, data
609
+
610
def read_labeled_slices_allx(arr, ax):
    """Indices and data of labeled slices for the all-axes random walk.

    A binary foreground mask is built, then voxels that form one-voxel-thin
    structures along the y-axis and then along the x-axis are stripped
    (``gradient == 2`` means both neighbors along that axis differ).
    Slices that still contain foreground are returned as ``(k, ax)`` index
    tuples together with the original slice data.
    """
    gradient = np.zeros(arr.shape, dtype=np.int8)
    ones = np.zeros_like(gradient)
    ones[arr != 0] = 1
    # strip voxels that are one-voxel thin along axis 1 (y)
    tmp = ones[:,:-1] - ones[:,1:]
    tmp = np.abs(tmp)
    gradient[:,:-1] += tmp
    gradient[:,1:] += tmp
    ones[gradient == 2] = 0
    gradient.fill(0)
    # strip voxels that are one-voxel thin along axis 2 (x)
    tmp = ones[:,:,:-1] - ones[:,:,1:]
    tmp = np.abs(tmp)
    gradient[:,:,:-1] += tmp
    gradient[:,:,1:] += tmp
    ones[gradient == 2] = 0
    indices = []
    data = np.zeros((0, arr.shape[1], arr.shape[2]), dtype=np.int32)
    # NOTE(review): np.append in a loop is quadratic; acceptable while the
    # number of labeled slices stays small
    for k, slc in enumerate(ones[:]):
        if np.any(slc):
            data = np.append(data, [arr[k]], axis=0)
            indices.append((k, ax))
    return indices, data
632
+
633
def read_indices_allx(arr, ax):
    """Indices ``(k, ax)`` of slices that keep foreground after stripping
    voxels that are one-voxel thin along the y- and then the x-axis
    (an edge count of 2 means both neighbors along that axis differ).
    """
    mask = (arr != 0).astype(np.int8)
    edges = np.zeros(arr.shape, dtype=np.int8)
    # pass 1: thin structures along axis 1 (y)
    diff = np.abs(mask[:, :-1] - mask[:, 1:])
    edges[:, :-1] += diff
    edges[:, 1:] += diff
    mask[edges == 2] = 0
    edges.fill(0)
    # pass 2: thin structures along axis 2 (x)
    diff = np.abs(mask[:, :, :-1] - mask[:, :, 1:])
    edges[:, :, :-1] += diff
    edges[:, :, 1:] += diff
    mask[edges == 2] = 0
    return [(k, ax) for k in range(mask.shape[0]) if np.any(mask[k])]
653
+
654
def read_labeled_slices_large(arr):
    """Collect sparsely sampled labeled slices for large volumes.

    After a labeled slice is found, the following four slices are skipped
    so that at most every fifth slice is used. Returns ``(indices, data)``.
    """
    ysh, xsh = arr.shape[1], arr.shape[2]
    data = np.zeros((0, ysh, xsh), dtype=np.int32)
    indices = []
    k = 0
    while k < arr.shape[0]:
        if np.any(arr[k]):
            data = np.append(data, [arr[k]], axis=0)
            indices.append(k)
            k += 5
        else:
            k += 1
    return indices, data
667
+
668
def read_labeled_slices_allx_large(arr):
    """All-axes variant of labeled-slice collection for large volumes.

    Same thinning as ``read_labeled_slices_allx`` (foreground mask built
    from ``arr > 0``; one-voxel-thin voxels stripped along y, then x), but
    plain integer indices are returned instead of ``(k, ax)`` tuples.
    """
    gradient = np.zeros(arr.shape, dtype=np.int8)
    ones = np.zeros_like(gradient)
    ones[arr > 0] = 1
    # strip voxels that are one-voxel thin along axis 1 (y)
    tmp = ones[:,:-1] - ones[:,1:]
    tmp = np.abs(tmp)
    gradient[:,:-1] += tmp
    gradient[:,1:] += tmp
    ones[gradient == 2] = 0
    gradient.fill(0)
    # strip voxels that are one-voxel thin along axis 2 (x)
    tmp = ones[:,:,:-1] - ones[:,:,1:]
    tmp = np.abs(tmp)
    gradient[:,:,:-1] += tmp
    gradient[:,:,1:] += tmp
    ones[gradient == 2] = 0
    indices = []
    data = np.zeros((0, arr.shape[1], arr.shape[2]), dtype=np.int32)
    # NOTE(review): np.append in a loop is quadratic; acceptable while the
    # number of labeled slices stays small
    for k, slc in enumerate(ones[:]):
        if np.any(slc):
            data = np.append(data, [arr[k]], axis=0)
            indices.append(k)
    return indices, data
690
+
691
def predict_blocksize(bm):
    """Compute a padded bounding box around all labeled voxels.

    Scans ``bm.labelData`` slice by slice, finds the tight bounding box of
    the nonzero voxels, enlarges it by 100 voxels on every side (clamped
    to the volume) and stores it on *bm* as ``argmin_/argmax_`` attributes
    for x, y and z. Returns *bm*.

    Fix: the per-slice presence test now uses ``x.size`` instead of
    ``x.any()`` — the latter treated a slice whose labeled voxels all sit
    at column index 0 as empty. Also removed the unused ``zmin, zmax``
    locals.
    """
    zsh, ysh, xsh = bm.labelData.shape
    argmin_z, argmax_z, argmin_y, argmax_y, argmin_x, argmax_x = zsh, 0, ysh, 0, xsh, 0
    for k in range(zsh):
        y, x = np.nonzero(bm.labelData[k])
        if x.size:  # slice contains at least one labeled voxel
            argmin_x = min(argmin_x, np.amin(x))
            argmax_x = max(argmax_x, np.amax(x))
            argmin_y = min(argmin_y, np.amin(y))
            argmax_y = max(argmax_y, np.amax(y))
            argmin_z = min(argmin_z, k)
            argmax_z = max(argmax_z, k)
    # pad by 100 voxels, clamped to the volume boundaries
    bm.argmin_x = argmin_x - 100 if argmin_x - 100 > 0 else 0
    bm.argmax_x = argmax_x + 100 if argmax_x + 100 < xsh else xsh
    bm.argmin_y = argmin_y - 100 if argmin_y - 100 > 0 else 0
    bm.argmax_y = argmax_y + 100 if argmax_y + 100 < ysh else ysh
    bm.argmin_z = argmin_z - 100 if argmin_z - 100 > 0 else 0
    bm.argmax_z = argmax_z + 100 if argmax_z + 100 < zsh else zsh
    return bm
711
+
712
def splitlargedata(data):
    """Split *data* along axis 0 into chunks of at most ~1.5 GB.

    Small arrays come back as a single-element list containing the array
    itself; large arrays are split into equally sized views plus one
    remainder chunk.
    """
    chunks = []
    nbytes = data.nbytes
    if nbytes > 1500000000:
        nparts = int(math.ceil(nbytes / float(1500000000)))
        step = data.shape[0] // nparts
        for i in range(nparts):
            chunks.append(data[step * i:step * (i + 1)])
        # remainder rows (possibly empty)
        chunks.append(data[step * nparts:])
    else:
        chunks.append(data)
    return chunks
725
+
726
def sendToChildLarge(comm, indices, dest, dataListe, Labels, nbrw, sorw, blocks,
    allx, allLabels, smooth, uncertainty, platform):
    """Send a large image (as chunk list), labels and random-walk parameters
    to MPI rank *dest*.

    Protocol (must match the receiving child): chunk count on tag 0; per
    chunk the shape/dtype (pickled send) followed by the raw buffer; the
    parameter list on tag 1; label chunk counts and buffers; finally
    allLabels (tag 99), indices (tag 8) and blocks (tag 9).
    """
    from mpi4py import MPI
    comm.send(len(dataListe), dest=dest, tag=0)
    for k, tmp in enumerate(dataListe):
        # ensure a contiguous buffer for the raw MPI send
        tmp = tmp.copy(order='C')
        comm.send([tmp.shape[0], tmp.shape[1], tmp.shape[2], tmp.dtype], dest=dest, tag=10+(2*k))
        if tmp.dtype == 'uint8':
            comm.Send([tmp, MPI.BYTE], dest=dest, tag=10+(2*k+1))
        else:
            comm.Send([tmp, MPI.FLOAT], dest=dest, tag=10+(2*k+1))

    comm.send([nbrw, sorw, allx, smooth, uncertainty, platform], dest=dest, tag=1)

    if allx:
        # one label volume per axis
        for k in range(3):
            labelsListe = splitlargedata(Labels[k])
            comm.send(len(labelsListe), dest=dest, tag=2+k)
            for l, tmp in enumerate(labelsListe):
                tmp = tmp.copy(order='C')
                # NOTE(review): the tag depends on k only, so every chunk l of
                # one axis reuses the same tag — relies on MPI message
                # ordering between a fixed sender/receiver pair; confirm the
                # child receives in the same order
                comm.send([tmp.shape[0], tmp.shape[1], tmp.shape[2]], dest=dest, tag=100+(2*k))
                comm.Send([tmp, MPI.INT], dest=dest, tag=100+(2*k+1))
    else:
        labelsListe = splitlargedata(Labels)
        comm.send(len(labelsListe), dest=dest, tag=2)
        for k, tmp in enumerate(labelsListe):
            tmp = tmp.copy(order='C')
            comm.send([tmp.shape[0], tmp.shape[1], tmp.shape[2]], dest=dest, tag=100+(2*k))
            comm.Send([tmp, MPI.INT], dest=dest, tag=100+(2*k+1))

    comm.send(allLabels, dest=dest, tag=99)
    comm.send(indices, dest=dest, tag=8)
    comm.send(blocks, dest=dest, tag=9)
760
def sendToChild(comm, indices, indices_child, dest, data, Labels, nbrw, sorw, allx, platform):
    """Send one image block, its labels and random-walk parameters to MPI
    rank *dest* (counterpart of the receiving child process).

    Shape/dtype metadata goes out as pickled sends; the array buffers as
    raw ``Send`` with a type matching the dtype.
    """
    from mpi4py import MPI
    # ensure a contiguous buffer for the raw MPI send
    data = data.copy(order='C')
    comm.send([data.shape[0], data.shape[1], data.shape[2], data.dtype], dest=dest, tag=0)
    if data.dtype == 'uint8':
        comm.Send([data, MPI.BYTE], dest=dest, tag=1)
    else:
        comm.Send([data, MPI.FLOAT], dest=dest, tag=1)
    comm.send([allx, nbrw, sorw, platform], dest=dest, tag=2)
    if allx:
        # one label volume per axis
        for k in range(3):
            labels = Labels[k].copy(order='C')
            comm.send([labels.shape[0], labels.shape[1], labels.shape[2]], dest=dest, tag=k+3)
            comm.Send([labels, MPI.INT], dest=dest, tag=k+6)
    else:
        labels = Labels.copy(order='C')
        comm.send([labels.shape[0], labels.shape[1], labels.shape[2]], dest=dest, tag=3)
        comm.Send([labels, MPI.INT], dest=dest, tag=6)
    comm.send(indices, dest=dest, tag=9)
    comm.send(indices_child, dest=dest, tag=10)
780
+
781
+ def _split_indices(indices, ngpus):
782
+ ngpus = ngpus if ngpus < len(indices) else len(indices)
783
+ nindices = len(indices)
784
+ parts = []
785
+ for i in range(0, ngpus):
786
+ slice_idx = indices[i]
787
+ parts.append([slice_idx])
788
+ if ngpus < nindices:
789
+ for i in range(ngpus, nindices):
790
+ gid = i % ngpus
791
+ slice_idx = indices[i]
792
+ parts[gid].append(slice_idx)
793
+ return parts
794
+
795
def get_labels(pre_final, labels):
    """Map consecutive internal label ids back to the original label values.

    Every nonzero value ``v`` in *pre_final* is replaced by ``labels[v]``;
    the first unique value (background) stays zero.
    """
    final = np.zeros_like(pre_final)
    for val in np.unique(pre_final)[1:]:
        final[pre_final == val] = labels[val]
    return final
801
+