cuslines 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,472 @@
1
+ import numpy as np
2
+ from abc import ABC, abstractmethod
3
+ import logging
4
+ from importlib.resources import files
5
+ from time import time
6
+
7
+ from dipy.reconst import shm
8
+
9
+ from cuda.core import Device, LaunchConfig, Program, launch, ProgramOptions
10
+ from cuda.pathfinder import find_nvidia_header_directory
11
+ from cuda.cccl import get_include_paths
12
+ from cuda.bindings import runtime, driver
13
+ from cuda.bindings.runtime import cudaMemcpyKind
14
+
15
+ from cuslines.cuda_python.cutils import (
16
+ REAL_SIZE,
17
+ REAL_DTYPE,
18
+ REAL_DTYPE_AS_STR,
19
+ REAL3_DTYPE_AS_STR,
20
+ checkCudaErrors,
21
+ ModelType,
22
+ THR_X_SL,
23
+ BLOCK_Y,
24
+ )
25
+
26
+ logger = logging.getLogger("GPUStreamlines")
27
+
28
+
29
class GPUDirectionGetter(ABC):
    """Abstract base for direction getters that launch compiled CUDA kernels.

    Subclasses must implement :meth:`getNumStreamlines` and
    :meth:`generateStreamlines`. Before calling :meth:`compile_program` they
    must also set ``self.getnum_kernel_name`` and
    ``self.genstreamlines_kernel_name``, because both attributes are read
    there as the name expressions to compile.
    """

    @abstractmethod
    def getNumStreamlines(self, n, nseeds_gpu, block, grid, sp):
        """Subclass hook; must be overridden by every concrete subclass."""
        pass

    @abstractmethod
    def generateStreamlines(self, n, nseeds_gpu, block, grid, sp):
        """Subclass hook; must be overridden by every concrete subclass."""
        pass

    def allocate_on_gpu(self, n):
        """Optional hook; the base implementation does nothing."""
        pass

    def deallocate_on_gpu(self, n):
        """Optional hook; the base implementation does nothing."""
        pass

    def compile_program(self, debug: bool = False):
        """Compile ``generate_streamlines_cuda.cu`` and keep the module.

        The source file is read from the ``cuda_c`` directory of the
        ``cuslines`` package and compiled to a cubin for the two kernel name
        expressions stored on the instance. The result is stored in
        ``self.module``.

        Parameters
        ----------
        debug : bool, optional
            When true, compile with ``-O0 -v`` ptxas options plus the debug
            and line-info flags; otherwise compile with ptxas ``-O3``.
        """
        start_time = time()
        logger.info("Compiling GPUStreamlines")

        cuslines_cuda = files("cuslines").joinpath("cuda_c")

        if debug:
            program_opts = {
                "ptxas_options": ["-O0", "-v"],
                "device_code_optimize": True,
                "debug": True,
                "lineinfo": True,
            }
        else:
            program_opts = {"ptxas_options": ["-O3"]}

        # The header lookups may come back empty; passing None on as an
        # include path is never useful, so skip it and say so in the log.
        include_path = [str(cuslines_cuda)]
        for libname in ("cudart", "curand"):
            header_dir = find_nvidia_header_directory(libname)
            if header_dir is None:
                logger.warning(
                    "Could not locate the %s header directory; "
                    "omitting it from the include path",
                    libname,
                )
            else:
                include_path.append(header_dir)
        include_path.append(get_include_paths().libcudacxx)

        program_options = ProgramOptions(
            name="cuslines",
            use_fast_math=True,
            std="c++17",
            define_macro="__NVRTC__",
            include_path=include_path,
            **program_opts,
        )

        # Here we assume all devices are the same,
        # so we compile once for any current device.
        # I think this is reasonable
        dev = Device()
        dev.set_current()

        # Read through the resource API with an explicit encoding instead of
        # builtin open(), so the read does not depend on the locale default.
        cuda_path = cuslines_cuda.joinpath("generate_streamlines_cuda.cu")
        source = cuda_path.read_text(encoding="utf-8")

        prog = Program(source, code_type="c++", options=program_options)
        self.module = prog.compile(
            "cubin",
            name_expressions=(
                self.getnum_kernel_name,
                self.genstreamlines_kernel_name,
            ),
        )
        logger.info(
            "GPUStreamlines compiled successfully in %.2f seconds", time() - start_time
        )
