cuslines 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1066 @@
1
// Debug switch: when USE_FIXED_PERMUTATION is defined, get_direction_boot_d()
// takes its residual resampling indices from the constant table below
// instead of drawing them with curand().
//#define USE_FIXED_PERMUTATION
#if defined(USE_FIXED_PERMUTATION)
// Earlier, shorter table kept for reference only:
//__device__ const int fixedPerm[] = {44, 47, 53,  0,  3,  3, 39,  9, 19, 21, 50, 36, 23,
//                                     6, 24, 24, 12,  1, 38, 39, 23, 46, 24, 17, 37, 25,
//                                    13,  8,  9, 20, 51, 16, 51,  5, 15, 47,  0, 18, 35,
//                                    24, 49, 51, 29, 19, 19, 14, 39, 32,  1,  9, 32, 31,
//                                    10, 52, 23};
//
// Active table: 150 entries, every value in [0, 149]. It is indexed with
// j+tidx < hr_side by its user, so it only covers hr_side <= 150.
__device__ const int fixedPerm[] = {
    47, 117, 67, 103, 9, 21, 36, 87, 70, 88, 140, 58, 39, 87, 88, 81, 25, 77,
    72, 9, 148, 115, 79, 82, 99, 29, 147, 147, 142, 32, 9, 127, 32, 31, 114, 28,
    34, 128, 128, 53, 133, 38, 17, 79, 132, 105, 42, 31, 120, 1, 65, 57, 35, 102,
    119, 11, 82, 91, 128, 142, 99, 53, 140, 121, 84, 68, 6, 47, 127, 131, 100, 78,
    143, 148, 23, 141, 117, 85, 48, 49, 69, 95, 94, 0, 113, 36, 48, 93, 131, 98,
    42, 112, 149, 127, 0, 138, 114, 43, 127, 23, 130, 121, 98, 62, 123, 82, 148, 50,
    14, 41, 58, 36, 10, 86, 43, 104, 11, 2, 51, 80, 32, 128, 38, 19, 42, 115,
    77, 30, 24, 125, 2, 3, 94, 107, 13, 112, 40, 72, 19, 95, 72, 67, 61, 14,
    96, 4, 139, 86, 121, 109};
#endif
19
+
20
/*
 * avgMask: mean of the entries of `data` whose `mask` entry is nonzero.
 *
 * Cooperative call: the BDIM_X threads sharing the same threadIdx.y each scan
 * a strided subset of [0, mskLen) and then combine their partial sum/count
 * with an xor-shuffle exchange, so every thread of the group returns the
 * same value.
 *
 * NOTE(review): if no mask entry is nonzero the count is 0 and the result of
 * the final division is not a usable number; callers appear to rely on the
 * mask selecting at least one element -- confirm.
 */
template<int BDIM_X,
         typename VAL_T>
__device__ VAL_T avgMask(const int mskLen,
                         const int *__restrict__ mask,
                         const VAL_T *__restrict__ data) {

    // Lane of this thread in its warp and shuffle mask of its BDIM_X-wide group.
    const int laneInWarp = (threadIdx.y*BDIM_X + threadIdx.x) % 32;
    const int groupBase  = laneInWarp & (~(BDIM_X-1));
    const unsigned int WMASK = ((1ull << BDIM_X)-1) << groupBase;

    VAL_T partialSum = 0;
    int   partialCnt = 0;

    // Per-thread partial accumulation over elements tidx, tidx+BDIM_X, ...
    int pos = threadIdx.x;
    while (pos < mskLen) {
        if (mask[pos]) {
            partialCnt += 1;
            partialSum += data[pos];
        }
        pos += BDIM_X;
    }

    // Exchange partials across the group.
    #pragma unroll
    for(int off = BDIM_X >> 1; off > 0; off >>= 1) {
        partialSum += __shfl_xor_sync(WMASK, partialSum, off, BDIM_X);
        partialCnt += __shfl_xor_sync(WMASK, partialCnt, off, BDIM_X);
    }

    return partialSum/partialCnt;
}
50
+
51
/*
 * maskGet: copy into `masked`, contiguously and in order, the elements of
 * `plain[0..n)` whose `mask` entry is zero.
 *
 * Cooperative call executed by the BDIM_X threads sharing the same
 * threadIdx.y; the lane arithmetic assumes BDIM_X is a power of two not
 * larger than 32. Every thread of the group returns the number of elements
 * written.
 */
template<
    int BDIM_X,
    typename LEN_T,
    typename MSK_T,
    typename VAL_T>
__device__ LEN_T maskGet(const LEN_T n,
                         const MSK_T *__restrict__ mask,
                         const VAL_T *__restrict__ plain,
                               VAL_T *__restrict__ masked) {

    const int tidx = threadIdx.x;

    const int lid = (threadIdx.y*BDIM_X + threadIdx.x) % 32;
    // First warp lane occupied by this BDIM_X-wide group.
    const int grpShift = lid & (~(BDIM_X-1));
    const unsigned int WMASK = ((1ull << BDIM_X)-1) << grpShift;

    // Bits of the group threads that precede this one. Unsigned on purpose:
    // tidx can be 31 and 1 << 31 overflows a signed int.
    const unsigned int __laneMask = (1u << tidx)-1;

    int woff = 0;
    for(int j = 0; j < n; j += BDIM_X) {

        const int __act = (j+tidx < n) ? !mask[j+tidx] : 0;

        // __ballot_sync() reports each vote at the voter's warp-lane
        // position. Shift the votes down to the group base so that bit k
        // belongs to tidx == k and lines up with __laneMask; without this
        // groups that do not start at lane 0 (BDIM_X < 32) would all compute
        // toff == 0.
        const unsigned int __msk = (__ballot_sync(WMASK, __act) & WMASK) >> grpShift;

        // Number of selected elements held by lower-numbered threads.
        const int toff = __popc(__msk & __laneMask);
        if (__act) {
            masked[woff+toff] = plain[j+tidx];
        }
        woff += __popc(__msk);
    }
    return woff;
}
82
+
83
/*
 * maskPut: inverse of the compaction done by maskGet. Consecutive elements of
 * `masked` are scattered, in order, to the positions of `plain[0..n)` whose
 * `mask` entry is zero; other positions of `plain` are left untouched.
 *
 * Cooperative call executed by the BDIM_X threads sharing the same
 * threadIdx.y; the lane arithmetic assumes BDIM_X is a power of two not
 * larger than 32.
 */
