biomedisa 2024.5.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. biomedisa/__init__.py +53 -0
  2. biomedisa/__main__.py +18 -0
  3. biomedisa/biomedisa_features/DataGenerator.py +299 -0
  4. biomedisa/biomedisa_features/DataGeneratorCrop.py +121 -0
  5. biomedisa/biomedisa_features/PredictDataGenerator.py +87 -0
  6. biomedisa/biomedisa_features/PredictDataGeneratorCrop.py +74 -0
  7. biomedisa/biomedisa_features/__init__.py +0 -0
  8. biomedisa/biomedisa_features/active_contour.py +434 -0
  9. biomedisa/biomedisa_features/amira_to_np/__init__.py +0 -0
  10. biomedisa/biomedisa_features/amira_to_np/amira_data_stream.py +980 -0
  11. biomedisa/biomedisa_features/amira_to_np/amira_grammar.py +369 -0
  12. biomedisa/biomedisa_features/amira_to_np/amira_header.py +290 -0
  13. biomedisa/biomedisa_features/amira_to_np/amira_helper.py +72 -0
  14. biomedisa/biomedisa_features/assd.py +167 -0
  15. biomedisa/biomedisa_features/biomedisa_helper.py +801 -0
  16. biomedisa/biomedisa_features/create_slices.py +286 -0
  17. biomedisa/biomedisa_features/crop_helper.py +586 -0
  18. biomedisa/biomedisa_features/curvop_numba.py +149 -0
  19. biomedisa/biomedisa_features/django_env.py +172 -0
  20. biomedisa/biomedisa_features/keras_helper.py +1219 -0
  21. biomedisa/biomedisa_features/nc_reader.py +179 -0
  22. biomedisa/biomedisa_features/pid.py +52 -0
  23. biomedisa/biomedisa_features/process_image.py +253 -0
  24. biomedisa/biomedisa_features/pycuda_test.py +84 -0
  25. biomedisa/biomedisa_features/random_walk/__init__.py +0 -0
  26. biomedisa/biomedisa_features/random_walk/gpu_kernels.py +183 -0
  27. biomedisa/biomedisa_features/random_walk/pycuda_large.py +826 -0
  28. biomedisa/biomedisa_features/random_walk/pycuda_large_allx.py +806 -0
  29. biomedisa/biomedisa_features/random_walk/pycuda_small.py +414 -0
  30. biomedisa/biomedisa_features/random_walk/pycuda_small_allx.py +493 -0
  31. biomedisa/biomedisa_features/random_walk/pyopencl_large.py +760 -0
  32. biomedisa/biomedisa_features/random_walk/pyopencl_small.py +441 -0
  33. biomedisa/biomedisa_features/random_walk/rw_large.py +390 -0
  34. biomedisa/biomedisa_features/random_walk/rw_small.py +310 -0
  35. biomedisa/biomedisa_features/remove_outlier.py +399 -0
  36. biomedisa/biomedisa_features/split_volume.py +274 -0
  37. biomedisa/deeplearning.py +519 -0
  38. biomedisa/interpolation.py +371 -0
  39. biomedisa/mesh.py +406 -0
  40. biomedisa-2024.5.14.dist-info/LICENSE +191 -0
  41. biomedisa-2024.5.14.dist-info/METADATA +306 -0
  42. biomedisa-2024.5.14.dist-info/RECORD +44 -0
  43. biomedisa-2024.5.14.dist-info/WHEEL +5 -0
  44. biomedisa-2024.5.14.dist-info/top_level.txt +1 -0
