cuslines 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,103 @@
1
+ /* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2
+ *
3
+ * Redistribution and use in source and binary forms, with or without
4
+ * modification, are permitted provided that the following conditions are met:
5
+ *
6
+ * 1. Redistributions of source code must retain the above copyright notice, this
7
+ * list of conditions and the following disclaimer.
8
+ *
9
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
10
+ * this list of conditions and the following disclaimer in the documentation
11
+ * and/or other materials provided with the distribution.
12
+ *
13
+ * 3. Neither the name of the copyright holder nor the names of its
14
+ * contributors may be used to endorse or promote products derived from
15
+ * this software without specific prior written permission.
16
+ *
17
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
20
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
21
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
23
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
24
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
25
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
26
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27
+ */
28
+
29
+ #ifndef __GLOBALS_H__
30
+ #define __GLOBALS_H__
31
+
32
// Floating-point precision used throughout the kernels:
//   REAL_SIZE == 4 -> single precision (float, fast-math intrinsics for log/exp/sin/cos)
//   REAL_SIZE == 8 -> double precision
// Defaults to 4; may be overridden from the build (e.g. -DREAL_SIZE=8).
#ifndef REAL_SIZE
#define REAL_SIZE 4
#endif

#if REAL_SIZE == 4

#define REAL float
#define REAL3 float3
#define MAKE_REAL3 make_float3
#define RCONV "%f"
#define FLOOR floorf
#define LOG __logf
#define EXP __expf
#define REAL_MAX __int_as_float(0x7f7fffffU)
#define REAL_MIN (-REAL_MAX)
#define COS __cosf
#define SIN __sinf
#define FABS fabsf
#define SQRT sqrtf
#define RSQRT rsqrtf
#define ACOS acosf

#elif REAL_SIZE == 8

#define REAL double
#define REAL3 double3
#define MAKE_REAL3 make_double3
#define RCONV "%lf"
#define FLOOR floor
#define LOG log
#define EXP exp
#define REAL_MAX __longlong_as_double(0x7fefffffffffffffLL)
#define REAL_MIN (-REAL_MAX)
#define COS cos
#define SIN sin
#define FABS fabs
#define SQRT sqrt
#define RSQRT rsqrt
#define ACOS acos

#else
// Without this, REAL/REAL3/... would silently stay undefined and every use
// site would fail with an unrelated-looking error.
#error "REAL_SIZE must be 4 (float) or 8 (double)"
#endif
71
// ---- Streamline output limits -------------------------------------------
// TODO: half this in when WMGMI seeding
#define MAX_SLINE_LEN (501)
#define MAX_SLINES_PER_SEED (10)
#define EXCESS_ALLOC_FACT 2

// ---- Numeric thresholds (typed as REAL) ---------------------------------
// Fraction of the maximum PMF value below which data support is treated as absent.
#define PMF_THRESHOLD_P ((REAL)0.05)
// Vectors whose length does not exceed this are treated as degenerate.
#define NORM_EPS ((REAL)1e-8)

// ---- Thread-layout constants ---------------------------------------------
// NOTE(review): names suggest threads in x per block (BL) / per streamline (SL); confirm.
#define THR_X_BL (64)
#define THR_X_SL (32)

// ---- Small arithmetic helpers (macros: arguments may be evaluated twice) --
#define MIN(x,y) (((x)<(y))?(x):(y))
#define MAX(x,y) (((x)>(y))?(x):(y))
#define POW2(n) (1 << (n))
#define DIV_UP(a,b) (((a)+((b)-1))/(b))

// Uncomment to enable DEBUG-guarded code.
// #define DEBUG
93
+
94
// Direction-getter model identifiers. Values are spelled out because they are
// presumably shared with host-side code — NOTE(review): confirm before renumbering.
enum ModelType {
    OPDT = 0,
    CSA  = 1,
    PROB = 2,
    PTT  = 3,
};

// Point status codes (0..3). Their exact meaning is defined by the code that
// returns them, which is not part of this header.
enum {
    OUTSIDEIMAGE = 0,
    INVALIDPOINT = 1,
    TRACKPOINT   = 2,
    ENDPOINT     = 3
};
102
+
103
+ #endif
cuslines/cuda_c/ptt.cu ADDED
@@ -0,0 +1,559 @@
1
/* Normalize the 3-vector `num` in place.
 *
 * If its Euclidean length is not greater than NORM_EPS (or is NaN), the
 * vector is replaced by the unit basis vector along axis `fail_ind`
 * (expected to be 0, 1 or 2). This can happen randomly during propagation,
 * though it is exceedingly rare.
 */
