warp-lang 1.6.1__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (401) hide show
  1. warp/__init__.py +21 -7
  2. warp/autograd.py +14 -6
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +424 -6
  6. warp/build_dll.py +20 -20
  7. warp/builtins.py +467 -368
  8. warp/codegen.py +193 -125
  9. warp/config.py +56 -12
  10. warp/constants.py +14 -6
  11. warp/context.py +524 -277
  12. warp/dlpack.py +22 -12
  13. warp/examples/__init__.py +14 -6
  14. warp/examples/assets/nonuniform.usd +0 -0
  15. warp/examples/assets/nvidia_logo.png +0 -0
  16. warp/examples/benchmarks/benchmark_api.py +14 -6
  17. warp/examples/benchmarks/benchmark_cloth.py +14 -6
  18. warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
  19. warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
  20. warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
  21. warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
  22. warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
  23. warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
  24. warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
  25. warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
  26. warp/examples/benchmarks/benchmark_gemm.py +82 -48
  27. warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
  28. warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
  29. warp/examples/benchmarks/benchmark_launches.py +14 -6
  30. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  31. warp/examples/browse.py +14 -6
  32. warp/examples/core/example_cupy.py +14 -6
  33. warp/examples/core/example_dem.py +14 -6
  34. warp/examples/core/example_fluid.py +14 -6
  35. warp/examples/core/example_graph_capture.py +14 -6
  36. warp/examples/core/example_marching_cubes.py +14 -6
  37. warp/examples/core/example_mesh.py +14 -6
  38. warp/examples/core/example_mesh_intersect.py +14 -6
  39. warp/examples/core/example_nvdb.py +14 -6
  40. warp/examples/core/example_raycast.py +14 -6
  41. warp/examples/core/example_raymarch.py +14 -6
  42. warp/examples/core/example_render_opengl.py +14 -6
  43. warp/examples/core/example_sample_mesh.py +300 -0
  44. warp/examples/core/example_sph.py +14 -6
  45. warp/examples/core/example_torch.py +14 -6
  46. warp/examples/core/example_wave.py +14 -6
  47. warp/examples/fem/example_adaptive_grid.py +14 -6
  48. warp/examples/fem/example_apic_fluid.py +15 -7
  49. warp/examples/fem/example_burgers.py +16 -8
  50. warp/examples/fem/example_convection_diffusion.py +14 -6
  51. warp/examples/fem/example_convection_diffusion_dg.py +14 -6
  52. warp/examples/fem/example_deformed_geometry.py +15 -7
  53. warp/examples/fem/example_diffusion.py +14 -6
  54. warp/examples/fem/example_diffusion_3d.py +14 -6
  55. warp/examples/fem/example_diffusion_mgpu.py +14 -6
  56. warp/examples/fem/example_distortion_energy.py +15 -7
  57. warp/examples/fem/example_magnetostatics.py +20 -12
  58. warp/examples/fem/example_mixed_elasticity.py +14 -6
  59. warp/examples/fem/example_navier_stokes.py +14 -6
  60. warp/examples/fem/example_nonconforming_contact.py +14 -6
  61. warp/examples/fem/example_stokes.py +14 -6
  62. warp/examples/fem/example_stokes_transfer.py +14 -6
  63. warp/examples/fem/example_streamlines.py +14 -6
  64. warp/examples/fem/utils.py +24 -3
  65. warp/examples/interop/example_jax_callable.py +116 -0
  66. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  67. warp/examples/interop/example_jax_kernel.py +205 -0
  68. warp/examples/optim/example_bounce.py +14 -6
  69. warp/examples/optim/example_cloth_throw.py +14 -6
  70. warp/examples/optim/example_diffray.py +14 -6
  71. warp/examples/optim/example_drone.py +14 -6
  72. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  73. warp/examples/optim/example_inverse_kinematics.py +14 -6
  74. warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
  75. warp/examples/optim/example_softbody_properties.py +14 -6
  76. warp/examples/optim/example_spring_cage.py +14 -6
  77. warp/examples/optim/example_trajectory.py +14 -6
  78. warp/examples/sim/example_cartpole.py +14 -6
  79. warp/examples/sim/example_cloth.py +14 -6
  80. warp/examples/sim/example_cloth_self_contact.py +14 -6
  81. warp/examples/sim/example_granular.py +14 -6
  82. warp/examples/sim/example_granular_collision_sdf.py +14 -6
  83. warp/examples/sim/example_jacobian_ik.py +14 -6
  84. warp/examples/sim/example_particle_chain.py +14 -6
  85. warp/examples/sim/example_quadruped.py +14 -6
  86. warp/examples/sim/example_rigid_chain.py +14 -6
  87. warp/examples/sim/example_rigid_contact.py +14 -6
  88. warp/examples/sim/example_rigid_force.py +14 -6
  89. warp/examples/sim/example_rigid_gyroscopic.py +14 -6
  90. warp/examples/sim/example_rigid_soft_contact.py +14 -6
  91. warp/examples/sim/example_soft_body.py +14 -6
  92. warp/examples/tile/example_tile_cholesky.py +14 -6
  93. warp/examples/tile/example_tile_convolution.py +14 -6
  94. warp/examples/tile/example_tile_fft.py +14 -6
  95. warp/examples/tile/example_tile_filtering.py +14 -6
  96. warp/examples/tile/example_tile_matmul.py +16 -10
  97. warp/examples/tile/example_tile_mlp.py +14 -6
  98. warp/examples/tile/example_tile_nbody.py +14 -6
  99. warp/examples/tile/example_tile_walker.py +14 -6
  100. warp/fabric.py +15 -0
  101. warp/fem/__init__.py +26 -1
  102. warp/fem/adaptivity.py +19 -4
  103. warp/fem/cache.py +15 -0
  104. warp/fem/dirichlet.py +15 -0
  105. warp/fem/domain.py +15 -0
  106. warp/fem/field/__init__.py +15 -0
  107. warp/fem/field/field.py +15 -0
  108. warp/fem/field/nodal_field.py +37 -68
  109. warp/fem/field/restriction.py +15 -0
  110. warp/fem/field/virtual.py +77 -23
  111. warp/fem/geometry/__init__.py +15 -0
  112. warp/fem/geometry/adaptive_nanogrid.py +24 -10
  113. warp/fem/geometry/closest_point.py +16 -1
  114. warp/fem/geometry/deformed_geometry.py +20 -2
  115. warp/fem/geometry/element.py +15 -0
  116. warp/fem/geometry/geometry.py +20 -0
  117. warp/fem/geometry/grid_2d.py +27 -12
  118. warp/fem/geometry/grid_3d.py +27 -15
  119. warp/fem/geometry/hexmesh.py +20 -7
  120. warp/fem/geometry/nanogrid.py +24 -11
  121. warp/fem/geometry/partition.py +15 -0
  122. warp/fem/geometry/quadmesh.py +28 -13
  123. warp/fem/geometry/tetmesh.py +18 -4
  124. warp/fem/geometry/trimesh.py +18 -8
  125. warp/fem/integrate.py +277 -93
  126. warp/fem/linalg.py +20 -5
  127. warp/fem/operator.py +15 -0
  128. warp/fem/polynomial.py +15 -0
  129. warp/fem/quadrature/__init__.py +15 -0
  130. warp/fem/quadrature/pic_quadrature.py +52 -22
  131. warp/fem/quadrature/quadrature.py +209 -25
  132. warp/fem/space/__init__.py +16 -1
  133. warp/fem/space/basis_function_space.py +19 -2
  134. warp/fem/space/basis_space.py +40 -18
  135. warp/fem/space/dof_mapper.py +15 -0
  136. warp/fem/space/function_space.py +15 -0
  137. warp/fem/space/grid_2d_function_space.py +15 -0
  138. warp/fem/space/grid_3d_function_space.py +15 -0
  139. warp/fem/space/hexmesh_function_space.py +17 -2
  140. warp/fem/space/nanogrid_function_space.py +15 -0
  141. warp/fem/space/partition.py +21 -2
  142. warp/fem/space/quadmesh_function_space.py +23 -8
  143. warp/fem/space/restriction.py +15 -0
  144. warp/fem/space/shape/__init__.py +15 -0
  145. warp/fem/space/shape/cube_shape_function.py +38 -23
  146. warp/fem/space/shape/shape_function.py +15 -0
  147. warp/fem/space/shape/square_shape_function.py +27 -12
  148. warp/fem/space/shape/tet_shape_function.py +15 -0
  149. warp/fem/space/shape/triangle_shape_function.py +16 -1
  150. warp/fem/space/tetmesh_function_space.py +18 -3
  151. warp/fem/space/topology.py +15 -0
  152. warp/fem/space/trimesh_function_space.py +17 -2
  153. warp/fem/types.py +15 -0
  154. warp/fem/utils.py +27 -6
  155. warp/jax.py +28 -7
  156. warp/jax_experimental/__init__.py +16 -0
  157. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
  158. warp/jax_experimental/ffi.py +698 -0
  159. warp/jax_experimental/xla_ffi.py +602 -0
  160. warp/math.py +103 -6
  161. warp/native/array.h +28 -6
  162. warp/native/builtin.h +44 -9
  163. warp/native/bvh.cpp +18 -7
  164. warp/native/bvh.cu +57 -20
  165. warp/native/bvh.h +17 -7
  166. warp/native/clang/clang.cpp +45 -9
  167. warp/native/coloring.cpp +15 -6
  168. warp/native/crt.cpp +15 -6
  169. warp/native/crt.h +15 -6
  170. warp/native/cuda_crt.h +15 -6
  171. warp/native/cuda_util.cpp +29 -6
  172. warp/native/cuda_util.h +17 -6
  173. warp/native/error.cpp +15 -6
  174. warp/native/error.h +15 -6
  175. warp/native/exports.h +85 -63
  176. warp/native/fabric.h +15 -6
  177. warp/native/hashgrid.cpp +15 -6
  178. warp/native/hashgrid.cu +15 -6
  179. warp/native/hashgrid.h +15 -6
  180. warp/native/initializer_array.h +15 -6
  181. warp/native/intersect.h +41 -32
  182. warp/native/intersect_adj.h +48 -39
  183. warp/native/intersect_tri.h +17 -0
  184. warp/native/marching.cpp +16 -0
  185. warp/native/marching.cu +16 -7
  186. warp/native/marching.h +17 -0
  187. warp/native/mat.h +528 -15
  188. warp/native/mathdx.cpp +15 -6
  189. warp/native/matnn.h +15 -6
  190. warp/native/mesh.cpp +15 -6
  191. warp/native/mesh.cu +15 -6
  192. warp/native/mesh.h +25 -16
  193. warp/native/noise.h +15 -6
  194. warp/native/quat.h +114 -17
  195. warp/native/rand.h +21 -6
  196. warp/native/range.h +15 -6
  197. warp/native/reduce.cpp +15 -6
  198. warp/native/reduce.cu +15 -6
  199. warp/native/runlength_encode.cpp +15 -6
  200. warp/native/runlength_encode.cu +15 -6
  201. warp/native/scan.cpp +15 -6
  202. warp/native/scan.cu +15 -6
  203. warp/native/scan.h +15 -6
  204. warp/native/solid_angle.h +17 -0
  205. warp/native/sort.cpp +137 -65
  206. warp/native/sort.cu +167 -21
  207. warp/native/sort.h +23 -7
  208. warp/native/sparse.cpp +58 -28
  209. warp/native/sparse.cu +67 -23
  210. warp/native/spatial.h +15 -6
  211. warp/native/svd.h +131 -6
  212. warp/native/temp_buffer.h +15 -6
  213. warp/native/tile.h +316 -111
  214. warp/native/tile_reduce.h +61 -9
  215. warp/native/vec.h +83 -13
  216. warp/native/volume.cpp +100 -119
  217. warp/native/volume.cu +15 -6
  218. warp/native/volume.h +15 -6
  219. warp/native/volume_builder.cu +40 -16
  220. warp/native/volume_builder.h +21 -6
  221. warp/native/volume_impl.h +15 -6
  222. warp/native/warp.cpp +20 -12
  223. warp/native/warp.cu +114 -16
  224. warp/native/warp.h +34 -16
  225. warp/optim/__init__.py +14 -6
  226. warp/optim/adam.py +14 -6
  227. warp/optim/linear.py +25 -10
  228. warp/optim/sgd.py +14 -6
  229. warp/paddle.py +14 -6
  230. warp/render/__init__.py +14 -6
  231. warp/render/render_opengl.py +14 -6
  232. warp/render/render_usd.py +14 -6
  233. warp/render/utils.py +14 -6
  234. warp/sim/__init__.py +14 -7
  235. warp/sim/articulation.py +18 -10
  236. warp/sim/collide.py +35 -16
  237. warp/sim/graph_coloring.py +14 -6
  238. warp/sim/import_mjcf.py +463 -162
  239. warp/sim/import_snu.py +14 -7
  240. warp/sim/import_urdf.py +46 -18
  241. warp/sim/import_usd.py +14 -7
  242. warp/sim/inertia.py +14 -6
  243. warp/sim/integrator.py +14 -6
  244. warp/sim/integrator_euler.py +19 -11
  245. warp/sim/integrator_featherstone.py +17 -16
  246. warp/sim/integrator_vbd.py +222 -8
  247. warp/sim/integrator_xpbd.py +19 -11
  248. warp/sim/model.py +56 -19
  249. warp/sim/particles.py +14 -6
  250. warp/sim/render.py +14 -6
  251. warp/sim/utils.py +17 -2
  252. warp/sparse.py +657 -555
  253. warp/stubs.py +231 -19
  254. warp/tape.py +14 -6
  255. warp/tests/aux_test_class_kernel.py +14 -6
  256. warp/tests/aux_test_compile_consts_dummy.py +14 -6
  257. warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
  258. warp/tests/aux_test_dependent.py +14 -6
  259. warp/tests/aux_test_grad_customs.py +14 -6
  260. warp/tests/aux_test_instancing_gc.py +14 -6
  261. warp/tests/aux_test_module_unload.py +14 -6
  262. warp/tests/aux_test_name_clash1.py +14 -6
  263. warp/tests/aux_test_name_clash2.py +14 -6
  264. warp/tests/aux_test_unresolved_func.py +14 -6
  265. warp/tests/aux_test_unresolved_symbol.py +14 -6
  266. warp/tests/cuda/__init__.py +0 -0
  267. warp/tests/{test_async.py → cuda/test_async.py} +14 -6
  268. warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
  269. warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
  270. warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
  271. warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
  272. warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
  273. warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
  274. warp/tests/geometry/__init__.py +0 -0
  275. warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
  276. warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
  277. warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
  278. warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
  279. warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
  280. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
  281. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
  282. warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
  283. warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
  284. warp/tests/interop/__init__.py +0 -0
  285. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
  286. warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
  287. warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
  288. warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
  289. warp/tests/run_coverage_serial.py +14 -6
  290. warp/tests/sim/__init__.py +0 -0
  291. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
  292. warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
  293. warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
  294. warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
  295. warp/tests/{test_model.py → sim/test_model.py} +55 -7
  296. warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
  297. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
  298. warp/tests/sim/test_vbd.py +597 -0
  299. warp/tests/test_adam.py +14 -6
  300. warp/tests/test_arithmetic.py +14 -6
  301. warp/tests/test_array.py +14 -6
  302. warp/tests/test_array_reduce.py +14 -6
  303. warp/tests/test_assert.py +14 -6
  304. warp/tests/test_atomic.py +14 -6
  305. warp/tests/test_bool.py +15 -7
  306. warp/tests/test_builtins_resolution.py +14 -6
  307. warp/tests/test_closest_point_edge_edge.py +14 -6
  308. warp/tests/test_codegen.py +14 -6
  309. warp/tests/test_codegen_instancing.py +14 -6
  310. warp/tests/test_compile_consts.py +14 -6
  311. warp/tests/test_conditional.py +14 -6
  312. warp/tests/test_context.py +14 -6
  313. warp/tests/test_copy.py +14 -6
  314. warp/tests/test_ctypes.py +14 -6
  315. warp/tests/test_dense.py +14 -6
  316. warp/tests/test_devices.py +14 -6
  317. warp/tests/test_examples.py +42 -42
  318. warp/tests/test_fabricarray.py +14 -6
  319. warp/tests/test_fast_math.py +14 -6
  320. warp/tests/test_fem.py +37 -10
  321. warp/tests/test_fp16.py +14 -6
  322. warp/tests/test_func.py +14 -6
  323. warp/tests/test_future_annotations.py +14 -6
  324. warp/tests/test_generics.py +14 -6
  325. warp/tests/test_grad.py +14 -6
  326. warp/tests/test_grad_customs.py +14 -6
  327. warp/tests/test_grad_debug.py +14 -6
  328. warp/tests/test_implicit_init.py +14 -6
  329. warp/tests/test_import.py +14 -6
  330. warp/tests/test_indexedarray.py +14 -6
  331. warp/tests/test_intersect.py +14 -6
  332. warp/tests/test_iter.py +14 -6
  333. warp/tests/test_large.py +14 -6
  334. warp/tests/test_launch.py +14 -6
  335. warp/tests/test_lerp.py +14 -6
  336. warp/tests/test_linear_solvers.py +15 -11
  337. warp/tests/test_lvalue.py +14 -6
  338. warp/tests/test_mat.py +247 -85
  339. warp/tests/test_mat_lite.py +14 -6
  340. warp/tests/test_mat_scalar_ops.py +18 -10
  341. warp/tests/test_math.py +14 -6
  342. warp/tests/test_mlp.py +14 -6
  343. warp/tests/test_module_hashing.py +14 -6
  344. warp/tests/test_modules_lite.py +14 -6
  345. warp/tests/test_noise.py +14 -6
  346. warp/tests/test_operators.py +14 -6
  347. warp/tests/test_options.py +14 -6
  348. warp/tests/test_overwrite.py +15 -60
  349. warp/tests/test_print.py +14 -6
  350. warp/tests/test_quat.py +81 -52
  351. warp/tests/test_rand.py +58 -43
  352. warp/tests/test_reload.py +14 -6
  353. warp/tests/test_rounding.py +14 -6
  354. warp/tests/test_runlength_encode.py +14 -6
  355. warp/tests/test_scalar_ops.py +14 -6
  356. warp/tests/test_smoothstep.py +14 -6
  357. warp/tests/test_snippet.py +15 -0
  358. warp/tests/test_sparse.py +61 -12
  359. warp/tests/test_spatial.py +89 -6
  360. warp/tests/test_special_values.py +14 -6
  361. warp/tests/test_static.py +15 -7
  362. warp/tests/test_struct.py +14 -6
  363. warp/tests/test_tape.py +14 -6
  364. warp/tests/test_transient_module.py +14 -6
  365. warp/tests/test_triangle_closest_point.py +14 -6
  366. warp/tests/test_types.py +14 -6
  367. warp/tests/test_utils.py +98 -10
  368. warp/tests/test_vec.py +60 -40
  369. warp/tests/test_vec_lite.py +14 -6
  370. warp/tests/test_vec_scalar_ops.py +14 -6
  371. warp/tests/test_verify_fp.py +14 -6
  372. warp/tests/tile/__init__.py +0 -0
  373. warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
  374. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
  375. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
  376. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
  377. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
  378. warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
  379. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
  380. warp/tests/unittest_serial.py +15 -6
  381. warp/tests/unittest_suites.py +59 -65
  382. warp/tests/unittest_utils.py +16 -7
  383. warp/tests/walkthrough_debug.py +14 -6
  384. warp/thirdparty/unittest_parallel.py +15 -8
  385. warp/torch.py +14 -6
  386. warp/types.py +124 -664
  387. warp/utils.py +151 -78
  388. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
  389. warp_lang-1.7.0.dist-info/RECORD +429 -0
  390. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  391. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  392. warp/examples/optim/example_walker.py +0 -309
  393. warp/native/cutlass_gemm.cpp +0 -34
  394. warp/native/cutlass_gemm.cu +0 -373
  395. warp/tests/test_matmul.py +0 -503
  396. warp/tests/test_matmul_lite.py +0 -403
  397. warp/tests/test_vbd.py +0 -378
  398. warp/tests/unused_test_misc.py +0 -69
  399. warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
  400. warp_lang-1.6.1.dist-info/RECORD +0 -419
  401. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/native/sort.cpp CHANGED