@@ -0,0 +1,179 @@
1
+ ##########################################################################
2
+ ## ##
3
+ ## Copyright (c) 2024 Philipp Lösel. All rights reserved. ##
4
+ ## ##
5
+ ## This file is part of the open source project biomedisa. ##
6
+ ## ##
7
+ ## Licensed under the European Union Public Licence (EUPL) ##
8
+ ## v1.2, or - as soon as they will be approved by the ##
9
+ ## European Commission - subsequent versions of the EUPL; ##
10
+ ## ##
11
+ ## You may redistribute it and/or modify it under the terms ##
12
+ ## of the EUPL v1.2. You may not use this work except in ##
13
+ ## compliance with this Licence. ##
14
+ ## ##
15
+ ## You can obtain a copy of the Licence at: ##
16
+ ## ##
17
+ ## https://joinup.ec.europa.eu/page/eupl-text-11-12 ##
18
+ ## ##
19
+ ## Unless required by applicable law or agreed to in ##
20
+ ## writing, software distributed under the Licence is ##
21
+ ## distributed on an "AS IS" basis, WITHOUT WARRANTIES ##
22
+ ## OR CONDITIONS OF ANY KIND, either express or implied. ##
23
+ ## ##
24
+ ## See the Licence for the specific language governing ##
25
+ ## permissions and limitations under the Licence. ##
26
+ ## ##
27
+ ##########################################################################
28
+
29
+ import os
30
+ import glob
31
+ import numpy as np
32
+
33
def save_nc_block(path_to_dst, arr, path_to_src, offset):
    """Write one slab of `arr` into a new NetCDF file modelled on a reference file.

    All global attributes, dimensions and variables of `path_to_src` are
    copied to `path_to_dst`. The image variables are the exception:
    'labels'/'segmented' (created zlib-compressed) and 'tomo' receive
    ``arr[offset:offset+zsh]``, where ``zsh`` is the first-axis length of that
    variable in the reference file.

    Parameters
    ----------
    path_to_dst : str
        Output NetCDF file (opened in 'w' mode, i.e. overwritten).
    arr : array-like
        Full stacked volume; it is sliced along axis 0.
    path_to_src : str
        Reference NetCDF file providing layout and metadata.
    offset : int
        Index along axis 0 of `arr` where this file's slab starts.

    Returns
    -------
    int
        ``offset + zsh``, the start index for the next file. If the reference
        contains none of the image variables, `offset` is returned unchanged.

    Raises
    ------
    Exception
        If the netCDF4 package is not installed.
    """
    try:
        import netCDF4
    except ImportError as err:
        raise Exception("netCDF4 not found. please use `pip install netCDF4`") from err

    # Number of slices consumed from `arr`. Initialised so that a reference
    # without image variables no longer crashes with an unbound name.
    zsh = 0
    with netCDF4.Dataset(path_to_src, 'r') as src:
        with netCDF4.Dataset(path_to_dst, 'w') as dst:
            # copy global attributes all at once via dictionary
            dst.setncatts(src.__dict__)
            # copy dimensions, keeping unlimited dimensions unlimited
            for name, dimension in src.dimensions.items():
                dst.createDimension(
                    name, None if dimension.isunlimited() else len(dimension))
            # copy all variables; image variables get their data from `arr`
            for name, variable in src.variables.items():
                if name in ('labels', 'segmented', 'tomo'):
                    # slab height from the reference's 3-D shape; reading the
                    # shape is metadata only, the array itself is not loaded
                    zsh, _, _ = variable.shape
                    if name == 'tomo':
                        dst.createVariable(name, variable.datatype, variable.dimensions)
                    else:
                        dst.createVariable(name, variable.datatype, variable.dimensions,
                                           compression='zlib')
                    dst[name][:] = arr[offset:offset + zsh]
                else:
                    dst.createVariable(name, variable.datatype, variable.dimensions)
                    dst[name][:] = src[name][:]
                # copy variable attributes all at once via dictionary
                dst[name].setncatts(src[name].__dict__)
    return offset + zsh
