warp-lang 1.6.1__py3-none-win_amd64.whl → 1.7.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +21 -7
- warp/autograd.py +14 -6
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +424 -6
- warp/build_dll.py +20 -20
- warp/builtins.py +467 -368
- warp/codegen.py +193 -125
- warp/config.py +56 -12
- warp/constants.py +14 -6
- warp/context.py +524 -277
- warp/dlpack.py +22 -12
- warp/examples/__init__.py +14 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_api.py +14 -6
- warp/examples/benchmarks/benchmark_cloth.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
- warp/examples/benchmarks/benchmark_gemm.py +82 -48
- warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
- warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
- warp/examples/benchmarks/benchmark_launches.py +14 -6
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/browse.py +14 -6
- warp/examples/core/example_cupy.py +14 -6
- warp/examples/core/example_dem.py +14 -6
- warp/examples/core/example_fluid.py +14 -6
- warp/examples/core/example_graph_capture.py +14 -6
- warp/examples/core/example_marching_cubes.py +14 -6
- warp/examples/core/example_mesh.py +14 -6
- warp/examples/core/example_mesh_intersect.py +14 -6
- warp/examples/core/example_nvdb.py +14 -6
- warp/examples/core/example_raycast.py +14 -6
- warp/examples/core/example_raymarch.py +14 -6
- warp/examples/core/example_render_opengl.py +14 -6
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +14 -6
- warp/examples/core/example_torch.py +14 -6
- warp/examples/core/example_wave.py +14 -6
- warp/examples/fem/example_adaptive_grid.py +14 -6
- warp/examples/fem/example_apic_fluid.py +15 -7
- warp/examples/fem/example_burgers.py +16 -8
- warp/examples/fem/example_convection_diffusion.py +14 -6
- warp/examples/fem/example_convection_diffusion_dg.py +14 -6
- warp/examples/fem/example_deformed_geometry.py +15 -7
- warp/examples/fem/example_diffusion.py +14 -6
- warp/examples/fem/example_diffusion_3d.py +14 -6
- warp/examples/fem/example_diffusion_mgpu.py +14 -6
- warp/examples/fem/example_distortion_energy.py +15 -7
- warp/examples/fem/example_magnetostatics.py +20 -12
- warp/examples/fem/example_mixed_elasticity.py +14 -6
- warp/examples/fem/example_navier_stokes.py +14 -6
- warp/examples/fem/example_nonconforming_contact.py +14 -6
- warp/examples/fem/example_stokes.py +14 -6
- warp/examples/fem/example_stokes_transfer.py +14 -6
- warp/examples/fem/example_streamlines.py +14 -6
- warp/examples/fem/utils.py +24 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_bounce.py +14 -6
- warp/examples/optim/example_cloth_throw.py +14 -6
- warp/examples/optim/example_diffray.py +14 -6
- warp/examples/optim/example_drone.py +14 -6
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/optim/example_inverse_kinematics.py +14 -6
- warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
- warp/examples/optim/example_softbody_properties.py +14 -6
- warp/examples/optim/example_spring_cage.py +14 -6
- warp/examples/optim/example_trajectory.py +14 -6
- warp/examples/sim/example_cartpole.py +14 -6
- warp/examples/sim/example_cloth.py +14 -6
- warp/examples/sim/example_cloth_self_contact.py +14 -6
- warp/examples/sim/example_granular.py +14 -6
- warp/examples/sim/example_granular_collision_sdf.py +14 -6
- warp/examples/sim/example_jacobian_ik.py +14 -6
- warp/examples/sim/example_particle_chain.py +14 -6
- warp/examples/sim/example_quadruped.py +14 -6
- warp/examples/sim/example_rigid_chain.py +14 -6
- warp/examples/sim/example_rigid_contact.py +14 -6
- warp/examples/sim/example_rigid_force.py +14 -6
- warp/examples/sim/example_rigid_gyroscopic.py +14 -6
- warp/examples/sim/example_rigid_soft_contact.py +14 -6
- warp/examples/sim/example_soft_body.py +14 -6
- warp/examples/tile/example_tile_cholesky.py +14 -6
- warp/examples/tile/example_tile_convolution.py +14 -6
- warp/examples/tile/example_tile_fft.py +14 -6
- warp/examples/tile/example_tile_filtering.py +14 -6
- warp/examples/tile/example_tile_matmul.py +16 -10
- warp/examples/tile/example_tile_mlp.py +14 -6
- warp/examples/tile/example_tile_nbody.py +14 -6
- warp/examples/tile/example_tile_walker.py +14 -6
- warp/fabric.py +15 -0
- warp/fem/__init__.py +26 -1
- warp/fem/adaptivity.py +19 -4
- warp/fem/cache.py +15 -0
- warp/fem/dirichlet.py +15 -0
- warp/fem/domain.py +15 -0
- warp/fem/field/__init__.py +15 -0
- warp/fem/field/field.py +15 -0
- warp/fem/field/nodal_field.py +37 -68
- warp/fem/field/restriction.py +15 -0
- warp/fem/field/virtual.py +77 -23
- warp/fem/geometry/__init__.py +15 -0
- warp/fem/geometry/adaptive_nanogrid.py +24 -10
- warp/fem/geometry/closest_point.py +16 -1
- warp/fem/geometry/deformed_geometry.py +20 -2
- warp/fem/geometry/element.py +15 -0
- warp/fem/geometry/geometry.py +20 -0
- warp/fem/geometry/grid_2d.py +27 -12
- warp/fem/geometry/grid_3d.py +27 -15
- warp/fem/geometry/hexmesh.py +20 -7
- warp/fem/geometry/nanogrid.py +24 -11
- warp/fem/geometry/partition.py +15 -0
- warp/fem/geometry/quadmesh.py +28 -13
- warp/fem/geometry/tetmesh.py +18 -4
- warp/fem/geometry/trimesh.py +18 -8
- warp/fem/integrate.py +277 -93
- warp/fem/linalg.py +20 -5
- warp/fem/operator.py +15 -0
- warp/fem/polynomial.py +15 -0
- warp/fem/quadrature/__init__.py +15 -0
- warp/fem/quadrature/pic_quadrature.py +52 -22
- warp/fem/quadrature/quadrature.py +209 -25
- warp/fem/space/__init__.py +16 -1
- warp/fem/space/basis_function_space.py +19 -2
- warp/fem/space/basis_space.py +40 -18
- warp/fem/space/dof_mapper.py +15 -0
- warp/fem/space/function_space.py +15 -0
- warp/fem/space/grid_2d_function_space.py +15 -0
- warp/fem/space/grid_3d_function_space.py +15 -0
- warp/fem/space/hexmesh_function_space.py +17 -2
- warp/fem/space/nanogrid_function_space.py +15 -0
- warp/fem/space/partition.py +21 -2
- warp/fem/space/quadmesh_function_space.py +23 -8
- warp/fem/space/restriction.py +15 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +38 -23
- warp/fem/space/shape/shape_function.py +15 -0
- warp/fem/space/shape/square_shape_function.py +27 -12
- warp/fem/space/shape/tet_shape_function.py +15 -0
- warp/fem/space/shape/triangle_shape_function.py +16 -1
- warp/fem/space/tetmesh_function_space.py +18 -3
- warp/fem/space/topology.py +15 -0
- warp/fem/space/trimesh_function_space.py +17 -2
- warp/fem/types.py +15 -0
- warp/fem/utils.py +27 -6
- warp/jax.py +28 -7
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +103 -6
- warp/native/array.h +28 -6
- warp/native/builtin.h +44 -9
- warp/native/bvh.cpp +18 -7
- warp/native/bvh.cu +57 -20
- warp/native/bvh.h +17 -7
- warp/native/clang/clang.cpp +45 -9
- warp/native/coloring.cpp +15 -6
- warp/native/crt.cpp +15 -6
- warp/native/crt.h +15 -6
- warp/native/cuda_crt.h +15 -6
- warp/native/cuda_util.cpp +29 -6
- warp/native/cuda_util.h +17 -6
- warp/native/error.cpp +15 -6
- warp/native/error.h +15 -6
- warp/native/exports.h +85 -63
- warp/native/fabric.h +15 -6
- warp/native/hashgrid.cpp +15 -6
- warp/native/hashgrid.cu +15 -6
- warp/native/hashgrid.h +15 -6
- warp/native/initializer_array.h +15 -6
- warp/native/intersect.h +41 -32
- warp/native/intersect_adj.h +48 -39
- warp/native/intersect_tri.h +17 -0
- warp/native/marching.cpp +16 -0
- warp/native/marching.cu +16 -7
- warp/native/marching.h +17 -0
- warp/native/mat.h +528 -15
- warp/native/mathdx.cpp +15 -6
- warp/native/matnn.h +15 -6
- warp/native/mesh.cpp +15 -6
- warp/native/mesh.cu +15 -6
- warp/native/mesh.h +25 -16
- warp/native/noise.h +15 -6
- warp/native/quat.h +114 -17
- warp/native/rand.h +21 -6
- warp/native/range.h +15 -6
- warp/native/reduce.cpp +15 -6
- warp/native/reduce.cu +15 -6
- warp/native/runlength_encode.cpp +15 -6
- warp/native/runlength_encode.cu +15 -6
- warp/native/scan.cpp +15 -6
- warp/native/scan.cu +15 -6
- warp/native/scan.h +15 -6
- warp/native/solid_angle.h +17 -0
- warp/native/sort.cpp +137 -65
- warp/native/sort.cu +167 -21
- warp/native/sort.h +23 -7
- warp/native/sparse.cpp +58 -28
- warp/native/sparse.cu +67 -23
- warp/native/spatial.h +15 -6
- warp/native/svd.h +131 -6
- warp/native/temp_buffer.h +15 -6
- warp/native/tile.h +316 -111
- warp/native/tile_reduce.h +61 -9
- warp/native/vec.h +83 -13
- warp/native/volume.cpp +100 -119
- warp/native/volume.cu +15 -6
- warp/native/volume.h +15 -6
- warp/native/volume_builder.cu +40 -16
- warp/native/volume_builder.h +21 -6
- warp/native/volume_impl.h +15 -6
- warp/native/warp.cpp +20 -12
- warp/native/warp.cu +114 -16
- warp/native/warp.h +34 -16
- warp/optim/__init__.py +14 -6
- warp/optim/adam.py +14 -6
- warp/optim/linear.py +25 -10
- warp/optim/sgd.py +14 -6
- warp/paddle.py +14 -6
- warp/render/__init__.py +14 -6
- warp/render/render_opengl.py +14 -6
- warp/render/render_usd.py +14 -6
- warp/render/utils.py +14 -6
- warp/sim/__init__.py +14 -7
- warp/sim/articulation.py +18 -10
- warp/sim/collide.py +35 -16
- warp/sim/graph_coloring.py +14 -6
- warp/sim/import_mjcf.py +463 -162
- warp/sim/import_snu.py +14 -7
- warp/sim/import_urdf.py +46 -18
- warp/sim/import_usd.py +14 -7
- warp/sim/inertia.py +14 -6
- warp/sim/integrator.py +14 -6
- warp/sim/integrator_euler.py +19 -11
- warp/sim/integrator_featherstone.py +17 -16
- warp/sim/integrator_vbd.py +222 -8
- warp/sim/integrator_xpbd.py +19 -11
- warp/sim/model.py +56 -19
- warp/sim/particles.py +14 -6
- warp/sim/render.py +14 -6
- warp/sim/utils.py +17 -2
- warp/sparse.py +657 -555
- warp/stubs.py +231 -19
- warp/tape.py +14 -6
- warp/tests/aux_test_class_kernel.py +14 -6
- warp/tests/aux_test_compile_consts_dummy.py +14 -6
- warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
- warp/tests/aux_test_dependent.py +14 -6
- warp/tests/aux_test_grad_customs.py +14 -6
- warp/tests/aux_test_instancing_gc.py +14 -6
- warp/tests/aux_test_module_unload.py +14 -6
- warp/tests/aux_test_name_clash1.py +14 -6
- warp/tests/aux_test_name_clash2.py +14 -6
- warp/tests/aux_test_unresolved_func.py +14 -6
- warp/tests/aux_test_unresolved_symbol.py +14 -6
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_async.py → cuda/test_async.py} +14 -6
- warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
- warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
- warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
- warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
- warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
- warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
- warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
- warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
- warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
- warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
- warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
- warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
- warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
- warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
- warp/tests/run_coverage_serial.py +14 -6
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
- warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
- warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
- warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
- warp/tests/{test_model.py → sim/test_model.py} +55 -7
- warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_adam.py +14 -6
- warp/tests/test_arithmetic.py +14 -6
- warp/tests/test_array.py +14 -6
- warp/tests/test_array_reduce.py +14 -6
- warp/tests/test_assert.py +14 -6
- warp/tests/test_atomic.py +14 -6
- warp/tests/test_bool.py +15 -7
- warp/tests/test_builtins_resolution.py +14 -6
- warp/tests/test_closest_point_edge_edge.py +14 -6
- warp/tests/test_codegen.py +14 -6
- warp/tests/test_codegen_instancing.py +14 -6
- warp/tests/test_compile_consts.py +14 -6
- warp/tests/test_conditional.py +14 -6
- warp/tests/test_context.py +14 -6
- warp/tests/test_copy.py +14 -6
- warp/tests/test_ctypes.py +14 -6
- warp/tests/test_dense.py +14 -6
- warp/tests/test_devices.py +14 -6
- warp/tests/test_examples.py +42 -42
- warp/tests/test_fabricarray.py +14 -6
- warp/tests/test_fast_math.py +14 -6
- warp/tests/test_fem.py +37 -10
- warp/tests/test_fp16.py +14 -6
- warp/tests/test_func.py +14 -6
- warp/tests/test_future_annotations.py +14 -6
- warp/tests/test_generics.py +14 -6
- warp/tests/test_grad.py +14 -6
- warp/tests/test_grad_customs.py +14 -6
- warp/tests/test_grad_debug.py +14 -6
- warp/tests/test_implicit_init.py +14 -6
- warp/tests/test_import.py +14 -6
- warp/tests/test_indexedarray.py +14 -6
- warp/tests/test_intersect.py +14 -6
- warp/tests/test_iter.py +14 -6
- warp/tests/test_large.py +14 -6
- warp/tests/test_launch.py +14 -6
- warp/tests/test_lerp.py +14 -6
- warp/tests/test_linear_solvers.py +15 -11
- warp/tests/test_lvalue.py +14 -6
- warp/tests/test_mat.py +247 -85
- warp/tests/test_mat_lite.py +14 -6
- warp/tests/test_mat_scalar_ops.py +18 -10
- warp/tests/test_math.py +14 -6
- warp/tests/test_mlp.py +14 -6
- warp/tests/test_module_hashing.py +14 -6
- warp/tests/test_modules_lite.py +14 -6
- warp/tests/test_noise.py +14 -6
- warp/tests/test_operators.py +14 -6
- warp/tests/test_options.py +14 -6
- warp/tests/test_overwrite.py +15 -60
- warp/tests/test_print.py +14 -6
- warp/tests/test_quat.py +81 -52
- warp/tests/test_rand.py +58 -43
- warp/tests/test_reload.py +14 -6
- warp/tests/test_rounding.py +14 -6
- warp/tests/test_runlength_encode.py +14 -6
- warp/tests/test_scalar_ops.py +14 -6
- warp/tests/test_smoothstep.py +14 -6
- warp/tests/test_snippet.py +15 -0
- warp/tests/test_sparse.py +61 -12
- warp/tests/test_spatial.py +89 -6
- warp/tests/test_special_values.py +14 -6
- warp/tests/test_static.py +15 -7
- warp/tests/test_struct.py +14 -6
- warp/tests/test_tape.py +14 -6
- warp/tests/test_transient_module.py +14 -6
- warp/tests/test_triangle_closest_point.py +14 -6
- warp/tests/test_types.py +14 -6
- warp/tests/test_utils.py +98 -10
- warp/tests/test_vec.py +60 -40
- warp/tests/test_vec_lite.py +14 -6
- warp/tests/test_vec_scalar_ops.py +14 -6
- warp/tests/test_verify_fp.py +14 -6
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
- warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
- warp/tests/unittest_serial.py +15 -6
- warp/tests/unittest_suites.py +59 -65
- warp/tests/unittest_utils.py +16 -7
- warp/tests/walkthrough_debug.py +14 -6
- warp/thirdparty/unittest_parallel.py +15 -8
- warp/torch.py +14 -6
- warp/types.py +124 -664
- warp/utils.py +151 -78
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
- warp_lang-1.7.0.dist-info/RECORD +429 -0
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
- warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
- warp/examples/optim/example_walker.py +0 -309
- warp/native/cutlass_gemm.cpp +0 -34
- warp/native/cutlass_gemm.cu +0 -373
- warp/tests/test_matmul.py +0 -503
- warp/tests/test_matmul_lite.py +0 -403
- warp/tests/test_vbd.py +0 -378
- warp/tests/unused_test_misc.py +0 -69
- warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
- warp_lang-1.6.1.dist-info/RECORD +0 -419
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/native/sort.cpp
CHANGED
|
@@ -1,9 +1,18 @@
|
|
|
1
|
-
|
|
2
|
-
* NVIDIA CORPORATION
|
|
3
|
-
*
|
|
4
|
-
*
|
|
5
|
-
*
|
|
6
|
-
*
|
|
1
|
+
/*
|
|
2
|
+
* SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
* you may not use this file except in compliance with the License.
|
|
7
|
+
* You may obtain a copy of the License at
|
|
8
|
+
*
|
|
9
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
*
|
|
11
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
* See the License for the specific language governing permissions and
|
|
15
|
+
* limitations under the License.
|
|
7
16
|
*/
|
|
8
17
|
|
|
9
18
|
#include "warp.h"
|
|
@@ -12,69 +21,75 @@
|
|
|
12
21
|
|
|
13
22
|
#include <cstdint>
|
|
14
23
|
|
|
15
|
-
|
|
24
|
+
//Only integer keys (bit count 32 or 64) are supported. Floats need to get converted into int first. see radix_float_to_int.
|
|
25
|
+
template <typename KeyType>
|
|
26
|
+
void radix_sort_pairs_host(KeyType* keys, int* values, int n, int offset_to_scratch_memory)
|
|
16
27
|
{
|
|
17
|
-
|
|
28
|
+
const int numPasses = sizeof(KeyType) / 2;
|
|
29
|
+
static int tables[numPasses][1 << 16];
|
|
18
30
|
memset(tables, 0, sizeof(tables));
|
|
19
|
-
|
|
20
|
-
int* auxKeys = keys + n;
|
|
21
|
-
int* auxValues = values + n;
|
|
22
|
-
|
|
31
|
+
|
|
23
32
|
// build histograms
|
|
24
|
-
for (int
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
33
|
+
for (int p = 0; p < numPasses; ++p)
|
|
34
|
+
{
|
|
35
|
+
for (int i=0; i < n; ++i)
|
|
36
|
+
{
|
|
37
|
+
const int shift = p * 16;
|
|
38
|
+
const int b = (keys[i] >> shift) & 0xffff;
|
|
39
|
+
|
|
40
|
+
++tables[p][b];
|
|
41
|
+
}
|
|
31
42
|
}
|
|
32
43
|
|
|
33
|
-
// convert histograms to offset tables in-place
|
|
34
|
-
int
|
|
35
|
-
int offhigh = 0;
|
|
36
|
-
|
|
37
|
-
for (int i=0; i < 65536; ++i)
|
|
44
|
+
// convert histograms to offset tables in-place
|
|
45
|
+
for (int p = 0; p < numPasses; ++p)
|
|
38
46
|
{
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
+
int off = 0;
|
|
48
|
+
for (int i = 0; i < 65536; ++i)
|
|
49
|
+
{
|
|
50
|
+
const int newoff = off + tables[p][i];
|
|
51
|
+
|
|
52
|
+
tables[p][i] = off;
|
|
53
|
+
|
|
54
|
+
off = newoff;
|
|
55
|
+
}
|
|
47
56
|
}
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
//
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
57
|
+
|
|
58
|
+
for (int p = 0; p < numPasses; ++p)
|
|
59
|
+
{
|
|
60
|
+
int flipFlop = p % 2;
|
|
61
|
+
KeyType* readKeys = keys + offset_to_scratch_memory * flipFlop;
|
|
62
|
+
int* readValues = values + offset_to_scratch_memory * flipFlop;
|
|
63
|
+
KeyType* writeKeys = keys + offset_to_scratch_memory * (1 - flipFlop);
|
|
64
|
+
int* writeValues = values + offset_to_scratch_memory * (1 - flipFlop);
|
|
65
|
+
|
|
66
|
+
// pass 1 - sort by low 16 bits
|
|
67
|
+
for (int i=0; i < n; ++i)
|
|
68
|
+
{
|
|
69
|
+
// lookup offset of input
|
|
70
|
+
const KeyType k = readKeys[i];
|
|
71
|
+
const int v = readValues[i];
|
|
72
|
+
|
|
73
|
+
const int shift = p * 16;
|
|
74
|
+
const int b = (k >> shift) & 0xffff;
|
|
75
|
+
|
|
76
|
+
// find offset and increment
|
|
77
|
+
const int offset = tables[p][b]++;
|
|
78
|
+
|
|
79
|
+
writeKeys[offset] = k;
|
|
80
|
+
writeValues[offset] = v;
|
|
81
|
+
}
|
|
82
|
+
}
|
|
83
|
+
}
|
|
70
84
|
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
85
|
+
void radix_sort_pairs_host(int* keys, int* values, int n)
|
|
86
|
+
{
|
|
87
|
+
radix_sort_pairs_host<int>(keys, values, n, n);
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
void radix_sort_pairs_host(int64_t* keys, int* values, int n)
|
|
91
|
+
{
|
|
92
|
+
radix_sort_pairs_host<int64_t>(keys, values, n, n);
|
|
78
93
|
}
|
|
79
94
|
|
|
80
95
|
//http://stereopsis.com/radix.html
|
|
@@ -85,13 +100,13 @@ inline unsigned int radix_float_to_int(float f)
|
|
|
85
100
|
return i ^ mask;
|
|
86
101
|
}
|
|
87
102
|
|
|
88
|
-
void radix_sort_pairs_host(float* keys, int* values, int n)
|
|
103
|
+
void radix_sort_pairs_host(float* keys, int* values, int n, int offset_to_scratch_memory)
|
|
89
104
|
{
|
|
90
105
|
static unsigned int tables[2][1 << 16];
|
|
91
106
|
memset(tables, 0, sizeof(tables));
|
|
92
107
|
|
|
93
|
-
float* auxKeys = keys +
|
|
94
|
-
int* auxValues = values +
|
|
108
|
+
float* auxKeys = keys + offset_to_scratch_memory;
|
|
109
|
+
int* auxValues = values + offset_to_scratch_memory;
|
|
95
110
|
|
|
96
111
|
// build histograms
|
|
97
112
|
for (int i=0; i < n; ++i)
|
|
@@ -153,14 +168,46 @@ void radix_sort_pairs_host(float* keys, int* values, int n)
|
|
|
153
168
|
}
|
|
154
169
|
}
|
|
155
170
|
|
|
171
|
+
void radix_sort_pairs_host(float* keys, int* values, int n)
|
|
172
|
+
{
|
|
173
|
+
radix_sort_pairs_host(keys, values, n, n);
|
|
174
|
+
}
|
|
175
|
+
|
|
176
|
+
void segmented_sort_pairs_host(float* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments)
|
|
177
|
+
{
|
|
178
|
+
for (int i = 0; i < num_segments; ++i)
|
|
179
|
+
{
|
|
180
|
+
const int start = segment_start_indices[i];
|
|
181
|
+
const int end = segment_end_indices[i];
|
|
182
|
+
radix_sort_pairs_host(keys + start, values + start, end - start, n);
|
|
183
|
+
}
|
|
184
|
+
}
|
|
185
|
+
|
|
186
|
+
void segmented_sort_pairs_host(int* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments)
|
|
187
|
+
{
|
|
188
|
+
for (int i = 0; i < num_segments; ++i)
|
|
189
|
+
{
|
|
190
|
+
const int start = segment_start_indices[i];
|
|
191
|
+
const int end = segment_end_indices[i];
|
|
192
|
+
radix_sort_pairs_host(keys + start, values + start, end - start, n);
|
|
193
|
+
}
|
|
194
|
+
}
|
|
195
|
+
|
|
196
|
+
|
|
156
197
|
#if !WP_ENABLE_CUDA
|
|
157
198
|
|
|
158
199
|
void radix_sort_reserve(void* context, int n, void** mem_out, size_t* size_out) {}
|
|
159
200
|
|
|
160
201
|
void radix_sort_pairs_int_device(uint64_t keys, uint64_t values, int n) {}
|
|
161
202
|
|
|
203
|
+
void radix_sort_pairs_int64_device(uint64_t keys, uint64_t values, int n) {}
|
|
204
|
+
|
|
162
205
|
void radix_sort_pairs_float_device(uint64_t keys, uint64_t values, int n) {}
|
|
163
206
|
|
|
207
|
+
void segmented_sort_pairs_float_device(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments) {}
|
|
208
|
+
|
|
209
|
+
void segmented_sort_pairs_int_device(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments) {}
|
|
210
|
+
|
|
164
211
|
#endif // !WP_ENABLE_CUDA
|
|
165
212
|
|
|
166
213
|
|
|
@@ -171,9 +218,34 @@ void radix_sort_pairs_int_host(uint64_t keys, uint64_t values, int n)
|
|
|
171
218
|
reinterpret_cast<int *>(values), n);
|
|
172
219
|
}
|
|
173
220
|
|
|
221
|
+
void radix_sort_pairs_int64_host(uint64_t keys, uint64_t values, int n)
|
|
222
|
+
{
|
|
223
|
+
radix_sort_pairs_host(
|
|
224
|
+
reinterpret_cast<int64_t *>(keys),
|
|
225
|
+
reinterpret_cast<int *>(values), n);
|
|
226
|
+
}
|
|
227
|
+
|
|
174
228
|
void radix_sort_pairs_float_host(uint64_t keys, uint64_t values, int n)
|
|
175
229
|
{
|
|
176
230
|
radix_sort_pairs_host(
|
|
177
231
|
reinterpret_cast<float *>(keys),
|
|
178
232
|
reinterpret_cast<int *>(values), n);
|
|
179
|
-
}
|
|
233
|
+
}
|
|
234
|
+
|
|
235
|
+
void segmented_sort_pairs_float_host(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments)
|
|
236
|
+
{
|
|
237
|
+
segmented_sort_pairs_host(
|
|
238
|
+
reinterpret_cast<float *>(keys),
|
|
239
|
+
reinterpret_cast<int *>(values), n,
|
|
240
|
+
reinterpret_cast<int *>(segment_start_indices),
|
|
241
|
+
reinterpret_cast<int *>(segment_end_indices), num_segments);
|
|
242
|
+
}
|
|
243
|
+
|
|
244
|
+
void segmented_sort_pairs_int_host(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments)
|
|
245
|
+
{
|
|
246
|
+
segmented_sort_pairs_host(
|
|
247
|
+
reinterpret_cast<int *>(keys),
|
|
248
|
+
reinterpret_cast<int *>(values), n,
|
|
249
|
+
reinterpret_cast<int *>(segment_start_indices),
|
|
250
|
+
reinterpret_cast<int *>(segment_end_indices), num_segments);
|
|
251
|
+
}
|
warp/native/sort.cu
CHANGED
|
@@ -1,9 +1,18 @@
|
|
|
1
|
-
|
|
2
|
-
* NVIDIA CORPORATION
|
|
3
|
-
*
|
|
4
|
-
*
|
|
5
|
-
*
|
|
6
|
-
*
|
|
1
|
+
/*
|
|
2
|
+
* SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
* you may not use this file except in compliance with the License.
|
|
7
|
+
* You may obtain a copy of the License at
|
|
8
|
+
*
|
|
9
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
*
|
|
11
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
* See the License for the specific language governing permissions and
|
|
15
|
+
* limitations under the License.
|
|
7
16
|
*/
|
|
8
17
|
|
|
9
18
|
#include "warp.h"
|
|
@@ -27,11 +36,12 @@ struct RadixSortTemp
|
|
|
27
36
|
static std::map<void*, RadixSortTemp> g_radix_sort_temp_map;
|
|
28
37
|
|
|
29
38
|
|
|
30
|
-
|
|
39
|
+
template <typename KeyType>
|
|
40
|
+
void radix_sort_reserve_internal(void* context, int n, void** mem_out, size_t* size_out)
|
|
31
41
|
{
|
|
32
42
|
ContextGuard guard(context);
|
|
33
43
|
|
|
34
|
-
cub::DoubleBuffer<
|
|
44
|
+
cub::DoubleBuffer<KeyType> d_keys;
|
|
35
45
|
cub::DoubleBuffer<int> d_values;
|
|
36
46
|
|
|
37
47
|
// compute temporary memory required
|
|
@@ -41,7 +51,7 @@ void radix_sort_reserve(void* context, int n, void** mem_out, size_t* size_out)
|
|
|
41
51
|
sort_temp_size,
|
|
42
52
|
d_keys,
|
|
43
53
|
d_values,
|
|
44
|
-
n, 0,
|
|
54
|
+
n, 0, sizeof(KeyType)*8,
|
|
45
55
|
(cudaStream_t)cuda_stream_get_current()));
|
|
46
56
|
|
|
47
57
|
if (!context)
|
|
@@ -62,15 +72,21 @@ void radix_sort_reserve(void* context, int n, void** mem_out, size_t* size_out)
|
|
|
62
72
|
*size_out = temp.size;
|
|
63
73
|
}
|
|
64
74
|
|
|
65
|
-
void
|
|
75
|
+
void radix_sort_reserve(void* context, int n, void** mem_out, size_t* size_out)
|
|
76
|
+
{
|
|
77
|
+
radix_sort_reserve_internal<int>(context, n, mem_out, size_out);
|
|
78
|
+
}
|
|
79
|
+
|
|
80
|
+
template <typename KeyType>
|
|
81
|
+
void radix_sort_pairs_device(void* context, KeyType* keys, int* values, int n)
|
|
66
82
|
{
|
|
67
83
|
ContextGuard guard(context);
|
|
68
84
|
|
|
69
|
-
cub::DoubleBuffer<
|
|
85
|
+
cub::DoubleBuffer<KeyType> d_keys(keys, keys + n);
|
|
70
86
|
cub::DoubleBuffer<int> d_values(values, values + n);
|
|
71
87
|
|
|
72
88
|
RadixSortTemp temp;
|
|
73
|
-
|
|
89
|
+
radix_sort_reserve_internal<KeyType>(WP_CURRENT_CONTEXT, n, &temp.mem, &temp.size);
|
|
74
90
|
|
|
75
91
|
// sort
|
|
76
92
|
check_cuda(cub::DeviceRadixSort::SortPairs(
|
|
@@ -78,16 +94,31 @@ void radix_sort_pairs_device(void* context, int* keys, int* values, int n)
|
|
|
78
94
|
temp.size,
|
|
79
95
|
d_keys,
|
|
80
96
|
d_values,
|
|
81
|
-
n, 0,
|
|
97
|
+
n, 0, sizeof(KeyType)*8,
|
|
82
98
|
(cudaStream_t)cuda_stream_get_current()));
|
|
83
99
|
|
|
84
100
|
if (d_keys.Current() != keys)
|
|
85
|
-
memcpy_d2d(WP_CURRENT_CONTEXT, keys, d_keys.Current(), sizeof(
|
|
101
|
+
memcpy_d2d(WP_CURRENT_CONTEXT, keys, d_keys.Current(), sizeof(KeyType)*n);
|
|
86
102
|
|
|
87
103
|
if (d_values.Current() != values)
|
|
88
104
|
memcpy_d2d(WP_CURRENT_CONTEXT, values, d_values.Current(), sizeof(int)*n);
|
|
89
105
|
}
|
|
90
106
|
|
|
107
|
+
void radix_sort_pairs_device(void* context, int* keys, int* values, int n)
|
|
108
|
+
{
|
|
109
|
+
radix_sort_pairs_device<int>(context, keys, values, n);
|
|
110
|
+
}
|
|
111
|
+
|
|
112
|
+
void radix_sort_pairs_device(void* context, float* keys, int* values, int n)
|
|
113
|
+
{
|
|
114
|
+
radix_sort_pairs_device<float>(context, keys, values, n);
|
|
115
|
+
}
|
|
116
|
+
|
|
117
|
+
void radix_sort_pairs_device(void* context, int64_t* keys, int* values, int n)
|
|
118
|
+
{
|
|
119
|
+
radix_sort_pairs_device<int64_t>(context, keys, values, n);
|
|
120
|
+
}
|
|
121
|
+
|
|
91
122
|
void radix_sort_pairs_int_device(uint64_t keys, uint64_t values, int n)
|
|
92
123
|
{
|
|
93
124
|
radix_sort_pairs_device(
|
|
@@ -96,7 +127,69 @@ void radix_sort_pairs_int_device(uint64_t keys, uint64_t values, int n)
|
|
|
96
127
|
reinterpret_cast<int *>(values), n);
|
|
97
128
|
}
|
|
98
129
|
|
|
99
|
-
void
|
|
130
|
+
void radix_sort_pairs_float_device(uint64_t keys, uint64_t values, int n)
|
|
131
|
+
{
|
|
132
|
+
radix_sort_pairs_device(
|
|
133
|
+
WP_CURRENT_CONTEXT,
|
|
134
|
+
reinterpret_cast<float *>(keys),
|
|
135
|
+
reinterpret_cast<int *>(values), n);
|
|
136
|
+
}
|
|
137
|
+
|
|
138
|
+
void radix_sort_pairs_int64_device(uint64_t keys, uint64_t values, int n)
|
|
139
|
+
{
|
|
140
|
+
radix_sort_pairs_device(
|
|
141
|
+
WP_CURRENT_CONTEXT,
|
|
142
|
+
reinterpret_cast<int64_t *>(keys),
|
|
143
|
+
reinterpret_cast<int *>(values), n);
|
|
144
|
+
}
|
|
145
|
+
|
|
146
|
+
void segmented_sort_reserve(void* context, int n, int num_segments, void** mem_out, size_t* size_out)
|
|
147
|
+
{
|
|
148
|
+
ContextGuard guard(context);
|
|
149
|
+
|
|
150
|
+
cub::DoubleBuffer<int> d_keys;
|
|
151
|
+
cub::DoubleBuffer<int> d_values;
|
|
152
|
+
|
|
153
|
+
int* start_indices = NULL;
|
|
154
|
+
int* end_indices = NULL;
|
|
155
|
+
|
|
156
|
+
// compute temporary memory required
|
|
157
|
+
size_t sort_temp_size;
|
|
158
|
+
check_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
|
|
159
|
+
NULL,
|
|
160
|
+
sort_temp_size,
|
|
161
|
+
d_keys,
|
|
162
|
+
d_values,
|
|
163
|
+
n,
|
|
164
|
+
num_segments,
|
|
165
|
+
start_indices,
|
|
166
|
+
end_indices,
|
|
167
|
+
0,
|
|
168
|
+
32,
|
|
169
|
+
(cudaStream_t)cuda_stream_get_current()));
|
|
170
|
+
|
|
171
|
+
if (!context)
|
|
172
|
+
context = cuda_context_get_current();
|
|
173
|
+
|
|
174
|
+
RadixSortTemp& temp = g_radix_sort_temp_map[context];
|
|
175
|
+
|
|
176
|
+
if (sort_temp_size > temp.size)
|
|
177
|
+
{
|
|
178
|
+
free_device(WP_CURRENT_CONTEXT, temp.mem);
|
|
179
|
+
temp.mem = alloc_device(WP_CURRENT_CONTEXT, sort_temp_size);
|
|
180
|
+
temp.size = sort_temp_size;
|
|
181
|
+
}
|
|
182
|
+
|
|
183
|
+
if (mem_out)
|
|
184
|
+
*mem_out = temp.mem;
|
|
185
|
+
if (size_out)
|
|
186
|
+
*size_out = temp.size;
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
// segment_start_indices and segment_end_indices are arrays of length num_segments, where segment_start_indices[i] is the index of the first element
|
|
190
|
+
// in the i-th segment and segment_end_indices[i] is the index after the last element in the i-th segment
|
|
191
|
+
// https://nvidia.github.io/cccl/cub/api/structcub_1_1DeviceSegmentedRadixSort.html
|
|
192
|
+
void segmented_sort_pairs_device(void* context, float* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments)
|
|
100
193
|
{
|
|
101
194
|
ContextGuard guard(context);
|
|
102
195
|
|
|
@@ -104,15 +197,20 @@ void radix_sort_pairs_device(void* context, float* keys, int* values, int n)
|
|
|
104
197
|
cub::DoubleBuffer<int> d_values(values, values + n);
|
|
105
198
|
|
|
106
199
|
RadixSortTemp temp;
|
|
107
|
-
|
|
200
|
+
segmented_sort_reserve(WP_CURRENT_CONTEXT, n, num_segments, &temp.mem, &temp.size);
|
|
108
201
|
|
|
109
202
|
// sort
|
|
110
|
-
check_cuda(cub::
|
|
203
|
+
check_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
|
|
111
204
|
temp.mem,
|
|
112
205
|
temp.size,
|
|
113
206
|
d_keys,
|
|
114
207
|
d_values,
|
|
115
|
-
n,
|
|
208
|
+
n,
|
|
209
|
+
num_segments,
|
|
210
|
+
segment_start_indices,
|
|
211
|
+
segment_end_indices,
|
|
212
|
+
0,
|
|
213
|
+
32,
|
|
116
214
|
(cudaStream_t)cuda_stream_get_current()));
|
|
117
215
|
|
|
118
216
|
if (d_keys.Current() != keys)
|
|
@@ -122,10 +220,58 @@ void radix_sort_pairs_device(void* context, float* keys, int* values, int n)
|
|
|
122
220
|
memcpy_d2d(WP_CURRENT_CONTEXT, values, d_values.Current(), sizeof(int)*n);
|
|
123
221
|
}
|
|
124
222
|
|
|
125
|
-
void
|
|
223
|
+
void segmented_sort_pairs_float_device(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments)
|
|
126
224
|
{
|
|
127
|
-
|
|
225
|
+
segmented_sort_pairs_device(
|
|
128
226
|
WP_CURRENT_CONTEXT,
|
|
129
227
|
reinterpret_cast<float *>(keys),
|
|
130
|
-
reinterpret_cast<int *>(values), n
|
|
228
|
+
reinterpret_cast<int *>(values), n,
|
|
229
|
+
reinterpret_cast<int *>(segment_start_indices),
|
|
230
|
+
reinterpret_cast<int *>(segment_end_indices),
|
|
231
|
+
num_segments);
|
|
232
|
+
}
|
|
233
|
+
|
|
234
|
+
// segment_indices is an array of length num_segments + 1, where segment_indices[i] is the index of the first element in the i-th segment
|
|
235
|
+
// The end of a segment is given by segment_indices[i+1]
|
|
236
|
+
// https://nvidia.github.io/cccl/cub/api/structcub_1_1DeviceSegmentedSort.html#a-simple-example
|
|
237
|
+
void segmented_sort_pairs_device(void* context, int* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments)
|
|
238
|
+
{
|
|
239
|
+
ContextGuard guard(context);
|
|
240
|
+
|
|
241
|
+
cub::DoubleBuffer<int> d_keys(keys, keys + n);
|
|
242
|
+
cub::DoubleBuffer<int> d_values(values, values + n);
|
|
243
|
+
|
|
244
|
+
RadixSortTemp temp;
|
|
245
|
+
segmented_sort_reserve(WP_CURRENT_CONTEXT, n, num_segments, &temp.mem, &temp.size);
|
|
246
|
+
|
|
247
|
+
// sort
|
|
248
|
+
check_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
|
|
249
|
+
temp.mem,
|
|
250
|
+
temp.size,
|
|
251
|
+
d_keys,
|
|
252
|
+
d_values,
|
|
253
|
+
n,
|
|
254
|
+
num_segments,
|
|
255
|
+
segment_start_indices,
|
|
256
|
+
segment_end_indices,
|
|
257
|
+
0,
|
|
258
|
+
32,
|
|
259
|
+
(cudaStream_t)cuda_stream_get_current()));
|
|
260
|
+
|
|
261
|
+
if (d_keys.Current() != keys)
|
|
262
|
+
memcpy_d2d(WP_CURRENT_CONTEXT, keys, d_keys.Current(), sizeof(float)*n);
|
|
263
|
+
|
|
264
|
+
if (d_values.Current() != values)
|
|
265
|
+
memcpy_d2d(WP_CURRENT_CONTEXT, values, d_values.Current(), sizeof(int)*n);
|
|
266
|
+
}
|
|
267
|
+
|
|
268
|
+
void segmented_sort_pairs_int_device(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments)
|
|
269
|
+
{
|
|
270
|
+
segmented_sort_pairs_device(
|
|
271
|
+
WP_CURRENT_CONTEXT,
|
|
272
|
+
reinterpret_cast<int *>(keys),
|
|
273
|
+
reinterpret_cast<int *>(values), n,
|
|
274
|
+
reinterpret_cast<int *>(segment_start_indices),
|
|
275
|
+
reinterpret_cast<int *>(segment_end_indices),
|
|
276
|
+
num_segments);
|
|
131
277
|
}
|
warp/native/sort.h
CHANGED
|
@@ -1,9 +1,18 @@
|
|
|
1
|
-
|
|
2
|
-
* NVIDIA CORPORATION
|
|
3
|
-
*
|
|
4
|
-
*
|
|
5
|
-
*
|
|
6
|
-
*
|
|
1
|
+
/*
|
|
2
|
+
* SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
* you may not use this file except in compliance with the License.
|
|
7
|
+
* You may obtain a copy of the License at
|
|
8
|
+
*
|
|
9
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
*
|
|
11
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
* See the License for the specific language governing permissions and
|
|
15
|
+
* limitations under the License.
|
|
7
16
|
*/
|
|
8
17
|
|
|
9
18
|
#pragma once
|
|
@@ -13,5 +22,12 @@
|
|
|
13
22
|
void radix_sort_reserve(void* context, int n, void** mem_out=NULL, size_t* size_out=NULL);
|
|
14
23
|
void radix_sort_pairs_host(int* keys, int* values, int n);
|
|
15
24
|
void radix_sort_pairs_host(float* keys, int* values, int n);
|
|
25
|
+
void radix_sort_pairs_host(int64_t* keys, int* values, int n);
|
|
16
26
|
void radix_sort_pairs_device(void* context, int* keys, int* values, int n);
|
|
17
|
-
void radix_sort_pairs_device(void* context, float* keys, int* values, int n);
|
|
27
|
+
void radix_sort_pairs_device(void* context, float* keys, int* values, int n);
|
|
28
|
+
void radix_sort_pairs_device(void* context, int64_t* keys, int* values, int n);
|
|
29
|
+
|
|
30
|
+
void segmented_sort_pairs_host(float* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
|
|
31
|
+
void segmented_sort_pairs_device(void* context, float* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
|
|
32
|
+
void segmented_sort_pairs_host(void* context, int* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
|
|
33
|
+
void segmented_sort_pairs_device(void* context, int* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
|