template<
    int BDIM_X,
    typename LEN_T,
    typename MSK_T,
    typename VAL_T>
__device__ void maskPut(const LEN_T n,
                        const MSK_T *__restrict__ mask,
                        const VAL_T *__restrict__ masked,
                              VAL_T *__restrict__ plain) {

    const int tidx = threadIdx.x;

    const int lid = (threadIdx.y*BDIM_X + threadIdx.x) % 32;
    // First warp lane occupied by this BDIM_X-wide group.
    const int grpShift = lid & (~(BDIM_X-1));
    const unsigned int WMASK = ((1ull << BDIM_X)-1) << grpShift;

    // Bits of the group threads that precede this one. Unsigned on purpose:
    // tidx can be 31 and 1 << 31 overflows a signed int.
    const unsigned int __laneMask = (1u << tidx)-1;

    int woff = 0;
    for(int j = 0; j < n; j += BDIM_X) {

        const int __act = (j+tidx < n) ? !mask[j+tidx] : 0;

        // Votes come back at warp-lane positions; realign them to the group
        // so that bit k belongs to tidx == k (matters when BDIM_X < 32 and
        // the group does not start at lane 0).
        const unsigned int __msk = (__ballot_sync(WMASK, __act) & WMASK) >> grpShift;

        // Number of selected positions owned by lower-numbered threads.
        const int toff = __popc(__msk & __laneMask);
        if (__act) {
            plain[j+tidx] = masked[woff+toff];
        }
        woff += __popc(__msk);
    }
    return;
}
114
+
115
/*
 * closest_peak_d: among `peaks[0..npeaks)` find the vector with the largest
 * absolute dot product with `direction` and, if it lies within `max_angle`
 * of `direction` or of its opposite, store it (sign-flipped in the latter
 * case) in peak[0].
 *
 * Cooperative call executed by the BDIM_X threads sharing the same
 * threadIdx.y; each thread writes its own `peak[0]`.
 *
 * Returns 1 when a peak was stored, 0 otherwise (including npeaks == 0 or
 * all dot products equal to zero).
 */
template<int BDIM_X,
         int BDIM_Y,
         typename REAL_T,
         typename REAL3_T>
__device__ int closest_peak_d(const REAL_T max_angle,
                              const REAL3_T direction, //dir
                              const int npeaks,
                              const REAL3_T *__restrict__ peaks,
                              REAL3_T *__restrict__ peak) {// dirs,

    const int laneInWarp = (threadIdx.y*BDIM_X + threadIdx.x) % 32;
    const unsigned int WMASK = ((1ull << BDIM_X)-1) << (laneInWarp & (~(BDIM_X-1)));

    // Acceptance threshold on the cosine of the angle.
    const REAL_T cos_similarity = COS(max_angle);

    REAL_T bestDot = 0;
    int    bestIdx = -1;

    // Each thread scans peaks tidx, tidx+BDIM_X, ...
    for(int k = threadIdx.x; k < npeaks; k += BDIM_X) {
        const REAL_T dot = direction.x*peaks[k].x+
                           direction.y*peaks[k].y+
                           direction.z*peaks[k].z;

        if (FABS(dot) > FABS(bestDot)) {
            bestDot = dot;
            bestIdx = k;
        }
    }

    // Combine the per-thread candidates across the group.
    #pragma unroll
    for(int off = BDIM_X/2; off > 0; off /= 2) {

        const REAL_T otherDot = __shfl_xor_sync(WMASK, bestDot, off, BDIM_X);
        const int    otherIdx = __shfl_xor_sync(WMASK, bestIdx, off, BDIM_X);
        if (FABS(otherDot) > FABS(bestDot)) {
            bestDot = otherDot;
            bestIdx = otherIdx;
        }
    }

    if (bestIdx < 0) {
        return 0;
    }
    if (bestDot >= cos_similarity) {
        peak[0] = peaks[bestIdx];
        return 1;
    }
    if (bestDot <= -cos_similarity) {
        // Best peak points the opposite way: return it flipped.
        peak[0] = MAKE_REAL3(-peaks[bestIdx].x,
                             -peaks[bestIdx].y,
                             -peaks[bestIdx].z);
        return 1;
    }
    return 0;
}
196
+
197
/*
 * ndotp_d: N dot products of length M.
 *
 *     dstV[i] = sum_j srcV[j] * srcM[i*M + j],   i in [0, N)
 *
 * Cooperative call executed by the BDIM_X threads sharing the same
 * threadIdx.y. Partial sums are combined with a shuffle-down exchange, so
 * only the thread with threadIdx.x == 0 holds the complete sum and writes
 * dstV[i]. No barrier is issued after the writes.
 */
template<int BDIM_X,
         typename VAL_T>