92
+
93
+
94
class BootDirectionGetter(GPUDirectionGetter):
    """Direction getter that launches the ``*Boot_k`` kernels (OPDT or CSA).

    Host copies of the model matrices are kept on the instance. Matching
    device buffers are stored in per-attribute lists (``H_d``, ``R_d``, ...)
    that are indexed by the ``n`` passed to :meth:`allocate_on_gpu`.
    """

    def __init__(
        self,
        model_type: str,
        min_signal: float,
        H: np.ndarray,
        R: np.ndarray,
        delta_b: np.ndarray,
        delta_q: np.ndarray,
        sampling_matrix: np.ndarray,
        b0s_mask: np.ndarray,
    ):
        """Store host arrays, build the kernel names and compile the program.

        Parameters
        ----------
        model_type : str
            ``"OPDT"`` or ``"CSA"`` (case-insensitive).
        min_signal : float
            Converted to ``REAL_DTYPE`` and passed to both kernels.
        H, R, delta_b, delta_q, sampling_matrix : array-like
            Stored as C-contiguous ``REAL_DTYPE`` arrays.
        b0s_mask : array-like
            Stored as a C-contiguous ``int32`` array.

        Raises
        ------
        ValueError
            If ``model_type`` is neither ``"OPDT"`` nor ``"CSA"``.
        """
        model_key = model_type.upper()
        if model_key == "OPDT":
            self.model_type = int(ModelType.OPDT)
        elif model_key == "CSA":
            self.model_type = int(ModelType.CSA)
        else:
            raise ValueError(
                f"Invalid model_type {model_type}, must be one of 'OPDT', 'CSA'"
            )

        checkCudaErrors(driver.cuInit(0))

        self.H = np.ascontiguousarray(H, dtype=REAL_DTYPE)
        self.R = np.ascontiguousarray(R, dtype=REAL_DTYPE)
        self.delta_b = np.ascontiguousarray(delta_b, dtype=REAL_DTYPE)
        self.delta_q = np.ascontiguousarray(delta_q, dtype=REAL_DTYPE)
        # Read the row count from the converted array so plain sequences
        # (which have no ``.shape``) are accepted as well.
        self.delta_nr = int(self.delta_b.shape[0])
        self.min_signal = REAL_DTYPE(min_signal)
        self.sampling_matrix = np.ascontiguousarray(sampling_matrix, dtype=REAL_DTYPE)
        self.b0s_mask = np.ascontiguousarray(b0s_mask, dtype=np.int32)

        # Device pointers, one slot per ``n``; None marks "not allocated".
        self.H_d = []
        self.R_d = []
        self.delta_b_d = []
        self.delta_q_d = []
        self.b0s_mask_d = []
        self.sampling_matrix_d = []

        self.getnum_kernel_name = f"getNumStreamlinesBoot_k<{THR_X_SL},{BLOCK_Y},{REAL_DTYPE_AS_STR},{REAL3_DTYPE_AS_STR}>"
        self.genstreamlines_kernel_name = f"genStreamlinesMergeBoot_k<{THR_X_SL},{BLOCK_Y},{model_key},{REAL_DTYPE_AS_STR},{REAL3_DTYPE_AS_STR}>"
        self.compile_program()

    @staticmethod
    def _hat_and_lcr_matrices(gtab, b0s_mask, sh_order_max):
        """Return ``(H, R)`` with ``H = shm.hat(B)`` and ``R = shm.lcr_matrix(H)``.

        ``B`` is ``shm.real_sym_sh_basis`` evaluated at the gradient
        directions of ``gtab`` that are not flagged in ``b0s_mask``.
        """
        x, y, z = gtab.gradients[~b0s_mask].T
        _, theta, phi = shm.cart2sphere(x, y, z)
        B, _, _ = shm.real_sym_sh_basis(sh_order_max, theta, phi)
        H = shm.hat(B)
        return H, shm.lcr_matrix(H)

    @classmethod
    def from_dipy_opdt(
        cls,
        gtab,
        sphere,
        sh_order_max=6,
        full_basis=False,
        sh_lambda=0.006,
        min_signal=1,
    ):
        """Build an ``"OPDT"`` getter from a gradient table and a sphere.

        ``delta_b`` and ``delta_q`` are the two entries of the
        ``_fit_matrix`` of a ``shm.OpdtModel`` built with
        ``smooth=sh_lambda``.
        """
        sampling_matrix, _, _ = shm.real_sh_descoteaux(
            sh_order_max, sphere.theta, sphere.phi, full_basis=full_basis, legacy=False
        )

        model = shm.OpdtModel(
            gtab, sh_order_max=sh_order_max, smooth=sh_lambda, min_signal=min_signal
        )
        delta_b, delta_q = model._fit_matrix

        H, R = cls._hat_and_lcr_matrices(model.gtab, gtab.b0s_mask, sh_order_max)

        return cls(
            model_type="OPDT",
            min_signal=min_signal,
            H=H,
            R=R,
            delta_b=delta_b,
            delta_q=delta_q,
            sampling_matrix=sampling_matrix,
            b0s_mask=gtab.b0s_mask,
        )

    @classmethod
    def from_dipy_csa(
        cls,
        gtab,
        sphere,
        sh_order_max=6,
        full_basis=False,
        sh_lambda=0.006,
        min_signal=1,
    ):
        """Build a ``"CSA"`` getter from a gradient table and a sphere.

        The ``_fit_matrix`` of a ``shm.CsaOdfModel`` built with
        ``smooth=sh_lambda`` is used for both ``delta_b`` and ``delta_q``.
        """
        sampling_matrix, _, _ = shm.real_sh_descoteaux(
            sh_order_max, sphere.theta, sphere.phi, full_basis=full_basis, legacy=False
        )

        model = shm.CsaOdfModel(
            gtab, sh_order_max=sh_order_max, smooth=sh_lambda, min_signal=min_signal
        )
        fit_matrix = model._fit_matrix

        H, R = cls._hat_and_lcr_matrices(model.gtab, gtab.b0s_mask, sh_order_max)

        return cls(
            model_type="CSA",
            min_signal=min_signal,
            H=H,
            R=R,
            delta_b=fit_matrix,
            delta_q=fit_matrix,
            sampling_matrix=sampling_matrix,
            b0s_mask=gtab.b0s_mask,
        )

    def _device_buffer_specs(self):
        """Return ``(device_slots, host_array, element_size)`` per buffer."""
        int32_size = np.int32().nbytes
        return (
            (self.H_d, self.H, REAL_SIZE),
            (self.R_d, self.R, REAL_SIZE),
            (self.delta_b_d, self.delta_b, REAL_SIZE),
            (self.delta_q_d, self.delta_q, REAL_SIZE),
            (self.b0s_mask_d, self.b0s_mask, int32_size),
            (self.sampling_matrix_d, self.sampling_matrix, REAL_SIZE),
        )

    def allocate_on_gpu(self, n):
        """Allocate the device buffers for slot ``n`` and upload the host data.

        Pointers are written to index ``n`` of each list rather than
        appended. This keeps ``X_d[n]`` pointing at the buffer that was just
        allocated even when the slot is allocated again after a
        :meth:`deallocate_on_gpu`.
        """
        for slots, host, elem_size in self._device_buffer_specs():
            if len(slots) <= n:
                slots.extend([None] * (n + 1 - len(slots)))
            elif slots[n]:
                # Slot still holds a live buffer: release it so it is not leaked.
                checkCudaErrors(runtime.cudaFree(slots[n]))
                slots[n] = None

            nbytes = elem_size * host.size
            slots[n] = checkCudaErrors(runtime.cudaMalloc(nbytes))
            checkCudaErrors(
                runtime.cudaMemcpy(
                    slots[n],
                    host.ctypes.data,
                    nbytes,
                    cudaMemcpyKind.cudaMemcpyHostToDevice,
                )
            )

    def deallocate_on_gpu(self, n):
        """Free the device buffers of slot ``n``; repeated calls are no-ops.

        Each freed slot is reset to ``None`` so the pointer can neither be
        freed twice nor reused after release.
        """
        for slots, _host, _elem_size in self._device_buffer_specs():
            if n < len(slots) and slots[n]:
                checkCudaErrors(runtime.cudaFree(slots[n]))
                slots[n] = None

    def _shared_mem_bytes(self, sp):
        """Return the byte count passed as ``shmem_size`` for both launches."""
        return (
            REAL_SIZE
            * BLOCK_Y
            * 2
            * (
                sp.gpu_tracker.n32dimt
                + max(sp.gpu_tracker.n32dimt, sp.gpu_tracker.samplm_nr)
            )
            + np.int32().nbytes * BLOCK_Y * sp.gpu_tracker.samplm_nr
        )

    def getNumStreamlines(self, n, nseeds_gpu, block, grid, sp):
        """Launch the ``getnum_kernel_name`` kernel on stream ``n``.

        The positional arguments after the kernel must stay in this exact
        order, since they are forwarded to the kernel as-is.
        """
        ker = self.module.get_kernel(self.getnum_kernel_name)
        shared_memory = self._shared_mem_bytes(sp)
        config = LaunchConfig(block=block, grid=grid, shmem_size=shared_memory)

        launch(
            sp.gpu_tracker.streams[n],
            config,
            ker,
            self.model_type,
            sp.gpu_tracker.max_angle,
            self.min_signal,
            sp.gpu_tracker.relative_peak_thresh,
            sp.gpu_tracker.min_separation_angle,
            sp.gpu_tracker.rng_seed,
            nseeds_gpu,
            sp.seeds_d[n],
            sp.gpu_tracker.dimx,
            sp.gpu_tracker.dimy,
            sp.gpu_tracker.dimz,
            sp.gpu_tracker.dimt,
            sp.gpu_tracker.dataf_d[n],
            self.H_d[n],
            self.R_d[n],
            self.delta_nr,
            self.delta_b_d[n],
            self.delta_q_d[n],
            self.b0s_mask_d[n],
            sp.gpu_tracker.samplm_nr,
            self.sampling_matrix_d[n],
            sp.gpu_tracker.sphere_vertices_d[n],
            sp.gpu_tracker.sphere_edges_d[n],
            sp.gpu_tracker.nedges,
            sp.shDirTemp0_d[n],
            sp.slinesOffs_d[n],
        )

    def generateStreamlines(self, n, nseeds_gpu, block, grid, sp):
        """Launch the ``genstreamlines_kernel_name`` kernel on stream ``n``.

        The positional arguments after the kernel must stay in this exact
        order, since they are forwarded to the kernel as-is.
        """
        ker = self.module.get_kernel(self.genstreamlines_kernel_name)
        shared_memory = self._shared_mem_bytes(sp)
        config = LaunchConfig(block=block, grid=grid, shmem_size=shared_memory)

        launch(
            sp.gpu_tracker.streams[n],
            config,
            ker,
            sp.gpu_tracker.max_angle,
            sp.gpu_tracker.tc_threshold,
            sp.gpu_tracker.step_size,
            sp.gpu_tracker.relative_peak_thresh,
            sp.gpu_tracker.min_separation_angle,
            sp.gpu_tracker.rng_seed,
            sp.gpu_tracker.rng_offset + n * nseeds_gpu,
            nseeds_gpu,
            sp.seeds_d[n],
            sp.gpu_tracker.dimx,
            sp.gpu_tracker.dimy,
            sp.gpu_tracker.dimz,
            sp.gpu_tracker.dimt,
            sp.gpu_tracker.dataf_d[n],
            sp.gpu_tracker.metric_map_d[n],
            sp.gpu_tracker.samplm_nr,
            sp.gpu_tracker.sphere_vertices_d[n],
            sp.gpu_tracker.sphere_edges_d[n],
            sp.gpu_tracker.nedges,
            self.min_signal,
            self.delta_nr,
            self.H_d[n],
            self.R_d[n],
            self.delta_b_d[n],
            self.delta_q_d[n],
            self.sampling_matrix_d[n],
            self.b0s_mask_d[n],
            sp.slinesOffs_d[n],
            sp.shDirTemp0_d[n],
            sp.slineSeed_d[n],
            sp.slineLen_d[n],
            sp.sline_d[n],
        )