64
+
65
def np_to_nc(results_dir, labeled_array, header=None, reference_dir=None, reference_file=None, start=0, stop=None):
    """Save a stacked volume as NetCDF file(s) that mirror reference file(s).

    The list of reference files is taken, in order of precedence, from:
    - all ``*.nc`` files in `reference_dir` (sorted);
    - ``header[1]`` (the file list returned by ``nc_to_np``);
    - `reference_file`.

    Files ``ref_files[start:stop+1]`` are processed in order. Each output
    consumes the next slab of `labeled_array` along axis 0 (see
    ``save_nc_block``).

    Parameters
    ----------
    results_dir : str
        A path ending in '.nc' writes a single file, and then exactly one
        reference file is required. Anything else is used as an output
        directory (not created here), and each output keeps its reference
        file's basename.
    labeled_array : array-like
        Volume to write, stacked along axis 0.
    header, reference_dir, reference_file
        Sources of the reference file list (see above).
    start : int
        Index of the first reference file to process.
    stop : int or None
        Inclusive index of the last reference file; None means the last one.

    Raises
    ------
    Exception
        If netCDF4 is missing, no reference is given, or a single output file
        is requested with more than one reference file.
    """
    try:
        import netCDF4  # noqa: F401 -- availability check only
    except ImportError as err:
        raise Exception("netCDF4 not found. please use `pip install netCDF4`") from err

    # save as file or directory
    is_file = os.path.splitext(results_dir)[1] == '.nc'

    # get reference information
    if reference_dir:
        ref_files = sorted(glob.glob(os.path.join(reference_dir, '*.nc')))
    elif header:
        ref_files = header[1]
    elif reference_file:
        ref_files = [reference_file]
    else:
        raise Exception("reference file(s) required")

    if is_file and len(ref_files) > 1:
        raise Exception("reference needs to be a file")

    # `stop` is inclusive; only None means "up to the last file"
    # (the previous `if not stop` silently ignored the valid index 0)
    if stop is None:
        stop = len(ref_files) - 1

    # save volume by volume
    offset = 0
    for path_to_src in ref_files[start:stop + 1]:
        if is_file:
            path_to_dst = results_dir
        else:
            path_to_dst = os.path.join(results_dir, os.path.basename(path_to_src))
        offset = save_nc_block(path_to_dst, labeled_array, path_to_src, offset)
102
+
103
def _decompress_bz2(path):
    """Decompress the .bz2 file `path` into a new uniquely named temp file next to it; return its path.

    A unique name is used instead of stripping '.bz2', so an existing
    uncompressed file of the same name is never overwritten or deleted.
    """
    import bz2
    import shutil
    import tempfile

    fd, tmp_path = tempfile.mkstemp(
        suffix='.nc', dir=os.path.dirname(os.path.abspath(path)))
    try:
        with os.fdopen(fd, 'wb') as out, bz2.open(path, 'rb') as packed:
            # stream instead of holding the whole decompressed file in memory
            shutil.copyfileobj(packed, out)
    except BaseException:
        os.remove(tmp_path)
        raise
    return tmp_path


def _read_nc_volume(path, show_keys=False):
    """Read the image variable of one NetCDF (optionally .bz2) file; return (name, array)."""
    import netCDF4

    tmp_path = _decompress_bz2(path) if path.endswith('.bz2') else None
    try:
        with netCDF4.Dataset(tmp_path or path, 'r') as f:
            if show_keys:
                print(f.variables.keys())
            # the last of these names present in the file wins
            name = None
            for n in ['labels', 'segmented', 'tomo']:
                if n in f.variables.keys():
                    name = n
            if name is None:
                raise ValueError(
                    f"none of 'labels', 'segmented', 'tomo' found in {path}")
            # read while the dataset is still open
            data = np.copy(f.variables[name], order='C')
    finally:
        # remove tmp file
        if tmp_path is not None:
            os.remove(tmp_path)
    return name, data


def nc_to_np(base_dir, start=0, stop=None, show_keys=False):
    """Load NetCDF volume(s) into a single NumPy array.

    Parameters
    ----------
    base_dir : str
        A single NetCDF file (optionally '.bz2'-compressed), or a directory
        that is searched recursively for '*.nc' and '*.bz2' files. In the
        directory case the files are read in sorted order and stacked along
        axis 0.
    start : int
        Directory mode only: index of the first file in the sorted list.
    stop : int or None
        Directory mode only: inclusive index of the last file; None means the
        last one.
    show_keys : bool
        Print each file's variable names.

    Returns
    -------
    tuple
        ``(output, header)``. `output` is the C-ordered volume. `header` is
        ``[name, files, dtype]``, where `name` is the variable that was read
        ('labels', 'segmented' or 'tomo').

    Raises
    ------
    Exception
        If netCDF4 is not installed.
    FileNotFoundError
        If `base_dir` does not exist or the directory holds no matching files.
    ValueError
        If the file range is empty or a file has no image variable.
    """
    try:
        import netCDF4  # noqa: F401 -- availability check only
    except ImportError as err:
        raise Exception("netCDF4 not found. please use `pip install netCDF4`") from err

    if os.path.isfile(base_dir):
        name, output = _read_nc_volume(base_dir, show_keys)
        header = [name, [base_dir], output.dtype]

    elif os.path.isdir(base_dir):
        # read volume by volume
        files = glob.glob(os.path.join(base_dir, '**', '*.nc'), recursive=True)
        files += glob.glob(os.path.join(base_dir, '**', '*.bz2'), recursive=True)
        files.sort()
        if not files:
            raise FileNotFoundError(f"no .nc or .bz2 files found in {base_dir}")

        # `stop` is inclusive; only None means "up to the last file"
        if stop is None:
            stop = len(files) - 1
        selected = files[start:stop + 1]
        if not selected:
            raise ValueError(f"no files in range start={start}, stop={stop}")

        volumes = []
        for filepath in selected:
            name, a = _read_nc_volume(filepath, show_keys)
            volumes.append(a)

        # a single concatenate instead of repeated np.append (quadratic copying)
        output = volumes[0] if len(volumes) == 1 else np.concatenate(volumes, axis=0)
        header = [name, selected, a.dtype]

    else:
        raise FileNotFoundError(f"{base_dir} is neither a file nor a directory")

    return output, header