__device__ void ndotp_d(const int N,
                        const int M,
                        const VAL_T *__restrict__ srcV,
                        const VAL_T *__restrict__ srcM,
                              VAL_T *__restrict__ dstV) {

    const int laneInWarp = (threadIdx.y*BDIM_X + threadIdx.x) % 32;
    const unsigned int WMASK = ((1ull << BDIM_X)-1) << (laneInWarp & (~(BDIM_X-1)));

    for(int row = 0; row < N; row++) {

        VAL_T acc = 0;

        // Strided partial dot product of srcV with row `row` of srcM.
        for(int col = threadIdx.x; col < M; col += BDIM_X) {
            acc += srcV[col]*srcM[row*M + col];
        }

        #pragma unroll
        for(int off = BDIM_X/2; off > 0; off /= 2) {
            acc += __shfl_down_sync(WMASK, acc, off, BDIM_X);
        }
        // values could be held by BDIM_X threads and written
        // together every BDIM_X iterations...

        if (threadIdx.x == 0) {
            dstV[row] = acc;
        }
    }
    return;
}
238
+
239
+
240
/*
 * ndotp_log_opdt_d: like ndotp_d, but each element v of srcV is first mapped
 * to -log(v)*(1.5 + log(v))*v:
 *
 *     dstV[i] = sum_j f(srcV[j]) * srcM[i*M + j],   i in [0, N)
 *
 * Cooperative call executed by the BDIM_X threads sharing the same
 * threadIdx.y; only the thread with threadIdx.x == 0 writes dstV[i].
 * No clamping is applied to srcV here.
 */
template<int BDIM_X,
         typename VAL_T>
__device__ void ndotp_log_opdt_d(const int N,
                                 const int M,
                                 const VAL_T *__restrict__ srcV,
                                 const VAL_T *__restrict__ srcM,
                                       VAL_T *__restrict__ dstV) {

    const int laneInWarp = (threadIdx.y*BDIM_X + threadIdx.x) % 32;
    const unsigned int WMASK = ((1ull << BDIM_X)-1) << (laneInWarp & (~(BDIM_X-1)));

    const VAL_T ONEP5 = static_cast<VAL_T>(1.5);

    for(int row = 0; row < N; row++) {

        VAL_T acc = 0;

        for(int col = threadIdx.x; col < M; col += BDIM_X) {
            const VAL_T v = srcV[col];
            acc += -LOG(v)*(ONEP5+LOG(v))*v * srcM[row*M + col];
        }

        #pragma unroll
        for(int off = BDIM_X/2; off > 0; off /= 2) {
            acc += __shfl_down_sync(WMASK, acc, off, BDIM_X);
        }
        // values could be held by BDIM_X threads and written
        // together every BDIM_X iterations...

        if (threadIdx.x == 0) {
            dstV[row] = acc;
        }
    }
    return;
}
284
+
285
/*
 * ndotp_log_csa_d: like ndotp_d, but each element of srcV is first clamped to
 * [0.001, 0.999] and mapped to log(-log(v)):
 *
 *     dstV[i] = sum_j log(-log(clamp(srcV[j]))) * srcM[i*M + j],   i in [0, N)
 *
 * Cooperative call executed by the BDIM_X threads sharing the same
 * threadIdx.y; only the thread with threadIdx.x == 0 writes dstV[i].
 */
template<int BDIM_X,
         typename VAL_T>
__device__ void ndotp_log_csa_d(const int N,
                                const int M,
                                const VAL_T *__restrict__ srcV,
                                const VAL_T *__restrict__ srcM,
                                      VAL_T *__restrict__ dstV) {

    const int laneInWarp = (threadIdx.y*BDIM_X + threadIdx.x) % 32;
    const unsigned int WMASK = ((1ull << BDIM_X)-1) << (laneInWarp & (~(BDIM_X-1)));

    // Clamp bounds keeping both logarithms finite.
    constexpr VAL_T clampLo = .001;
    constexpr VAL_T clampHi = .999;

    for(int row = 0; row < N; row++) {

        VAL_T acc = 0;

        for(int col = threadIdx.x; col < M; col += BDIM_X) {
            const VAL_T v = MIN(MAX(srcV[col], clampLo), clampHi);
            acc += LOG(-LOG(v)) * srcM[row*M + col];
        }

        #pragma unroll
        for(int off = BDIM_X/2; off > 0; off /= 2) {
            acc += __shfl_down_sync(WMASK, acc, off, BDIM_X);
        }
        // values could be held by BDIM_X threads and written
        // together every BDIM_X iterations...

        if (threadIdx.x == 0) {
            dstV[row] = acc;
        }
    }
    return;
}
330
+
331
+
332
/*
 * fit_opdt: compute the OPDT model coefficients into __r_sh.
 *
 *     __r_sh[k] = ndotp_log_opdt_d(data, delta_q)[k] - ndotp_d(data, delta_b)[k]
 *
 * for k in [0, delta_nr), where both products run over hr_side data values
 * taken from __msk_data_sh. __h_sh is used as scratch for the second product.
 * Cooperative call executed by the BDIM_X threads sharing the same
 * threadIdx.y; ends with a group barrier.
 */
template<int BDIM_X,
         typename REAL_T>
__device__ void fit_opdt(const int delta_nr,
                         const int hr_side,
                         const REAL_T *__restrict__ delta_q,
                         const REAL_T *__restrict__ delta_b,
                         const REAL_T *__restrict__ __msk_data_sh,
                         REAL_T *__restrict__ __h_sh,
                         REAL_T *__restrict__ __r_sh) {

    const int laneInWarp = (threadIdx.y*BDIM_X + threadIdx.x) % 32;
    const unsigned int WMASK = ((1ull << BDIM_X)-1) << (laneInWarp & (~(BDIM_X-1)));

    ndotp_log_opdt_d<BDIM_X>(delta_nr, hr_side, __msk_data_sh, delta_q, __r_sh);
    ndotp_d         <BDIM_X>(delta_nr, hr_side, __msk_data_sh, delta_b, __h_sh);
    // Both results are written by thread 0 only: make them visible to the group.
    __syncwarp(WMASK);

    #pragma unroll
    for(int base = 0; base < delta_nr; base += BDIM_X) {
        const int k = base + threadIdx.x;
        if (k < delta_nr) {
            __r_sh[k] -= __h_sh[k];
        }
    }
    __syncwarp(WMASK);
}
354
+
355
/*
 * fit_csa: compute the CSA model coefficients into __r_sh.
 *
 * __r_sh[0..delta_nr) receives ndotp_log_csa_d(data, fit_matrix), where the
 * products run over hr_side data values taken from __msk_data_sh; entry 0 is
 * then overwritten with the constant 0.5/sqrt(pi).
 * Cooperative call executed by the BDIM_X threads sharing the same
 * threadIdx.y; ends with a group barrier.
 */