@@ -1,9 +1,18 @@
1
- /** Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- * NVIDIA CORPORATION and its licensors retain all intellectual property
3
- * and proprietary rights in and to this software, related documentation
4
- * and any modifications thereto. Any use, reproduction, disclosure or
5
- * distribution of this software and related documentation without an express
6
- * license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
7
16
  */
8
17
 
9
18
  #include "warp.h"
@@ -12,69 +21,75 @@
12
21
 
13
22
  #include <cstdint>
14
23
 
15
- void radix_sort_pairs_host(int* keys, int* values, int n)
24
+ //Only integer keys (bit count 32 or 64) are supported. Floats need to get converted into int first. see radix_float_to_int.
25
+ template <typename KeyType>
26
+ void radix_sort_pairs_host(KeyType* keys, int* values, int n, int offset_to_scratch_memory)
16
27
  {
17
- static int tables[2][1 << 16];
28
+ const int numPasses = sizeof(KeyType) / 2;
29
+ static int tables[numPasses][1 << 16];
18
30
  memset(tables, 0, sizeof(tables));
19
-
20
- int* auxKeys = keys + n;
21
- int* auxValues = values + n;
22
-
31
+
23
32
  // build histograms
24
- for (int i=0; i < n; ++i)
25
- {
26
- const unsigned short low = keys[i] & 0xffff;
27
- const unsigned short high = keys[i] >> 16;
28
-
29
- ++tables[0][low];
30
- ++tables[1][high];
33
+ for (int p = 0; p < numPasses; ++p)
34
+ {
35
+ for (int i=0; i < n; ++i)
36
+ {
37
+ const int shift = p * 16;
38
+ const int b = (keys[i] >> shift) & 0xffff;
39
+
40
+ ++tables[p][b];
41
+ }
31
42
  }
32
43
 
33
- // convert histograms to offset tables in-place
34
- int offlow = 0;
35
- int offhigh = 0;
36
-
37
- for (int i=0; i < 65536; ++i)
44
+ // convert histograms to offset tables in-place
45
+ for (int p = 0; p < numPasses; ++p)
38
46
  {
39
- const int newofflow = offlow + tables[0][i];
40
- const int newoffhigh = offhigh + tables[1][i];
41
-
42
- tables[0][i] = offlow;
43
- tables[1][i] = offhigh;
44
-
45
- offlow = newofflow;
46
- offhigh = newoffhigh;
47
+ int off = 0;
48
+ for (int i = 0; i < 65536; ++i)
49
+ {
50
+ const int newoff = off + tables[p][i];
51
+
52
+ tables[p][i] = off;
53
+
54
+ off = newoff;
55
+ }
47
56
  }
48
-
49
- // pass 1 - sort by low 16 bits
50
- for (int i=0; i < n; ++i)
51
- {
52
- // lookup offset of input
53
- const int k = keys[i];
54
- const int v = values[i];
55
- const int b = k & 0xffff;
56
-
57
- // find offset and increment
58
- const int offset = tables[0][b]++;
59
-
60
- auxKeys[offset] = k;
61
- auxValues[offset] = v;
62
- }
63
-
64
- // pass 2 - sort by high 16 bits
65
- for (int i=0; i < n; ++i)
66
- {
67
- // lookup offset of input
68
- const int k = auxKeys[i];
69
- const int v = auxValues[i];
57
+
58
+ for (int p = 0; p < numPasses; ++p)
59
+ {
60
+ int flipFlop = p % 2;
61
+ KeyType* readKeys = keys + offset_to_scratch_memory * flipFlop;
62
+ int* readValues = values + offset_to_scratch_memory * flipFlop;
63
+ KeyType* writeKeys = keys + offset_to_scratch_memory * (1 - flipFlop);
64
+ int* writeValues = values + offset_to_scratch_memory * (1 - flipFlop);
65
+
66
+ // pass 1 - sort by low 16 bits
67
+ for (int i=0; i < n; ++i)
68
+ {
69
+ // lookup offset of input
70
+ const KeyType k = readKeys[i];
71
+ const int v = readValues[i];
72
+
73
+ const int shift = p * 16;
74
+ const int b = (k >> shift) & 0xffff;
75
+
76
+ // find offset and increment
77
+ const int offset = tables[p][b]++;
78
+
79
+ writeKeys[offset] = k;
80
+ writeValues[offset] = v;
81
+ }
82
+ }
83
+ }
70
84
 
71
- const int b = k >> 16;
72
-
73
- const int offset = tables[1][b]++;
74
-
75
- keys[offset] = k;
76
- values[offset] = v;
77
- }
85
+ void radix_sort_pairs_host(int* keys, int* values, int n)
86
+ {
87
+ radix_sort_pairs_host<int>(keys, values, n, n);
88
+ }
89
+
90
+ void radix_sort_pairs_host(int64_t* keys, int* values, int n)
91
+ {
92
+ radix_sort_pairs_host<int64_t>(keys, values, n, n);
78
93
  }
79
94
 
80
95
  //http://stereopsis.com/radix.html
@@ -85,13 +100,13 @@ inline unsigned int radix_float_to_int(float f)
85
100
  return i ^ mask;
86
101
  }