template<typename REAL_T>
__device__ __forceinline__ void norm3_d(REAL_T *num, int fail_ind) {
    const REAL_T len = SQRT(num[0] * num[0] + num[1] * num[1] + num[2] * num[2]);

    if (!(len > NORM_EPS)) {
        // Degenerate vector: fall back to a basis vector.
        num[2] = 0;
        num[1] = 0;
        num[0] = 0;
        num[fail_ind] = 1.0;
        return;
    }

    #pragma unroll
    for (int axis = 0; axis < 3; ++axis) {
        num[axis] /= len;
    }
}
14
+
15
/* dest = normalize(src1 x src2).
 *
 * Components are written in order 0, 1, 2, exactly as before. `fail_ind` is
 * forwarded to norm3_d and selects the fallback axis if the cross product is
 * degenerate (e.g. parallel inputs).
 */
template<typename REAL_T>
__device__ __forceinline__ void crossnorm3_d(REAL_T *dest, const REAL_T *src1, const REAL_T *src2, int fail_ind) {
    #pragma unroll
    for (int c = 0; c < 3; ++c) {
        // Cyclic indices: (c, a, b) = (0,1,2), (1,2,0), (2,0,1).
        const int a = (c + 1) % 3;
        const int b = (c + 2) % 3;
        dest[c] = src1[a] * src2[b] - src1[b] * src2[a];
    }

    norm3_d(dest, fail_ind);
}
23
+
24
/* Sample the PMF amplitude at `pos` for the direction frame[0..2].
 *
 * Cooperative helper: all BDIM_X lanes sharing a threadIdx.y row must call it.
 * The lanes scan the sphere vertices in a BDIM_X-strided loop. An XOR-shuffle
 * reduction over the lane group then finds the vertex with the largest
 * |cos| against frame[0..2]; ties go to the smaller index, so all lanes agree.
 * Finally the pmf volume is trilinearly interpolated at `pos` for that vertex.
 *
 * Returns the interpolated value, or 0 when trilinear_interp_d reports a
 * non-zero status (treated as "no support").
 */
template<int BDIM_X, typename REAL_T, typename REAL3_T>
__device__ REAL_T interp4_d(const REAL3_T pos, const REAL_T* frame, const REAL_T *__restrict__ pmf,
                            const int dimx, const int dimy, const int dimz, const int dimt,
                            const REAL3_T *__restrict__ odf_sphere_vertices) {
    const int lane_in_grp = threadIdx.x;
    const int warp_lane = (threadIdx.y*BDIM_X + threadIdx.x) % 32;
    // Mask covering the BDIM_X consecutive lanes of this row within the warp.
    const unsigned int grp_mask = ((1ull << BDIM_X)-1) << (warp_lane & (~(BDIM_X-1)));

    REAL_T best_cos = REAL_T(0);
    int best_idx = 0;

    // Per-lane partial search over the sphere vertices.
    #pragma unroll
    for (int v = lane_in_grp; v < dimt; v += BDIM_X) { // TODO: I need to think about better ways of parallelizing this
        const REAL3_T vert = odf_sphere_vertices[v];
        const REAL_T c = FABS(vert.x * frame[0] + vert.y * frame[1] + vert.z * frame[2]);
        if (c > best_cos) {
            best_cos = c;
            best_idx = v;
        }
    }
    __syncwarp(grp_mask);

    // Butterfly arg-max across the lane group; every lane ends with the same winner.
    #pragma unroll
    for (int offset = BDIM_X/2; offset > 0; offset >>= 1) {
        const REAL_T other_cos = __shfl_xor_sync(grp_mask, best_cos, offset, BDIM_X);
        const int other_idx = __shfl_xor_sync(grp_mask, best_idx, offset, BDIM_X);
        const bool take_other = (other_cos > best_cos) ||
                                (other_cos == best_cos && other_idx < best_idx);
        if (take_other) {
            best_cos = other_cos;
            best_idx = other_idx;
        }
    }
    __syncwarp(grp_mask);

#if 0
    if (best_idx >= dimt || best_idx < 0) {
        printf("Error: closest_odf_idx out of bounds: %d (dimt: %d)\n", best_idx, dimt);
    }
#endif

    // TODO: maybe this should be texture memory, I am not so sure
    // trilinear_interp_d writes the interpolated amplitude into best_cos.
    const int status = trilinear_interp_d<THR_X_SL>(dimx, dimy, dimz, dimt, best_idx, pmf, pos, &best_cos);

    return (status != 0) ? REAL_T(0) : best_cos; // 0 => no support
}
76
+
77
/* Build the 9-coefficient parallel-transport propagator for a curve with
 * curvatures (k1, k2) advanced by `arclength`.
 *
 * propagate_frame_d consumes the coefficients as:
 *   [0..2] displacement    = p0*T + p1*N + p2*B
 *   [3..5] new tangent     = p3*T + p4*N + p5*B
 *   [6..8] new binormal    = p6*T + p7*N + p8*B
 * When both |k1| and |k2| are below K_SMALL, the straight-line propagator is
 * used. Otherwise each near-zero curvature is clamped to K_SMALL before the
 * closed-form expressions are evaluated.
 */