179
+
@@ -0,0 +1,52 @@
1
+ ##########################################################################
2
+ ## ##
3
+ ## Copyright (c) 2023 Philipp Lösel. All rights reserved. ##
4
+ ## ##
5
+ ## This file is part of the open source project biomedisa. ##
6
+ ## ##
7
+ ## Licensed under the European Union Public Licence (EUPL) ##
8
+ ## v1.2, or - as soon as they will be approved by the ##
9
+ ## European Commission - subsequent versions of the EUPL; ##
10
+ ## ##
11
+ ## You may redistribute it and/or modify it under the terms ##
12
+ ## of the EUPL v1.2. You may not use this work except in ##
13
+ ## compliance with this Licence. ##
14
+ ## ##
15
+ ## You can obtain a copy of the Licence at: ##
16
+ ## ##
17
+ ## https://joinup.ec.europa.eu/page/eupl-text-11-12 ##
18
+ ## ##
19
+ ## Unless required by applicable law or agreed to in ##
20
+ ## writing, software distributed under the Licence is ##
21
+ ## distributed on an "AS IS" basis, WITHOUT WARRANTIES ##
22
+ ## OR CONDITIONS OF ANY KIND, either express or implied. ##
23
+ ## ##
24
+ ## See the Licence for the specific language governing ##
25
+ ## permissions and limitations under the Licence. ##
26
+ ## ##
27
+ ##########################################################################
28
+
29
+ import sys, os
30
+
31
if __name__ == "__main__":
    import signal
    import time

    # queue whose PID file is consumed (first command-line argument)
    queue = sys.argv[1]

    # path to biomedisa
    BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    pid_path = os.path.join(BASE_DIR, 'log', f'pid_{queue}')

    # wait for pid file; sleep between polls instead of spinning a CPU core
    while not os.path.exists(pid_path):
        time.sleep(0.1)

    # get pid
    with open(pid_path, 'r') as pidfile:
        pid = pidfile.read()

    # remove pid file
    os.remove(pid_path)

    # Stop the process(es): send SIGTERM (the default signal of `kill`)
    # directly rather than passing file contents to a shell command.
    for token in pid.split():
        try:
            os.kill(int(token), signal.SIGTERM)
        except (ValueError, OSError) as err:
            print(f'Could not stop process {token!r}: {err}')
52
+
@@ -0,0 +1,253 @@
1
+ #!/usr/bin/python3
2
+ ##########################################################################
3
+ ## ##
4
+ ## Copyright (c) 2024 Philipp Lösel. All rights reserved. ##
5
+ ## ##
6
+ ## This file is part of the open source project biomedisa. ##
7
+ ## ##
8
+ ## Licensed under the European Union Public Licence (EUPL) ##
9
+ ## v1.2, or - as soon as they will be approved by the ##
10
+ ## European Commission - subsequent versions of the EUPL; ##
11
+ ## ##
12
+ ## You may redistribute it and/or modify it under the terms ##
13
+ ## of the EUPL v1.2. You may not use this work except in ##
14
+ ## compliance with this Licence. ##
15
+ ## ##
16
+ ## You can obtain a copy of the Licence at: ##
17
+ ## ##
18
+ ## https://joinup.ec.europa.eu/page/eupl-text-11-12 ##
19
+ ## ##
20
+ ## Unless required by applicable law or agreed to in ##
21
+ ## writing, software distributed under the Licence is ##
22
+ ## distributed on an "AS IS" basis, WITHOUT WARRANTIES ##
23
+ ## OR CONDITIONS OF ANY KIND, either express or implied. ##
24
+ ## ##
25
+ ## See the Licence for the specific language governing ##
26
+ ## permissions and limitations under the Licence. ##
27
+ ## ##
28
+ ##########################################################################
29
+
30
+ import os, sys
31
# Package root: the parent of this file's directory. It is appended to
# sys.path (once) so the `biomedisa_features` imports below resolve when
# this module is executed directly as a script.
BASE_DIR = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__)))
if BASE_DIR not in sys.path:
    sys.path.append(BASE_DIR)