386
+
387
+
388
class ProbDirectionGetter(GPUDirectionGetter):
    """Direction getter that launches the ``*Prob_k`` kernels in ``PROB`` mode."""

    def __init__(self):
        """Initialise the CUDA driver, set the kernel names and compile."""
        init_status = driver.cuInit(0)
        checkCudaErrors(init_status)
        self.getnum_kernel_name = f"getNumStreamlinesProb_k<{THR_X_SL},{BLOCK_Y},{REAL_DTYPE_AS_STR},{REAL3_DTYPE_AS_STR}>"
        self.genstreamlines_kernel_name = f"genStreamlinesMergeProb_k<{THR_X_SL},{BLOCK_Y},PROB,{REAL_DTYPE_AS_STR},{REAL3_DTYPE_AS_STR}>"
        self.compile_program()

    def getNumStreamlines(self, n, nseeds_gpu, block, grid, sp):
        """Launch the ``getnum_kernel_name`` kernel on stream ``n``.

        The shared-memory size is computed here and does not go through
        :meth:`_shared_mem_bytes`.
        """
        kernel = self.module.get_kernel(self.getnum_kernel_name)
        tracker = sp.gpu_tracker

        real_part = REAL_SIZE * BLOCK_Y * tracker.n32dimt
        int_part = np.int32().nbytes * BLOCK_Y * tracker.n32dimt
        launch_cfg = LaunchConfig(
            block=block, grid=grid, shmem_size=real_part + int_part
        )

        # Positional kernel arguments; the order is forwarded unchanged.
        kernel_args = (
            tracker.max_angle,
            tracker.relative_peak_thresh,
            tracker.min_separation_angle,
            tracker.rng_seed,
            nseeds_gpu,
            sp.seeds_d[n],
            tracker.dimx,
            tracker.dimy,
            tracker.dimz,
            tracker.dimt,
            tracker.dataf_d[n],
            tracker.sphere_vertices_d[n],
            tracker.sphere_edges_d[n],
            tracker.nedges,
            sp.shDirTemp0_d[n],
            sp.slinesOffs_d[n],
        )
        launch(tracker.streams[n], launch_cfg, kernel, *kernel_args)

    def _shared_mem_bytes(self, sp):
        """Return the byte count used as ``shmem_size`` by ``generateStreamlines``."""
        n32dimt = sp.gpu_tracker.n32dimt
        return REAL_SIZE * BLOCK_Y * n32dimt

    def generateStreamlines(self, n, nseeds_gpu, block, grid, sp):
        """Launch the ``genstreamlines_kernel_name`` kernel on stream ``n``."""
        kernel = self.module.get_kernel(self.genstreamlines_kernel_name)
        launch_cfg = LaunchConfig(
            block=block, grid=grid, shmem_size=self._shared_mem_bytes(sp)
        )
        tracker = sp.gpu_tracker

        # Positional kernel arguments; the order is forwarded unchanged.
        kernel_args = (
            tracker.max_angle,
            tracker.tc_threshold,
            tracker.step_size,
            tracker.relative_peak_thresh,
            tracker.min_separation_angle,
            tracker.rng_seed,
            tracker.rng_offset + n * nseeds_gpu,
            nseeds_gpu,
            sp.seeds_d[n],
            tracker.dimx,
            tracker.dimy,
            tracker.dimz,
            tracker.dimt,
            tracker.dataf_d[n],
            tracker.metric_map_d[n],
            tracker.samplm_nr,
            tracker.sphere_vertices_d[n],
            tracker.sphere_edges_d[n],
            tracker.nedges,
            sp.slinesOffs_d[n],
            sp.shDirTemp0_d[n],
            sp.slineSeed_d[n],
            sp.slineLen_d[n],
            sp.sline_d[n],
        )
        launch(tracker.streams[n], launch_cfg, kernel, *kernel_args)