template<typename REAL_T>
__device__ void prepare_propagator_d(REAL_T k1, REAL_T k2, REAL_T arclength,
                                     REAL_T *propagator) {
    const bool k1_tiny = FABS(k1) < K_SMALL;
    const bool k2_tiny = FABS(k2) < K_SMALL;

    if (k1_tiny && k2_tiny) {
        // Straight segment: move `arclength` along T, frame unchanged.
        propagator[0] = arclength;
        propagator[1] = 0;
        propagator[2] = 0;
        propagator[3] = 1;
        propagator[4] = 0;
        propagator[5] = 0;
        propagator[6] = 0;
        propagator[7] = 0;
        propagator[8] = 1;
        return;
    }

    if (k1_tiny) k1 = K_SMALL;
    if (k2_tiny) k2 = K_SMALL;

    const REAL_T k     = SQRT(k1*k1+k2*k2);
    const REAL_T sinkt = SIN(k*arclength);
    const REAL_T coskt = COS(k*arclength);
    const REAL_T kk    = 1/(k*k);

    // displacement
    propagator[0] = sinkt/k;
    propagator[1] = k1*(1-coskt)*kk;
    propagator[2] = k2*(1-coskt)*kk;
    // tangent
    propagator[3] = coskt;
    propagator[4] = k1*sinkt/k;
    propagator[5] = k2*sinkt/k;
    // binormal
    propagator[6] = -propagator[5];
    propagator[7] = k1*k2*(coskt-1)*kk;
    propagator[8] = (k1*k1+k2*k2*coskt)*kk;
}
113
+
114
/* Draw a random normal vector for the frame in `probing_frame`.
 *
 * Reads the tangent from probing_frame[0..2] and writes probing_frame[3..5]:
 * a standard-normal random vector, minus its projection onto the tangent.
 * If the squared length of the result is below NORM_EPS, a deterministic
 * vector orthogonal to the tangent is used instead. It is built by crossing
 * the tangent with the axis of its smallest-magnitude component.
 * The result is NOT normalized here; callers normalize it afterwards.
 */
template<typename REAL_T>
__device__ void random_normal(curandStatePhilox4_32_10_t *st, REAL_T* probing_frame) {
    const REAL_T *t = probing_frame;     // tangent (input)
    REAL_T *n = probing_frame + 3;       // normal (output)

    n[0] = curand_normal(st);
    n[1] = curand_normal(st);
    n[2] = curand_normal(st);

    // Remove the component along the tangent.
    const REAL_T proj = n[0]*t[0]
                      + n[1]*t[1]
                      + n[2]*t[2];
    n[0] -= proj*t[0];
    n[1] -= proj*t[1];
    n[2] -= proj*t[2];

    const REAL_T len2 = n[0]*n[0]
                      + n[1]*n[1]
                      + n[2]*n[2];
    if (!(len2 < NORM_EPS)) {
        return;
    }

    // Degenerate draw: pick a vector orthogonal to the tangent deterministically.
    const REAL_T ax = FABS(t[0]);
    const REAL_T ay = FABS(t[1]);
    const REAL_T az = FABS(t[2]);

    if (ax <= ay && ax <= az) {
        n[0] = 0.0;
        n[1] = t[2];
        n[2] = -t[1];
    } else if (ay <= az) {
        n[0] = -t[2];
        n[1] = 0.0;
        n[2] = t[0];
    } else {
        n[0] = t[1];
        n[1] = -t[0];
        n[2] = 0.0;
    }
}
152
+
153
/* Fill the 9-value probing frame (tangent, normal, binormal) from `frame`.
 *
 * IS_INIT == false: copy frame[0..8] verbatim.
 * IS_INIT == true:  take only the tangent frame[0..2] and normalize it. Then
 *                   draw a random normal orthogonal to it (random_normal),
 *                   normalize that, and set binormal = normalize(T x N).
 */
template<bool IS_INIT, typename REAL_T>
__device__ void get_probing_frame_d(const REAL_T* frame, curandStatePhilox4_32_10_t *st, REAL_T* probing_frame) {
    if (!IS_INIT) {
        for (int k = 0; k < 9; ++k) {
            probing_frame[k] = frame[k];
        }
        return;
    }

    REAL_T *tangent  = probing_frame;
    REAL_T *normal   = probing_frame + 3;
    REAL_T *binormal = probing_frame + 6;

    for (int k = 0; k < 3; ++k) {
        tangent[k] = frame[k];
    }
    norm3_d(tangent, 0);

    random_normal(st, probing_frame);
    norm3_d(normal, 1);

    crossnorm3_d(binormal, tangent, normal, 2);
}
172
+
173
/* Advance `frame` (T = [0..2], N = [3..5], B = [6..8]) by one propagator step.
 *
 * Writes the displacement into direc[0..3). Then computes the new tangent and
 * binormal from the 9 propagator coefficients and re-orthonormalizes:
 * T' = normalize(tangent), N' = normalize(B' x T'), B' = normalize(T' x N').
 * `frame` is updated in place.
 */