87
102
 
88
- void radix_sort_pairs_host(float* keys, int* values, int n)
103
+ void radix_sort_pairs_host(float* keys, int* values, int n, int offset_to_scratch_memory)
89
104
  {
90
105
  static unsigned int tables[2][1 << 16];
91
106
  memset(tables, 0, sizeof(tables));
92
107
 
93
- float* auxKeys = keys + n;
94
- int* auxValues = values + n;
108
+ float* auxKeys = keys + offset_to_scratch_memory;
109
+ int* auxValues = values + offset_to_scratch_memory;
95
110
 
96
111
  // build histograms
97
112
  for (int i=0; i < n; ++i)
@@ -153,14 +168,46 @@ void radix_sort_pairs_host(float* keys, int* values, int n)
153
168
  }
154
169
  }
155
170
 
171
+ void radix_sort_pairs_host(float* keys, int* values, int n)
172
+ {
173
+ radix_sort_pairs_host(keys, values, n, n);
174
+ }
175
+
176
+ void segmented_sort_pairs_host(float* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments)
177
+ {
178
+ for (int i = 0; i < num_segments; ++i)
179
+ {
180
+ const int start = segment_start_indices[i];
181
+ const int end = segment_end_indices[i];
182
+ radix_sort_pairs_host(keys + start, values + start, end - start, n);
183
+ }
184
+ }
185
+
186
+ void segmented_sort_pairs_host(int* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments)
187
+ {
188
+ for (int i = 0; i < num_segments; ++i)
189
+ {
190
+ const int start = segment_start_indices[i];
191
+ const int end = segment_end_indices[i];
192
+ radix_sort_pairs_host(keys + start, values + start, end - start, n);
193
+ }
194
+ }
195
+
196
+
156
197
  #if !WP_ENABLE_CUDA