34
+ import numpy as np
35
+ from biomedisa_features.biomedisa_helper import (load_data, save_data,
36
+ img_to_uint8, smooth_img_3x3)
37
+ from biomedisa_features.create_slices import create_slices
38
+ from shutil import copytree
39
+ import argparse
40
+ import traceback
41
+ import subprocess
42
+
43
def init_process_image(id, process=None):
    """Run a 'convert' or 'smooth' job on an uploaded image (browser version).

    Loads the ``Upload`` with primary key `id`. It is then processed either
    locally or, if ``REMOTE_QUEUE_HOST`` is configured, on that host via
    ssh/scp (queue 5).

    On success a new ``Upload`` for the result is created. On failure a log
    ``Upload`` ('Invalid data.' / 'No valid image data.') is created instead.
    The source object's status and pid are reset to 0 at the end.

    Parameters
    ----------
    id : int
        Primary key of the ``Upload`` to process.
    process : str
        'convert' writes a uint8 TIFF ('.8bit.tif').
        'smooth' writes a denoised TIFF ('.denoised.tif') and requires
        ``imageType == 1``.
        NOTE(review): any other value leaves `suffix` undefined and fails
        below.
    """
    import django
    django.setup()
    from biomedisa_app.models import Upload
    from biomedisa_app.config import config
    from biomedisa_app.views import send_data_to_host, qsub_start, qsub_stop, unique_file_path
    from redis import Redis
    from rq import Queue

    # get object
    try:
        img = Upload.objects.get(pk=id)
    except Upload.DoesNotExist:
        # The upload was removed before this job ran. There is no object (and
        # therefore no user/project) to update or attach a log entry to. The
        # previous code dereferenced the unbound `img` here and crashed.
        return

    # get host information
    host = ''
    host_base = BASE_DIR
    subhost, qsub_pid = None, None
    if 'REMOTE_QUEUE_HOST' in config:
        host = config['REMOTE_QUEUE_HOST']
    if host and 'REMOTE_QUEUE_BASE_DIR' in config:
        host_base = config['REMOTE_QUEUE_BASE_DIR']

    # check if aborted (only continue while status > 0)
    if img.status > 0:

        if process == 'smooth' and img.imageType != 1:
            Upload.objects.create(user=img.user, project=img.project, log=1,
                imageType=None, shortfilename='No valid image data.')

        else:
            # set status to processing
            img.status = 2
            img.save()

            # suffix
            if process == 'convert':
                suffix = '.8bit.tif'
            elif process == 'smooth':
                suffix = '.denoised.tif'

            # create path to result ('.nii.gz' -> also strip '.nii')
            filename, extension = os.path.splitext(img.pic.path)
            if extension == '.gz':
                filename = filename[:-4]
            path_to_result = unique_file_path(filename + suffix)
            new_short_name = os.path.basename(path_to_result)
            pic_path = 'images/%s/%s' % (img.user.username, new_short_name)

            # remote server
            if host:

                # command
                cmd = ['python3', host_base+'/biomedisa_features/process_image.py', img.pic.path.replace(BASE_DIR,host_base)]
                cmd += [f'-iid={img.id}', '-r']
                if process == 'convert':
                    cmd += ['-c']
                elif process == 'smooth':
                    cmd += ['-s']

                # create user directory
                subprocess.Popen(['ssh', host, 'mkdir', '-p', host_base+'/private_storage/images/'+img.user.username]).wait()

                # send data to host
                success = send_data_to_host(img.pic.path, host+':'+img.pic.path.replace(BASE_DIR,host_base))

                # qsub start
                if 'REMOTE_QUEUE_QSUB' in config and config['REMOTE_QUEUE_QSUB']:
                    subhost, qsub_pid = qsub_start(host, host_base, 5)

                # check if aborted (re-read the object; it may have changed meanwhile)
                img = Upload.objects.get(pk=img.id)
                if img.status == 2 and img.queue == 5 and success == 0:

                    # set pid and processing status
                    img.message = 'Processing'
                    img.pid = -1
                    img.save()

                    # process image
                    if subhost:
                        cmd = ['ssh', '-t', host, 'ssh', subhost] + cmd
                    else:
                        cmd = ['ssh', host] + cmd
                    subprocess.Popen(cmd).wait()

                    # check if aborted: the remote run leaves log/pid_5 behind
                    success = subprocess.Popen(['scp', host+':'+host_base+f'/log/pid_5', BASE_DIR+f'/log/pid_5']).wait()

                    # get result
                    if success == 0:
                        # remove pid file
                        subprocess.Popen(['ssh', host, 'rm', host_base+f'/log/pid_5']).wait()

                        result_on_host = img.pic.path.replace(BASE_DIR,host_base)
                        result_on_host = result_on_host.replace(os.path.splitext(result_on_host)[1], suffix)
                        success = subprocess.Popen(['scp', host+':'+result_on_host, path_to_result]).wait()

                        if success == 0:
                            # create object
                            active = 1 if img.imageType == 3 else 0
                            Upload.objects.create(pic=pic_path, user=img.user, project=img.project,
                                imageType=img.imageType, shortfilename=new_short_name, active=active)
                        else:
                            # return error
                            Upload.objects.create(user=img.user, project=img.project,
                                log=1, imageType=None, shortfilename='Invalid data.')

            # local server
            else:

                # set pid and processing status
                img.pid = os.getpid()
                img.message = 'Processing'
                img.save()

                # load data
                data, _ = load_data(img.pic.path, process='converter')
                if data is None:
                    # return error
                    success = 1
                    Upload.objects.create(user=img.user, project=img.project,
                        log=1, imageType=None, shortfilename='Invalid data.')
                else:
                    # process data
                    success = 0
                    if process == 'convert':
                        data = img_to_uint8(data)
                        save_data(path_to_result, data, final_image_type='.tif')
                    elif process == 'smooth':
                        data = smooth_img_3x3(data)
                        save_data(path_to_result, data, final_image_type='.tif')

                    # create object
                    active = 1 if img.imageType == 3 else 0
                    Upload.objects.create(pic=pic_path, user=img.user, project=img.project,
                        imageType=img.imageType, shortfilename=new_short_name, active=active)

            # copy or create slices for preview
            if success == 0 and process == 'convert':
                path_to_source = img.pic.path.replace('images', 'sliceviewer', 1)
                path_to_dest = path_to_result.replace('images', 'sliceviewer', 1)
                if os.path.exists(path_to_source) and not os.path.exists(path_to_dest):
                    # hard-link the existing preview slices instead of copying data
                    copytree(path_to_source, path_to_dest, copy_function=os.link)
            elif success == 0 and process == 'smooth':
                q = Queue('slices', connection=Redis())
                q.enqueue_call(create_slices, args=(path_to_result, None,), timeout=-1)

        # close process
        img.status = 0
        img.pid = 0
        img.save()

    # qsub stop
    if 'REMOTE_QUEUE_QSUB' in config and config['REMOTE_QUEUE_QSUB']:
        qsub_stop(host, host_base, 5, 'process_image', subhost, qsub_pid)