template<typename REAL_T>
__device__ void propagate_frame_d(REAL_T* propagator, REAL_T* frame, REAL_T* direc) {
    REAL_T *T = frame;
    REAL_T *N = frame + 3;
    REAL_T *B = frame + 6;
    REAL_T new_t[3];

    for (int c = 0; c < 3; c++) {
        const REAL_T t = T[c];
        const REAL_T n = N[c];
        const REAL_T b = B[c];
        direc[c] = propagator[0]*t + propagator[1]*n + propagator[2]*b;
        new_t[c] = propagator[3]*t + propagator[4]*n + propagator[5]*b;
        B[c]     = propagator[6]*t + propagator[7]*n + propagator[8]*b;
    }

    norm3_d(new_t, 0);           // normalize tangent
    crossnorm3_d(N, B, new_t, 1); // normal from binormal x tangent
    crossnorm3_d(B, new_t, N, 2); // binormal from tangent x normal

    for (int c = 0; c < 3; c++) {
        T[c] = new_t[c];
    }
}
191
+
192
/* Probe a candidate curve and accumulate its data support.
 *
 * Starting from `support`, this builds the propagator for curvatures
 * (*k1_sh, *k2_sh) and arc length probe_step_size. It then advances the
 * probing frame/position PROBE_QUALITY times, adding the interp4_d amplitude
 * at each probe point.
 * Returns 0 as soon as one amplitude is below absolpmf_thresh (unless
 * ALLOW_WEAK_LINK); otherwise returns the accumulated support.
 *
 * Cooperative: all BDIM_X lanes of the row must call it. Lane 0 updates the
 * shared scratch buffers, and every lane joins interp4_d.
 * probing_frame_sh, probing_prop_sh, direc_sh and probing_pos_sh are
 * overwritten.
 */
template<int BDIM_X, typename REAL_T, typename REAL3_T>
__device__ REAL_T calculate_data_support_d(REAL_T support,
                                           const REAL3_T pos, const REAL_T *__restrict__ pmf,
                                           const int dimx, const int dimy, const int dimz, const int dimt,
                                           const REAL_T probe_step_size,
                                           const REAL_T absolpmf_thresh,
                                           const REAL3_T *__restrict__ odf_sphere_vertices,
                                           REAL_T* probing_prop_sh,
                                           REAL_T* direc_sh,
                                           REAL3_T* probing_pos_sh,
                                           REAL_T* k1_sh, REAL_T* k2_sh,
                                           REAL_T* probing_frame_sh) {
    const int warp_lane = (threadIdx.y*BDIM_X + threadIdx.x) % 32;
    const unsigned int grp_mask = ((1ull << BDIM_X)-1) << (warp_lane & (~(BDIM_X-1)));
    const bool leader = (threadIdx.x == 0);

    if (leader) {
        prepare_propagator_d(*k1_sh, *k2_sh, probe_step_size, probing_prop_sh);
        *probing_pos_sh = pos;
    }
    __syncwarp(grp_mask);

    for (int step = 0; step < PROBE_QUALITY; ++step) { // we spend about 2/3 of our time in this loop when doing PTT
        if (leader) {
            propagate_frame_d(probing_prop_sh, probing_frame_sh, direc_sh);
            probing_pos_sh->x += direc_sh[0];
            probing_pos_sh->y += direc_sh[1];
            probing_pos_sh->z += direc_sh[2];
        }
        __syncwarp(grp_mask);

        // This is the most expensive call
        const REAL_T amp = interp4_d<BDIM_X>(*probing_pos_sh, probing_frame_sh, pmf,
                                             dimx, dimy, dimz, dimt,
                                             odf_sphere_vertices);

        const bool too_weak = !ALLOW_WEAK_LINK && (amp < absolpmf_thresh);
        if (too_weak) {
            return 0;
        }
        support += amp;
    }
    return support;
}
244
+
245
/* Select the next direction with Parallel Transport Tractography (PTT).
 *
 * Aydogan DB, Shi Y. Parallel Transport Tractography. IEEE Trans
 * Med Imaging. 2021 Feb;40(2):635-647. doi: 10.1109/TMI.2020.3034038.
 * Epub 2021 Feb 2. PMID: 33104507; PMCID: PMC7931442.
 * https://github.com/nibrary/nibrary/blob/main/src/dMRI/tractography/algorithms/ptt
 * Assumes probe count 1, data_support_exponent 1 for now.
 * Implemented with a CDF sampling strategy over a discretized curvature disc,
 * and using initial directions from voxel-wise peaks as in DIPY.
 *
 * Cooperative: every BDIM_X lane of the threadIdx.y row must call this.
 * Lane 0 does the serial frame / RNG work; all lanes join the interpolation.
 *
 * Outline:
 *  1. (IS_INIT) seed the frame tangent __frame_sh[0..2] with `dir`.
 *  2. For every disc vertex, probe the curve with (k1,k2) = vertex * max_curvature.
 *     Store its support in the vertex PDF, or 0 if below PROBE_QUALITY * threshold.
 *  3. Face weight = sum of its 3 vertex values. Outside init, a face with an
 *     unsupported vertex gets 0. Prefix-sum the weights into a CDF.
 *  4. Up to TRIES_PER_REJECTION_SAMPLING times, draw a face from the CDF and a
 *     uniform point inside it, then probe it. Accept the first supported one.
 *
 * On acceptance, lane 0 writes dirs[0]: `dir` when IS_INIT, otherwise the
 * unit direction after propagating step_size/STEP_FRAC. The group then copies
 * the 9-value probing frame into __frame_sh, and the function returns 1.
 * Returns 0 when no supported candidate is found.
 */