157
198
 
158
199
  void radix_sort_reserve(void* context, int n, void** mem_out, size_t* size_out) {}
159
200
 
160
201
  void radix_sort_pairs_int_device(uint64_t keys, uint64_t values, int n) {}
161
202
 
203
+ void radix_sort_pairs_int64_device(uint64_t keys, uint64_t values, int n) {}
204
+
162
205
  void radix_sort_pairs_float_device(uint64_t keys, uint64_t values, int n) {}
163
206
 
207
+ void segmented_sort_pairs_float_device(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments) {}
208
+
209
+ void segmented_sort_pairs_int_device(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments) {}
210
+
164
211
  #endif // !WP_ENABLE_CUDA
165
212
 
166
213
 
@@ -171,9 +218,34 @@ void radix_sort_pairs_int_host(uint64_t keys, uint64_t values, int n)
171
218
  reinterpret_cast<int *>(values), n);
172
219
  }
173
220
 
221
+ void radix_sort_pairs_int64_host(uint64_t keys, uint64_t values, int n)
222
+ {
223
+ radix_sort_pairs_host(
224
+ reinterpret_cast<int64_t *>(keys),
225
+ reinterpret_cast<int *>(values), n);
226
+ }
227
+
174
228
  void radix_sort_pairs_float_host(uint64_t keys, uint64_t values, int n)