template<int BDIM_X, typename REAL_T>
__device__ void fit_csa(const int delta_nr,
                        const int hr_side,
                        const REAL_T *__restrict__ fit_matrix,
                        const REAL_T *__restrict__ __msk_data_sh,
                        REAL_T *__restrict__ __r_sh) {
    const int tidx = threadIdx.x;
    const int lid = (threadIdx.y*BDIM_X + threadIdx.x) % 32;
    const unsigned int WMASK = ((1ull << BDIM_X)-1) << (lid & (~(BDIM_X-1)));

    // .5 / sqrt(pi), held in the precision this instantiation works in
    // (the template parameter, not a file-global type name).
    constexpr REAL_T _n0_const = static_cast<REAL_T>(0.28209479177387814);

    ndotp_log_csa_d<BDIM_X>(delta_nr, hr_side, __msk_data_sh, fit_matrix, __r_sh);
    // Wait for thread 0's writes before entry 0 is replaced.
    __syncwarp(WMASK);
    if (tidx == 0) {
        __r_sh[0] = _n0_const;
    }
    __syncwarp(WMASK);
}
373
+
374
/*
 * fit_model_coef: dispatch, on the compile-time MODEL_T, to the routine that
 * computes the model coefficients into __r_sh.
 *
 *   OPDT -> fit_opdt(delta_q, delta_b), using __h_sh as scratch
 *   CSA  -> fit_csa with delta_q as its fit matrix (delta_b, __h_sh unused)
 *
 * Any other MODEL_T only prints an error message.
 */
template<int BDIM_X, ModelType MODEL_T, typename REAL_T>
__device__ void fit_model_coef(const int delta_nr, // delta_nr is number of ODF directions
                               const int hr_side, // hr_side is number of data directions
                               const REAL_T *__restrict__ delta_q,
                               const REAL_T *__restrict__ delta_b, // these are fit matrices the model can use, different for each model
                               const REAL_T *__restrict__ __msk_data_sh, // __msk_data_sh is the part of the data currently being operated on by this block
                               REAL_T *__restrict__ __h_sh, // these last two are modifications to the coefficients that will be returned
                               REAL_T *__restrict__ __r_sh) {
    if (MODEL_T == OPDT) {
        fit_opdt<BDIM_X>(delta_nr, hr_side, delta_q, delta_b, __msk_data_sh, __h_sh, __r_sh);
    } else if (MODEL_T == CSA) {
        fit_csa<BDIM_X>(delta_nr, hr_side, delta_q, __msk_data_sh, __r_sh);
    } else {
        printf("FATAL: Invalid Model Type.\n");
    }
}
394
+
395
/*
 * get_direction_boot_d: bootstrap direction getter.
 *
 * Cooperative call executed by the BDIM_X threads sharing the same
 * threadIdx.y. For up to NATTEMPTS attempts it:
 *   1. interpolates the dimt data values at `point` (trilinear_interp_d);
 *   2. keeps the values whose b0s_mask entry is zero (maskGet);
 *   3. multiplies them by R and H, adds to each H-product entry an R-product
 *      entry picked at a random index (curand, or fixedPerm when
 *      USE_FIXED_PERMUTATION is defined) and scatters the result back;
 *   4. clamps the data to min_signal from below and divides by the mean of
 *      the entries whose b0s_mask is nonzero (avgMask);
 *   5. fits the model selected by MODEL_T and multiplies the coefficients by
 *      sampling_matrix to get samplm_nr values (the "pmf" of the Python code);
 *   6. zeroes values below PMF_THRESHOLD_P * max and extracts peak
 *      directions into `dirs` (peak_directions_d).
 * If the interpolation fails (non-zero return) the pmf is all zeros.
 *
 * Dynamic shared memory: the function carves, per threadIdx.y, two arrays of
 * n32dimt = roundup(dimt, 32) elements and two of MAX(n32dimt, samplm_nr)
 * elements out of the extern array, i.e. the block needs
 * BDIM_Y * (2*n32dimt + 2*MAX(n32dimt, samplm_nr)) elements of REAL_T.
 *
 * Returns:
 *   NATTEMPTS == 1 : the number of peaks found at the first attempt (they are
 *                    left in `dirs`);
 *   NATTEMPTS  > 1 : 1 as soon as an attempt yields a peak within max_angle
 *                    of `dir` (stored in dirs[0]), 0 if no attempt does.
 */
template<int BDIM_X,
         int BDIM_Y,
         int NATTEMPTS,
         ModelType MODEL_T,
         typename REAL_T,
         typename REAL3_T>