template<int BDIM_X,
         int BDIM_Y,
         bool IS_INIT,
         typename REAL_T,
         typename REAL3_T>
__device__ int get_direction_ptt_d(
    curandStatePhilox4_32_10_t *st,
    const REAL_T *__restrict__ pmf,
    const REAL_T max_angle,
    const REAL_T step_size,
    REAL3_T dir,
    REAL_T *__frame_sh,
    const int dimx, const int dimy, const int dimz, const int dimt,
    REAL3_T pos,
    const REAL3_T *__restrict__ odf_sphere_vertices,
    REAL3_T *__restrict__ dirs) {

    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;

    const int lid = (threadIdx.y*BDIM_X + threadIdx.x) % 32;
    const unsigned int WMASK = ((1ull << BDIM_X)-1) << (lid & (~(BDIM_X-1)));

    // Per-row (threadIdx.y) scratch buffers in shared memory.
    __shared__ REAL_T face_cdf_sh[BDIM_Y*DISC_FACE_CNT];
    __shared__ REAL_T vert_pdf_sh[BDIM_Y*DISC_VERT_CNT];

    __shared__ REAL_T probing_frame_sh[BDIM_Y*9];
    __shared__ REAL_T k1_probe_sh[BDIM_Y];
    __shared__ REAL_T k2_probe_sh[BDIM_Y];

    __shared__ REAL_T probing_prop_sh[BDIM_Y*9];
    __shared__ REAL_T direc_sh[BDIM_Y*3];
    __shared__ REAL3_T probing_pos_sh[BDIM_Y];

    REAL_T *__face_cdf_sh = face_cdf_sh + tidy*DISC_FACE_CNT;
    REAL_T *__vert_pdf_sh = vert_pdf_sh + tidy*DISC_VERT_CNT;

    REAL_T *__probing_frame_sh = probing_frame_sh + tidy*9;
    REAL_T *__k1_probe_sh = k1_probe_sh + tidy;
    REAL_T *__k2_probe_sh = k2_probe_sh + tidy;

    REAL_T *__probing_prop_sh = probing_prop_sh + tidy*9;
    REAL_T *__direc_sh = direc_sh + tidy*3;
    REAL3_T *__probing_pos_sh = probing_pos_sh + tidy;

    const REAL_T probe_step_size = ((step_size / PROBE_FRAC) / (PROBE_QUALITY - 1));
    // REAL_T literals keep float builds in single precision (a bare 2.0 is a double).
    const REAL_T max_curvature = REAL_T(2) * SIN(max_angle / REAL_T(2)) / (step_size / PROBE_FRAC); // This seems to work well
    const REAL_T absolpmf_thresh = PMF_THRESHOLD_P * max_d<BDIM_X>(dimt, pmf, REAL_MIN);
    // Minimum summed support for a probed curve to count as supported.
    const REAL_T min_support = PROBE_QUALITY * absolpmf_thresh;

#if 0
    printf("absolpmf_thresh: %f, max_curvature: %f, probe_step_size: %f\n", absolpmf_thresh, max_curvature, probe_step_size);
    printf("max_angle: %f\n", max_angle);
    printf("step_size: %f\n", step_size);
#endif

    __syncwarp(WMASK);
    if (IS_INIT) {
        if (tidx == 0) {
            __frame_sh[0] = dir.x;
            __frame_sh[1] = dir.y;
            __frame_sh[2] = dir.z;
        }
        // Publish lane 0's tangent before every lane reads __frame_sh in interp4_d.
        __syncwarp(WMASK);
    }

    const REAL_T first_val = interp4_d<BDIM_X>(
        pos, __frame_sh, pmf,
        dimx, dimy, dimz, dimt,
        odf_sphere_vertices);
    __syncwarp(WMASK);

    // Calculate __vert_pdf_sh: support of the curve probed at each disc vertex.
    bool support_found = false;
    for (int ii = 0; ii < DISC_VERT_CNT; ii++) {
        if (tidx == 0) {
            *__k1_probe_sh = DISC_VERT[ii*2] * max_curvature;
            *__k2_probe_sh = DISC_VERT[ii*2+1] * max_curvature;
            get_probing_frame_d<IS_INIT>(__frame_sh, st, __probing_frame_sh);
        }
        __syncwarp(WMASK);

        const REAL_T this_support = calculate_data_support_d<BDIM_X>(
            first_val,
            pos, pmf, dimx, dimy, dimz, dimt,
            probe_step_size,
            absolpmf_thresh,
            odf_sphere_vertices,
            __probing_prop_sh, __direc_sh, __probing_pos_sh,
            __k1_probe_sh, __k2_probe_sh,
            __probing_frame_sh);

#if 0
        if (threadIdx.y == 1 && ii == 0) {
            printf(" k1_probe: %f, k2_probe %f, support %f for id: %i\n", *__k1_probe_sh, *__k2_probe_sh, this_support, tidx);
        }
#endif

        if (this_support < min_support) {
            if (tidx == 0) {
                __vert_pdf_sh[ii] = 0;
            }
        } else {
            if (tidx == 0) {
                __vert_pdf_sh[ii] = this_support;
            }
            support_found = true;
        }
    }
    if (!support_found) {
        return 0;
    }

#if 0
    __syncwarp(WMASK);
    if (threadIdx.y == 1 && threadIdx.x == 0) {
        printArrayAlways("VERT PDF", 8, DISC_VERT_CNT, __vert_pdf_sh);
    }
    __syncwarp(WMASK);
#endif

    // Initialize __face_cdf_sh
    for (int ii = tidx; ii < DISC_FACE_CNT; ii+=BDIM_X) {
        __face_cdf_sh[ii] = 0;
    }
    __syncwarp(WMASK);

    // Move vert to face: face weight = sum of its three vertex supports.
    for (int ii = tidx; ii < DISC_FACE_CNT; ii+=BDIM_X) {
        bool all_verts_valid = true;
        for (int jj = 0; jj < 3; jj++) {
            const REAL_T vert_val = __vert_pdf_sh[DISC_FACE[ii*3 + jj]];
            if (vert_val == 0) {
                all_verts_valid = IS_INIT; // On init, even go with faces that are not fully supported
            }
            __face_cdf_sh[ii] += vert_val;
        }
        if (!all_verts_valid) {
            __face_cdf_sh[ii] = 0;
        }
    }
    __syncwarp(WMASK);

#if 0
    __syncwarp(WMASK);
    if (threadIdx.y == 1 && threadIdx.x == 0) {
        printArrayAlways("Face PDF", 8, DISC_FACE_CNT, __face_cdf_sh);
    }
    __syncwarp(WMASK);
#endif

    // Prefix sum __face_cdf_sh and return 0 if all 0
    prefix_sum_sh_d<BDIM_X>(__face_cdf_sh, DISC_FACE_CNT);
    const REAL_T last_cdf = __face_cdf_sh[DISC_FACE_CNT - 1];

    if (last_cdf == 0) {
        return 0;
    }

#if 0
    __syncwarp(WMASK);
    if (threadIdx.y == 1 && threadIdx.x == 0) {
        printArrayAlways("Face CDF", 8, DISC_FACE_CNT, __face_cdf_sh);
    }
    __syncwarp(WMASK);
#endif

    // Rejection-sample random points on randomly chosen valid faces.
    for (int ii = 0; ii < TRIES_PER_REJECTION_SAMPLING; ii++) {
        if (tidx == 0) {
            // Barycentric point uniform in the triangle (reflect into the lower half of the unit square).
            REAL_T r1 = curand_uniform(st);
            REAL_T r2 = curand_uniform(st);
            if (r1 + r2 > 1) {
                r1 = 1 - r1;
                r2 = 1 - r2;
            }

            // Choose a face with probability proportional to its weight.
            const REAL_T target = curand_uniform(st) * last_cdf;
            int jj;
            for (jj = 0; jj < DISC_FACE_CNT; jj++) { // TODO: parallelize this
                if (__face_cdf_sh[jj] >= target) {
                    break;
                }
            }
            // Never run past the last face (e.g. NaN in the CDF) so the
            // DISC_FACE lookups below stay in bounds.
            if (jj >= DISC_FACE_CNT) {
                jj = DISC_FACE_CNT - 1;
            }

            const REAL_T vx0 = max_curvature * DISC_VERT[DISC_FACE[jj*3]*2];
            const REAL_T vx1 = max_curvature * DISC_VERT[DISC_FACE[jj*3+1]*2];
            const REAL_T vx2 = max_curvature * DISC_VERT[DISC_FACE[jj*3+2]*2];

            const REAL_T vy0 = max_curvature * DISC_VERT[DISC_FACE[jj*3]*2 + 1];
            const REAL_T vy1 = max_curvature * DISC_VERT[DISC_FACE[jj*3+1]*2 + 1];
            const REAL_T vy2 = max_curvature * DISC_VERT[DISC_FACE[jj*3+2]*2 + 1];

            *__k1_probe_sh = vx0 + r1 * (vx1 - vx0) + r2 * (vx2 - vx0);
            *__k2_probe_sh = vy0 + r1 * (vy1 - vy0) + r2 * (vy2 - vy0);
            get_probing_frame_d<IS_INIT>(__frame_sh, st, __probing_frame_sh);
        }
        __syncwarp(WMASK);

        const REAL_T this_support = calculate_data_support_d<BDIM_X>(
            first_val,
            pos, pmf, dimx, dimy, dimz, dimt,
            probe_step_size,
            absolpmf_thresh,
            odf_sphere_vertices,
            __probing_prop_sh, __direc_sh, __probing_pos_sh,
            __k1_probe_sh, __k2_probe_sh,
            __probing_frame_sh);
        __syncwarp(WMASK);

        if (this_support < min_support) {
            continue;
        }

        if (tidx == 0) {
            if (IS_INIT) {
                dirs[0] = dir;
            } else {
                // propagate, but only 1/STEP_FRAC of a step
                prepare_propagator_d(
                    *__k1_probe_sh, *__k2_probe_sh,
                    step_size/STEP_FRAC, __probing_prop_sh);
                get_probing_frame_d<0>(__frame_sh, st, __probing_frame_sh);
                propagate_frame_d(__probing_prop_sh, __probing_frame_sh, __direc_sh);
                norm3_d(__direc_sh, 0); // this will be scaled by the generic stepping code
                dirs[0] = MAKE_REAL3(__direc_sh[0], __direc_sh[1], __direc_sh[2]);
            }
        }
        // Lane 0 may have just read __frame_sh and rewritten __probing_frame_sh.
        // Order that before the group overwrites __frame_sh below.
        __syncwarp(WMASK);

        // NOTE(review): on the IS_INIT path lane 0 does not rebuild
        // __probing_frame_sh above. The frame copied here is therefore the one
        // left by calculate_data_support_d after PROBE_QUALITY probe
        // propagations; confirm this is intended.
        // Strided so the full 9-value frame is copied even when BDIM_X < 9.
        for (int kk = tidx; kk < 9; kk += BDIM_X) {
            __frame_sh[kk] = __probing_frame_sh[kk];
        }
        __syncwarp(WMASK);
        return 1;
    }
    return 0;
}
488
+
489
+
490
/* Find an initial PTT frame at `seed` for a streamline whose first step is
 * `first_step`.
 *
 * First runs get_direction_ptt_d (IS_INIT) with tangent -first_step. If that
 * fails, it retries with +first_step and, on success, negates all nine frame
 * entries.
 * On success, __frame[0..8] holds the frame and __frame[9..17] its negation
 * (saved for the second run), and true is returned. Returns false when both
 * attempts fail (rare).
 *
 * Cooperative: every BDIM_X lane of the row must call it. __frame must hold at
 * least 18 values.
 */