205
+
206
if __name__ == "__main__":

    # ---- command-line interface -----------------------------------------
    parser = argparse.ArgumentParser(
        description='Biomedisa process image.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # positional argument
    parser.add_argument('path_to_data', type=str, metavar='PATH_TO_DATA',
        help='Location of image data')

    # flags and options
    parser.add_argument('-c', '--convert', action='store_true', default=False,
        help='Convert to uint8 TIFF')
    parser.add_argument('-s', '--smooth', action='store_true', default=False,
        help='Denoise/smooth image data')
    parser.add_argument('-iid', '--img_id', type=str, default=None,
        help='Image ID within django environment/browser version')
    parser.add_argument('-r', '--remote', action='store_true', default=False,
        help='Process is carried out on a remote server. Must be set up in config.py')
    bm = parser.parse_args()

    # When launched remotely, record this process's PID (queue 5)
    # through django_env.create_pid_object.
    if bm.remote:
        from biomedisa_features.django_env import create_pid_object
        create_pid_object(os.getpid(), True, 5, bm.img_id)

    # Only act when an operation was requested.
    if bm.convert or bm.smooth:
        bm.image, _ = load_data(bm.path_to_data)
        if bm.image is None:
            print('Error: Invalid data.')
        else:
            try:
                # -c takes precedence when both flags are given
                if bm.convert:
                    bm.image = img_to_uint8(bm.image)
                    suffix = '.8bit.tif'
                else:
                    bm.image = smooth_img_3x3(bm.image)
                    suffix = '.denoised.tif'

                # Every occurrence of the input's extension is replaced by the
                # suffix. NOTE(review): this must stay in sync with how
                # init_process_image derives `result_on_host`.
                old_extension = os.path.splitext(bm.path_to_data)[1]
                path_to_result = bm.path_to_data.replace(old_extension, suffix)
                save_data(path_to_result, bm.image, final_image_type='.tif')

            except Exception:
                print(traceback.format_exc())
253
+
@@ -0,0 +1,84 @@
1
+ ##########################################################################
2
+ ## ##
3
+ ## Copyright (c) 2022 Philipp Lösel. All rights reserved. ##
4
+ ## ##
5
+ ## This file is part of the open source project biomedisa. ##
6
+ ## ##
7
+ ## Licensed under the European Union Public Licence (EUPL) ##
8
+ ## v1.2, or - as soon as they will be approved by the ##
9
+ ## European Commission - subsequent versions of the EUPL; ##
10
+ ## ##
11
+ ## You may redistribute it and/or modify it under the terms ##
12
+ ## of the EUPL v1.2. You may not use this work except in ##
13
+ ## compliance with this Licence. ##
14
+ ## ##
15
+ ## You can obtain a copy of the Licence at: ##
16
+ ## ##
17
+ ## https://joinup.ec.europa.eu/page/eupl-text-11-12 ##
18
+ ## ##
19
+ ## Unless required by applicable law or agreed to in ##
20
+ ## writing, software distributed under the Licence is ##
21
+ ## distributed on an "AS IS" basis, WITHOUT WARRANTIES ##
22
+ ## OR CONDITIONS OF ANY KIND, either express or implied. ##
23
+ ## ##
24
+ ## See the Licence for the specific language governing ##
25
+ ## permissions and limitations under the Licence. ##
26
+ ## ##
27
+ ##########################################################################
28
+
29
+ import numpy as np
30
+ import pycuda.driver as cuda
31
+ import pycuda.gpuarray as gpuarray
32
+ from pycuda.compiler import SourceModule
33
+
34
if __name__ == "__main__":

    cuda.init()
    dev = cuda.Device(0)
    ctx = dev.make_context()

    try:
        # Each thread writes its own linear (plane, row, column) index into the
        # array, so a correct run reproduces np.arange over the volume.
        # Extents use blockDim instead of a hard-coded 10, so the block size
        # below can change without editing the kernel (blockDim.z must be 1).
        code = """
        __global__ void Funktion(int *a) {

            int xsh = gridDim.x * blockDim.x;
            int ysh = gridDim.y * blockDim.y;
            int zsh = gridDim.z;

            int column = blockIdx.x * blockDim.x + threadIdx.x;
            int row = blockIdx.y * blockDim.y + threadIdx.y;
            int plane = blockIdx.z;

            int index = plane * ysh * xsh + row * xsh + column;

            if ( index < xsh*ysh*zsh ) {
                a[index] = index;
            }

        }
        """
        mod = SourceModule(code)
        kernel = mod.get_function("Funktion")

        xsh = 100
        ysh = 100
        zsh = 100
        bs = 10  # threads per block in x and y; must divide xsh and ysh

        # expected result: linear index of every voxel
        a = np.arange(xsh*ysh*zsh, dtype=np.int32).reshape(zsh, ysh, xsh)

        a_gpu = gpuarray.zeros((zsh, ysh, xsh), np.int32)

        block = (bs, bs, 1)
        grid = (xsh//bs, ysh//bs, zsh)

        kernel(a_gpu, block=block, grid=grid)

        if np.array_equal(a_gpu.get(), a):
            print("PyCUDA test okay!")
        else:
            print("Something went wrong!")

    finally:
        # release the context even if compilation or the kernel launch fails
        ctx.pop()
        del ctx
File without changes