__device__ int get_direction_boot_d(
        curandStatePhilox4_32_10_t *st,
        const REAL_T max_angle,
        const REAL_T min_signal,
        const REAL_T relative_peak_thres,
        const REAL_T min_separation_angle,
        REAL3_T dir,
        const int dimx,
        const int dimy,
        const int dimz,
        const int dimt,
        const REAL_T *__restrict__ dataf,
        const int *__restrict__ b0s_mask, // not using this (and its opposite, dwi_mask)
                                          // but not clear if it will never be needed so
                                          // we'll keep it here for now...
        const REAL3_T point,
        const REAL_T *__restrict__ H,
        const REAL_T *__restrict__ R,
        // model unused
        // max_angle, pmf_threshold from global defines
        // b0s_mask already passed
        // min_signal from global defines
        const int delta_nr,
        const REAL_T *__restrict__ delta_b,
        const REAL_T *__restrict__ delta_q, // fit_matrix
        const int samplm_nr,
        const REAL_T *__restrict__ sampling_matrix,
        const REAL3_T *__restrict__ sphere_vertices,
        const int2 *__restrict__ sphere_edges,
        const int num_edges,
        REAL3_T *__restrict__ dirs) {

    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;

    const int lid = (threadIdx.y*BDIM_X + threadIdx.x) % 32;
    const unsigned int WMASK = ((1ull << BDIM_X)-1) << (lid & (~(BDIM_X-1)));

    // dimt rounded up to a multiple of 32
    const int n32dimt = ((dimt+31)/32)*32;

    extern REAL_T __shared__ __sh[];

    // Partition the dynamic shared array: first all the vox/msk slices, then
    // all the r/h slices; each threadIdx.y then moves to its own slice.
    REAL_T *__vox_data_sh = reinterpret_cast<REAL_T *>(__sh);
    REAL_T *__msk_data_sh = __vox_data_sh + BDIM_Y*n32dimt;

    REAL_T *__r_sh = __msk_data_sh + BDIM_Y*n32dimt;
    REAL_T *__h_sh = __r_sh + BDIM_Y*MAX(n32dimt, samplm_nr);

    __vox_data_sh += tidy*n32dimt;
    __msk_data_sh += tidy*n32dimt;

    __r_sh += tidy*MAX(n32dimt, samplm_nr);
    __h_sh += tidy*MAX(n32dimt, samplm_nr);

    // compute hr_side (may be passed from python):
    // number of entries of b0s_mask that are zero
    int hr_side = 0;
    for(int j = tidx; j < dimt; j += BDIM_X) {
        hr_side += !b0s_mask[j] ? 1 : 0;
    }
    #pragma unroll
    for(int i = BDIM_X/2; i; i /= 2) {
        hr_side += __shfl_xor_sync(WMASK, hr_side, i, BDIM_X);
    }

    #pragma unroll
    for(int i = 0; i < NATTEMPTS; i++) {

        const int rv = trilinear_interp_d<BDIM_X>(dimx, dimy, dimz, dimt, -1, dataf, point, __vox_data_sh);

        maskGet<BDIM_X>(dimt, b0s_mask, __vox_data_sh, __msk_data_sh);

        __syncwarp(WMASK);

        if (rv == 0) {

            // __r_sh <- R * masked data, __h_sh <- H * masked data
            ndotp_d<BDIM_X>(hr_side, hr_side, __msk_data_sh, R, __r_sh);
            ndotp_d<BDIM_X>(hr_side, hr_side, __msk_data_sh, H, __h_sh);

            __syncwarp(WMASK);

            // __h_sh[k] += __r_sh[random index]
            for(int j = 0; j < hr_side; j += BDIM_X) {
                if (j+tidx < hr_side) {
#ifdef USE_FIXED_PERMUTATION
                    const int srcPermInd = fixedPerm[j+tidx];
#else
                    const int srcPermInd = curand(st) % hr_side;
#endif
                    __h_sh[j+tidx] += __r_sh[srcPermInd];
                }
            }
            __syncwarp(WMASK);

            // vox_data[dwi_mask] = masked_data
            maskPut<BDIM_X>(dimt, b0s_mask, __h_sh, __vox_data_sh);
            __syncwarp(WMASK);

            for(int j = tidx; j < dimt; j += BDIM_X) {
                __vox_data_sh[j] = MAX(min_signal, __vox_data_sh[j]);
            }
            __syncwarp(WMASK);

            const REAL_T denom = avgMask<BDIM_X>(dimt, b0s_mask, __vox_data_sh);

            for(int j = tidx; j < dimt; j += BDIM_X) {
                __vox_data_sh[j] /= denom;
            }
            // Group-scoped barrier like every other one in this function: a
            // full-warp __syncwarp() is invalid when BDIM_X < 32 because the
            // other groups of the warp may be on a different path or may
            // already have returned.
            __syncwarp(WMASK);

            maskGet<BDIM_X>(dimt, b0s_mask, __vox_data_sh, __msk_data_sh);
            __syncwarp(WMASK);

            fit_model_coef<BDIM_X, MODEL_T>(delta_nr, hr_side, delta_q, delta_b, __msk_data_sh, __h_sh, __r_sh);

            // __r_sh <- python 'coef'

            ndotp_d<BDIM_X>(samplm_nr, delta_nr, __r_sh, sampling_matrix, __h_sh);

            // __h_sh <- python 'pmf'
        } else {
            #pragma unroll
            for(int j = tidx; j < samplm_nr; j += BDIM_X) {
                __h_sh[j] = 0;
            }
            // __h_sh <- python 'pmf'
        }
        __syncwarp(WMASK);

        const REAL_T abs_pmf_thr = PMF_THRESHOLD_P*max_d<BDIM_X>(samplm_nr, __h_sh, REAL_MIN);
        __syncwarp(WMASK);

        #pragma unroll
        for(int j = tidx; j < samplm_nr; j += BDIM_X) {
            const REAL_T __v = __h_sh[j];
            if (__v < abs_pmf_thr) {
                __h_sh[j] = 0;
            }
        }
        __syncwarp(WMASK);

        // Python reference for what follows:
        //   if init:
        //       directions = peak_directions(pmf, sphere)[0]
        //       return directions
        //   else:
        //       peaks = peak_directions(pmf, sphere)[0]
        //       if (len(peaks) > 0):
        //           return closest_peak(directions, peaks, cos_similarity)
        const int ndir = peak_directions_d<BDIM_X,
                                           BDIM_Y>(__h_sh, dirs,
                                                   sphere_vertices,
                                                   sphere_edges,
                                                   num_edges,
                                                   samplm_nr,
                                                   reinterpret_cast<int *>(__r_sh), // reuse __r_sh as shInd in func which is large enough
                                                   relative_peak_thres,
                                                   min_separation_angle);
        if (NATTEMPTS == 1) { // init=True...
            return ndir; // and dirs;
        } else { // init=False...
            if (ndir > 0) {
                REAL3_T peak;
                const int foundPeak = closest_peak_d<BDIM_X, BDIM_Y, REAL_T, REAL3_T>(max_angle, dir, ndir, dirs, &peak);
                // all threads must be done reading dirs before it is overwritten
                __syncwarp(WMASK);
                if (foundPeak) {
                    if (tidx == 0) {
                        dirs[0] = peak;
                    }
                    return 1;
                }
            }
        }
    }
    return 0;
}
633
+
634
/*
 * getNumStreamlinesBoot_k: for every seed, compute the initial tracking
 * directions with the bootstrap direction getter (single attempt).
 *
 * Thread layout: seed index = blockIdx.x*blockDim.y + threadIdx.y; the
 * threads along x cooperate on one seed. Presumably launched with
 * blockDim == (BDIM_X, BDIM_Y) -- confirm against the host code.
 * Dynamic shared memory: as required by get_direction_boot_d.
 *
 * Outputs:
 *   shDir0[slid*samplm_nr ...] : directions found for seed `slid`
 *   slineOutOff[slid]          : number of directions found (0 for an
 *                                unknown model_type)
 */