175
229
  {
176
230
  radix_sort_pairs_host(
177
231
  reinterpret_cast<float *>(keys),
178
232
  reinterpret_cast<int *>(values), n);
179
- }
233
+ }
234
+
235
+ void segmented_sort_pairs_float_host(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments)
236
+ {
237
+ segmented_sort_pairs_host(
238
+ reinterpret_cast<float *>(keys),
239
+ reinterpret_cast<int *>(values), n,
240
+ reinterpret_cast<int *>(segment_start_indices),
241
+ reinterpret_cast<int *>(segment_end_indices), num_segments);
242
+ }
243
+
244
+ void segmented_sort_pairs_int_host(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments)
245
+ {
246
+ segmented_sort_pairs_host(
247
+ reinterpret_cast<int *>(keys),
248
+ reinterpret_cast<int *>(values), n,
249
+ reinterpret_cast<int *>(segment_start_indices),
250
+ reinterpret_cast<int *>(segment_end_indices), num_segments);
251
+ }
warp/native/sort.cu CHANGED
@@ -1,9 +1,18 @@
1
- /** Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- * NVIDIA CORPORATION and its licensors retain all intellectual property
3
- * and proprietary rights in and to this software, related documentation
4
- * and any modifications thereto. Any use, reproduction, disclosure or
5
- * distribution of this software and related documentation without an express
6
- * license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
7
16
  */
