warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release. This version of warp-lang might be problematic.

Files changed (300)
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/native/bvh.cu CHANGED
@@ -9,6 +9,7 @@
 #include "warp.h"
 #include "cuda_util.h"
 #include "bvh.h"
+#include "sort.h"
 
 #include <vector>
 #include <algorithm>
@@ -16,25 +17,32 @@
 #include <cuda.h>
 #include <cuda_runtime_api.h>
 
+#define THRUST_IGNORE_CUB_VERSION_CHECK
+
+#include <cub/cub.cuh>
+
+
 namespace wp
 {
 
-__global__ void bvh_refit_kernel(int n, const int* __restrict__ parents, int* __restrict__ child_count, BVHPackedNodeHalf* __restrict__ lowers, BVHPackedNodeHalf* __restrict__ uppers, const bounds3* bounds)
+__global__ void bvh_refit_kernel(int n, const int* __restrict__ parents, int* __restrict__ child_count, BVHPackedNodeHalf* __restrict__ node_lowers, BVHPackedNodeHalf* __restrict__ node_uppers, const vec3* item_lowers, const vec3* item_uppers)
 {
     int index = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (index < n)
     {
-        bool leaf = lowers[index].b;
+        bool leaf = node_lowers[index].b;
 
         if (leaf)
        {
             // update the leaf node
-            const int leaf_index = lowers[index].i;
-            const bounds3& b = bounds[leaf_index];
+            const int leaf_index = node_lowers[index].i;
 
-            make_node(lowers+index, b.lower, leaf_index, true);
-            make_node(uppers+index, b.upper, 0, false);
+            vec3 lower = item_lowers[leaf_index];
+            vec3 upper = item_uppers[leaf_index];
+
+            make_node(node_lowers+index, lower, leaf_index, true);
+            make_node(node_uppers+index, upper, 0, false);
         }
         else
         {
@@ -59,6 +67,214 @@ __global__ void bvh_refit_kernel(int n, const int* __restrict__ parents, int* __
             // if we have are the last thread (such that the parent node is now complete)
             // then update its bounds and move onto the the next parent in the hierarchy
             if (finished == 1)
+            {
+                const int left_child = node_lowers[parent].i;
+                const int right_child = node_uppers[parent].i;
+
+                vec3 left_lower = vec3(node_lowers[left_child].x,
+                                       node_lowers[left_child].y,
+                                       node_lowers[left_child].z);
+
+                vec3 left_upper = vec3(node_uppers[left_child].x,
+                                       node_uppers[left_child].y,
+                                       node_uppers[left_child].z);
+
+                vec3 right_lower = vec3(node_lowers[right_child].x,
+                                        node_lowers[right_child].y,
+                                        node_lowers[right_child].z);
+
+
+                vec3 right_upper = vec3(node_uppers[right_child].x,
+                                        node_uppers[right_child].y,
+                                        node_uppers[right_child].z);
+
+                // union of child bounds
+                vec3 lower = min(left_lower, right_lower);
+                vec3 upper = max(left_upper, right_upper);
+
+                // write new BVH nodes
+                make_node(node_lowers+parent, lower, left_child, false);
+                make_node(node_uppers+parent, upper, right_child, false);
+
+                // move onto processing the parent
+                index = parent;
+            }
+            else
+            {
+                // parent not ready (we are the first child), terminate thread
+                break;
+            }
+        }
+    }
+}
+
+
+void bvh_refit_device(BVH& bvh)
+{
+    ContextGuard guard(bvh.context);
+
+    // clear child counters
+    memset_device(WP_CURRENT_CONTEXT, bvh.node_counts, 0, sizeof(int)*bvh.max_nodes);
+
+    wp_launch_device(WP_CURRENT_CONTEXT, bvh_refit_kernel, bvh.num_items, (bvh.num_items, bvh.node_parents, bvh.node_counts, bvh.node_lowers, bvh.node_uppers, bvh.item_lowers, bvh.item_uppers));
+}
+
+
+/////////////////////////////////////////////////////////////////////////////////////////////
+
+// Create a linear BVH as described in Fast and Simple Agglomerative LBVH construction
+// this is a bottom-up clustering method that outputs one node per-leaf
+//
+class LinearBVHBuilderGPU
+{
+public:
+
+    LinearBVHBuilderGPU();
+    ~LinearBVHBuilderGPU();
+
+    // takes a bvh (host ref), and pointers to the GPU lower and upper bounds for each triangle
+    void build(BVH& bvh, const vec3* item_lowers, const vec3* item_uppers, int num_items, bounds3* total_bounds);
+
+private:
+
+    // temporary data used during building
+    int* indices;
+    int* keys;
+    int* deltas;
+    int* range_lefts;
+    int* range_rights;
+    int* num_children;
+
+    // bounds data when total item bounds built on GPU
+    vec3* total_lower;
+    vec3* total_upper;
+    vec3* total_inv_edges;
+};
+
+////////////////////////////////////////////////////////
+
+
+
+__global__ void compute_morton_codes(const vec3* __restrict__ item_lowers, const vec3* __restrict__ item_uppers, int n, const vec3* grid_lower, const vec3* grid_inv_edges, int* __restrict__ indices, int* __restrict__ keys)
+{
+    const int index = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (index < n)
+    {
+        vec3 lower = item_lowers[index];
+        vec3 upper = item_uppers[index];
+
+        vec3 center = 0.5f*(lower+upper);
+
+        vec3 local = cw_mul((center-grid_lower[0]), grid_inv_edges[0]);
+
+        // 10-bit Morton codes stored in lower 30bits (1024^3 effective resolution)
+        int key = morton3<1024>(local[0], local[1], local[2]);
+
+        indices[index] = index;
+        keys[index] = key;
+    }
+}
+
+// calculate the index of the first differing bit between two adjacent Morton keys
+__global__ void compute_key_deltas(const int* __restrict__ keys, int* __restrict__ deltas, int n)
+{
+    const int index = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (index < n)
+    {
+        int a = keys[index];
+        int b = keys[index+1];
+
+        int x = a^b;
+
+        deltas[index] = x;// __clz(x);
+    }
+}
+
+__global__ void build_leaves(const vec3* __restrict__ item_lowers, const vec3* __restrict__ item_uppers, int n, const int* __restrict__ indices, int* __restrict__ range_lefts, int* __restrict__ range_rights, BVHPackedNodeHalf* __restrict__ lowers, BVHPackedNodeHalf* __restrict__ uppers)
+{
+    const int index = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (index < n)
+    {
+        const int item = indices[index];
+
+        vec3 lower = item_lowers[item];
+        vec3 upper = item_uppers[item];
+
+        // write leaf nodes
+        lowers[index] = make_node(lower, item, true);
+        uppers[index] = make_node(upper, item, false);
+
+        // write leaf key ranges
+        range_lefts[index] = index;
+        range_rights[index] = index;
+    }
+}
+
+// this bottom-up process assigns left and right children and combines bounds to form internal nodes
+// there is one thread launched per-leaf node, each thread calculates it's parent node and assigns
+// itself to either the left or right parent slot, the last child to complete the parent and moves
+// up the hierarchy
+__global__ void build_hierarchy(int n, int* root, const int* __restrict__ deltas, int* __restrict__ num_children, volatile int* __restrict__ range_lefts, volatile int* __restrict__ range_rights, volatile int* __restrict__ parents, volatile BVHPackedNodeHalf* __restrict__ lowers, volatile BVHPackedNodeHalf* __restrict__ uppers)
+{
+    int index = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (index < n)
+    {
+        const int internal_offset = n;
+
+        for (;;)
+        {
+            int left = range_lefts[index];
+            int right = range_rights[index];
+
+            // check if we are the root node, if so then store out our index and terminate
+            if (left == 0 && right == n-1)
+            {
+                *root = index;
+                parents[index] = -1;
+
+                break;
+            }
+
+            int childCount = 0;
+
+            int parent;
+
+            if (left == 0 || (right != n-1 && deltas[right] < deltas[left-1]))
+            {
+                parent = right + internal_offset;
+
+                // set parent left child
+                parents[index] = parent;
+                lowers[parent].i = index;
+                range_lefts[parent] = left;
+
+                // ensure above writes are visible to all threads
+                __threadfence();
+
+                childCount = atomicAdd(&num_children[parent], 1);
+            }
+            else
+            {
+                parent = left + internal_offset - 1;
+
+                // set parent right child
+                parents[index] = parent;
+                uppers[parent].i = index;
+                range_rights[parent] = right;
+
+                // ensure above writes are visible to all threads
+                __threadfence();
+
+                childCount = atomicAdd(&num_children[parent], 1);
+            }
+
+            // if we have are the last thread (such that the parent node is now complete)
+            // then update its bounds and move onto the the next parent in the hierarchy
+            if (childCount == 1)
             {
                 const int left_child = lowers[parent].i;
                 const int right_child = uppers[parent].i;
@@ -72,15 +288,15 @@ __global__ void bvh_refit_kernel(int n, const int* __restrict__ parents, int* __
                                        uppers[left_child].z);
 
                 vec3 right_lower = vec3(lowers[right_child].x,
-                                       lowers[right_child].y,
-                                       lowers[right_child].z);
+                                        lowers[right_child].y,
+                                        lowers[right_child].z);
 
 
                 vec3 right_upper = vec3(uppers[right_child].x,
-                                       uppers[right_child].y,
-                                       uppers[right_child].z);
+                                        uppers[right_child].y,
+                                        uppers[right_child].z);
 
-                // union of child bounds
+                // bounds_union of child bounds
                 vec3 lower = min(left_lower, right_lower);
                 vec3 upper = max(left_upper, right_upper);
 
@@ -100,30 +316,158 @@ __global__ void bvh_refit_kernel(int n, const int* __restrict__ parents, int* __
     }
 }
 
+CUDA_CALLABLE inline vec3 Vec3Max(const vec3& a, const vec3& b) { return wp::max(a, b); }
+CUDA_CALLABLE inline vec3 Vec3Min(const vec3& a, const vec3& b) { return wp::min(a, b); }
 
-void bvh_refit_device(BVH& bvh, const bounds3* b)
+__global__ void compute_total_bounds(const vec3* item_lowers, const vec3* item_uppers, vec3* total_lower, vec3* total_upper, int num_items)
 {
-    ContextGuard guard(bvh.context);
+    typedef cub::BlockReduce<vec3, 256> BlockReduce;
 
-    // clear child counters
-    memset_device(WP_CURRENT_CONTEXT, bvh.node_counts, 0, sizeof(int)*bvh.max_nodes);
+    __shared__ typename BlockReduce::TempStorage temp_storage;
+
+    const int blockStart = blockDim.x*blockIdx.x;
+    const int numValid = ::min(num_items-blockStart, blockDim.x);
+
+    const int tid = blockStart + threadIdx.x;
+
+    if (tid < num_items)
+    {
+        vec3 lower = item_lowers[tid];
+        vec3 upper = item_uppers[tid];
+
+        vec3 block_upper = BlockReduce(temp_storage).Reduce(upper, Vec3Max, numValid);
+
+        // sync threads because second reduce uses same temp storage as first
+        __syncthreads();
+
+        vec3 block_lower = BlockReduce(temp_storage).Reduce(lower, Vec3Min, numValid);
 
-    wp_launch_device(WP_CURRENT_CONTEXT, bvh_refit_kernel, bvh.max_nodes, (bvh.max_nodes, bvh.node_parents, bvh.node_counts, bvh.node_lowers, bvh.node_uppers, b));
+        if (threadIdx.x == 0)
+        {
+            // write out block results, expanded by the radius
+            atomic_max(total_upper, block_upper);
+            atomic_min(total_lower, block_lower);
+        }
+    }
+}
+
+// compute inverse edge length, this is just done on the GPU to avoid a CPU->GPU sync point
+__global__ void compute_total_inv_edges(const vec3* total_lower, const vec3* total_upper, vec3* total_inv_edges)
+{
+    vec3 edges = (total_upper[0]-total_lower[0]);
+    edges += vec3(0.0001f);
+
+    total_inv_edges[0] = vec3(1.0f/edges[0], 1.0f/edges[1], 1.0f/edges[2]);
 }
 
-__global__ void set_bounds_from_lowers_and_uppers(int n, bounds3* b, const vec3* lowers, const vec3* uppers)
+
+
+LinearBVHBuilderGPU::LinearBVHBuilderGPU()
+    : indices(NULL)
+    , keys(NULL)
+    , deltas(NULL)
+    , range_lefts(NULL)
+    , range_rights(NULL)
+    , num_children(NULL)
+    , total_lower(NULL)
+    , total_upper(NULL)
+    , total_inv_edges(NULL)
 {
-    const int tid = blockIdx.x*blockDim.x + threadIdx.x;
+    total_lower = (vec3*)alloc_temp_device(WP_CURRENT_CONTEXT, sizeof(vec3));
+    total_upper = (vec3*)alloc_temp_device(WP_CURRENT_CONTEXT, sizeof(vec3));
+    total_inv_edges = (vec3*)alloc_temp_device(WP_CURRENT_CONTEXT, sizeof(vec3));
+}
 
-    if (tid < n)
+LinearBVHBuilderGPU::~LinearBVHBuilderGPU()
+{
+    free_temp_device(WP_CURRENT_CONTEXT, total_lower);
+    free_temp_device(WP_CURRENT_CONTEXT, total_upper);
+    free_temp_device(WP_CURRENT_CONTEXT, total_inv_edges);
+}
+
+
+
+void LinearBVHBuilderGPU::build(BVH& bvh, const vec3* item_lowers, const vec3* item_uppers, int num_items, bounds3* total_bounds)
+{
+    // allocate temporary memory used during building
+    indices = (int*)alloc_temp_device(WP_CURRENT_CONTEXT, sizeof(int)*num_items*2);  // *2 for radix sort
+    keys = (int*)alloc_temp_device(WP_CURRENT_CONTEXT, sizeof(int)*num_items*2);     // *2 for radix sort
+    deltas = (int*)alloc_temp_device(WP_CURRENT_CONTEXT, sizeof(int)*num_items);     // highest differenting bit between keys for item i and i+1
+    range_lefts = (int*)alloc_temp_device(WP_CURRENT_CONTEXT, sizeof(int)*bvh.max_nodes);
+    range_rights = (int*)alloc_temp_device(WP_CURRENT_CONTEXT, sizeof(int)*bvh.max_nodes);
+    num_children = (int*)alloc_temp_device(WP_CURRENT_CONTEXT, sizeof(int)*bvh.max_nodes);
+
+    // if total bounds supplied by the host then we just
+    // compute our edge length and upload it to the GPU directly
+    if (total_bounds)
     {
-        b[tid] = bounds3(lowers[tid], uppers[tid]);
+        // calculate Morton codes
+        vec3 edges = (*total_bounds).edges();
+        edges += vec3(0.0001f);
+
+        vec3 inv_edges = vec3(1.0f/edges[0], 1.0f/edges[1], 1.0f/edges[2]);
+
+        memcpy_h2d(WP_CURRENT_CONTEXT, total_lower, &total_bounds->lower[0], sizeof(vec3));
+        memcpy_h2d(WP_CURRENT_CONTEXT, total_upper, &total_bounds->upper[0], sizeof(vec3));
+        memcpy_h2d(WP_CURRENT_CONTEXT, total_inv_edges, &inv_edges[0], sizeof(vec3));
+    }
+    else
+    {
+        static vec3 upper(-FLT_MAX);
+        static vec3 lower(FLT_MAX);
+
+        memcpy_h2d(WP_CURRENT_CONTEXT, total_lower, &lower, sizeof(lower));
+        memcpy_h2d(WP_CURRENT_CONTEXT, total_upper, &upper, sizeof(upper));
+
+        // compute the total bounds on the GPU
+        wp_launch_device(WP_CURRENT_CONTEXT, compute_total_bounds, num_items, (item_lowers, item_uppers, total_lower, total_upper, num_items));
+
+        // compute the total edge length
+        wp_launch_device(WP_CURRENT_CONTEXT, compute_total_inv_edges, 1, (total_lower, total_upper, total_inv_edges));
     }
+
+    // assign 30-bit Morton code based on the centroid of each triangle and bounds for each leaf
+    wp_launch_device(WP_CURRENT_CONTEXT, compute_morton_codes, num_items, (item_lowers, item_uppers, num_items, total_lower, total_inv_edges, indices, keys));
+
+    // sort items based on Morton key (note the 32-bit sort key corresponds to the template parameter to morton3, i.e. 3x9 bit keys combined)
+    radix_sort_pairs_device(WP_CURRENT_CONTEXT, keys, indices, num_items);
+
+    // calculate deltas between adjacent keys
+    wp_launch_device(WP_CURRENT_CONTEXT, compute_key_deltas, num_items, (keys, deltas, num_items-1));
+
+    // initialize leaf nodes
+    wp_launch_device(WP_CURRENT_CONTEXT, build_leaves, num_items, (item_lowers, item_uppers, num_items, indices, range_lefts, range_rights, bvh.node_lowers, bvh.node_uppers));
+
+    // reset children count, this is our atomic counter so we know when an internal node is complete, only used during building
+    memset_device(WP_CURRENT_CONTEXT, num_children, 0, sizeof(int)*bvh.max_nodes);
+
+    // build the tree and internal node bounds
+    wp_launch_device(WP_CURRENT_CONTEXT, build_hierarchy, num_items, (num_items, bvh.root, deltas, num_children, range_lefts, range_rights, bvh.node_parents, bvh.node_lowers, bvh.node_uppers));
+
+    // free temporary memory
+    free_temp_device(WP_CURRENT_CONTEXT, indices);
+    free_temp_device(WP_CURRENT_CONTEXT, keys);
+    free_temp_device(WP_CURRENT_CONTEXT, deltas);
+
+    free_temp_device(WP_CURRENT_CONTEXT, range_lefts);
+    free_temp_device(WP_CURRENT_CONTEXT, range_rights);
+    free_temp_device(WP_CURRENT_CONTEXT, num_children);
+
+}
+
+void bvh_destroy_device(wp::BVH& bvh)
+{
+    ContextGuard guard(bvh.context);
+
+    free_device(WP_CURRENT_CONTEXT, bvh.node_lowers); bvh.node_lowers = NULL;
+    free_device(WP_CURRENT_CONTEXT, bvh.node_uppers); bvh.node_uppers = NULL;
+    free_device(WP_CURRENT_CONTEXT, bvh.node_parents); bvh.node_parents = NULL;
+    free_device(WP_CURRENT_CONTEXT, bvh.node_counts); bvh.node_counts = NULL;
+    free_device(WP_CURRENT_CONTEXT, bvh.root); bvh.root = NULL;
 }
 
 } // namespace wp
 
-// refit to data stored in the bvh
 
 void bvh_refit_device(uint64_t id)
 {
@@ -131,12 +475,51 @@ void bvh_refit_device(uint64_t id)
     if (bvh_get_descriptor(id, bvh))
     {
         ContextGuard guard(bvh.context);
-        wp_launch_device(WP_CURRENT_CONTEXT, wp::set_bounds_from_lowers_and_uppers, bvh.num_bounds, (bvh.num_bounds, bvh.bounds, bvh.lowers, bvh.uppers));
 
-        bvh_refit_device(bvh, bvh.bounds);
+        bvh_refit_device(bvh);
     }
+}
 
+uint64_t bvh_create_device(void* context, wp::vec3* lowers, wp::vec3* uppers, int num_items)
+{
+    ContextGuard guard(context);
+
+    wp::BVH bvh_host;
+    bvh_host.num_items = num_items;
+    bvh_host.max_nodes = 2*num_items;
+    bvh_host.node_lowers = (wp::BVHPackedNodeHalf*)alloc_device(WP_CURRENT_CONTEXT, sizeof(wp::BVHPackedNodeHalf)*bvh_host.max_nodes);
+    bvh_host.node_uppers = (wp::BVHPackedNodeHalf*)alloc_device(WP_CURRENT_CONTEXT, sizeof(wp::BVHPackedNodeHalf)*bvh_host.max_nodes);
+    bvh_host.node_parents = (int*)alloc_device(WP_CURRENT_CONTEXT, sizeof(int)*bvh_host.max_nodes);
+    bvh_host.node_counts = (int*)alloc_device(WP_CURRENT_CONTEXT, sizeof(int)*bvh_host.max_nodes);
+    bvh_host.root = (int*)alloc_device(WP_CURRENT_CONTEXT, sizeof(int));
+    bvh_host.item_lowers = lowers;
+    bvh_host.item_uppers = uppers;
+
+    bvh_host.context = context ? context : cuda_context_get_current();
+
+    wp::LinearBVHBuilderGPU builder;
+    builder.build(bvh_host, lowers, uppers, num_items, NULL);
+
+    // create device-side BVH descriptor
+    wp::BVH* bvh_device = (wp::BVH*)alloc_device(WP_CURRENT_CONTEXT, sizeof(wp::BVH));
+    memcpy_h2d(WP_CURRENT_CONTEXT, bvh_device, &bvh_host, sizeof(wp::BVH));
+
+    uint64_t bvh_id = (uint64_t)bvh_device;
+    wp::bvh_add_descriptor(bvh_id, bvh_host);
+
+    return bvh_id;
 }
 
 
+void bvh_destroy_device(uint64_t id)
+{
+    wp::BVH bvh;
+    if (wp::bvh_get_descriptor(id, bvh))
+    {
+        wp::bvh_destroy_device(bvh);
+        wp::bvh_rem_descriptor(id);
 
+        // free descriptor
+        free_device(WP_CURRENT_CONTEXT, (void*)id);
+    }
+}
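
A note on the Morton-code step above, which is the heart of the new LBVH build: each item centroid is normalized into the total bounds, quantized onto a 1024^3 grid, and its x/y/z bits are interleaved so that sorting by key groups spatially nearby items; the XOR deltas between adjacent sorted keys then drive the bottom-up parent selection in build_hierarchy. The following standalone C++ sketch shows the usual bit-interleaving trick; the names expand_bits10 and morton30 are illustrative stand-ins, not the package's morton3 implementation.

#include <cstdint>
#include <cstdio>

// spread the lower 10 bits of v so each bit is followed by two zero bits
static uint32_t expand_bits10(uint32_t v)
{
    v = (v * 0x00010001u) & 0xFF0000FFu;
    v = (v * 0x00000101u) & 0x0F00F00Fu;
    v = (v * 0x00000011u) & 0xC30C30C3u;
    v = (v * 0x00000005u) & 0x49249249u;
    return v;
}

// 30-bit Morton key for normalized coordinates in [0, 1)
static uint32_t morton30(float x, float y, float z)
{
    auto quantize = [](float f) {
        int i = (int)(f * 1024.0f);  // 10-bit grid (1024^3 effective resolution)
        return (uint32_t)(i < 0 ? 0 : (i > 1023 ? 1023 : i));
    };
    return (expand_bits10(quantize(x)) << 2) |
           (expand_bits10(quantize(y)) << 1) |
            expand_bits10(quantize(z));
}

int main()
{
    // nearby centroids share long common key prefixes, which is what the
    // delta-based passes (compute_key_deltas/build_hierarchy) exploit
    printf("%08x %08x\n", morton30(0.50f, 0.50f, 0.50f), morton30(0.51f, 0.50f, 0.50f));
    return 0;
}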
warp/native/bvh.h CHANGED
@@ -113,7 +113,7 @@ struct BVHPackedNodeHalf
 };
 
 struct BVH
-{
+{
     BVHPackedNodeHalf* node_lowers;
     BVHPackedNodeHalf* node_uppers;
 
@@ -123,33 +123,22 @@ struct BVH
 
     int max_depth;
    int max_nodes;
-    int num_nodes;
-
-    int root;
+    int num_nodes;
+
+    // pointer (CPU or GPU) to a single integer index in node_lowers, node_uppers
+    // representing the root of the tree, this is not always the first node
+    // for bottom-up builders
+    int* root;
 
-    vec3* lowers;
-    vec3* uppers;
-    bounds3* bounds;
-    int num_bounds;
+    // item bounds are not owned by the BVH but by the caller
+    vec3* item_lowers;
+    vec3* item_uppers;
+    int num_items;
 
+    // cuda context
     void* context;
 };
 
-#if !defined(__CUDA_ARCH__)
-
-BVH bvh_create(const bounds3* bounds, int num_bounds);
-
-void bvh_destroy_host(BVH& bvh);
-void bvh_destroy_device(BVH& bvh);
-
-void bvh_refit_host(BVH& bvh, const bounds3* bounds);
-void bvh_refit_device(BVH& bvh, const bounds3* bounds);
-
-// copy host BVH to device
-BVH bvh_clone(void* context, const BVH& bvh_host);
-
-#endif // !__CUDA_ARCH__
-
 CUDA_CALLABLE inline BVHPackedNodeHalf make_node(const vec3& bound, int child, bool leaf)
 {
     BVHPackedNodeHalf n;
@@ -162,7 +151,7 @@ CUDA_CALLABLE inline BVHPackedNodeHalf make_node(const vec3& bound, int child, b
     return n;
 }
 
-// variation of make_node through volatile pointers used in BuildHierarchy
+// variation of make_node through volatile pointers used in build_hierarchy
 CUDA_CALLABLE inline void make_node(volatile BVHPackedNodeHalf* n, const vec3& bound, int child, bool leaf)
 {
     n->x = bound[0];
@@ -211,7 +200,7 @@ CUDA_CALLABLE inline BVH bvh_get(uint64_t id)
 CUDA_CALLABLE inline int bvh_get_num_bounds(uint64_t id)
 {
     BVH bvh = bvh_get(id);
-    return bvh.num_bounds;
+    return bvh.num_items;
 }
 
 
@@ -220,11 +209,20 @@ CUDA_CALLABLE inline int bvh_get_num_bounds(uint64_t id)
 struct bvh_query_t
 {
     CUDA_CALLABLE bvh_query_t()
+        : bvh(),
+          stack(),
+          count(0),
+          is_ray(false),
+          input_lower(),
+          input_upper(),
+          bounds_nr(0)
+    {}
+
+    // Required for adjoint computations.
+    CUDA_CALLABLE inline bvh_query_t& operator+=(const bvh_query_t& other)
     {
+        return *this;
     }
-    CUDA_CALLABLE bvh_query_t(int)
-    {
-    } // for backward pass
 
 
     BVH bvh;
@@ -257,17 +255,8 @@ CUDA_CALLABLE inline bvh_query_t bvh_query(
     query.bvh = bvh;
     query.is_ray = is_ray;
 
-
-    // if no bvh nodes, return empty query.
-    if (bvh.num_nodes == 0)
-    {
-        query.count = 0;
-        return query;
-    }
-
-    // optimization: make the latest
-
-    query.stack[0] = bvh.root;
+    // optimization: make the latest
+    query.stack[0] = *bvh.root;
     query.count = 1;
     query.input_lower = lower;
     query.input_upper = upper;
@@ -422,17 +411,19 @@ CUDA_CALLABLE inline void adj_bvh_query_next(bvh_query_t& query, int& index, bvh
 
 }
 
-
-
-
 CUDA_CALLABLE bool bvh_get_descriptor(uint64_t id, BVH& bvh);
 CUDA_CALLABLE void bvh_add_descriptor(uint64_t id, const BVH& bvh);
 CUDA_CALLABLE void bvh_rem_descriptor(uint64_t id);
 
+#if !__CUDA_ARCH__
 
+void bvh_destroy_host(wp::BVH& bvh);
+void bvh_refit_host(wp::BVH& bvh);
 
+void bvh_destroy_device(wp::BVH& bvh);
+void bvh_refit_device(uint64_t id);
 
-
+#endif
 
 } // namespace wp
 
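For context on the query changes above: bvh_query now seeds its traversal stack with *bvh.root rather than a fixed index, since bottom-up builders don't guarantee the root is node 0. Below is a minimal host-side C++ sketch of the stack-based AABB traversal pattern that bvh_query_t/bvh_query_next implement; NodeHalf and query_aabb are illustrative stand-ins (the real BVHPackedNodeHalf packs the index and leaf flag into bitfields).

#include <cstdio>

// illustrative stand-in for BVHPackedNodeHalf: one bounds corner, a child or
// item index i, and a leaf flag b
struct NodeHalf { float x, y, z; unsigned int i; bool b; };

static bool overlap(const NodeHalf& lo, const NodeHalf& hi,
                    const float qlo[3], const float qhi[3])
{
    return lo.x <= qhi[0] && hi.x >= qlo[0] &&
           lo.y <= qhi[1] && hi.y >= qlo[1] &&
           lo.z <= qhi[2] && hi.z >= qlo[2];
}

// depth-first AABB query with an explicit stack, seeded from the root index
// (mirroring query.stack[0] = *bvh.root above)
static void query_aabb(const NodeHalf* lowers, const NodeHalf* uppers, int root,
                       const float qlo[3], const float qhi[3])
{
    int stack[32];
    int count = 0;
    stack[count++] = root;

    while (count)
    {
        const int node = stack[--count];

        if (!overlap(lowers[node], uppers[node], qlo, qhi))
            continue;

        if (lowers[node].b)
        {
            // leaf: the lower half stores the item index
            printf("hit item %u\n", lowers[node].i);
        }
        else
        {
            // internal: lower half stores the left child, upper half the right
            stack[count++] = (int)lowers[node].i;
            stack[count++] = (int)uppers[node].i;
        }
    }
}

int main()
{
    // toy tree: node 0 is the root with leaf children 1 and 2
    NodeHalf lowers[3] = { {0, 0, 0, 1, false}, {0, 0, 0, 0, true}, {2, 2, 2, 1, true} };
    NodeHalf uppers[3] = { {3, 3, 3, 2, false}, {1, 1, 1, 0, false}, {3, 3, 3, 0, false} };

    const float qlo[3] = {0.5f, 0.5f, 0.5f};
    const float qhi[3] = {0.9f, 0.9f, 0.9f};

    query_aabb(lowers, uppers, 0, qlo, qhi); // prints: hit item 0
    return 0;
}
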
warp/native/clang/clang.cpp CHANGED
@@ -63,14 +63,14 @@ extern void __jit_debug_register_code();
 }
 
 namespace wp {
-
+
 #if defined (_WIN32)
-// Windows defaults to using the COFF binary format (aka. "msvc" in the target triple).
-// Override it to use the ELF format to support DWARF debug info, but keep using the
-// Microsoft calling convention (see also https://llvm.org/docs/DebuggingJITedCode.html).
-static const char* target_triple = "x86_64-pc-windows-elf";
+// Windows defaults to using the COFF binary format (aka. "msvc" in the target triple).
+// Override it to use the ELF format to support DWARF debug info, but keep using the
+// Microsoft calling convention (see also https://llvm.org/docs/DebuggingJITedCode.html).
+static const char* target_triple = "x86_64-pc-windows-elf";
 #else
-static const char* target_triple = LLVM_DEFAULT_TARGET_TRIPLE;
+static const char* target_triple = LLVM_DEFAULT_TARGET_TRIPLE;
 #endif
 
 static void initialize_llvm()
@@ -95,6 +95,11 @@ static std::unique_ptr<llvm::Module> cpp_to_llvm(const std::string& input_file,
     args.push_back("-triple");
     args.push_back(target_triple);
 
+#if defined(__x86_64__) || defined(_M_X64)
+    args.push_back("-target-feature");
+    args.push_back("+f16c");  // Enables support for _Float16
+#endif
+
     clang::IntrusiveRefCntPtr<clang::DiagnosticOptions> diagnostic_options = new clang::DiagnosticOptions();
     std::unique_ptr<clang::TextDiagnosticPrinter> text_diagnostic_printer =
         std::make_unique<clang::TextDiagnosticPrinter>(llvm::errs(), &*diagnostic_options);
@@ -354,6 +359,7 @@ WP_API int load_obj(const char* object_file, const char* module_name)
         SYMBOL(log10f), SYMBOL_T(log10, double(*)(double)),
         SYMBOL(expf), SYMBOL_T(exp, double(*)(double)),
         SYMBOL(sqrtf), SYMBOL_T(sqrt, double(*)(double)),
+        SYMBOL(cbrtf), SYMBOL_T(cbrt, double(*)(double)),
         SYMBOL(powf), SYMBOL_T(pow, double(*)(double, double)),
         SYMBOL(floorf), SYMBOL_T(floor, double(*)(double)),
         SYMBOL(ceilf), SYMBOL_T(ceil, double(*)(double)),
@@ -382,8 +388,7 @@ WP_API int load_obj(const char* object_file, const char* module_name)
         SYMBOL(__chkstk),
 #elif defined(__APPLE__)
         SYMBOL(__bzero),
-        SYMBOL(__sincos_stret),
-        SYMBOL(__sincosf_stret),
+        SYMBOL(__sincos_stret), SYMBOL(__sincosf_stret),
 #else
         SYMBOL(sincosf), SYMBOL_T(sincos, void(*)(double,double*,double*)),
 #endif
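
On the +f16c frontend flag added above: `-target-feature +f16c` is the cc1-level equivalent of the clang driver option `-mf16c`, and it lets the JIT-compiled CPU kernels use the hardware half/float conversion instructions when operating on `_Float16`. A minimal sketch (not part of the package) exercising `_Float16` under that feature:

// build with, e.g.: clang++ -mf16c -O2 f16_demo.cpp
#include <cstdio>

int main()
{
    _Float16 a = (_Float16)1.5f;   // demo values, not from the package
    _Float16 b = (_Float16)2.25f;
    _Float16 c = a + b;            // lowered via F16C converts on x86-64

    printf("%f\n", (double)c);     // prints 3.750000
    return 0;
}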