template<int BDIM_X,
         int BDIM_Y,
         typename REAL_T,
         typename REAL3_T>
__global__ void getNumStreamlinesBoot_k(
        const ModelType model_type,
        const REAL_T max_angle,
        const REAL_T min_signal,
        const REAL_T relative_peak_thres,
        const REAL_T min_separation_angle,
        const long long rndSeed,
        const int nseed,
        const REAL3_T *__restrict__ seeds,
        const int dimx,
        const int dimy,
        const int dimz,
        const int dimt,
        const REAL_T *__restrict__ dataf,
        const REAL_T *__restrict__ H,
        const REAL_T *__restrict__ R,
        const int delta_nr,
        const REAL_T *__restrict__ delta_b,
        const REAL_T *__restrict__ delta_q,
        const int *__restrict__ b0s_mask, // change to int
        const int samplm_nr,
        const REAL_T *__restrict__ sampling_matrix,
        const REAL3_T *__restrict__ sphere_vertices,
        const int2 *__restrict__ sphere_edges,
        const int num_edges,
        REAL3_T *__restrict__ shDir0,
        int *slineOutOff) {

    const int tidx = threadIdx.x;
    const int slid = blockIdx.x*blockDim.y + threadIdx.y;
    const size_t gid = blockIdx.x * blockDim.y * blockDim.x + blockDim.x * threadIdx.y + threadIdx.x;

    // All threads working on the same seed share slid, so whole groups leave together.
    if (slid >= nseed) {
        return;
    }

    REAL3_T seed = seeds[slid];
    // seed = lin_mat*seed + offset

    // Offset computed in size_t: slid*samplm_nr can exceed the int range.
    REAL3_T *__restrict__ __shDir = shDir0 + static_cast<size_t>(slid)*samplm_nr;

    // One Philox subsequence per thread.
    curandStatePhilox4_32_10_t st;
    curand_init(rndSeed, gid, 0, &st);

    // python:
    //directions = get_direction(None, dataf, dwi_mask, sphere, s, H, R, model, max_angle,
    //                pmf_threshold, b0s_mask, min_signal, fit_matrix,
    //                sampling_matrix, init=True)

    // Initialised so that the invalid-model path stores a defined value.
    int ndir = 0;
    switch(model_type) {
        case OPDT:
            ndir = get_direction_boot_d<BDIM_X,
                                        BDIM_Y,
                                        1,
                                        OPDT>(
                                            &st,
                                            max_angle,
                                            min_signal,
                                            relative_peak_thres,
                                            min_separation_angle,
                                            MAKE_REAL3(0,0,0),
                                            dimx, dimy, dimz, dimt, dataf,
                                            b0s_mask /* !dwi_mask */,
                                            seed,
                                            H, R,
                                            delta_nr,
                                            delta_b, delta_q, // fit_matrix
                                            samplm_nr,
                                            sampling_matrix,
                                            sphere_vertices,
                                            sphere_edges,
                                            num_edges,
                                            __shDir);
            break;
        case CSA:
            ndir = get_direction_boot_d<BDIM_X,
                                        BDIM_Y,
                                        1,
                                        CSA>(
                                            &st,
                                            max_angle,
                                            min_signal,
                                            relative_peak_thres,
                                            min_separation_angle,
                                            MAKE_REAL3(0,0,0),
                                            dimx, dimy, dimz, dimt, dataf,
                                            b0s_mask /* !dwi_mask */,
                                            seed,
                                            H, R,
                                            delta_nr,
                                            delta_b, delta_q, // fit_matrix
                                            samplm_nr,
                                            sampling_matrix,
                                            sphere_vertices,
                                            sphere_edges,
                                            num_edges,
                                            __shDir);
            break;
        default:
            printf("FATAL: Invalid Model Type.\n");
            break;
    }

    if (tidx == 0) {
        slineOutOff[slid] = ndir;
    }

    return;
}
763
+
764
/*
 * tracker_boot_d: propagate one streamline from `seed`, starting along
 * `first_step`, until no direction is found, the tissue classifier stops it
 * or MAX_SLINE_LEN points have been produced.
 *
 * Cooperative call executed by the BDIM_X threads sharing the same
 * threadIdx.y. At every step the next direction is obtained with
 * get_direction_boot_d (5 attempts), the point is advanced by
 * direction/voxel_size * step_size and stored in `streamline` by the thread
 * with threadIdx.x == 0, then classified with check_point_d.
 *
 * Outputs: streamline[0] is the seed, followed by the points produced;
 * nsteps[0] is written by every calling thread.
 * Returns the last tissue class (TRACKPOINT if no point was classified).
 *
 * NOTE(review): a single REAL3_T slot per threadIdx.y is handed to
 * get_direction_boot_d as `dirs`; confirm the peak extraction can never
 * store more than one direction there.
 */