8
17
 
9
18
  #include "warp.h"
@@ -27,11 +36,12 @@ struct RadixSortTemp
27
36
  static std::map<void*, RadixSortTemp> g_radix_sort_temp_map;
28
37
 
29
38
 
30
- void radix_sort_reserve(void* context, int n, void** mem_out, size_t* size_out)
39
+ template <typename KeyType>
40
+ void radix_sort_reserve_internal(void* context, int n, void** mem_out, size_t* size_out)
31
41
  {
32
42
  ContextGuard guard(context);
33
43
 
34
- cub::DoubleBuffer<int> d_keys;
44
+ cub::DoubleBuffer<KeyType> d_keys;
35
45
  cub::DoubleBuffer<int> d_values;
36
46
 
37
47
  // compute temporary memory required
@@ -41,7 +51,7 @@ void radix_sort_reserve(void* context, int n, void** mem_out, size_t* size_out)
41
51
  sort_temp_size,
42
52
  d_keys,
43
53
  d_values,
44
- n, 0, 32,
54
+ n, 0, sizeof(KeyType)*8,
45
55
  (cudaStream_t)cuda_stream_get_current()));
46
56
 
47
57
  if (!context)
@@ -62,15 +72,21 @@ void radix_sort_reserve(void* context, int n, void** mem_out, size_t* size_out)
62
72
  *size_out = temp.size;
63
73
  }
64
74
 
