warp-lang 1.9.1__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic.

Files changed (346)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +794 -305
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1382 -377
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -721
  95. warp/codegen.py +6 -4251
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -8062
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +1 -1
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -365
  136. warp/jax_experimental/ffi.py +17 -873
  137. warp/jax_experimental/xla_ffi.py +5 -605
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +314 -37
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sparse.cu +7 -3
  159. warp/native/spatial.h +12 -0
  160. warp/native/tile.h +681 -89
  161. warp/native/tile_radix_sort.h +1 -1
  162. warp/native/tile_reduce.h +394 -46
  163. warp/native/tile_scan.h +4 -4
  164. warp/native/vec.h +469 -0
  165. warp/native/version.h +23 -0
  166. warp/native/volume.cpp +1 -1
  167. warp/native/volume.cu +1 -0
  168. warp/native/volume.h +1 -1
  169. warp/native/volume_builder.cu +2 -0
  170. warp/native/warp.cpp +57 -29
  171. warp/native/warp.cu +253 -171
  172. warp/native/warp.h +11 -8
  173. warp/optim/__init__.py +6 -3
  174. warp/optim/adam.py +6 -145
  175. warp/optim/linear.py +14 -1585
  176. warp/optim/sgd.py +6 -94
  177. warp/paddle.py +6 -388
  178. warp/render/__init__.py +8 -4
  179. warp/render/imgui_manager.py +7 -267
  180. warp/render/render_opengl.py +6 -3618
  181. warp/render/render_usd.py +6 -919
  182. warp/render/utils.py +6 -142
  183. warp/sparse.py +37 -2563
  184. warp/tape.py +6 -1188
  185. warp/tests/__main__.py +1 -1
  186. warp/tests/cuda/test_async.py +4 -4
  187. warp/tests/cuda/test_conditional_captures.py +1 -1
  188. warp/tests/cuda/test_multigpu.py +1 -1
  189. warp/tests/cuda/test_streams.py +58 -1
  190. warp/tests/geometry/test_bvh.py +157 -22
  191. warp/tests/geometry/test_marching_cubes.py +0 -1
  192. warp/tests/geometry/test_mesh.py +5 -3
  193. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  194. warp/tests/geometry/test_mesh_query_point.py +5 -2
  195. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  196. warp/tests/geometry/test_volume_write.py +5 -5
  197. warp/tests/interop/test_dlpack.py +14 -14
  198. warp/tests/interop/test_jax.py +772 -49
  199. warp/tests/interop/test_paddle.py +1 -1
  200. warp/tests/test_adam.py +0 -1
  201. warp/tests/test_arithmetic.py +9 -9
  202. warp/tests/test_array.py +527 -100
  203. warp/tests/test_array_reduce.py +3 -3
  204. warp/tests/test_atomic.py +12 -8
  205. warp/tests/test_atomic_bitwise.py +209 -0
  206. warp/tests/test_atomic_cas.py +4 -4
  207. warp/tests/test_bool.py +2 -2
  208. warp/tests/test_builtins_resolution.py +5 -571
  209. warp/tests/test_codegen.py +33 -14
  210. warp/tests/test_conditional.py +1 -1
  211. warp/tests/test_context.py +6 -6
  212. warp/tests/test_copy.py +242 -161
  213. warp/tests/test_ctypes.py +3 -3
  214. warp/tests/test_devices.py +24 -2
  215. warp/tests/test_examples.py +16 -84
  216. warp/tests/test_fabricarray.py +35 -35
  217. warp/tests/test_fast_math.py +0 -2
  218. warp/tests/test_fem.py +56 -10
  219. warp/tests/test_fixedarray.py +3 -3
  220. warp/tests/test_func.py +8 -5
  221. warp/tests/test_generics.py +1 -1
  222. warp/tests/test_indexedarray.py +24 -24
  223. warp/tests/test_intersect.py +39 -9
  224. warp/tests/test_large.py +1 -1
  225. warp/tests/test_lerp.py +3 -1
  226. warp/tests/test_linear_solvers.py +1 -1
  227. warp/tests/test_map.py +35 -4
  228. warp/tests/test_mat.py +52 -62
  229. warp/tests/test_mat_constructors.py +4 -5
  230. warp/tests/test_mat_lite.py +1 -1
  231. warp/tests/test_mat_scalar_ops.py +121 -121
  232. warp/tests/test_math.py +34 -0
  233. warp/tests/test_module_aot.py +4 -4
  234. warp/tests/test_modules_lite.py +28 -2
  235. warp/tests/test_print.py +11 -11
  236. warp/tests/test_quat.py +93 -58
  237. warp/tests/test_runlength_encode.py +1 -1
  238. warp/tests/test_scalar_ops.py +38 -10
  239. warp/tests/test_smoothstep.py +1 -1
  240. warp/tests/test_sparse.py +126 -15
  241. warp/tests/test_spatial.py +105 -87
  242. warp/tests/test_special_values.py +6 -6
  243. warp/tests/test_static.py +7 -7
  244. warp/tests/test_struct.py +13 -2
  245. warp/tests/test_triangle_closest_point.py +48 -1
  246. warp/tests/test_types.py +27 -15
  247. warp/tests/test_utils.py +52 -52
  248. warp/tests/test_vec.py +29 -29
  249. warp/tests/test_vec_constructors.py +5 -5
  250. warp/tests/test_vec_scalar_ops.py +97 -97
  251. warp/tests/test_version.py +75 -0
  252. warp/tests/tile/test_tile.py +178 -0
  253. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  254. warp/tests/tile/test_tile_cholesky.py +7 -4
  255. warp/tests/tile/test_tile_load.py +26 -2
  256. warp/tests/tile/test_tile_mathdx.py +3 -3
  257. warp/tests/tile/test_tile_matmul.py +1 -1
  258. warp/tests/tile/test_tile_mlp.py +2 -4
  259. warp/tests/tile/test_tile_reduce.py +214 -13
  260. warp/tests/unittest_suites.py +6 -14
  261. warp/tests/unittest_utils.py +10 -9
  262. warp/tests/walkthrough_debug.py +3 -1
  263. warp/torch.py +6 -373
  264. warp/types.py +29 -5764
  265. warp/utils.py +10 -1659
  266. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
  267. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  268. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  269. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  270. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  271. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  283. warp/examples/assets/cartpole.urdf +0 -110
  284. warp/examples/assets/crazyflie.usd +0 -0
  285. warp/examples/assets/nv_ant.xml +0 -92
  286. warp/examples/assets/nv_humanoid.xml +0 -183
  287. warp/examples/assets/quadruped.urdf +0 -268
  288. warp/examples/optim/example_bounce.py +0 -266
  289. warp/examples/optim/example_cloth_throw.py +0 -228
  290. warp/examples/optim/example_drone.py +0 -870
  291. warp/examples/optim/example_inverse_kinematics.py +0 -182
  292. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  293. warp/examples/optim/example_softbody_properties.py +0 -400
  294. warp/examples/optim/example_spring_cage.py +0 -245
  295. warp/examples/optim/example_trajectory.py +0 -227
  296. warp/examples/sim/example_cartpole.py +0 -143
  297. warp/examples/sim/example_cloth.py +0 -225
  298. warp/examples/sim/example_cloth_self_contact.py +0 -316
  299. warp/examples/sim/example_granular.py +0 -130
  300. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  301. warp/examples/sim/example_jacobian_ik.py +0 -244
  302. warp/examples/sim/example_particle_chain.py +0 -124
  303. warp/examples/sim/example_quadruped.py +0 -203
  304. warp/examples/sim/example_rigid_chain.py +0 -203
  305. warp/examples/sim/example_rigid_contact.py +0 -195
  306. warp/examples/sim/example_rigid_force.py +0 -133
  307. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  308. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  309. warp/examples/sim/example_soft_body.py +0 -196
  310. warp/examples/tile/example_tile_walker.py +0 -327
  311. warp/sim/__init__.py +0 -74
  312. warp/sim/articulation.py +0 -793
  313. warp/sim/collide.py +0 -2570
  314. warp/sim/graph_coloring.py +0 -307
  315. warp/sim/import_mjcf.py +0 -791
  316. warp/sim/import_snu.py +0 -227
  317. warp/sim/import_urdf.py +0 -579
  318. warp/sim/import_usd.py +0 -898
  319. warp/sim/inertia.py +0 -357
  320. warp/sim/integrator.py +0 -245
  321. warp/sim/integrator_euler.py +0 -2000
  322. warp/sim/integrator_featherstone.py +0 -2101
  323. warp/sim/integrator_vbd.py +0 -2487
  324. warp/sim/integrator_xpbd.py +0 -3295
  325. warp/sim/model.py +0 -4821
  326. warp/sim/particles.py +0 -121
  327. warp/sim/render.py +0 -431
  328. warp/sim/utils.py +0 -431
  329. warp/tests/sim/disabled_kinematics.py +0 -244
  330. warp/tests/sim/test_cloth.py +0 -863
  331. warp/tests/sim/test_collision.py +0 -743
  332. warp/tests/sim/test_coloring.py +0 -347
  333. warp/tests/sim/test_inertia.py +0 -161
  334. warp/tests/sim/test_model.py +0 -226
  335. warp/tests/sim/test_sim_grad.py +0 -287
  336. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  337. warp/tests/sim/test_sim_kinematics.py +0 -98
  338. warp/thirdparty/__init__.py +0 -0
  339. warp_lang-1.9.1.dist-info/RECORD +0 -456
  340. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  341. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  342. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  343. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  344. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  345. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  346. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/native/bvh.h CHANGED
@@ -20,7 +20,12 @@
  #include "builtin.h"
  #include "intersect.h"

- #define BVH_LEAF_SIZE (4)
+ #ifdef __CUDA_ARCH__
+ #define BVH_SHARED_STACK 1
+ #else
+ #define BVH_SHARED_STACK 0
+ #endif
+
  #define SAH_NUM_BUCKETS (16)
  #define USE_LOAD4
  #define BVH_QUERY_STACK_SIZE (32)
@@ -34,133 +39,133 @@ namespace wp

  struct bounds3
  {
- CUDA_CALLABLE inline bounds3() : lower( FLT_MAX)
- , upper(-FLT_MAX) {}
-
- CUDA_CALLABLE inline bounds3(const vec3& lower, const vec3& upper) : lower(lower), upper(upper) {}
-
- CUDA_CALLABLE inline vec3 center() const { return 0.5f*(lower+upper); }
- CUDA_CALLABLE inline vec3 edges() const { return upper-lower; }
-
- CUDA_CALLABLE inline void expand(float r)
- {
- lower -= vec3(r);
- upper += vec3(r);
- }
-
- CUDA_CALLABLE inline void expand(const vec3& r)
- {
- lower -= r;
- upper += r;
- }
-
- CUDA_CALLABLE inline bool empty() const { return lower[0] >= upper[0] || lower[1] >= upper[1] || lower[2] >= upper[2]; }
-
- CUDA_CALLABLE inline bool overlaps(const vec3& p) const
- {
- if (p[0] < lower[0] ||
- p[1] < lower[1] ||
- p[2] < lower[2] ||
- p[0] > upper[0] ||
- p[1] > upper[1] ||
- p[2] > upper[2])
- {
- return false;
- }
- else
- {
- return true;
- }
- }
-
- CUDA_CALLABLE inline bool overlaps(const bounds3& b) const
- {
- if (lower[0] > b.upper[0] ||
- lower[1] > b.upper[1] ||
- lower[2] > b.upper[2] ||
- upper[0] < b.lower[0] ||
- upper[1] < b.lower[1] ||
- upper[2] < b.lower[2])
- {
- return false;
- }
- else
- {
- return true;
- }
- }
-
- CUDA_CALLABLE inline bool overlaps(const vec3& b_lower, const vec3& b_upper) const
- {
- if (lower[0] > b_upper[0] ||
- lower[1] > b_upper[1] ||
- lower[2] > b_upper[2] ||
- upper[0] < b_lower[0] ||
- upper[1] < b_lower[1] ||
- upper[2] < b_lower[2])
- {
- return false;
- }
- else
- {
- return true;
- }
- }
-
- CUDA_CALLABLE inline void add_point(const vec3& p)
- {
- lower = min(lower, p);
- upper = max(upper, p);
- }
-
- CUDA_CALLABLE inline void add_bounds(const vec3& lower_other, const vec3& upper_other)
- {
- // lower_other will only impact the lower of the new bounds
- // upper_other will only impact the upper of the new bounds
- // this costs only half of the computation of adding lower_other and upper_other separately
- lower = min(lower, lower_other);
- upper = max(upper, upper_other);
- }
-
- CUDA_CALLABLE inline float area() const
- {
- vec3 e = upper-lower;
- return 2.0f*(e[0]*e[1] + e[0]*e[2] + e[1]*e[2]);
- }
-
- vec3 lower;
- vec3 upper;
+ CUDA_CALLABLE inline bounds3() : lower( FLT_MAX)
+ , upper(-FLT_MAX) {}
+
+ CUDA_CALLABLE inline bounds3(const vec3& lower, const vec3& upper) : lower(lower), upper(upper) {}
+
+ CUDA_CALLABLE inline vec3 center() const { return 0.5f*(lower+upper); }
+ CUDA_CALLABLE inline vec3 edges() const { return upper-lower; }
+
+ CUDA_CALLABLE inline void expand(float r)
+ {
+ lower -= vec3(r);
+ upper += vec3(r);
+ }
+
+ CUDA_CALLABLE inline void expand(const vec3& r)
+ {
+ lower -= r;
+ upper += r;
+ }
+
+ CUDA_CALLABLE inline bool empty() const { return lower[0] >= upper[0] || lower[1] >= upper[1] || lower[2] >= upper[2]; }
+
+ CUDA_CALLABLE inline bool overlaps(const vec3& p) const
+ {
+ if (p[0] < lower[0] ||
+ p[1] < lower[1] ||
+ p[2] < lower[2] ||
+ p[0] > upper[0] ||
+ p[1] > upper[1] ||
+ p[2] > upper[2])
+ {
+ return false;
+ }
+ else
+ {
+ return true;
+ }
+ }
+
+ CUDA_CALLABLE inline bool overlaps(const bounds3& b) const
+ {
+ if (lower[0] > b.upper[0] ||
+ lower[1] > b.upper[1] ||
+ lower[2] > b.upper[2] ||
+ upper[0] < b.lower[0] ||
+ upper[1] < b.lower[1] ||
+ upper[2] < b.lower[2])
+ {
+ return false;
+ }
+ else
+ {
+ return true;
+ }
+ }
+
+ CUDA_CALLABLE inline bool overlaps(const vec3& b_lower, const vec3& b_upper) const
+ {
+ if (lower[0] > b_upper[0] ||
+ lower[1] > b_upper[1] ||
+ lower[2] > b_upper[2] ||
+ upper[0] < b_lower[0] ||
+ upper[1] < b_lower[1] ||
+ upper[2] < b_lower[2])
+ {
+ return false;
+ }
+ else
+ {
+ return true;
+ }
+ }
+
+ CUDA_CALLABLE inline void add_point(const vec3& p)
+ {
+ lower = min(lower, p);
+ upper = max(upper, p);
+ }
+
+ CUDA_CALLABLE inline void add_bounds(const vec3& lower_other, const vec3& upper_other)
+ {
+ // lower_other will only impact the lower of the new bounds
+ // upper_other will only impact the upper of the new bounds
+ // this costs only half of the computation of adding lower_other and upper_other separately
+ lower = min(lower, lower_other);
+ upper = max(upper, upper_other);
+ }
+
+ CUDA_CALLABLE inline float area() const
+ {
+ vec3 e = upper-lower;
+ return 2.0f*(e[0]*e[1] + e[0]*e[2] + e[1]*e[2]);
+ }
+
+ vec3 lower;
+ vec3 upper;
  };

  CUDA_CALLABLE inline bounds3 bounds_union(const bounds3& a, const vec3& b)
  {
- return bounds3(min(a.lower, b), max(a.upper, b));
+ return bounds3(min(a.lower, b), max(a.upper, b));
  }

  CUDA_CALLABLE inline bounds3 bounds_union(const bounds3& a, const bounds3& b)
  {
- return bounds3(min(a.lower, b.lower), max(a.upper, b.upper));
+ return bounds3(min(a.lower, b.lower), max(a.upper, b.upper));
  }

  CUDA_CALLABLE inline bounds3 bounds_intersection(const bounds3& a, const bounds3& b)
  {
- return bounds3(max(a.lower, b.lower), min(a.upper, b.upper));
+ return bounds3(max(a.lower, b.lower), min(a.upper, b.upper));
  }

  struct BVHPackedNodeHalf
  {
- float x;
- float y;
- float z;
- // For non-leaf nodes:
- // - 'lower.i' represents the index of the left child node.
- // - 'upper.i' represents the index of the right child node.
- //
- // For leaf nodes:
- // - 'lower.i' indicates the start index of the primitives in 'primitive_indices'.
- // - 'upper.i' indicates the index just after the last primitive in 'primitive_indices'
- unsigned int i : 31;
- unsigned int b : 1;
+ float x;
+ float y;
+ float z;
+ // For non-leaf nodes:
+ // - 'lower.i' represents the index of the left child node.
+ // - 'upper.i' represents the index of the right child node.
+ //
+ // For leaf nodes:
+ // - 'lower.i' indicates the start index of the primitives in 'primitive_indices'.
+ // - 'upper.i' indicates the index just after the last primitive in 'primitive_indices'
+ unsigned int i : 31;
+ unsigned int b : 1;
  };

  struct BVH
@@ -168,30 +173,32 @@ struct BVH
  BVHPackedNodeHalf* node_lowers;
  BVHPackedNodeHalf* node_uppers;

- // used for fast refits
- int* node_parents;
- int* node_counts;
- // reordered primitive indices corresponds to the ordering of leaf nodes
- int* primitive_indices;
-
- int max_depth;
- int max_nodes;
- int num_nodes;
- // since we use packed leaf nodes, the number of them is no longer the number of items, but variable
- int num_leaf_nodes;
-
- // pointer (CPU or GPU) to a single integer index in node_lowers, node_uppers
- // representing the root of the tree, this is not always the first node
- // for bottom-up builders
- int* root;
-
- // item bounds are not owned by the BVH but by the caller
+ // used for fast refits
+ int* node_parents;
+ int* node_counts;
+ // reordered primitive indices corresponds to the ordering of leaf nodes
+ int* primitive_indices;
+
+ int max_depth;
+ int max_nodes;
+ int num_nodes;
+ // since we use packed leaf nodes, the number of them is no longer the number of items, but variable
+ int num_leaf_nodes;
+
+ // pointer (CPU or GPU) to a single integer index in node_lowers, node_uppers
+ // representing the root of the tree, this is not always the first node
+ // for bottom-up builders
+ int* root;
+
+ // item bounds are not owned by the BVH but by the caller
  vec3* item_lowers;
- vec3* item_uppers;
- int num_items;
+ vec3* item_uppers;
+ int num_items;

- // cuda context
- void* context;
+ int leaf_size;
+
+ // cuda context
+ void* context;
  };

  CUDA_CALLABLE inline BVHPackedNodeHalf make_node(const vec3& bound, int child, bool leaf)
@@ -220,17 +227,18 @@ CUDA_CALLABLE inline void make_node(volatile BVHPackedNodeHalf* n, const vec3& b
  __device__ inline wp::BVHPackedNodeHalf bvh_load_node(const wp::BVHPackedNodeHalf* nodes, int index)
  {
  #ifdef USE_LOAD4
- //return (const wp::BVHPackedNodeHalf&)(__ldg((const float4*)(nodes)+index));
- return (const wp::BVHPackedNodeHalf&)(*((const float4*)(nodes)+index));
+ float4 f4 = __ldg((const float4*)(nodes)+index);
+ return (const wp::BVHPackedNodeHalf&)f4;
+ //return (const wp::BVHPackedNodeHalf&)(*((const float4*)(nodes)+index));
  #else
- return nodes[index];
+ return nodes[index];
  #endif // USE_LOAD4

  }
  #else
  inline wp::BVHPackedNodeHalf bvh_load_node(const wp::BVHPackedNodeHalf* nodes, int index)
  {
- return nodes[index];
+ return nodes[index];
  }
  #endif // __CUDACC__

@@ -272,10 +280,22 @@ CUDA_CALLABLE inline BVH bvh_get(uint64_t id)

  CUDA_CALLABLE inline int bvh_get_num_bounds(uint64_t id)
  {
- BVH bvh = bvh_get(id);
- return bvh.num_items;
+ BVH bvh = bvh_get(id);
+ return bvh.num_items;
  }

+ // represents a strided stack in shared memory
+ // so each level of the stack is stored contiguously
+ // across the block
+ struct bvh_stack_t
+ {
+ inline int operator[](int depth) const { return ptr[depth*WP_TILE_BLOCK_DIM]; }
+ inline int& operator[](int depth) { return ptr[depth*WP_TILE_BLOCK_DIM]; }
+
+ int* ptr;
+
+ };
+

  // stores state required to traverse the BVH nodes that
  // overlap with a query AABB.
@@ -289,7 +309,7 @@ struct bvh_query_t
  input_lower(),
  input_upper(),
  bounds_nr(0),
- primitive_counter(-1)
+ primitive_counter(-1)
  {}

  // Required for adjoint computations.
@@ -300,214 +320,194 @@ bvh_query_t

  BVH bvh;

- // BVH traversal stack:
- int stack[BVH_QUERY_STACK_SIZE];
- int count;
+ // BVH traversal stack:
+ #if BVH_SHARED_STACK
+ bvh_stack_t stack;
+ #else
+ int stack[BVH_QUERY_STACK_SIZE];
+ #endif

- // >= 0 if currently in a packed leaf node
- int primitive_counter;
-
+ int count;
+
+ // >= 0 if currently in a packed leaf node
+ int primitive_counter;
+
  // inputs
  wp::vec3 input_lower; // start for ray
  wp::vec3 input_upper; // dir for ray

- int bounds_nr;
- bool is_ray;
+ int bounds_nr;
+ bool is_ray;
  };

  CUDA_CALLABLE inline bool bvh_query_intersection_test(const bvh_query_t& query, const vec3& node_lower, const vec3& node_upper)
  {
- if (query.is_ray)
- {
- float t = 0.0f;
- return intersect_ray_aabb(query.input_lower, query.input_upper, node_lower, node_upper, t);
- }
- else
- {
- return intersect_aabb_aabb(query.input_lower, query.input_upper, node_lower, node_upper);
- }
+ if (query.is_ray)
+ {
+ float t = 0.0f;
+ return intersect_ray_aabb(query.input_lower, query.input_upper, node_lower, node_upper, t);
+ }
+ else
+ {
+ return intersect_aabb_aabb(query.input_lower, query.input_upper, node_lower, node_upper);
+ }
  }

  CUDA_CALLABLE inline bvh_query_t bvh_query(
- uint64_t id, bool is_ray, const vec3& lower, const vec3& upper)
+ uint64_t id, bool is_ray, const vec3& lower, const vec3& upper)
  {
- // This routine traverses the BVH tree until it finds
- // the first overlapping bound.
-
- // initialize empty
- bvh_query_t query;
+ // This routine traverses the BVH tree until it finds
+ // the first overlapping bound.
+
+ // initialize empty
+ bvh_query_t query;
+
+ #if BVH_SHARED_STACK
+ __shared__ int stack[BVH_QUERY_STACK_SIZE*WP_TILE_BLOCK_DIM];
+ query.stack.ptr = &stack[threadIdx.x];
+ #endif

- query.bounds_nr = -1;
+ query.bounds_nr = -1;

- BVH bvh = bvh_get(id);
+ BVH bvh = bvh_get(id);

- query.bvh = bvh;
- query.is_ray = is_ray;
+ query.bvh = bvh;
+ query.is_ray = is_ray;

- // optimization: make the latest
- query.stack[0] = *bvh.root;
- query.count = 1;
- query.input_lower = lower;
- query.input_upper = upper;
+ // optimization: make the latest
+ query.stack[0] = *bvh.root;
+ query.count = 1;
+ query.input_lower = lower;
+ query.input_upper = upper;

- // Navigate through the bvh, find the first overlapping leaf node.
- while (query.count)
- {
- const int node_index = query.stack[--query.count];
- BVHPackedNodeHalf node_lower = bvh_load_node(bvh.node_lowers, node_index);
- BVHPackedNodeHalf node_upper = bvh_load_node(bvh.node_uppers, node_index);
+ // Navigate through the bvh, find the first overlapping leaf node.
+ while (query.count)
+ {
+ const int node_index = query.stack[--query.count];
+ BVHPackedNodeHalf node_lower = bvh_load_node(bvh.node_lowers, node_index);
+ BVHPackedNodeHalf node_upper = bvh_load_node(bvh.node_uppers, node_index);

  if (!bvh_query_intersection_test(query, reinterpret_cast<vec3&>(node_lower), reinterpret_cast<vec3&>(node_upper)))
- {
- continue;
- }
-
- const int left_index = node_lower.i;
- const int right_index = node_upper.i;
- // Make bounds from this AABB
- if (node_lower.b)
- {
- // Reached a leaf node, point to its first primitive
- // Back up one level and return
- query.primitive_counter = left_index;
- query.stack[query.count++] = node_index;
- return query;
- }
- else
- {
- query.stack[query.count++] = left_index;
- query.stack[query.count++] = right_index;
- }
- }
-
- return query;
+ {
+ continue;
+ }
+
+ const int left_index = node_lower.i;
+ const int right_index = node_upper.i;
+ // Make bounds from this AABB
+ if (node_lower.b)
+ {
+ // Reached a leaf node, point to its first primitive
+ // Back up one level and return
+ query.primitive_counter = 0;
+ query.stack[query.count++] = node_index;
+ return query;
+ }
+ else
+ {
+ query.stack[query.count++] = left_index;
+ query.stack[query.count++] = right_index;
+ }
+ }
+
+ return query;
  }

  CUDA_CALLABLE inline bvh_query_t bvh_query_aabb(
  uint64_t id, const vec3& lower, const vec3& upper)
  {
- return bvh_query(id, false, lower, upper);
+ return bvh_query(id, false, lower, upper);
  }

-
- CUDA_CALLABLE inline bvh_query_t bvh_query_ray(
- uint64_t id, const vec3& start, const vec3& dir)
+ CUDA_CALLABLE inline bvh_query_t bvh_query_ray(uint64_t id, const vec3& start, const vec3& dir)
  {
- return bvh_query(id, true, start, 1.0f / dir);
+ return bvh_query(id, true, start, 1.0f / dir);
  }

  //Stub
  CUDA_CALLABLE inline void adj_bvh_query_aabb(uint64_t id, const vec3& lower, const vec3& upper,
- uint64_t, vec3&, vec3&, bvh_query_t&)
+ uint64_t, vec3&, vec3&, bvh_query_t&)
  {
  }


  CUDA_CALLABLE inline void adj_bvh_query_ray(uint64_t id, const vec3& start, const vec3& dir,
- uint64_t, vec3&, vec3&, bvh_query_t&)
+ uint64_t, vec3&, vec3&, bvh_query_t&)
  {
  }


  CUDA_CALLABLE inline bool bvh_query_next(bvh_query_t& query, int& index)
  {
- BVH bvh = query.bvh;
-
- if (query.primitive_counter != -1)
- // currently in a leaf node which is the last node in the stack
- {
- const int node_index = query.stack[query.count - 1];
- BVHPackedNodeHalf node_lower = bvh_load_node(bvh.node_lowers, node_index);
- BVHPackedNodeHalf node_upper = bvh_load_node(bvh.node_uppers, node_index);
-
- const int end = node_upper.i;
- for (int primitive_counter = query.primitive_counter; primitive_counter < end; primitive_counter++)
- {
- int primitive_index = bvh.primitive_indices[primitive_counter];
- if (bvh_query_intersection_test(query, bvh.item_lowers[primitive_index], bvh.item_uppers[primitive_index]))
- {
- if (primitive_counter < end - 1)
- // still need to come back to this leaf node for the leftover primitives
- {
- query.primitive_counter = primitive_counter + 1;
- }
- else
- // no need to come back to this leaf node
- {
- query.count--;
- query.primitive_counter = -1;
- }
- index = primitive_index;
- query.bounds_nr = primitive_index;
-
- return true;
- }
- }
- // if we reach here that means we have finished the current leaf node without finding intersections
- query.primitive_counter = -1;
- // remove the leaf node from the back of the stack because it is finished
- // and continue the bvh traversal
- query.count--;
- }
-
- // Navigate through the bvh, find the first overlapping leaf node.
- while (query.count)
- {
- const int node_index = query.stack[--query.count];
- BVHPackedNodeHalf node_lower = bvh_load_node(bvh.node_lowers, node_index);
- BVHPackedNodeHalf node_upper = bvh_load_node(bvh.node_uppers, node_index);
-
- const int left_index = node_lower.i;
- const int right_index = node_upper.i;
-
- wp::vec3 lower_pos(node_lower.x, node_lower.y, node_lower.z);
- wp::vec3 upper_pos(node_upper.x, node_upper.y, node_upper.z);
- wp::bounds3 current_bounds(lower_pos, upper_pos);
+ BVH bvh = query.bvh;

- if (!bvh_query_intersection_test(query, reinterpret_cast<vec3&>(node_lower), reinterpret_cast<vec3&>(node_upper)))
- {
- continue;
- }
-
- if (node_lower.b)
- {
- // found leaf, loop through its content primitives
- const int start = left_index;
- const int end = right_index;
-
- for (int primitive_counter = start; primitive_counter < end; primitive_counter++)
- {
- int primitive_index = bvh.primitive_indices[primitive_counter];
- if (bvh_query_intersection_test(query, bvh.item_lowers[primitive_index], bvh.item_uppers[primitive_index]))
- {
- if (primitive_counter < end - 1)
- // still need to come back to this leaf node for the leftover primitives
- {
- query.primitive_counter = primitive_counter + 1;
- query.stack[query.count++] = node_index;
- }
- else
- // no need to come back to this leaf node
- {
- query.primitive_counter = -1;
- }
- index = primitive_index;
- query.bounds_nr = primitive_index;
-
- return true;
- }
- }
- }
- else
- {
- query.stack[query.count++] = left_index;
- query.stack[query.count++] = right_index;
- }
- }
- return false;
+ // Navigate through the bvh, find the first overlapping leaf node.
+ while (query.count)
+ {
+ const int node_index = query.stack[--query.count];
+
+ BVHPackedNodeHalf node_lower = bvh_load_node(bvh.node_lowers, node_index);
+ BVHPackedNodeHalf node_upper = bvh_load_node(bvh.node_uppers, node_index);
+
+ if (query.primitive_counter == 0) {
+ if (!bvh_query_intersection_test(query, reinterpret_cast<vec3&>(node_lower), reinterpret_cast<vec3&>(node_upper)))
+ {
+ continue;
+ }
+ }
+
+ const int left_index = node_lower.i;
+ const int right_index = node_upper.i;
+
+ if (node_lower.b)
+ {
+ // found leaf, loop through its content primitives
+ const int start = left_index;
+
+ if (bvh.leaf_size == 1)
+ {
+ int primitive_index = bvh.primitive_indices[start];
+ index = primitive_index;
+ query.bounds_nr = primitive_index;
+ return true;
+ }
+ else
+ {
+ const int end = right_index;
+ int primitive_index = bvh.primitive_indices[start + (query.primitive_counter++)];
+
+ // if already visited the last primitive in the leaf node
+ // move to the next node and reset the primitive counter to 0
+ if (start + query.primitive_counter == end)
+ {
+ query.primitive_counter = 0;
+ }
+ // otherwise we need to keep this leaf node in stack for a future visit
+ else
+ {
+ query.stack[query.count++] = node_index;
+ }
+ // return true;
+ if (bvh_query_intersection_test(query, bvh.item_lowers[primitive_index], bvh.item_uppers[primitive_index]))
+ {
+ index = primitive_index;
+ query.bounds_nr = primitive_index;
+
+ return true;
+ }
+ }
+ }
+ else
+ {
+ // if it's not a leaf node we treat it as if we have visited the last primitive
+ query.primitive_counter = 0;
+ query.stack[query.count++] = left_index;
+ query.stack[query.count++] = right_index;
+ }
+ }
+ return false;
  }

-
  CUDA_CALLABLE inline int iter_next(bvh_query_t& query)
  {
  return query.bounds_nr;
@@ -540,15 +540,16 @@ CUDA_CALLABLE bool bvh_get_descriptor(uint64_t id, BVH& bvh);
  CUDA_CALLABLE void bvh_add_descriptor(uint64_t id, const BVH& bvh);
  CUDA_CALLABLE void bvh_rem_descriptor(uint64_t id);

- #if !__CUDA_ARCH__
-
- void bvh_create_host(vec3* lowers, vec3* uppers, int num_items, int constructor_type, BVH& bvh);
+ void bvh_create_host(vec3* lowers, vec3* uppers, int num_items, int constructor_type, BVH& bvh, int leaf_size);
  void bvh_destroy_host(wp::BVH& bvh);
  void bvh_refit_host(wp::BVH& bvh);

- void bvh_destroy_device(wp::BVH& bvh);
- void bvh_refit_device(uint64_t id);
+ #if WP_ENABLE_CUDA

- #endif
+ void bvh_create_device(void* context, vec3* lowers, vec3* uppers, int num_items, int constructor_type, BVH& bvh_device_on_host, int leaf_size);
+ void bvh_destroy_device(BVH& bvh);
+ void bvh_refit_device(BVH& bvh);
+
+ #endif // WP_ENABLE_CUDA

  } // namespace wp
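
Note on the shared-memory traversal stack introduced above: when BVH_SHARED_STACK is enabled, bvh_query_t no longer holds a per-thread local array but a bvh_stack_t view into a block-shared buffer, indexed as ptr[depth*WP_TILE_BLOCK_DIM] with ptr offset by threadIdx.x. The sketch below is illustrative only and is not Warp source; BLOCK_DIM and STACK_SIZE are stand-ins for WP_TILE_BLOCK_DIM and BVH_QUERY_STACK_SIZE, and the kernel exists purely to show the strided indexing pattern.

// Illustrative CUDA sketch (assumed names, not part of the warp-lang package):
// each thread owns one "column" of a shared-memory stack, so all threads
// accessing the same stack depth touch consecutive shared-memory words.
#include <cstdio>
#include <cuda_runtime.h>

#define BLOCK_DIM 128   // stand-in for WP_TILE_BLOCK_DIM
#define STACK_SIZE 32   // stand-in for BVH_QUERY_STACK_SIZE

struct strided_stack_t
{
    // element `depth` for this thread; neighbouring threads at the same depth
    // sit next to each other in shared memory
    __device__ int& operator[](int depth) { return ptr[depth * BLOCK_DIM]; }
    int* ptr;
};

__global__ void stack_demo(int* out)
{
    __shared__ int storage[STACK_SIZE * BLOCK_DIM];

    strided_stack_t stack;
    stack.ptr = &storage[threadIdx.x];   // this thread's column

    // push a few thread-dependent values, then pop them back (LIFO)
    int count = 0;
    for (int i = 0; i < 4; ++i)
        stack[count++] = threadIdx.x * 10 + i;

    int sum = 0;
    while (count)
        sum += stack[--count];

    out[threadIdx.x] = sum;   // expected: 40*threadIdx.x + 6
}

int main()
{
    int* d_out = nullptr;
    cudaMalloc(&d_out, BLOCK_DIM * sizeof(int));

    stack_demo<<<1, BLOCK_DIM>>>(d_out);
    cudaDeviceSynchronize();

    int h_out[BLOCK_DIM];
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);

    bool ok = true;
    for (int t = 0; t < BLOCK_DIM; ++t)
        ok = ok && (h_out[t] == 40 * t + 6);

    printf("strided stack %s\n", ok ? "ok" : "mismatch");
    cudaFree(d_out);
    return ok ? 0 : 1;
}

With this layout, threads of a warp that push or pop at the same depth hit consecutive shared-memory addresses (distinct banks), and the traversal stack no longer consumes per-thread local memory, which appears to be the motivation for the BVH_SHARED_STACK path in the diff.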