template<int BDIM_X,
         int BDIM_Y,
         ModelType MODEL_T,
         typename REAL_T,
         typename REAL3_T>
__device__ int tracker_boot_d(
        curandStatePhilox4_32_10_t *st,
        const REAL_T max_angle,
        const REAL_T tc_threshold,
        const REAL_T step_size,
        const REAL_T relative_peak_thres,
        const REAL_T min_separation_angle,
        REAL3_T seed,
        REAL3_T first_step,
        REAL3_T voxel_size,
        const int dimx,
        const int dimy,
        const int dimz,
        const int dimt,
        const REAL_T *__restrict__ dataf,
        const REAL_T *__restrict__ metric_map,
        const int samplm_nr,
        const REAL3_T *__restrict__ sphere_vertices,
        const int2 *__restrict__ sphere_edges,
        const int num_edges,
        /*BOOT specific params*/
        const REAL_T min_signal,
        const int delta_nr,
        const REAL_T *__restrict__ H,
        const REAL_T *__restrict__ R,
        const REAL_T *__restrict__ delta_b,
        const REAL_T *__restrict__ delta_q,
        const REAL_T *__restrict__ sampling_matrix,
        const int *__restrict__ b0s_mask,
        /*BOOT specific params*/
        int *__restrict__ nsteps,
        REAL3_T *__restrict__ streamline) {

    const int laneInWarp = (threadIdx.y*BDIM_X + threadIdx.x) % 32;
    const unsigned int WMASK = ((1ull << BDIM_X)-1) << (laneInWarp & (~(BDIM_X-1)));

    const bool isWriter = (threadIdx.x == 0);
    const int  row      = threadIdx.y;

    // One shared direction slot per threadIdx.y.
    __shared__ REAL3_T __sh_new_dir[BDIM_Y];

    int tissue_class = TRACKPOINT;

    REAL3_T point = seed;
    REAL3_T direction = first_step;

    if (isWriter) {
        streamline[0] = point;
    }
    __syncwarp(WMASK);

    // Number of sub-steps per stored point (currently fixed to 1).
    int step_frac = 1;

    int i = 1;
    while (i < MAX_SLINE_LEN*step_frac) {
        const int ndir = get_direction_boot_d<BDIM_X,
                                              BDIM_Y,
                                              5,
                                              MODEL_T>(
                                                  st,
                                                  max_angle,
                                                  min_signal,
                                                  relative_peak_thres,
                                                  min_separation_angle,
                                                  direction,
                                                  dimx, dimy, dimz, dimt, dataf,
                                                  b0s_mask /* !dwi_mask */,
                                                  point,
                                                  H, R,
                                                  delta_nr,
                                                  delta_b, delta_q, // fit_matrix
                                                  samplm_nr,
                                                  sampling_matrix,
                                                  sphere_vertices,
                                                  sphere_edges,
                                                  num_edges,
                                                  __sh_new_dir + row);
        __syncwarp(WMASK);
        direction = __sh_new_dir[row];
        __syncwarp(WMASK);

        // No usable direction: stop without adding a point.
        if (ndir == 0) {
            break;
        }

        point.x += (direction.x / voxel_size.x) * (step_size / step_frac);
        point.y += (direction.y / voxel_size.y) * (step_size / step_frac);
        point.z += (direction.z / voxel_size.z) * (step_size / step_frac);

        if (isWriter && ((i % step_frac) == 0)) {
            streamline[i/step_frac] = point;
        }
        __syncwarp(WMASK);

        tissue_class = check_point_d<BDIM_X, BDIM_Y>(tc_threshold, point, dimx, dimy, dimz, metric_map);

        const bool mustStop = (tissue_class == ENDPOINT)     ||
                              (tissue_class == INVALIDPOINT) ||
                              (tissue_class == OUTSIDEIMAGE);
        if (mustStop) {
            break;
        }
        i++;
    }

    nsteps[0] = i/step_frac;
    // A trailing partial step (only possible when step_frac > 1) is stored too.
    const bool hasPartialStep = ((i % step_frac) != 0) && i < step_frac*(MAX_SLINE_LEN - 1);
    if (hasPartialStep) {
        nsteps[0]++;
        if (isWriter) {
            streamline[nsteps[0]] = point;
        }
    }

    return tissue_class;
}
885
+
886
/*
 * genStreamlinesMergeBoot_k: generate the streamlines of every seed.
 *
 * Thread layout: seed index = blockIdx.x*blockDim.y + threadIdx.y; the
 * threads along x cooperate on one seed. Presumably launched with
 * blockDim == (BDIM_X, BDIM_Y) -- confirm against the host code.
 * Dynamic shared memory: as required by get_direction_boot_d.
 *
 * For seed `slid`, the number of streamlines is
 * slineOutOff[slid+1]-slineOutOff[slid] and the first one is stored at index
 * slineOutOff[slid]; the initial directions are read from
 * shDir0[slid*samplm_nr + i]. Each streamline owns 2*MAX_SLINE_LEN points of
 * `sline`: it is tracked backward (along -first_step), that part is reversed
 * in place, then tracked forward starting at the slot holding the seed point.
 *
 * Outputs per streamline: slineSeed (seed index), slineLen (number of
 * points, stepsB-1+stepsF) and the points in `sline`.
 */