template<int BDIM_X,
         int BDIM_Y,
         typename REAL_T,
         typename REAL3_T>
__device__ bool init_frame_ptt_d(
    curandStatePhilox4_32_10_t *st,
    const REAL_T *__restrict__ pmf,
    const REAL_T max_angle,
    const REAL_T step_size,
    REAL3_T first_step,
    const int dimx, const int dimy, const int dimz, const int dimt,
    REAL3_T seed,
    const REAL3_T *__restrict__ sphere_vertices,
    REAL_T* __frame) {
    const int tidx = threadIdx.x;
    const int warp_lane = (threadIdx.y*BDIM_X + tidx) % 32;
    const unsigned int grp_mask = ((1ull << BDIM_X)-1) << (warp_lane & (~(BDIM_X-1)));

    REAL3_T unused_dir; // dirs output of get_direction_ptt_d is not needed here

    // Probabilistically find a good initial normal for this initial direction.
    bool found = (bool) get_direction_ptt_d<BDIM_X, BDIM_Y, true>(
        st, pmf, max_angle, step_size,
        MAKE_REAL3(-first_step.x, -first_step.y, -first_step.z),
        __frame,
        dimx, dimy, dimz, dimt,
        seed, sphere_vertices, &unused_dir);
    __syncwarp(grp_mask);

    if (!found) {
        // Try the other direction.
        found = (bool) get_direction_ptt_d<BDIM_X, BDIM_Y, true>(
            st, pmf, max_angle, step_size,
            first_step,
            __frame,
            dimx, dimy, dimz, dimt,
            seed, sphere_vertices, &unused_dir);
        __syncwarp(grp_mask);

        if (!found) {
            return false; // This is rare
        }

        if (tidx == 0) {
            for (int k = 0; k < 9; k++) {
                __frame[k] = -__frame[k];
            }
        }
        __syncwarp(grp_mask);
    }

    if (tidx == 0) {
        // Save the flipped frame for the second run.
        for (int k = 0; k < 9; k++) {
            __frame[9 + k] = -__frame[k];
        }
    }
    __syncwarp(grp_mask);
    return true;
}
@@ -0,0 +1,47 @@
1
+ #ifndef __PTT_CUH__
2
+ #define __PTT_CUH__
3
+
4
+ #include "disc.h"
5
+ #include "globals.h"
6
+
7
// PTT tuning parameters. Each one except K_SMALL can be overridden from the
// build (e.g. -DSAMPLING_QUALITY=4); the defaults are unchanged.
#ifndef STEP_FRAC
#define STEP_FRAC (20) // divides output step size (usually 0.5) into this many internal steps
#endif
#ifndef PROBE_FRAC
#define PROBE_FRAC (2) // divides output step size (usually 0.5) to find probe length
#endif
#ifndef PROBE_QUALITY
#define PROBE_QUALITY (4) // Number of probing steps
#endif
#ifndef SAMPLING_QUALITY
#define SAMPLING_QUALITY (2) // can be 2-7
#endif
#ifndef ALLOW_WEAK_LINK
#define ALLOW_WEAK_LINK (0)
#endif
#ifndef TRIES_PER_REJECTION_SAMPLING
#define TRIES_PER_REJECTION_SAMPLING (1024)
#endif
#define K_SMALL ((REAL) 0.0001)

