warp-lang 1.9.0__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (350) hide show
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2220 -313
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1497 -226
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +1 -1
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +313 -201
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +14 -14
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +529 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/native/mesh.cpp CHANGED
@@ -126,7 +126,7 @@ void bvh_refit_with_solid_angle_recursive_host(BVH& bvh, int index, Mesh& mesh)
126
126
 
127
127
  // write new BVH nodes
128
128
  reinterpret_cast<vec3&>(lower) = new_lower;
129
- reinterpret_cast<vec3&>(upper) = new_upper;
129
+ reinterpret_cast<vec3&>(upper) = new_upper;
130
130
  }
131
131
  }
132
132
 
@@ -135,7 +135,7 @@ void bvh_refit_with_solid_angle_host(BVH& bvh, Mesh& mesh)
135
135
  bvh_refit_with_solid_angle_recursive_host(bvh, 0, mesh);
136
136
  }
137
137
 
138
- uint64_t wp_mesh_create_host(array_t<wp::vec3> points, array_t<wp::vec3> velocities, array_t<int> indices, int num_points, int num_tris, int support_winding_number, int constructor_type)
138
+ uint64_t wp_mesh_create_host(array_t<wp::vec3> points, array_t<wp::vec3> velocities, array_t<int> indices, int num_points, int num_tris, int support_winding_number, int constructor_type, int bvh_leaf_size)
139
139
  {
140
140
  Mesh* m = new Mesh(points, velocities, indices, num_points, num_tris);
141
141
 
@@ -163,7 +163,7 @@ uint64_t wp_mesh_create_host(array_t<wp::vec3> points, array_t<wp::vec3> velocit
163
163
  }
164
164
  m->average_edge_length = sum / (num_tris*3);
165
165
 
166
- wp::bvh_create_host(m->lowers, m->uppers, num_tris, constructor_type, m->bvh);
166
+ wp::bvh_create_host(m->lowers, m->uppers, num_tris, constructor_type, m->bvh, bvh_leaf_size);
167
167
 
168
168
  if (support_winding_number)
169
169
  {
@@ -256,7 +256,7 @@ void wp_mesh_set_velocities_host(uint64_t id, wp::array_t<wp::vec3> velocities)
256
256
  #if !WP_ENABLE_CUDA
257
257
 
258
258
 
259
- WP_API uint64_t wp_mesh_create_device(void* context, wp::array_t<wp::vec3> points, wp::array_t<wp::vec3> velocities, wp::array_t<int> tris, int num_points, int num_tris, int support_winding_number, int constructor_type) { return 0; }
259
+ WP_API uint64_t wp_mesh_create_device(void* context, wp::array_t<wp::vec3> points, wp::array_t<wp::vec3> velocities, wp::array_t<int> tris, int num_points, int num_tris, int support_winding_number, int constructor_type, int bvh_leaf_size) { return 0; }
260
260
  WP_API void wp_mesh_destroy_device(uint64_t id) {}
261
261
  WP_API void wp_mesh_refit_device(uint64_t id) {}
262
262
  WP_API void wp_mesh_set_points_device(uint64_t id, wp::array_t<wp::vec3> points) {};
warp/native/mesh.cu CHANGED
@@ -21,6 +21,8 @@
21
21
  #include "bvh.h"
22
22
  #include "scan.h"
23
23
 
24
+ extern CUcontext get_current_context();
25
+
24
26
  namespace wp
25
27
  {
26
28
 
@@ -245,7 +247,7 @@ void bvh_refit_with_solid_angle_device(BVH& bvh, Mesh& mesh)
245
247
  } // namespace wp
246
248
 
247
249
 
248
- uint64_t wp_mesh_create_device(void* context, wp::array_t<wp::vec3> points, wp::array_t<wp::vec3> velocities, wp::array_t<int> indices, int num_points, int num_tris, int support_winding_number, int constructor_type)
250
+ uint64_t wp_mesh_create_device(void* context, wp::array_t<wp::vec3> points, wp::array_t<wp::vec3> velocities, wp::array_t<int> indices, int num_points, int num_tris, int support_winding_number, int constructor_type, int bvh_leaf_size)
249
251
  {
250
252
  ContextGuard guard(context);
251
253
 
@@ -280,7 +282,7 @@ uint64_t wp_mesh_create_device(void* context, wp::array_t<wp::vec3> points, wp::
280
282
 
281
283
  // compute triangle bound and construct BVH
282
284
  wp_launch_device(WP_CURRENT_CONTEXT, wp::compute_triangle_bounds, mesh.num_tris, (mesh.num_tris, mesh.points, mesh.indices, mesh.lowers, mesh.uppers));
283
- wp::bvh_create_device(mesh.context, mesh.lowers, mesh.uppers, num_tris, constructor_type, mesh.bvh);
285
+ wp::bvh_create_device(mesh.context, mesh.lowers, mesh.uppers, num_tris, constructor_type, mesh.bvh, bvh_leaf_size);
284
286
 
285
287
  // we need to overwrite mesh.bvh because it is not initialized when we construct it on device
286
288
  wp_memcpy_h2d(WP_CURRENT_CONTEXT, &(mesh_device->bvh), &mesh.bvh, sizeof(wp::BVH));
warp/native/mesh.h CHANGED
@@ -1374,6 +1374,7 @@ CUDA_CALLABLE inline bool mesh_query_ray(uint64_t id, const vec3& start, const v
1374
1374
  Mesh mesh = mesh_get(id);
1375
1375
 
1376
1376
  int stack[BVH_QUERY_STACK_SIZE];
1377
+
1377
1378
  stack[0] = *mesh.bvh.root;
1378
1379
  int count = 1;
1379
1380
 
@@ -1463,6 +1464,139 @@ CUDA_CALLABLE inline bool mesh_query_ray(uint64_t id, const vec3& start, const v
1463
1464
 
1464
1465
  }
1465
1466
 
1467
+ template <typename T>
1468
+ CUDA_CALLABLE inline void _swap(T& a, T& b)
1469
+ {
1470
+ T t = a; a = b; b = t;
1471
+ }
1472
+
1473
+ CUDA_CALLABLE inline bool mesh_query_ray_ordered(uint64_t id, const vec3& start, const vec3& dir, float max_t, float& t, float& u, float& v, float& sign, vec3& normal, int& face)
1474
+ {
1475
+ Mesh mesh = mesh_get(id);
1476
+
1477
+ int stack[BVH_QUERY_STACK_SIZE];
1478
+ float stack_dist[BVH_QUERY_STACK_SIZE];
1479
+
1480
+ stack[0] = *mesh.bvh.root;
1481
+ stack_dist[0] = -FLT_MAX;
1482
+
1483
+ int count = 1;
1484
+
1485
+ vec3 rcp_dir = vec3(1.0f/dir[0], 1.0f/dir[1], 1.0f/dir[2]);
1486
+
1487
+ float min_t = max_t;
1488
+ int min_face;
1489
+ float min_u;
1490
+ float min_v;
1491
+ float min_sign = 1.0f;
1492
+ vec3 min_normal;
1493
+
1494
+ while (count)
1495
+ {
1496
+ count -= 1;
1497
+
1498
+ const int nodeIndex = stack[count];
1499
+ const float nodeDist = stack_dist[count];
1500
+
1501
+ if (nodeDist < min_t)
1502
+ {
1503
+ int left_index = mesh.bvh.node_lowers[nodeIndex].i;
1504
+ int right_index = mesh.bvh.node_uppers[nodeIndex].i;
1505
+ bool leaf = mesh.bvh.node_lowers[nodeIndex].b;
1506
+
1507
+ if (leaf)
1508
+ {
1509
+ const int start_index = left_index;
1510
+ const int end_index = right_index;
1511
+ // loops through primitives in the leaf
1512
+ for (int primitive_counter = start_index; primitive_counter < end_index ; primitive_counter++)
1513
+ {
1514
+ int primitive_index = mesh.bvh.primitive_indices[primitive_counter];
1515
+ int i = mesh.indices[primitive_index * 3 + 0];
1516
+ int j = mesh.indices[primitive_index * 3 + 1];
1517
+ int k = mesh.indices[primitive_index * 3 + 2];
1518
+
1519
+ vec3 p = mesh.points[i];
1520
+ vec3 q = mesh.points[j];
1521
+ vec3 r = mesh.points[k];
1522
+
1523
+ float t, u, v, w, sign;
1524
+ vec3 n;
1525
+
1526
+ if (intersect_ray_tri_rtcd(start, dir, p, q, r, t, u, v, w, sign, &n))
1527
+ {
1528
+ if (t < min_t && t >= 0.0f)
1529
+ {
1530
+ min_t = t;
1531
+ min_face = primitive_index;
1532
+ min_u = u;
1533
+ min_v = v;
1534
+ min_sign = sign;
1535
+ min_normal = n;
1536
+ }
1537
+ }
1538
+ }
1539
+ }
1540
+ else
1541
+ {
1542
+ const float eps = 1.e-3f;
1543
+
1544
+ BVHPackedNodeHalf left_lower = bvh_load_node(mesh.bvh.node_lowers, left_index);
1545
+ BVHPackedNodeHalf left_upper = bvh_load_node(mesh.bvh.node_uppers, left_index);
1546
+
1547
+ BVHPackedNodeHalf right_lower = bvh_load_node(mesh.bvh.node_lowers, right_index);
1548
+ BVHPackedNodeHalf right_upper = bvh_load_node(mesh.bvh.node_uppers, right_index);
1549
+
1550
+ float left_dist = FLT_MAX;
1551
+ bool left_hit = intersect_ray_aabb(start, rcp_dir, vec3(left_lower.x-eps, left_lower.y-eps, left_lower.z-eps), vec3(left_upper.x+eps, left_upper.y+eps, left_upper.z+eps), left_dist);
1552
+
1553
+ float right_dist = FLT_MAX;
1554
+ bool right_hit = intersect_ray_aabb(start, rcp_dir, vec3(right_lower.x-eps, right_lower.y-eps, right_lower.z-eps), vec3(right_upper.x+eps, right_upper.y+eps, right_upper.z+eps), right_dist);
1555
+
1556
+
1557
+ if (left_dist < right_dist)
1558
+ {
1559
+ _swap(left_index, right_index);
1560
+ _swap(left_dist, right_dist);
1561
+ _swap(left_hit, right_hit);
1562
+ }
1563
+
1564
+ if (left_hit && left_dist < min_t)
1565
+ {
1566
+ stack[count] = left_index;
1567
+ stack_dist[count] = left_dist;
1568
+ count += 1;
1569
+ }
1570
+
1571
+ if (right_hit && right_dist < min_t)
1572
+ {
1573
+ stack[count] = right_index;
1574
+ stack_dist[count] = right_dist;
1575
+ count += 1;
1576
+ }
1577
+ }
1578
+ }
1579
+ }
1580
+
1581
+ if (min_t < max_t)
1582
+ {
1583
+ // write outputs
1584
+ u = min_u;
1585
+ v = min_v;
1586
+ sign = min_sign;
1587
+ t = min_t;
1588
+ normal = normalize(min_normal);
1589
+ face = min_face;
1590
+
1591
+ return true;
1592
+ }
1593
+ else
1594
+ {
1595
+ return false;
1596
+ }
1597
+
1598
+ }
1599
+
1466
1600
 
1467
1601
  CUDA_CALLABLE inline void adj_mesh_query_ray(
1468
1602
  uint64_t id, const vec3& start, const vec3& dir, float max_t, float t, float u, float v, float sign, const vec3& n, int face,
@@ -1589,7 +1723,12 @@ struct mesh_query_aabb_t
1589
1723
  // Mesh Id
1590
1724
  Mesh mesh;
1591
1725
  // BVH traversal stack:
1726
+ #if BVH_SHARED_STACK
1727
+ bvh_stack_t stack;
1728
+ #else
1592
1729
  int stack[BVH_QUERY_STACK_SIZE];
1730
+ #endif
1731
+
1593
1732
  int count;
1594
1733
 
1595
1734
  // inputs
@@ -1617,13 +1756,16 @@ CUDA_CALLABLE inline mesh_query_aabb_t mesh_query_aabb(
1617
1756
 
1618
1757
  Mesh mesh = mesh_get(id);
1619
1758
  query.mesh = mesh;
1759
+
1760
+ #if BVH_SHARED_STACK
1761
+ __shared__ int stack[BVH_QUERY_STACK_SIZE * WP_TILE_BLOCK_DIM];
1762
+ query.stack.ptr = &stack[threadIdx.x];
1763
+ #endif
1620
1764
 
1621
1765
  query.stack[0] = *mesh.bvh.root;
1622
1766
  query.count = 1;
1623
1767
  query.input_lower = lower;
1624
1768
  query.input_upper = upper;
1625
-
1626
- wp::bounds3 input_bounds(query.input_lower, query.input_upper);
1627
1769
 
1628
1770
  // Navigate through the bvh, find the first overlapping leaf node.
1629
1771
  while (query.count)
@@ -1632,10 +1774,13 @@ CUDA_CALLABLE inline mesh_query_aabb_t mesh_query_aabb(
1632
1774
  BVHPackedNodeHalf node_lower = bvh_load_node(mesh.bvh.node_lowers, nodeIndex);
1633
1775
  BVHPackedNodeHalf node_upper = bvh_load_node(mesh.bvh.node_uppers, nodeIndex);
1634
1776
 
1635
- if (!input_bounds.overlaps(reinterpret_cast<vec3&>(node_lower), reinterpret_cast<vec3&>(node_upper)))
1777
+ if (query.primitive_counter == 0)
1636
1778
  {
1637
- // Skip this box, it doesn't overlap with our target box.
1638
- continue;
1779
+ if (!intersect_aabb_aabb(query.input_lower, query.input_upper, reinterpret_cast<vec3&>(node_lower), reinterpret_cast<vec3&>(node_upper)))
1780
+ {
1781
+ // Skip this box, it doesn't overlap with our target box.
1782
+ continue;
1783
+ }
1639
1784
  }
1640
1785
 
1641
1786
  const int left_index = node_lower.i;
@@ -1646,7 +1791,7 @@ CUDA_CALLABLE inline mesh_query_aabb_t mesh_query_aabb(
1646
1791
  {
1647
1792
  // Reached a leaf node, point to its first primitive
1648
1793
  // Back up one level and return
1649
- query.primitive_counter = left_index;
1794
+ query.primitive_counter = 0;
1650
1795
  query.stack[query.count++] = nodeIndex;
1651
1796
  return query;
1652
1797
  }
@@ -1671,45 +1816,6 @@ CUDA_CALLABLE inline bool mesh_query_aabb_next(mesh_query_aabb_t& query, int& in
1671
1816
  {
1672
1817
  Mesh mesh = query.mesh;
1673
1818
 
1674
- wp::bounds3 input_bounds(query.input_lower, query.input_upper);
1675
-
1676
- if (query.primitive_counter != -1)
1677
- // currently in a leaf node which is the last node in the stack
1678
- {
1679
- const int node_index = query.stack[query.count - 1];
1680
- BVHPackedNodeHalf node_lower = bvh_load_node(mesh.bvh.node_lowers, node_index);
1681
- BVHPackedNodeHalf node_upper = bvh_load_node(mesh.bvh.node_uppers, node_index);
1682
-
1683
- const int end = node_upper.i;
1684
- for (int primitive_counter = query.primitive_counter; primitive_counter < end; primitive_counter++)
1685
- {
1686
- int primitive_index = mesh.bvh.primitive_indices[primitive_counter];
1687
- if (input_bounds.overlaps(mesh.lowers[primitive_index], mesh.uppers[primitive_index]))
1688
- {
1689
- if (primitive_counter < end - 1)
1690
- // still need to come back to this leaf node for the leftover primitives
1691
- {
1692
- query.primitive_counter = primitive_counter + 1;
1693
- }
1694
- else
1695
- // no need to come back to this leaf node
1696
- {
1697
- query.count--;
1698
- query.primitive_counter = -1;
1699
- }
1700
- index = primitive_index;
1701
- query.face = primitive_index;
1702
-
1703
- return true;
1704
- }
1705
- }
1706
- // if we reach here it means we have finished the current leaf node without finding intersections
1707
- query.primitive_counter = -1;
1708
- // remove the leaf node from the back of the stack because it is finished
1709
- // and continue the bvh traversal
1710
- query.count--;
1711
- }
1712
-
1713
1819
  // Navigate through the bvh, find the first overlapping leaf node.
1714
1820
  while (query.count)
1715
1821
  {
@@ -1717,7 +1823,7 @@ CUDA_CALLABLE inline bool mesh_query_aabb_next(mesh_query_aabb_t& query, int& in
1717
1823
  BVHPackedNodeHalf node_lower = bvh_load_node(mesh.bvh.node_lowers, node_index);
1718
1824
  BVHPackedNodeHalf node_upper = bvh_load_node(mesh.bvh.node_uppers, node_index);
1719
1825
 
1720
- if (!input_bounds.overlaps(reinterpret_cast<vec3&>(node_lower), reinterpret_cast<vec3&>(node_upper)))
1826
+ if (!intersect_aabb_aabb(query.input_lower, query.input_upper, reinterpret_cast<vec3&>(node_lower), reinterpret_cast<vec3&>(node_upper)))
1721
1827
  {
1722
1828
  // Skip this box, it doesn't overlap with our target box.
1723
1829
  continue;
@@ -1731,24 +1837,32 @@ CUDA_CALLABLE inline bool mesh_query_aabb_next(mesh_query_aabb_t& query, int& in
1731
1837
  {
1732
1838
  // found leaf, loop through its content primitives
1733
1839
  const int start = left_index;
1734
- const int end = right_index;
1735
1840
 
1736
- for (int primitive_counter = start; primitive_counter < end; primitive_counter++)
1841
+ if (mesh.bvh.leaf_size == 1)
1737
1842
  {
1738
- int primitive_index = mesh.bvh.primitive_indices[primitive_counter];
1739
- if (input_bounds.overlaps(mesh.lowers[primitive_index], mesh.uppers[primitive_index]))
1843
+ int primitive_index = mesh.bvh.primitive_indices[start];
1844
+ index = primitive_index;
1845
+ query.face = primitive_index;
1846
+ return true;
1847
+ }
1848
+ else
1849
+ {
1850
+ const int end = right_index;
1851
+ int primitive_index = mesh.bvh.primitive_indices[start + (query.primitive_counter++)];
1852
+ // if already visited the last primitive in the leaf node
1853
+ // move to the next node and reset the primitive counter to 0
1854
+ if (start + query.primitive_counter == end)
1855
+ {
1856
+ query.primitive_counter = 0;
1857
+ }
1858
+ // otherwise we need to keep this leaf node in stack for a future visit
1859
+ else
1860
+ {
1861
+ query.count++;
1862
+ }
1863
+
1864
+ if (intersect_aabb_aabb(query.input_lower, query.input_upper, mesh.lowers[primitive_index], mesh.uppers[primitive_index]))
1740
1865
  {
1741
- if (primitive_counter < end - 1)
1742
- // still need to come back to this leaf node for the leftover primitives
1743
- {
1744
- query.primitive_counter = primitive_counter + 1;
1745
- query.stack[query.count++] = node_index;
1746
- }
1747
- else
1748
- // no need to come back to this leaf node
1749
- {
1750
- query.primitive_counter = -1;
1751
- }
1752
1866
  index = primitive_index;
1753
1867
  query.face = primitive_index;
1754
1868
 
@@ -1758,6 +1872,7 @@ CUDA_CALLABLE inline bool mesh_query_aabb_next(mesh_query_aabb_t& query, int& in
1758
1872
  }
1759
1873
  else
1760
1874
  {
1875
+ query.primitive_counter = 0;
1761
1876
  query.stack[query.count++] = left_index;
1762
1877
  query.stack[query.count++] = right_index;
1763
1878
  }
warp/native/quat.h CHANGED
@@ -1621,58 +1621,6 @@ inline CUDA_CALLABLE void adj_quat_from_matrix(const mat_t<Rows,Cols,Type>& m, m
1621
1621
  adj_m.data[2][2] += dot(dq_dm22, adj_q);
1622
1622
  }
1623
1623
 
1624
- template<typename Type>
1625
- inline CUDA_CALLABLE void adj_mat_t(const vec_t<3,Type>& pos, const quat_t<Type>& rot, const vec_t<3,Type>& scale,
1626
- vec_t<3,Type>& adj_pos, quat_t<Type>& adj_rot, vec_t<3,Type>& adj_scale, const mat_t<4,4,Type>& adj_ret)
1627
- {
1628
- mat_t<3,3,Type> R = quat_to_matrix(rot);
1629
- mat_t<3,3,Type> adj_R(0);
1630
-
1631
- adj_pos[0] += adj_ret.data[0][3];
1632
- adj_pos[1] += adj_ret.data[1][3];
1633
- adj_pos[2] += adj_ret.data[2][3];
1634
-
1635
- adj_mul(R.data[0][0], scale[0], adj_R.data[0][0], adj_scale[0], adj_ret.data[0][0]);
1636
- adj_mul(R.data[1][0], scale[0], adj_R.data[1][0], adj_scale[0], adj_ret.data[1][0]);
1637
- adj_mul(R.data[2][0], scale[0], adj_R.data[2][0], adj_scale[0], adj_ret.data[2][0]);
1638
-
1639
- adj_mul(R.data[0][1], scale[1], adj_R.data[0][1], adj_scale[1], adj_ret.data[0][1]);
1640
- adj_mul(R.data[1][1], scale[1], adj_R.data[1][1], adj_scale[1], adj_ret.data[1][1]);
1641
- adj_mul(R.data[2][1], scale[1], adj_R.data[2][1], adj_scale[1], adj_ret.data[2][1]);
1642
-
1643
- adj_mul(R.data[0][2], scale[2], adj_R.data[0][2], adj_scale[2], adj_ret.data[0][2]);
1644
- adj_mul(R.data[1][2], scale[2], adj_R.data[1][2], adj_scale[2], adj_ret.data[1][2]);
1645
- adj_mul(R.data[2][2], scale[2], adj_R.data[2][2], adj_scale[2], adj_ret.data[2][2]);
1646
-
1647
- adj_quat_to_matrix(rot, adj_rot, adj_R);
1648
- }
1649
-
1650
- template<unsigned Rows, unsigned Cols, typename Type>
1651
- inline CUDA_CALLABLE mat_t<Rows,Cols,Type>::mat_t(const vec_t<3,Type>& pos, const quat_t<Type>& rot, const vec_t<3,Type>& scale)
1652
- {
1653
- mat_t<3,3,Type> R = quat_to_matrix(rot);
1654
-
1655
- data[0][0] = R.data[0][0]*scale[0];
1656
- data[1][0] = R.data[1][0]*scale[0];
1657
- data[2][0] = R.data[2][0]*scale[0];
1658
- data[3][0] = Type(0);
1659
-
1660
- data[0][1] = R.data[0][1]*scale[1];
1661
- data[1][1] = R.data[1][1]*scale[1];
1662
- data[2][1] = R.data[2][1]*scale[1];
1663
- data[3][1] = Type(0);
1664
-
1665
- data[0][2] = R.data[0][2]*scale[2];
1666
- data[1][2] = R.data[1][2]*scale[2];
1667
- data[2][2] = R.data[2][2]*scale[2];
1668
- data[3][2] = Type(0);
1669
-
1670
- data[0][3] = pos[0];
1671
- data[1][3] = pos[1];
1672
- data[2][3] = pos[2];
1673
- data[3][3] = Type(1);
1674
- }
1675
-
1676
1624
  template<typename Type=float32>
1677
1625
  inline CUDA_CALLABLE quat_t<Type> quat_identity()
1678
1626
  {
warp/native/scan.cu CHANGED
@@ -18,6 +18,8 @@
18
18
  #include "warp.h"
19
19
  #include "scan.h"
20
20
 
21
+ #include "cuda_util.h"
22
+
21
23
  #define THRUST_IGNORE_CUB_VERSION_CHECK
22
24
 
23
25
  #include <cub/device/device_scan.cuh>
warp/native/sort.cu CHANGED
@@ -23,7 +23,7 @@
23
23
 
24
24
  #include <cub/cub.cuh>
25
25
 
26
- #include <map>
26
+ #include <unordered_map>
27
27
 
28
28
  // temporary buffer for radix sort
29
29
  struct RadixSortTemp
@@ -32,8 +32,8 @@ struct RadixSortTemp
32
32
  size_t size = 0;
33
33
  };
34
34
 
35
- // map temp buffers to CUDA contexts
36
- static std::map<void*, RadixSortTemp> g_radix_sort_temp_map;
35
+ // use unique temp buffers per CUDA stream to avoid race conditions
36
+ static std::unordered_map<void*, RadixSortTemp> g_radix_sort_temp_map;
37
37
 
38
38
 
39
39
  template <typename KeyType>
@@ -44,6 +44,8 @@ void radix_sort_reserve_internal(void* context, int n, void** mem_out, size_t* s
44
44
  cub::DoubleBuffer<KeyType> d_keys;
45
45
  cub::DoubleBuffer<int> d_values;
46
46
 
47
+ CUstream stream = static_cast<CUstream>(wp_cuda_stream_get_current());
48
+
47
49
  // compute temporary memory required
48
50
  size_t sort_temp_size;
49
51
  check_cuda(cub::DeviceRadixSort::SortPairs(
@@ -52,12 +54,9 @@ void radix_sort_reserve_internal(void* context, int n, void** mem_out, size_t* s
52
54
  d_keys,
53
55
  d_values,
54
56
  n, 0, sizeof(KeyType)*8,
55
- (cudaStream_t)wp_cuda_stream_get_current()));
56
-
57
- if (!context)
58
- context = wp_cuda_context_get_current();
57
+ stream));
59
58
 
60
- RadixSortTemp& temp = g_radix_sort_temp_map[context];
59
+ RadixSortTemp& temp = g_radix_sort_temp_map[stream];
61
60
 
62
61
  if (sort_temp_size > temp.size)
63
62
  {
@@ -77,6 +76,17 @@ void radix_sort_reserve(void* context, int n, void** mem_out, size_t* size_out)
77
76
  radix_sort_reserve_internal<int>(context, n, mem_out, size_out);
78
77
  }
79
78
 
79
+ void radix_sort_release(void* context, void* stream)
80
+ {
81
+ // release temporary buffer for the given stream, if it exists
82
+ auto it = g_radix_sort_temp_map.find(stream);
83
+ if (it != g_radix_sort_temp_map.end())
84
+ {
85
+ wp_free_device(context, it->second.mem);
86
+ g_radix_sort_temp_map.erase(it);
87
+ }
88
+ }
89
+
80
90
  template <typename KeyType>
81
91
  void radix_sort_pairs_device(void* context, KeyType* keys, int* values, int n)
82
92
  {
@@ -153,6 +163,8 @@ void segmented_sort_reserve(void* context, int n, int num_segments, void** mem_o
153
163
  int* start_indices = NULL;
154
164
  int* end_indices = NULL;
155
165
 
166
+ CUstream stream = static_cast<CUstream>(wp_cuda_stream_get_current());
167
+
156
168
  // compute temporary memory required
157
169
  size_t sort_temp_size;
158
170
  check_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
@@ -166,12 +178,9 @@ void segmented_sort_reserve(void* context, int n, int num_segments, void** mem_o
166
178
  end_indices,
167
179
  0,
168
180
  32,
169
- (cudaStream_t)wp_cuda_stream_get_current()));
170
-
171
- if (!context)
172
- context = wp_cuda_context_get_current();
181
+ stream));
173
182
 
174
- RadixSortTemp& temp = g_radix_sort_temp_map[context];
183
+ RadixSortTemp& temp = g_radix_sort_temp_map[stream];
175
184
 
176
185
  if (sort_temp_size > temp.size)
177
186
  {
warp/native/sort.h CHANGED
@@ -20,6 +20,8 @@
20
20
  #include <stddef.h>
21
21
 
22
22
  void radix_sort_reserve(void* context, int n, void** mem_out=NULL, size_t* size_out=NULL);
23
+ void radix_sort_release(void* context, void* stream);
24
+
23
25
  void radix_sort_pairs_host(int* keys, int* values, int n);
24
26
  void radix_sort_pairs_host(float* keys, int* values, int n);
25
27
  void radix_sort_pairs_host(int64_t* keys, int* values, int n);
warp/native/sparse.cu CHANGED
@@ -16,8 +16,9 @@
16
16
  */
17
17
 
18
18
  #include "cuda_util.h"
19
+ #include "temp_buffer.h"
19
20
  #include "warp.h"
20
- #include "stdint.h"
21
+
21
22
  #include <cstdint>
22
23
 
23
24
  #define THRUST_IGNORE_CUB_VERSION_CHECK
@@ -26,6 +27,8 @@
26
27
  #include <cub/device/device_run_length_encode.cuh>
27
28
  #include <cub/device/device_scan.cuh>
28
29
 
30
+ extern CUcontext get_current_context();
31
+
29
32
  namespace
30
33
  {
31
34
 
@@ -361,7 +364,8 @@ WP_API void wp_bsr_matrix_from_triplets_device(
361
364
 
362
365
  if (bsr_nnz_event)
363
366
  {
364
- wp_cuda_event_record(bsr_nnz_event, stream);
367
+ const bool external = true;
368
+ wp_cuda_event_record(bsr_nnz_event, stream, external);
365
369
  }
366
370
  }
367
371
 
@@ -416,7 +420,7 @@ WP_API void wp_bsr_transpose_device(int row_count, int col_count, int nnz,
416
420
  // Ensures the sorted keys are available in summed_block_indices if needed
417
421
  if(d_keys.Current() != src_block_indices)
418
422
  {
419
- check_cuda(cudaMemcpy(src_block_indices, src_block_indices+nnz, size_t(nnz) * sizeof(int), cudaMemcpyDeviceToDevice));
423
+ check_cuda(cudaMemcpyAsync(src_block_indices, src_block_indices+nnz, size_t(nnz) * sizeof(int), cudaMemcpyDeviceToDevice, stream));
420
424
  }
421
425
  }
422
426
 
warp/native/spatial.h CHANGED
@@ -992,6 +992,18 @@ CUDA_CALLABLE inline void adj_transform_t(const vec_t<3,Type>& p, const quat_t<T
992
992
  adj_q += adj_ret.q;
993
993
  }
994
994
 
995
+ template<typename Type>
996
+ CUDA_CALLABLE inline void adj_transform_t(const initializer_array<7, Type> &l, const initializer_array<7, Type*>& adj_l, const transform_t<Type>& adj_ret)
997
+ {
998
+ *adj_l[0] += adj_ret.p[0];
999
+ *adj_l[1] += adj_ret.p[1];
1000
+ *adj_l[2] += adj_ret.p[2];
1001
+ *adj_l[3] += adj_ret.q[0];
1002
+ *adj_l[4] += adj_ret.q[1];
1003
+ *adj_l[5] += adj_ret.q[2];
1004
+ *adj_l[6] += adj_ret.q[3];
1005
+ }
1006
+
995
1007
  template<typename Type>
996
1008
  CUDA_CALLABLE inline void adj_transform_inverse(const transform_t<Type>& t, transform_t<Type>& adj_t, const transform_t<Type>& adj_ret)
997
1009
  {