462
+
463
+
464
class PttDirectionGetter(ProbDirectionGetter):
    """Variant of ``ProbDirectionGetter`` that compiles the ``PTT`` kernel.

    It uses the same ``getNumStreamlinesProb_k`` name as the parent, but
    instantiates ``genStreamlinesMergeProb_k`` with ``PTT`` and reports zero
    bytes from :meth:`_shared_mem_bytes`.
    """

    def __init__(self):
        """Initialise the CUDA driver, set the kernel names and compile."""
        init_status = driver.cuInit(0)
        checkCudaErrors(init_status)

        count_kernel = f"getNumStreamlinesProb_k<{THR_X_SL},{BLOCK_Y},{REAL_DTYPE_AS_STR},{REAL3_DTYPE_AS_STR}>"
        generate_kernel = f"genStreamlinesMergeProb_k<{THR_X_SL},{BLOCK_Y},PTT,{REAL_DTYPE_AS_STR},{REAL3_DTYPE_AS_STR}>"
        self.getnum_kernel_name = count_kernel
        self.genstreamlines_kernel_name = generate_kernel

        self.compile_program()

    def _shared_mem_bytes(self, sp):
        """Return 0; ``sp`` is accepted only to match the parent signature."""
        no_shared_memory = 0
        return no_shared_memory