65
- void radix_sort_pairs_device(void* context, int* keys, int* values, int n)
75
+ void radix_sort_reserve(void* context, int n, void** mem_out, size_t* size_out)
76
+ {
77
+ radix_sort_reserve_internal<int>(context, n, mem_out, size_out);
78
+ }
79
+
80
+ template <typename KeyType>
81
+ void radix_sort_pairs_device(void* context, KeyType* keys, int* values, int n)
66
82
  {
67
83
  ContextGuard guard(context);
68
84
 
69
- cub::DoubleBuffer<int> d_keys(keys, keys + n);
85
+ cub::DoubleBuffer<KeyType> d_keys(keys, keys + n);
70
86
  cub::DoubleBuffer<int> d_values(values, values + n);
71
87
 
72
88
  RadixSortTemp temp;
73
- radix_sort_reserve(WP_CURRENT_CONTEXT, n, &temp.mem, &temp.size);
89
+ radix_sort_reserve_internal<KeyType>(WP_CURRENT_CONTEXT, n, &temp.mem, &temp.size);
74
90
 
75
91
  // sort
76
92
  check_cuda(cub::DeviceRadixSort::SortPairs(
@@ -78,16 +94,31 @@ void radix_sort_pairs_device(void* context, int* keys, int* values, int n)
78
94
  temp.size,
79
95
  d_keys,
80
96
  d_values,
81
- n, 0, 32,
97
+ n, 0, sizeof(KeyType)*8,
82
98
  (cudaStream_t)cuda_stream_get_current()));
83
99
 
84
100
  if (d_keys.Current() != keys)
85
- memcpy_d2d(WP_CURRENT_CONTEXT, keys, d_keys.Current(), sizeof(int)*n);
101
+ memcpy_d2d(WP_CURRENT_CONTEXT, keys, d_keys.Current(), sizeof(KeyType)*n);
86
102
 
87
103
  if (d_values.Current() != values)
88
104
  memcpy_d2d(WP_CURRENT_CONTEXT, values, d_values.Current(), sizeof(int)*n);
89
105
  }
90
106
 
107
+ void radix_sort_pairs_device(void* context, int* keys, int* values, int n)
108
+ {
109
+ radix_sort_pairs_device<int>(context, keys, values, n);
110
+ }
111
+
112
+ void radix_sort_pairs_device(void* context, float* keys, int* values, int n)
113
+ {
114
+ radix_sort_pairs_device<float>(context, keys, values, n);
115
+ }
116
+
117
+ void radix_sort_pairs_device(void* context, int64_t* keys, int* values, int n)
118
+ {
119
+ radix_sort_pairs_device<int64_t>(context, keys, values, n);
120
+ }
121
+
91
122
  void radix_sort_pairs_int_device(uint64_t keys, uint64_t values, int n)
92
123
  {
93
124
  radix_sort_pairs_device(
@@ -96,7 +127,69 @@ void radix_sort_pairs_int_device(uint64_t keys, uint64_t values, int n)
96
127
  reinterpret_cast<int *>(values), n);
97
128
  }
98
129
 
99
- void radix_sort_pairs_device(void* context, float* keys, int* values, int n)
130
+ void radix_sort_pairs_float_device(uint64_t keys, uint64_t values, int n)
131
+ {
132
+ radix_sort_pairs_device(
133
+ WP_CURRENT_CONTEXT,
134
+ reinterpret_cast<float *>(keys),
135
+ reinterpret_cast<int *>(values), n);
136
+ }
137
+
138
+ void radix_sort_pairs_int64_device(uint64_t keys, uint64_t values, int n)
139
+ {
140
+ radix_sort_pairs_device(
141
+ WP_CURRENT_CONTEXT,
142
+ reinterpret_cast<int64_t *>(keys),
143
+ reinterpret_cast<int *>(values), n);
144
+ }
145
+
146
+ void segmented_sort_reserve(void* context, int n, int num_segments, void** mem_out, size_t* size_out)
147
+ {
148
+ ContextGuard guard(context);
149
+
150
+ cub::DoubleBuffer<int> d_keys;
151
+ cub::DoubleBuffer<int> d_values;
152
+
153
+ int* start_indices = NULL;
154
+ int* end_indices = NULL;
155
+
156
+ // compute temporary memory required
157
+ size_t sort_temp_size;
158
+ check_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
159
+ NULL,
160
+ sort_temp_size,
161
+ d_keys,
162
+ d_values,
163
+ n,
164
+ num_segments,
165
+ start_indices,
166
+ end_indices,
167
+ 0,
168
+ 32,
169
+ (cudaStream_t)cuda_stream_get_current()));
170
+
171
+ if (!context)
172
+ context = cuda_context_get_current();
173
+
174
+ RadixSortTemp& temp = g_radix_sort_temp_map[context];
175
+
176
+ if (sort_temp_size > temp.size)
177
+ {
178
+ free_device(WP_CURRENT_CONTEXT, temp.mem);
179
+ temp.mem = alloc_device(WP_CURRENT_CONTEXT, sort_temp_size);
180
+ temp.size = sort_temp_size;
181
+ }
182
+
183
+ if (mem_out)
184
+ *mem_out = temp.mem;
185
+ if (size_out)
186
+ *size_out = temp.size;
187
+ }
188
+
189
+ // segment_start_indices and segment_end_indices are arrays of length num_segments, where segment_start_indices[i] is the index of the first element
190
+ // in the i-th segment and segment_end_indices[i] is the index after the last element in the i-th segment
191
+ // https://nvidia.github.io/cccl/cub/api/structcub_1_1DeviceSegmentedRadixSort.html
192
+ void segmented_sort_pairs_device(void* context, float* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments)
100
193
  {
101
194
  ContextGuard guard(context);
102
195
 
@@ -104,15 +197,20 @@ void radix_sort_pairs_device(void* context, float* keys, int* values, int n)
104
197
  cub::DoubleBuffer<int> d_values(values, values + n);
105
198
 
106
199
  RadixSortTemp temp;
107
- radix_sort_reserve(WP_CURRENT_CONTEXT, n, &temp.mem, &temp.size);
200
+ segmented_sort_reserve(WP_CURRENT_CONTEXT, n, num_segments, &temp.mem, &temp.size);
108
201
 
109
202
  // sort
110
- check_cuda(cub::DeviceRadixSort::SortPairs(
203
+ check_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
111
204
  temp.mem,
112
205
  temp.size,
113
206
  d_keys,
114
207
  d_values,
115
- n, 0, 32,
208
+ n,
209
+ num_segments,
210
+ segment_start_indices,
211
+ segment_end_indices,
212
+ 0,
213
+ 32,
116
214
  (cudaStream_t)cuda_stream_get_current()));
117
215
 
118
216
  if (d_keys.Current() != keys)
@@ -122,10 +220,58 @@ void radix_sort_pairs_device(void* context, float* keys, int* values, int n)
122
220
  memcpy_d2d(WP_CURRENT_CONTEXT, values, d_values.Current(), sizeof(int)*n);
123
221
  }