template<int BDIM_X,
         int BDIM_Y,
         ModelType MODEL_T,
         typename REAL_T,
         typename REAL3_T>
__global__ void genStreamlinesMergeBoot_k(
        const REAL_T max_angle,
        const REAL_T tc_threshold,
        const REAL_T step_size,
        const REAL_T relative_peak_thres,
        const REAL_T min_separation_angle,
        const long long rndSeed,
        const int rndOffset,
        const int nseed,
        const REAL3_T *__restrict__ seeds,
        const int dimx,
        const int dimy,
        const int dimz,
        const int dimt,
        const REAL_T *__restrict__ dataf,
        const REAL_T *__restrict__ metric_map,
        const int samplm_nr,
        const REAL3_T *__restrict__ sphere_vertices,
        const int2 *__restrict__ sphere_edges,
        const int num_edges,
        /*BOOT specific params*/
        const REAL_T min_signal,
        const int delta_nr,
        const REAL_T *__restrict__ H,
        const REAL_T *__restrict__ R,
        const REAL_T *__restrict__ delta_b,
        const REAL_T *__restrict__ delta_q,
        const REAL_T *__restrict__ sampling_matrix,
        const int *__restrict__ b0s_mask,
        /*BOOT specific params*/
        const int *__restrict__ slineOutOff,
        REAL3_T *__restrict__ shDir0,
        int *__restrict__ slineSeed,
        int *__restrict__ slineLen,
        REAL3_T *__restrict__ sline) {

    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;

    const int slid = blockIdx.x*blockDim.y + threadIdx.y;

    // Threads without a seed leave before paying for the RNG set-up; all
    // threads working on the same seed share slid, so groups leave together.
    if (slid >= nseed) {
        return;
    }

    const int lid = (tidy*BDIM_X + tidx) % 32;
    const unsigned int WMASK = ((1ull << BDIM_X)-1) << (lid & (~(BDIM_X-1)));

    // One Philox subsequence per thread (gid+1).
    curandStatePhilox4_32_10_t st;
    const size_t gid = blockIdx.x * blockDim.y * blockDim.x + blockDim.x * threadIdx.y + threadIdx.x;
    curand_init(rndSeed, gid+1, 0, &st);

    REAL3_T seed = seeds[slid];

    int ndir = slineOutOff[slid+1]-slineOutOff[slid];
    __syncwarp(WMASK);

    int slineOff = slineOutOff[slid];

    for(int i = 0; i < ndir; i++) {
        // Offsets computed in size_t: the int products can overflow for
        // large numbers of seeds/streamlines.
        REAL3_T first_step = shDir0[static_cast<size_t>(slid)*samplm_nr + i];

        REAL3_T *__restrict__ currSline = sline + static_cast<size_t>(slineOff)*MAX_SLINE_LEN*2;

        if (tidx == 0) {
            slineSeed[slineOff] = slid;
        }

        // Backward half: track along -first_step.
        int stepsB;
        const int tissue_classB = tracker_boot_d<BDIM_X,
                                                 BDIM_Y,
                                                 MODEL_T>(
                                                     &st,
                                                     max_angle,
                                                     tc_threshold,
                                                     step_size,
                                                     relative_peak_thres,
                                                     min_separation_angle,
                                                     seed,
                                                     MAKE_REAL3(-first_step.x, -first_step.y, -first_step.z),
                                                     MAKE_REAL3(1, 1, 1),
                                                     dimx, dimy, dimz, dimt, dataf,
                                                     metric_map,
                                                     samplm_nr,
                                                     sphere_vertices,
                                                     sphere_edges,
                                                     num_edges,
                                                     min_signal,
                                                     delta_nr,
                                                     H,
                                                     R,
                                                     delta_b,
                                                     delta_q,
                                                     sampling_matrix,
                                                     b0s_mask,
                                                     &stepsB,
                                                     currSline);

        // reverse backward sline
        for(int j = 0; j < stepsB/2; j += BDIM_X) {
            if (j+tidx < stepsB/2) {
                const REAL3_T __p = currSline[j+tidx];
                currSline[j+tidx] = currSline[stepsB-1 - (j+tidx)];
                currSline[stepsB-1 - (j+tidx)] = __p;
            }
        }

        // Forward half: starts on the slot that now holds the seed point.
        int stepsF;
        const int tissue_classF = tracker_boot_d<BDIM_X,
                                                 BDIM_Y,
                                                 MODEL_T>(
                                                     &st,
                                                     max_angle,
                                                     tc_threshold,
                                                     step_size,
                                                     relative_peak_thres,
                                                     min_separation_angle,
                                                     seed,
                                                     first_step,
                                                     MAKE_REAL3(1, 1, 1),
                                                     dimx, dimy, dimz, dimt, dataf,
                                                     metric_map,
                                                     samplm_nr,
                                                     sphere_vertices,
                                                     sphere_edges,
                                                     num_edges,
                                                     min_signal,
                                                     delta_nr,
                                                     H,
                                                     R,
                                                     delta_b,
                                                     delta_q,
                                                     sampling_matrix,
                                                     b0s_mask,
                                                     &stepsF,
                                                     currSline + stepsB-1);
        if (tidx == 0) {
            // the seed point is shared by the two halves
            slineLen[slineOff] = stepsB-1+stepsF;
        }

        slineOff += 1;

        // tissue_classB / tissue_classF are currently not used to filter
        // streamlines; the disabled filter was:
        //if (/* !return_all || */0 &&
        //    tissue_classF != ENDPOINT &&
        //    tissue_classF != OUTSIDEIMAGE) {
        //        continue;
        //}
        //if (/* !return_all || */ 0 &&
        //    tissue_classB != ENDPOINT &&
        //    tissue_classB != OUTSIDEIMAGE) {
        //        continue;
        //}
    }
    return;
}