// The probe step size is computed as (step_size / PROBE_FRAC) / (PROBE_QUALITY - 1).
#if PROBE_QUALITY < 2
#error "PROBE_QUALITY must be >= 2"
#endif

// Curvature-disc discretization selected by SAMPLING_QUALITY.
#if SAMPLING_QUALITY == 2
#define DISC_VERT_CNT DISC_2_VERT_CNT
#define DISC_FACE_CNT DISC_2_FACE_CNT
__device__ __constant__ REAL DISC_VERT[DISC_VERT_CNT*2] = DISC_2_VERT;
__device__ __constant__ int DISC_FACE[DISC_FACE_CNT*3] = DISC_2_FACE;
#elif SAMPLING_QUALITY == 3
#define DISC_VERT_CNT DISC_3_VERT_CNT
#define DISC_FACE_CNT DISC_3_FACE_CNT
__device__ __constant__ REAL DISC_VERT[DISC_VERT_CNT*2] = DISC_3_VERT;
__device__ __constant__ int DISC_FACE[DISC_FACE_CNT*3] = DISC_3_FACE;
#elif SAMPLING_QUALITY == 4
#define DISC_VERT_CNT DISC_4_VERT_CNT
#define DISC_FACE_CNT DISC_4_FACE_CNT
__device__ __constant__ REAL DISC_VERT[DISC_VERT_CNT*2] = DISC_4_VERT;
__device__ __constant__ int DISC_FACE[DISC_FACE_CNT*3] = DISC_4_FACE;
#elif SAMPLING_QUALITY == 5
#define DISC_VERT_CNT DISC_5_VERT_CNT
#define DISC_FACE_CNT DISC_5_FACE_CNT
__device__ __constant__ REAL DISC_VERT[DISC_VERT_CNT*2] = DISC_5_VERT;
__device__ __constant__ int DISC_FACE[DISC_FACE_CNT*3] = DISC_5_FACE;
#elif SAMPLING_QUALITY == 6
#define DISC_VERT_CNT DISC_6_VERT_CNT
#define DISC_FACE_CNT DISC_6_FACE_CNT
__device__ __constant__ REAL DISC_VERT[DISC_VERT_CNT*2] = DISC_6_VERT;
__device__ __constant__ int DISC_FACE[DISC_FACE_CNT*3] = DISC_6_FACE;
#elif SAMPLING_QUALITY == 7
#define DISC_VERT_CNT DISC_7_VERT_CNT
#define DISC_FACE_CNT DISC_7_FACE_CNT
__device__ __constant__ REAL DISC_VERT[DISC_VERT_CNT*2] = DISC_7_VERT; // TODO: check if removing __constant__ helps or hurts
__device__ __constant__ int DISC_FACE[DISC_FACE_CNT*3] = DISC_7_FACE;
#else
// Without this, DISC_VERT_CNT / DISC_VERT would be undefined and the kernels
// would fail to compile with unrelated-looking errors.
#error "SAMPLING_QUALITY must be an integer in [2, 7]"
#endif
46
+
47
+ #endif