124
222
 
125
- void radix_sort_pairs_float_device(uint64_t keys, uint64_t values, int n)
223
+ void segmented_sort_pairs_float_device(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments)
126
224
  {
127
- radix_sort_pairs_device(
225
+ segmented_sort_pairs_device(
128
226
  WP_CURRENT_CONTEXT,
129
227
  reinterpret_cast<float *>(keys),
130
- reinterpret_cast<int *>(values), n);
228
+ reinterpret_cast<int *>(values), n,
229
+ reinterpret_cast<int *>(segment_start_indices),
230
+ reinterpret_cast<int *>(segment_end_indices),
231
+ num_segments);
232
+ }
233
+
234
+ // segment_indices is an array of length num_segments + 1, where segment_indices[i] is the index of the first element in the i-th segment
235
+ // The end of a segment is given by segment_indices[i+1]
236
+ // https://nvidia.github.io/cccl/cub/api/structcub_1_1DeviceSegmentedSort.html#a-simple-example
237
+ void segmented_sort_pairs_device(void* context, int* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments)
238
+ {
239
+ ContextGuard guard(context);
240
+
241
+ cub::DoubleBuffer<int> d_keys(keys, keys + n);
242
+ cub::DoubleBuffer<int> d_values(values, values + n);
243
+
244
+ RadixSortTemp temp;
245
+ segmented_sort_reserve(WP_CURRENT_CONTEXT, n, num_segments, &temp.mem, &temp.size);
246
+
247
+ // sort
248
+ check_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
249
+ temp.mem,
250
+ temp.size,
251
+ d_keys,
252
+ d_values,
253
+ n,
254
+ num_segments,
255
+ segment_start_indices,
256
+ segment_end_indices,
257
+ 0,
258
+ 32,
259
+ (cudaStream_t)cuda_stream_get_current()));
260
+
261
+ if (d_keys.Current() != keys)
262
+ memcpy_d2d(WP_CURRENT_CONTEXT, keys, d_keys.Current(), sizeof(float)*n);
263
+
264
+ if (d_values.Current() != values)
265
+ memcpy_d2d(WP_CURRENT_CONTEXT, values, d_values.Current(), sizeof(int)*n);
266
+ }
267
+
268
+ void segmented_sort_pairs_int_device(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments)
269
+ {
270
+ segmented_sort_pairs_device(
271
+ WP_CURRENT_CONTEXT,
272
+ reinterpret_cast<int *>(keys),
273
+ reinterpret_cast<int *>(values), n,
274
+ reinterpret_cast<int *>(segment_start_indices),
275
+ reinterpret_cast<int *>(segment_end_indices),
276
+ num_segments);
131
277
  }
warp/native/sort.h CHANGED
@@ -1,9 +1,18 @@
1
- /** Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- * NVIDIA CORPORATION and its licensors retain all intellectual property
3
- * and proprietary rights in and to this software, related documentation
4
- * and any modifications thereto. Any use, reproduction, disclosure or
5
- * distribution of this software and related documentation without an express
6
- * license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
7
16
  */
8
17
 
9
18
  #pragma once
@@ -13,5 +22,12 @@
13
22
  void radix_sort_reserve(void* context, int n, void** mem_out=NULL, size_t* size_out=NULL);
14
23
  void radix_sort_pairs_host(int* keys, int* values, int n);
15
24
  void radix_sort_pairs_host(float* keys, int* values, int n);
25
+ void radix_sort_pairs_host(int64_t* keys, int* values, int n);
16
26
  void radix_sort_pairs_device(void* context, int* keys, int* values, int n);
17
- void radix_sort_pairs_device(void* context, float* keys, int* values, int n);
27
+ void radix_sort_pairs_device(void* context, float* keys, int* values, int n);
28
+ void radix_sort_pairs_device(void* context, int64_t* keys, int* values, int n);
29
+
30
+ void segmented_sort_pairs_host(float* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
31
+ void segmented_sort_pairs_device(void* context, float* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
32
+ void segmented_sort_pairs_host(void* context, int* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
33
+ void segmented_sort_pairs_device(void* context, int* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);