warp-lang 1.10.0 (warp_lang-1.10.0-py3-none-macosx_11_0_arm64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic.
- warp/__init__.py +334 -0
- warp/__init__.pyi +5856 -0
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1077 -0
- warp/_src/build.py +620 -0
- warp/_src/build_dll.py +642 -0
- warp/_src/builtins.py +10555 -0
- warp/_src/codegen.py +4361 -0
- warp/_src/config.py +178 -0
- warp/_src/constants.py +59 -0
- warp/_src/context.py +8352 -0
- warp/_src/dlpack.py +464 -0
- warp/_src/fabric.py +362 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +510 -0
- warp/_src/fem/cache.py +689 -0
- warp/_src/fem/dirichlet.py +190 -0
- warp/_src/fem/domain.py +553 -0
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +703 -0
- warp/_src/fem/field/nodal_field.py +403 -0
- warp/_src/fem/field/restriction.py +39 -0
- warp/_src/fem/field/virtual.py +1021 -0
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
- warp/_src/fem/geometry/closest_point.py +99 -0
- warp/_src/fem/geometry/deformed_geometry.py +277 -0
- warp/_src/fem/geometry/element.py +854 -0
- warp/_src/fem/geometry/geometry.py +693 -0
- warp/_src/fem/geometry/grid_2d.py +478 -0
- warp/_src/fem/geometry/grid_3d.py +539 -0
- warp/_src/fem/geometry/hexmesh.py +956 -0
- warp/_src/fem/geometry/nanogrid.py +660 -0
- warp/_src/fem/geometry/partition.py +483 -0
- warp/_src/fem/geometry/quadmesh.py +597 -0
- warp/_src/fem/geometry/tetmesh.py +762 -0
- warp/_src/fem/geometry/trimesh.py +588 -0
- warp/_src/fem/integrate.py +2507 -0
- warp/_src/fem/linalg.py +385 -0
- warp/_src/fem/operator.py +398 -0
- warp/_src/fem/polynomial.py +231 -0
- warp/_src/fem/quadrature/__init__.py +17 -0
- warp/_src/fem/quadrature/pic_quadrature.py +318 -0
- warp/_src/fem/quadrature/quadrature.py +665 -0
- warp/_src/fem/space/__init__.py +248 -0
- warp/_src/fem/space/basis_function_space.py +499 -0
- warp/_src/fem/space/basis_space.py +681 -0
- warp/_src/fem/space/dof_mapper.py +253 -0
- warp/_src/fem/space/function_space.py +312 -0
- warp/_src/fem/space/grid_2d_function_space.py +179 -0
- warp/_src/fem/space/grid_3d_function_space.py +229 -0
- warp/_src/fem/space/hexmesh_function_space.py +255 -0
- warp/_src/fem/space/nanogrid_function_space.py +199 -0
- warp/_src/fem/space/partition.py +435 -0
- warp/_src/fem/space/quadmesh_function_space.py +222 -0
- warp/_src/fem/space/restriction.py +221 -0
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
- warp/_src/fem/space/shape/shape_function.py +134 -0
- warp/_src/fem/space/shape/square_shape_function.py +928 -0
- warp/_src/fem/space/shape/tet_shape_function.py +829 -0
- warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
- warp/_src/fem/space/tetmesh_function_space.py +270 -0
- warp/_src/fem/space/topology.py +461 -0
- warp/_src/fem/space/trimesh_function_space.py +193 -0
- warp/_src/fem/types.py +114 -0
- warp/_src/fem/utils.py +488 -0
- warp/_src/jax.py +188 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +389 -0
- warp/_src/jax_experimental/ffi.py +1286 -0
- warp/_src/jax_experimental/xla_ffi.py +658 -0
- warp/_src/marching_cubes.py +710 -0
- warp/_src/math.py +416 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +165 -0
- warp/_src/optim/linear.py +1608 -0
- warp/_src/optim/sgd.py +114 -0
- warp/_src/paddle.py +408 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +291 -0
- warp/_src/render/render_opengl.py +3638 -0
- warp/_src/render/render_usd.py +939 -0
- warp/_src/render/utils.py +162 -0
- warp/_src/sparse.py +2718 -0
- warp/_src/tape.py +1208 -0
- warp/_src/thirdparty/__init__.py +0 -0
- warp/_src/thirdparty/appdirs.py +598 -0
- warp/_src/thirdparty/dlpack.py +145 -0
- warp/_src/thirdparty/unittest_parallel.py +676 -0
- warp/_src/torch.py +393 -0
- warp/_src/types.py +5888 -0
- warp/_src/utils.py +1695 -0
- warp/autograd.py +33 -0
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +29 -0
- warp/build_dll.py +24 -0
- warp/codegen.py +24 -0
- warp/constants.py +24 -0
- warp/context.py +33 -0
- warp/dlpack.py +24 -0
- warp/examples/__init__.py +24 -0
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +389 -0
- warp/examples/benchmarks/benchmark_cloth.py +296 -0
- warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
- warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
- warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
- warp/examples/benchmarks/benchmark_gemm.py +164 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
- warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
- warp/examples/benchmarks/benchmark_launches.py +301 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/browse.py +37 -0
- warp/examples/core/example_cupy.py +86 -0
- warp/examples/core/example_dem.py +241 -0
- warp/examples/core/example_fluid.py +299 -0
- warp/examples/core/example_graph_capture.py +150 -0
- warp/examples/core/example_marching_cubes.py +195 -0
- warp/examples/core/example_mesh.py +180 -0
- warp/examples/core/example_mesh_intersect.py +211 -0
- warp/examples/core/example_nvdb.py +182 -0
- warp/examples/core/example_raycast.py +111 -0
- warp/examples/core/example_raymarch.py +205 -0
- warp/examples/core/example_render_opengl.py +290 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +411 -0
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_torch.py +211 -0
- warp/examples/core/example_wave.py +269 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/distributed/example_jacobi_mpi.py +506 -0
- warp/examples/fem/example_adaptive_grid.py +286 -0
- warp/examples/fem/example_apic_fluid.py +469 -0
- warp/examples/fem/example_burgers.py +261 -0
- warp/examples/fem/example_convection_diffusion.py +181 -0
- warp/examples/fem/example_convection_diffusion_dg.py +225 -0
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +172 -0
- warp/examples/fem/example_diffusion.py +196 -0
- warp/examples/fem/example_diffusion_3d.py +225 -0
- warp/examples/fem/example_diffusion_mgpu.py +225 -0
- warp/examples/fem/example_distortion_energy.py +228 -0
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +242 -0
- warp/examples/fem/example_mixed_elasticity.py +293 -0
- warp/examples/fem/example_navier_stokes.py +263 -0
- warp/examples/fem/example_nonconforming_contact.py +300 -0
- warp/examples/fem/example_stokes.py +213 -0
- warp/examples/fem/example_stokes_transfer.py +262 -0
- warp/examples/fem/example_streamlines.py +357 -0
- warp/examples/fem/utils.py +1047 -0
- warp/examples/interop/example_jax_callable.py +146 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +232 -0
- warp/examples/optim/example_diffray.py +561 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +88 -0
- warp/examples/tile/example_tile_convolution.py +66 -0
- warp/examples/tile/example_tile_fft.py +55 -0
- warp/examples/tile/example_tile_filtering.py +113 -0
- warp/examples/tile/example_tile_matmul.py +85 -0
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/examples/tile/example_tile_mlp.py +385 -0
- warp/examples/tile/example_tile_nbody.py +199 -0
- warp/fabric.py +24 -0
- warp/fem/__init__.py +173 -0
- warp/fem/adaptivity.py +26 -0
- warp/fem/cache.py +30 -0
- warp/fem/dirichlet.py +24 -0
- warp/fem/field/__init__.py +24 -0
- warp/fem/field/field.py +26 -0
- warp/fem/geometry/__init__.py +21 -0
- warp/fem/geometry/closest_point.py +31 -0
- warp/fem/linalg.py +38 -0
- warp/fem/operator.py +32 -0
- warp/fem/polynomial.py +29 -0
- warp/fem/space/__init__.py +22 -0
- warp/fem/space/basis_space.py +24 -0
- warp/fem/space/shape/__init__.py +68 -0
- warp/fem/space/topology.py +24 -0
- warp/fem/types.py +24 -0
- warp/fem/utils.py +32 -0
- warp/jax.py +29 -0
- warp/jax_experimental/__init__.py +29 -0
- warp/jax_experimental/custom_call.py +29 -0
- warp/jax_experimental/ffi.py +39 -0
- warp/jax_experimental/xla_ffi.py +24 -0
- warp/marching_cubes.py +24 -0
- warp/math.py +37 -0
- warp/native/array.h +1687 -0
- warp/native/builtin.h +2327 -0
- warp/native/bvh.cpp +562 -0
- warp/native/bvh.cu +826 -0
- warp/native/bvh.h +555 -0
- warp/native/clang/clang.cpp +541 -0
- warp/native/coloring.cpp +622 -0
- warp/native/crt.cpp +51 -0
- warp/native/crt.h +568 -0
- warp/native/cuda_crt.h +1058 -0
- warp/native/cuda_util.cpp +677 -0
- warp/native/cuda_util.h +313 -0
- warp/native/error.cpp +77 -0
- warp/native/error.h +36 -0
- warp/native/exports.h +2023 -0
- warp/native/fabric.h +246 -0
- warp/native/hashgrid.cpp +311 -0
- warp/native/hashgrid.cu +89 -0
- warp/native/hashgrid.h +240 -0
- warp/native/initializer_array.h +41 -0
- warp/native/intersect.h +1253 -0
- warp/native/intersect_adj.h +375 -0
- warp/native/intersect_tri.h +348 -0
- warp/native/mat.h +5189 -0
- warp/native/mathdx.cpp +93 -0
- warp/native/matnn.h +221 -0
- warp/native/mesh.cpp +266 -0
- warp/native/mesh.cu +406 -0
- warp/native/mesh.h +2097 -0
- warp/native/nanovdb/GridHandle.h +533 -0
- warp/native/nanovdb/HostBuffer.h +591 -0
- warp/native/nanovdb/NanoVDB.h +6246 -0
- warp/native/nanovdb/NodeManager.h +323 -0
- warp/native/nanovdb/PNanoVDB.h +3390 -0
- warp/native/noise.h +859 -0
- warp/native/quat.h +1664 -0
- warp/native/rand.h +342 -0
- warp/native/range.h +145 -0
- warp/native/reduce.cpp +174 -0
- warp/native/reduce.cu +363 -0
- warp/native/runlength_encode.cpp +79 -0
- warp/native/runlength_encode.cu +61 -0
- warp/native/scan.cpp +47 -0
- warp/native/scan.cu +55 -0
- warp/native/scan.h +23 -0
- warp/native/solid_angle.h +466 -0
- warp/native/sort.cpp +251 -0
- warp/native/sort.cu +286 -0
- warp/native/sort.h +35 -0
- warp/native/sparse.cpp +241 -0
- warp/native/sparse.cu +435 -0
- warp/native/spatial.h +1306 -0
- warp/native/svd.h +727 -0
- warp/native/temp_buffer.h +46 -0
- warp/native/tile.h +4124 -0
- warp/native/tile_radix_sort.h +1112 -0
- warp/native/tile_reduce.h +838 -0
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +2199 -0
- warp/native/version.h +23 -0
- warp/native/volume.cpp +501 -0
- warp/native/volume.cu +68 -0
- warp/native/volume.h +970 -0
- warp/native/volume_builder.cu +483 -0
- warp/native/volume_builder.h +52 -0
- warp/native/volume_impl.h +70 -0
- warp/native/warp.cpp +1143 -0
- warp/native/warp.cu +4604 -0
- warp/native/warp.h +358 -0
- warp/optim/__init__.py +20 -0
- warp/optim/adam.py +24 -0
- warp/optim/linear.py +35 -0
- warp/optim/sgd.py +24 -0
- warp/paddle.py +24 -0
- warp/py.typed +0 -0
- warp/render/__init__.py +22 -0
- warp/render/imgui_manager.py +29 -0
- warp/render/render_opengl.py +24 -0
- warp/render/render_usd.py +24 -0
- warp/render/utils.py +24 -0
- warp/sparse.py +51 -0
- warp/tape.py +24 -0
- warp/tests/__init__.py +1 -0
- warp/tests/__main__.py +4 -0
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/mlp_golden.npy +0 -0
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/spiky.usd +0 -0
- warp/tests/assets/test_grid.nvdb +0 -0
- warp/tests/assets/test_index_grid.nvdb +0 -0
- warp/tests/assets/test_int32_grid.nvdb +0 -0
- warp/tests/assets/test_vec_grid.nvdb +0 -0
- warp/tests/assets/torus.nvdb +0 -0
- warp/tests/assets/torus.usda +105 -0
- warp/tests/aux_test_class_kernel.py +34 -0
- warp/tests/aux_test_compile_consts_dummy.py +18 -0
- warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
- warp/tests/aux_test_dependent.py +29 -0
- warp/tests/aux_test_grad_customs.py +29 -0
- warp/tests/aux_test_instancing_gc.py +26 -0
- warp/tests/aux_test_module_aot.py +7 -0
- warp/tests/aux_test_module_unload.py +23 -0
- warp/tests/aux_test_name_clash1.py +40 -0
- warp/tests/aux_test_name_clash2.py +40 -0
- warp/tests/aux_test_reference.py +9 -0
- warp/tests/aux_test_reference_reference.py +8 -0
- warp/tests/aux_test_square.py +16 -0
- warp/tests/aux_test_unresolved_func.py +22 -0
- warp/tests/aux_test_unresolved_symbol.py +22 -0
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/cuda/test_async.py +676 -0
- warp/tests/cuda/test_conditional_captures.py +1147 -0
- warp/tests/cuda/test_ipc.py +124 -0
- warp/tests/cuda/test_mempool.py +233 -0
- warp/tests/cuda/test_multigpu.py +169 -0
- warp/tests/cuda/test_peer.py +139 -0
- warp/tests/cuda/test_pinned.py +84 -0
- warp/tests/cuda/test_streams.py +691 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/geometry/test_bvh.py +335 -0
- warp/tests/geometry/test_hash_grid.py +259 -0
- warp/tests/geometry/test_marching_cubes.py +294 -0
- warp/tests/geometry/test_mesh.py +318 -0
- warp/tests/geometry/test_mesh_query_aabb.py +392 -0
- warp/tests/geometry/test_mesh_query_point.py +935 -0
- warp/tests/geometry/test_mesh_query_ray.py +323 -0
- warp/tests/geometry/test_volume.py +1103 -0
- warp/tests/geometry/test_volume_write.py +346 -0
- warp/tests/interop/__init__.py +0 -0
- warp/tests/interop/test_dlpack.py +730 -0
- warp/tests/interop/test_jax.py +1673 -0
- warp/tests/interop/test_paddle.py +800 -0
- warp/tests/interop/test_torch.py +1001 -0
- warp/tests/run_coverage_serial.py +39 -0
- warp/tests/test_adam.py +162 -0
- warp/tests/test_arithmetic.py +1096 -0
- warp/tests/test_array.py +3756 -0
- warp/tests/test_array_reduce.py +156 -0
- warp/tests/test_assert.py +303 -0
- warp/tests/test_atomic.py +336 -0
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +312 -0
- warp/tests/test_bool.py +220 -0
- warp/tests/test_builtins_resolution.py +732 -0
- warp/tests/test_closest_point_edge_edge.py +327 -0
- warp/tests/test_codegen.py +974 -0
- warp/tests/test_codegen_instancing.py +1495 -0
- warp/tests/test_compile_consts.py +215 -0
- warp/tests/test_conditional.py +298 -0
- warp/tests/test_context.py +35 -0
- warp/tests/test_copy.py +319 -0
- warp/tests/test_ctypes.py +618 -0
- warp/tests/test_dense.py +73 -0
- warp/tests/test_devices.py +127 -0
- warp/tests/test_enum.py +136 -0
- warp/tests/test_examples.py +424 -0
- warp/tests/test_fabricarray.py +998 -0
- warp/tests/test_fast_math.py +72 -0
- warp/tests/test_fem.py +2204 -0
- warp/tests/test_fixedarray.py +229 -0
- warp/tests/test_fp16.py +136 -0
- warp/tests/test_func.py +501 -0
- warp/tests/test_future_annotations.py +100 -0
- warp/tests/test_generics.py +656 -0
- warp/tests/test_grad.py +893 -0
- warp/tests/test_grad_customs.py +339 -0
- warp/tests/test_grad_debug.py +341 -0
- warp/tests/test_implicit_init.py +411 -0
- warp/tests/test_import.py +45 -0
- warp/tests/test_indexedarray.py +1140 -0
- warp/tests/test_intersect.py +103 -0
- warp/tests/test_iter.py +76 -0
- warp/tests/test_large.py +177 -0
- warp/tests/test_launch.py +411 -0
- warp/tests/test_lerp.py +151 -0
- warp/tests/test_linear_solvers.py +223 -0
- warp/tests/test_lvalue.py +427 -0
- warp/tests/test_map.py +526 -0
- warp/tests/test_mat.py +3515 -0
- warp/tests/test_mat_assign_copy.py +178 -0
- warp/tests/test_mat_constructors.py +573 -0
- warp/tests/test_mat_lite.py +122 -0
- warp/tests/test_mat_scalar_ops.py +2913 -0
- warp/tests/test_math.py +212 -0
- warp/tests/test_module_aot.py +287 -0
- warp/tests/test_module_hashing.py +258 -0
- warp/tests/test_modules_lite.py +70 -0
- warp/tests/test_noise.py +252 -0
- warp/tests/test_operators.py +299 -0
- warp/tests/test_options.py +129 -0
- warp/tests/test_overwrite.py +551 -0
- warp/tests/test_print.py +408 -0
- warp/tests/test_quat.py +2653 -0
- warp/tests/test_quat_assign_copy.py +145 -0
- warp/tests/test_rand.py +339 -0
- warp/tests/test_reload.py +303 -0
- warp/tests/test_rounding.py +157 -0
- warp/tests/test_runlength_encode.py +196 -0
- warp/tests/test_scalar_ops.py +133 -0
- warp/tests/test_smoothstep.py +108 -0
- warp/tests/test_snippet.py +318 -0
- warp/tests/test_sparse.py +845 -0
- warp/tests/test_spatial.py +2859 -0
- warp/tests/test_spatial_assign_copy.py +160 -0
- warp/tests/test_special_values.py +361 -0
- warp/tests/test_static.py +640 -0
- warp/tests/test_struct.py +901 -0
- warp/tests/test_tape.py +242 -0
- warp/tests/test_transient_module.py +93 -0
- warp/tests/test_triangle_closest_point.py +192 -0
- warp/tests/test_tuple.py +361 -0
- warp/tests/test_types.py +615 -0
- warp/tests/test_utils.py +594 -0
- warp/tests/test_vec.py +1408 -0
- warp/tests/test_vec_assign_copy.py +143 -0
- warp/tests/test_vec_constructors.py +325 -0
- warp/tests/test_vec_lite.py +80 -0
- warp/tests/test_vec_scalar_ops.py +2327 -0
- warp/tests/test_verify_fp.py +100 -0
- warp/tests/test_version.py +75 -0
- warp/tests/tile/__init__.py +0 -0
- warp/tests/tile/test_tile.py +1519 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +608 -0
- warp/tests/tile/test_tile_load.py +724 -0
- warp/tests/tile/test_tile_mathdx.py +156 -0
- warp/tests/tile/test_tile_matmul.py +179 -0
- warp/tests/tile/test_tile_mlp.py +400 -0
- warp/tests/tile/test_tile_reduce.py +950 -0
- warp/tests/tile/test_tile_shared_memory.py +376 -0
- warp/tests/tile/test_tile_sort.py +121 -0
- warp/tests/tile/test_tile_view.py +173 -0
- warp/tests/unittest_serial.py +47 -0
- warp/tests/unittest_suites.py +430 -0
- warp/tests/unittest_utils.py +469 -0
- warp/tests/walkthrough_debug.py +95 -0
- warp/torch.py +24 -0
- warp/types.py +51 -0
- warp/utils.py +31 -0
- warp_lang-1.10.0.dist-info/METADATA +459 -0
- warp_lang-1.10.0.dist-info/RECORD +468 -0
- warp_lang-1.10.0.dist-info/WHEEL +5 -0
- warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp_lang-1.10.0.dist-info/top_level.txt +1 -0
warp/native/tile_radix_sort.h
@@ -0,0 +1,1112 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tile.h"

#if defined(__clang__)
// disable warnings related to C++17 extensions on CPU JIT builds
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wc++17-extensions"
#endif

namespace wp
{

// After this threshold, using segmented_sort from cub is faster
// The threshold must be a power of 2
// The radix sort in this file is consistently slower than the bitonic sort
#define BITONIC_SORT_THRESHOLD 2048

struct UintKeyToUint
{
    inline CUDA_CALLABLE uint32_t convert(uint32 value)
    {
        return value;
    }

    inline CUDA_CALLABLE uint32_t max_possible_key_value()
    {
        return 0xFFFFFFFF;
    }
};

struct IntKeyToUint
{
    inline CUDA_CALLABLE uint32_t convert(int value)
    {
        // Flip the sign bit: ensures negative numbers come before positive numbers
        return static_cast<uint32_t>(value) ^ 0x80000000;
    }

    inline CUDA_CALLABLE int max_possible_key_value()
    {
        return 2147483647;
    }
};

struct FloatKeyToUint
{
    // http://stereopsis.com/radix.html
    inline CUDA_CALLABLE uint32_t convert(float value)
    {
        unsigned int i = reinterpret_cast<unsigned int&>(value);
        unsigned int mask = (unsigned int)(-(int)(i >> 31)) | 0x80000000;
        return i ^ mask;
    }

    inline CUDA_CALLABLE float max_possible_key_value()
    {
        return FLT_MAX;
    }
};
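
// Worked example of the key mappings above (illustrative values only, not used by the code):
// FloatKeyToUint::convert flips every bit of a negative float but only the sign bit of a
// non-negative one, so comparing the results as unsigned integers reproduces float ordering:
//   convert(-1.0f) = 0x407FFFFF < convert(0.0f) = 0x80000000 < convert(1.0f) = 0xBF800000
// IntKeyToUint behaves analogously: convert(-1) = 0x7FFFFFFF < convert(0) = 0x80000000.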

constexpr inline CUDA_CALLABLE bool is_power_of_two(int x)
{
    return (x & (x - 1)) == 0;
}

constexpr inline CUDA_CALLABLE int next_higher_pow2(int input)
{
    if (input <= 0) return 1; // Smallest power of 2 is 1

    input--; // Decrement so that inputs that are already a power of 2 map to themselves
    input |= input >> 1;
    input |= input >> 2;
    input |= input >> 4;
    input |= input >> 8;
    input |= input >> 16;
    input++; // Next power of 2

    return input;
}
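
// Example trace (illustrative only): next_higher_pow2(33) decrements to 32 = 0b100000,
// the shifted ORs smear the top bit down to 0b111111 = 63, and the final increment
// yields 64. Exact powers of two map to themselves: next_higher_pow2(32) == 32.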

#if defined(__CUDA_ARCH__)

// Bitonic sort fast pass for small arrays

template<typename T>
inline CUDA_CALLABLE T shfl_xor(unsigned int thread_id, T* sh_mem, unsigned int lane_mask)
{
    unsigned int source_lane = thread_id ^ lane_mask;
    return sh_mem[source_lane];
}

template<typename K, typename V, int num_loops>
inline CUDA_CALLABLE void bitonic_sort_single_stage_full_thread_block(int k, unsigned int thread_id, unsigned int stride, K* key_sh_mem, V* val_sh_mem, int length, K max_key_value,
    K* key_register, V* val_register)
{
    __syncthreads();
#pragma unroll
    for (int loop_id = 0; loop_id < num_loops; ++loop_id)
    {
        int thread_id2 = loop_id * WP_TILE_BLOCK_DIM + thread_id;

        key_register[loop_id] = thread_id2 < length ? key_sh_mem[thread_id2] : max_key_value;
        val_register[loop_id] = thread_id2 < length ? val_sh_mem[thread_id2] : static_cast<V>(0);
    }

    __syncthreads();

    K s_key[num_loops];
    V s_val[num_loops];
    bool swap[num_loops];
#pragma unroll
    for (int loop_id = 0; loop_id < num_loops; ++loop_id)
    {
        int thread_id2 = loop_id * WP_TILE_BLOCK_DIM + thread_id;

        if (thread_id2 < length)
        {
            s_key[loop_id] = shfl_xor(thread_id2, key_sh_mem, stride);
            s_val[loop_id] = shfl_xor(thread_id2, val_sh_mem, stride);
            swap[loop_id] = (((thread_id2 & stride) != 0 ? key_register[loop_id] > s_key[loop_id] : key_register[loop_id] < s_key[loop_id])) ^ ((thread_id2 & k) == 0);
        }
    }

    __syncthreads();

#pragma unroll
    for (int loop_id = 0; loop_id < num_loops; ++loop_id)
    {
        int thread_id2 = loop_id * WP_TILE_BLOCK_DIM + thread_id;
        if (thread_id2 < length)
        {
            key_sh_mem[thread_id2] = swap[loop_id] ? s_key[loop_id] : key_register[loop_id];
            val_sh_mem[thread_id2] = swap[loop_id] ? s_val[loop_id] : val_register[loop_id];
        }
    }
    __syncthreads();
}

// stride can be 1, 2, 4, 8, 16
template<typename K, typename V>
inline CUDA_CALLABLE void bitonic_sort_single_stage_full_warp(int k, unsigned int thread_id, int stride, K& key, V& val)
{
    auto s_key = __shfl_xor_sync(0xFFFFFFFFu, key, stride);
    auto s_val = __shfl_xor_sync(0xFFFFFFFFu, val, stride);
    auto swap = (((thread_id & stride) != 0 ? key > s_key : key < s_key)) ^ ((thread_id & k) == 0);
    key = swap ? s_key : key;
    val = swap ? s_val : val;
}

// Sorts 32 elements according to keys
template<typename K, typename V>
inline CUDA_CALLABLE void bitonic_sort_single_warp(unsigned int thread_id, K& key, V& val)
{
#pragma unroll
    for (int k = 2; k <= 32; k <<= 1)
    {
#pragma unroll
        for (int stride = k / 2; stride > 0; stride >>= 1)
        {
            bitonic_sort_single_stage_full_warp(k, thread_id, stride, key, val);
        }
    }
}
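
// Illustrative note on the 32-lane network above: k takes the values 2, 4, 8, 16, 32
// and each k contributes log2(k) strides, so a warp performs 1 + 2 + 3 + 4 + 5 = 15
// compare-exchange stages - the usual O(log^2 n) depth of a bitonic network for n = 32.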

template<typename K, typename V, typename KeyToUint>
inline CUDA_CALLABLE void bitonic_sort_single_warp(int thread_id,
    K* keys_input,
    V* values_input,
    int num_elements_to_sort)
{
    KeyToUint key_converter;

    __syncwarp();

    K key = thread_id < num_elements_to_sort ? keys_input[thread_id] : key_converter.max_possible_key_value();
    V value;
    if (thread_id < num_elements_to_sort)
        value = values_input[thread_id];

    __syncwarp();
    bitonic_sort_single_warp(thread_id, key, value);
    __syncwarp();

    if (thread_id < num_elements_to_sort)
    {
        keys_input[thread_id] = key;
        values_input[thread_id] = value;
    }
    __syncwarp();
}

// Sorts according to keys
template<int max_num_elements, typename K, typename V>
inline CUDA_CALLABLE void bitonic_sort_pow2_length(unsigned int thread_id, K* key_sh_mem, V* val_sh_mem, int length, K key_max_possible_value)
{
    constexpr int num_loops = (max_num_elements + WP_TILE_BLOCK_DIM - 1) / WP_TILE_BLOCK_DIM;
    K key[num_loops];
    V val[num_loops];

#pragma unroll
    for (int loop_id = 0; loop_id < num_loops; ++loop_id)
    {
        int thread_id2 = loop_id * WP_TILE_BLOCK_DIM + thread_id;
        key[loop_id] = thread_id2 < length ? key_sh_mem[thread_id2] : key_max_possible_value;
        if (thread_id2 < length)
            val[loop_id] = val_sh_mem[thread_id2];
    }

    __syncthreads();
    bool full_block_sort_active = false;

    for (int k = 2; k <= length; k <<= 1)
    {
        for (int stride = k / 2; stride > 0; stride >>= 1)
        {
            if (stride <= 16) // no inter-warp communication needed up to stride 16
            {
                if (full_block_sort_active)
                {
                    __syncthreads();
#pragma unroll
                    for (int loop_id = 0; loop_id < num_loops; ++loop_id)
                    {
                        int thread_id2 = loop_id * WP_TILE_BLOCK_DIM + thread_id;

                        // Switch from shared mem to registers
                        if (thread_id2 < length)
                        {
                            key[loop_id] = key_sh_mem[thread_id2];
                            val[loop_id] = val_sh_mem[thread_id2];
                        }
                    }
                    full_block_sort_active = false;
                    __syncthreads();
                }

#pragma unroll
                for (int loop_id = 0; loop_id < num_loops; ++loop_id)
                {
                    int thread_id2 = loop_id * WP_TILE_BLOCK_DIM + thread_id;
                    bitonic_sort_single_stage_full_warp(k, thread_id2, stride, key[loop_id], val[loop_id]);
                }
            }
            else
            {
                if (!full_block_sort_active)
                {
                    __syncthreads();
#pragma unroll
                    for (int loop_id = 0; loop_id < num_loops; ++loop_id)
                    {
                        int thread_id2 = loop_id * WP_TILE_BLOCK_DIM + thread_id;

                        // Switch from registers to shared mem
                        if (thread_id2 < length)
                        {
                            key_sh_mem[thread_id2] = key[loop_id];
                            val_sh_mem[thread_id2] = val[loop_id];
                        }
                    }
                    full_block_sort_active = true;
                    __syncthreads();
                }

                bitonic_sort_single_stage_full_thread_block<K, V, num_loops>(k, thread_id, (unsigned int)stride, key_sh_mem, val_sh_mem, length, key_max_possible_value, key, val);
            }
        }
    }

    if (!full_block_sort_active)
    {
#pragma unroll
        for (int loop_id = 0; loop_id < num_loops; ++loop_id)
        {
            int thread_id2 = loop_id * WP_TILE_BLOCK_DIM + thread_id;
            // Switch from registers to shared mem
            if (thread_id2 < length)
            {
                key_sh_mem[thread_id2] = key[loop_id];
                val_sh_mem[thread_id2] = val[loop_id];
            }
        }
        full_block_sort_active = true;
        __syncthreads();
    }
}

// Allocates shared memory to buffer the arrays that need to be sorted
template <int max_num_elements, typename K, typename V, typename KeyToUint>
inline CUDA_CALLABLE void bitonic_sort_thread_block_shared_mem(
    int thread_id,
    K* keys_input,
    V* values_input,
    int num_elements_to_sort)
{
    if constexpr (max_num_elements < 32)
    {
        // Fast track - single warp sort
        if (thread_id < 32)
            bitonic_sort_single_warp<K, V, KeyToUint>(thread_id, keys_input, values_input, num_elements_to_sort);
        __syncthreads();
    }
    else
    {
        KeyToUint key_converter;
        const K key_max_possible_value = key_converter.max_possible_key_value();

        constexpr int shared_mem_count = next_higher_pow2(max_num_elements);

        __shared__ K keys_shared_mem[shared_mem_count]; // TODO: This shared memory can be avoided if keys_input is already shared memory
        __shared__ V values_shared_mem[shared_mem_count]; // TODO: This shared memory can be avoided if values_input is already shared memory

        for (int i = thread_id; i < shared_mem_count; i += WP_TILE_BLOCK_DIM)
        {
            if (i < num_elements_to_sort)
            {
                keys_shared_mem[i] = keys_input[i];
                values_shared_mem[i] = values_input[i];
            }
            else
            {
                // Note that these values may end up in the output if enough NaN or Inf values are present in keys_input
                keys_shared_mem[i] = key_max_possible_value;
                values_shared_mem[i] = static_cast<V>(0);
            }
        }
        __syncthreads();

        bitonic_sort_pow2_length<shared_mem_count, K, V>((unsigned int)thread_id, keys_shared_mem, values_shared_mem, shared_mem_count, key_max_possible_value);

        __syncthreads();

        for (int i = thread_id; i < num_elements_to_sort; i += WP_TILE_BLOCK_DIM)
        {
            keys_input[i] = keys_shared_mem[i];
            values_input[i] = values_shared_mem[i];
        }
        __syncthreads();
    }
}

// Specialization for int keys
template <int max_num_elements, typename V>
inline CUDA_CALLABLE void bitonic_sort_thread_block_shared_mem(
    int thread_id,
    int* keys_input,
    V* values_input,
    int num_elements_to_sort)
{
    bitonic_sort_thread_block_shared_mem<max_num_elements, int, V, IntKeyToUint>(
        thread_id, keys_input, values_input, num_elements_to_sort);
}

// Specialization for unsigned int keys
template <int max_num_elements, typename V>
inline CUDA_CALLABLE void bitonic_sort_thread_block_shared_mem(
    int thread_id,
    unsigned int* keys_input,
    V* values_input,
    int num_elements_to_sort)
{
    bitonic_sort_thread_block_shared_mem<max_num_elements, unsigned int, V, UintKeyToUint>(
        thread_id, keys_input, values_input, num_elements_to_sort);
}

// Specialization for float keys
template <int max_num_elements, typename V>
inline CUDA_CALLABLE void bitonic_sort_thread_block_shared_mem(
    int thread_id,
    float* keys_input,
    V* values_input,
    int num_elements_to_sort)
{
    bitonic_sort_thread_block_shared_mem<max_num_elements, float, V, FloatKeyToUint>(
        thread_id, keys_input, values_input, num_elements_to_sort);
}

// Ideally keys_input and values_input point into fast memory (shared memory)
template <int max_num_elements, typename K, typename V, typename KeyToUint>
inline CUDA_CALLABLE void bitonic_sort_thread_block_direct(
    int thread_id,
    K* keys_input,
    V* values_input,
    int num_elements_to_sort)
{
    if constexpr (max_num_elements < 32)
    {
        // Fast track - single warp sort
        if (thread_id < 32)
            bitonic_sort_single_warp<K, V, KeyToUint>(thread_id, keys_input, values_input, num_elements_to_sort);
        __syncthreads();
    }
    else
    {
        assert(num_elements_to_sort <= max_num_elements);

        KeyToUint key_converter;
        const K key_max_possible_value = key_converter.max_possible_key_value();

        bitonic_sort_pow2_length<max_num_elements, K, V>((unsigned int)thread_id, keys_input, values_input, num_elements_to_sort, key_max_possible_value);
    }
}

// Specialization for int keys
template <int max_num_elements, typename V>
inline CUDA_CALLABLE void bitonic_sort_thread_block_direct(
    int thread_id,
    int* keys_input,
    V* values_input,
    int num_elements_to_sort)
{
    bitonic_sort_thread_block_direct<max_num_elements, int, V, IntKeyToUint>(
        thread_id, keys_input, values_input, num_elements_to_sort);
}

// Specialization for unsigned int keys
template <int max_num_elements, typename V>
inline CUDA_CALLABLE void bitonic_sort_thread_block_direct(
    int thread_id,
    unsigned int* keys_input,
    V* values_input,
    int num_elements_to_sort)
{
    bitonic_sort_thread_block_direct<max_num_elements, unsigned int, V, UintKeyToUint>(
        thread_id, keys_input, values_input, num_elements_to_sort);
}

// Specialization for float keys
template <int max_num_elements, typename V>
inline CUDA_CALLABLE void bitonic_sort_thread_block_direct(
    int thread_id,
    float* keys_input,
    V* values_input,
    int num_elements_to_sort)
{
    bitonic_sort_thread_block_direct<max_num_elements, float, V, FloatKeyToUint>(
        thread_id, keys_input, values_input, num_elements_to_sort);
}

// End bitonic sort

inline CUDA_CALLABLE int warp_scan_inclusive(int lane, unsigned int ballot_mask)
{
    uint32_t mask = ((1u << (lane + 1)) - 1);
    return __popc(ballot_mask & mask);
}
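
// Example (illustrative only): with ballot_mask = 0b101101, lane 5 builds the lane mask
// (1u << 6) - 1 = 0b111111 and returns __popc(0b101101 & 0b111111) = 4, the number of
// contributing threads at or below lane 5.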

inline CUDA_CALLABLE int warp_scan_inclusive(int lane, unsigned int mask, bool thread_contributes_element)
{
    return warp_scan_inclusive(lane, __ballot_sync(mask, thread_contributes_element));
}

template<typename T>
inline CUDA_CALLABLE T warp_scan_inclusive(int lane, T value)
{
    // Computes an inclusive cumulative sum
#pragma unroll
    for (int i = 1; i <= 32; i *= 2)
    {
        auto n = __shfl_up_sync(0xffffffffu, value, i, 32);

        if (lane >= i)
            value = value + n;
    }
    return value;
}

template<typename T>
inline CUDA_CALLABLE T warp_scan_exclusive(int lane, T value)
{
    T scan = warp_scan_inclusive(lane, value);
    return scan - value;
}

template <int num_warps, int num_threads, typename K, typename V, typename KeyToUint>
inline CUDA_CALLABLE void radix_sort_thread_block_core(
    int thread_id,
    K* keys_input, K* keys_tmp,
    V* values_input, V* values_tmp,
    int num_elements_to_sort)
{
    KeyToUint key_converter;

    int num_bits_to_sort = 32; // Sort all bits because that's what the bitonic fast pass does as well

    const int warp_id = thread_id / 32;
    const int lane_id = thread_id & 31;

    const int bits_per_pass = 4; // Higher than 5 is currently not supported - 2^5=32 is the warp size and is still just fine
    const int lowest_bits_mask = (1 << bits_per_pass) - 1;
    const int num_scan_buckets = (1 << bits_per_pass);

    const int num_warp_passes = (num_scan_buckets + num_warps - 1) / num_warps;

    __shared__ int buckets[num_scan_buckets];
    __shared__ int buckets2[num_scan_buckets];
    __shared__ int buckets_cumulative_sum[num_scan_buckets];
    __shared__ int shared_mem[num_warps][num_scan_buckets];

    const int num_passes = (num_bits_to_sort + bits_per_pass - 1) / bits_per_pass;
    const int num_inner_loops = (num_elements_to_sort + num_threads - 1) / num_threads;

    for (int pass_id = 0; pass_id < num_passes; ++pass_id)
    {
        __syncthreads();
        if (thread_id < num_scan_buckets)
        {
            buckets[lane_id] = 0;
            buckets2[lane_id] = 0;
        }
        __syncthreads();

        int shift = pass_id * bits_per_pass;

        for (int j = thread_id; j < num_inner_loops * num_threads; j += num_threads)
        {
            int digit = j < num_elements_to_sort ? (int)((key_converter.convert(keys_input[j]) >> shift) & lowest_bits_mask) : num_scan_buckets;

            for (int b = 0; b < num_scan_buckets; b++)
            {
                bool contributes = digit == b;
                int sum_per_warp = warp_scan_inclusive(lane_id, 0xFFFFFFFF, contributes);

                if (lane_id == 31)
                    shared_mem[warp_id][b] = sum_per_warp;
            }
            __syncthreads();

            for (int b = warp_id; b < num_warp_passes * num_warps; b += num_warps)
            {
                int f = lane_id < num_warps ? shared_mem[lane_id][b] : 0;
                f = warp_scan_inclusive(lane_id, f);
                if (lane_id == 31)
                    buckets[b] += f;
            }
            __syncthreads();
        }

#if VALIDATE_SORT
        if (thread_id == 0)
        {
            for (int b = 0; b < num_scan_buckets; b++)
            {
                int bucket_sum = 0;
                for (int j = 0; j < num_elements_to_sort; j++)
                {
                    int digit = j < num_elements_to_sort ? (int)((key_converter.convert(keys_input[j]) >> shift) & lowest_bits_mask) : num_scan_buckets;
                    if (digit == b)
                        ++bucket_sum;
                }
                assert(buckets[b] == bucket_sum);
            }
        }
        __syncthreads();
#endif

        if (warp_id == 0)
        {
            int value = lane_id < num_scan_buckets ? buckets[lane_id] : 0;
            value = warp_scan_exclusive(lane_id, value);
            if (lane_id < num_scan_buckets)
                buckets_cumulative_sum[lane_id] = value;

            if (lane_id == num_scan_buckets - 1)
                assert(value == num_elements_to_sort);
        }

        __syncthreads();

#if VALIDATE_SORT
        if (thread_id == 0)
        {
            for (int b = 0; b < num_scan_buckets; b++)
            {
                int bucket_sum = 0;
                for (int j = 0; j < num_elements_to_sort; j++)
                {
                    int digit = j < num_elements_to_sort ? (int)((key_converter.convert(keys_input[j]) >> shift) & lowest_bits_mask) : num_scan_buckets;
                    if (digit == b)
                        ++bucket_sum;
                }
                assert(buckets[b] == bucket_sum);
            }

            int exclusive_bucket_sum = 0;
            for (int b = 0; b < num_scan_buckets; b++)
            {
                assert(exclusive_bucket_sum == buckets_cumulative_sum[b]);
                exclusive_bucket_sum += buckets[b];
            }
            assert(exclusive_bucket_sum == num_elements_to_sort);
        }
        __syncthreads();
#endif

        // Now buckets holds numBuckets inclusive cumulative sums (e.g. 16 sums for 4 bit radix sort - 2^4=16)
        // The problem is that we either store local_offset_per_thread for every array element (potentially many) or we recompute it again
        for (int j = thread_id; j < num_inner_loops * num_threads; j += num_threads)
        {
            int digit = j < num_elements_to_sort ? (int)((key_converter.convert(keys_input[j]) >> shift) & lowest_bits_mask) : num_scan_buckets;

            int local_offset_per_thread = 0;

            for (int b = 0; b < num_scan_buckets; b++)
            {
                bool contributes = digit == b;
                int sum_per_warp = warp_scan_inclusive(lane_id, 0xFFFFFFFF, contributes);
                if (lane_id == 31)
                    shared_mem[warp_id][b] = sum_per_warp;

                if (contributes)
                    local_offset_per_thread = sum_per_warp - 1; // -1 because of inclusive scan and local_offset_per_thread needs exclusive scan
            }

            for (int b = 0; b < num_scan_buckets; b++)
            {
                __syncthreads();
                int global_offset = buckets2[b];
                __syncthreads();

                int f = lane_id < num_warps ? shared_mem[lane_id][b] : 0;
                int inclusive_scan = warp_scan_inclusive(lane_id, f);
                if (lane_id == 31 && warp_id == 0)
                {
                    buckets2[b] += inclusive_scan;
                }

                int warp_offset = __shfl_sync(0xFFFFFFFF, inclusive_scan - f, warp_id); // -f because warp_offset needs to be an exclusive scan

                bool contributes = digit == b;
                if (contributes)
                {
                    local_offset_per_thread += global_offset + warp_offset;

#if VALIDATE_SORT
                    int curr = buckets_cumulative_sum[b];
                    int next = b + 1 < num_scan_buckets ? buckets_cumulative_sum[b + 1] : num_elements_to_sort;
                    assert(local_offset_per_thread < next - curr && local_offset_per_thread >= 0);
#endif
                }
            }
            __syncthreads();

            if (j < num_elements_to_sort)
            {
                int final_offset = buckets_cumulative_sum[digit] + local_offset_per_thread;

                keys_tmp[final_offset] = keys_input[j];
                values_tmp[final_offset] = values_input[j];
            }
        }

        __syncthreads();

#if VALIDATE_SORT
        for (int j = thread_id; j < num_inner_loops * num_threads; j += num_threads)
        {
            if (j > 0 && j < num_elements_to_sort)
            {
                int digit1 = (int)((keys_tmp[j-1] >> shift) & lowest_bits_mask);
                int digit2 = (int)((keys_tmp[j] >> shift) & lowest_bits_mask);

                assert(digit1 <= digit2);
            }
        }
        __syncthreads();
#endif

        auto tmp = keys_tmp;
        keys_tmp = keys_input;
        keys_input = tmp;

        auto tmp2 = values_tmp;
        values_tmp = values_input;
        values_input = tmp2;
    }

    // For an odd number of passes, the result is in the wrong array - copy it over
    if (num_passes % 2 != 0)
    {
        for (int j = thread_id; j < num_inner_loops * num_threads; j += num_threads)
        {
            if (j < num_elements_to_sort)
            {
                keys_tmp[j] = keys_input[j];
                values_tmp[j] = values_input[j];
            }
        }

        auto tmp = keys_tmp;
        keys_tmp = keys_input;
        keys_input = tmp;

        auto tmp2 = values_tmp;
        values_tmp = values_input;
        values_input = tmp2;
    }
}

template <int num_warps, int num_threads, typename V>
inline CUDA_CALLABLE void radix_sort_thread_block(
    int thread_id,
    int* keys_input, int* keys_tmp,
    V* values_input, V* values_tmp,
    int num_elements_to_sort)
{
    radix_sort_thread_block_core<num_warps, num_threads, int, V, IntKeyToUint>(
        thread_id, keys_input, keys_tmp,
        values_input, values_tmp, num_elements_to_sort);
}

template <int num_warps, int num_threads, typename V>
inline CUDA_CALLABLE void radix_sort_thread_block(
    int thread_id,
    unsigned int* keys_input, unsigned int* keys_tmp,
    V* values_input, V* values_tmp,
    int num_elements_to_sort)
{
    radix_sort_thread_block_core<num_warps, num_threads, unsigned int, V, UintKeyToUint>(
        thread_id, keys_input, keys_tmp,
        values_input, values_tmp,
        num_elements_to_sort);
}

template <int num_warps, int num_threads, typename V>
inline CUDA_CALLABLE void radix_sort_thread_block(
    int thread_id,
    float* keys_input, float* keys_tmp,
    V* values_input, V* values_tmp,
    int num_elements_to_sort)
{
    radix_sort_thread_block_core<num_warps, num_threads, float, V, FloatKeyToUint>(
        thread_id, keys_input, keys_tmp,
        values_input, values_tmp,
        num_elements_to_sort);
}

template <typename TileK, typename TileV>
void tile_sort(TileK& t, TileV& t2)
{
    using T = typename TileK::Type;
    using V = typename TileV::Type;

    constexpr int num_elements_to_sort = TileK::Layout::Shape::size();
    T* keys = &t.data(0);
    V* values = &t2.data(0);

    // Trim away the code that won't be used - possible because the number of elements to sort is known at compile time
    if constexpr (num_elements_to_sort <= BITONIC_SORT_THRESHOLD)
    {
        if constexpr (is_power_of_two(num_elements_to_sort))
            bitonic_sort_thread_block_direct<num_elements_to_sort, V>(WP_TILE_THREAD_IDX, keys, values, num_elements_to_sort);
        else
            bitonic_sort_thread_block_shared_mem<num_elements_to_sort, V>(WP_TILE_THREAD_IDX, keys, values, num_elements_to_sort);
    }
    else
    {
        __shared__ T keys_tmp[num_elements_to_sort];
        __shared__ V values_tmp[num_elements_to_sort];

        constexpr int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1) / WP_TILE_WARP_SIZE;

        radix_sort_thread_block<warp_count, WP_TILE_BLOCK_DIM, V>(WP_TILE_THREAD_IDX, keys, keys_tmp,
            values, values_tmp, num_elements_to_sort);
    }

    WP_TILE_SYNC();
}

template <typename TileK, typename TileV>
void tile_sort(TileK& t, TileV& t2, int start, int length)
{
    using T = typename TileK::Type;
    using V = typename TileV::Type;

    constexpr int max_elements_to_sort = TileK::Layout::Shape::size();
    const int num_elements_to_sort = length;
    T* keys = &t.data(start);
    V* values = &t2.data(start);

    if (num_elements_to_sort <= BITONIC_SORT_THRESHOLD)
    {
        if (is_power_of_two(num_elements_to_sort))
            bitonic_sort_thread_block_direct<max_elements_to_sort, V>(WP_TILE_THREAD_IDX, keys, values, num_elements_to_sort);
        else
            bitonic_sort_thread_block_shared_mem<max_elements_to_sort, V>(WP_TILE_THREAD_IDX, keys, values, num_elements_to_sort);
    }
    else
    {
        if constexpr (max_elements_to_sort > BITONIC_SORT_THRESHOLD)
        {
            __shared__ T keys_tmp[max_elements_to_sort];
            __shared__ V values_tmp[max_elements_to_sort];

            constexpr int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1) / WP_TILE_WARP_SIZE;

            radix_sort_thread_block<warp_count, WP_TILE_BLOCK_DIM, V>(WP_TILE_THREAD_IDX, keys, keys_tmp,
                values, values_tmp, num_elements_to_sort);
        }
    }

    WP_TILE_SYNC();
}

#else

// CPU implementation

template <typename K>
void swap_elements(K& a, K& b)
{
    K tmp = a;
    a = b;
    b = tmp;
}

// length must be a power of two
template <typename K, typename V>
void bitonic_sort_pairs_pow2_length_cpu(K* keys, V* values, int length)
{
    for (int k = 2; k <= length; k *= 2)
    {
        for (int stride = k / 2; stride > 0; stride /= 2)
        {
            for (int i = 0; i < length; i++)
            {
                int swap_idx = i ^ stride;
                if (swap_idx > i)
                {
                    bool ascending = ((i & k) == 0);
                    if ((ascending && keys[i] > keys[swap_idx]) || (!ascending && keys[i] < keys[swap_idx]))
                    {
                        swap_elements(keys[i], keys[swap_idx]);
                        swap_elements(values[i], values[swap_idx]);
                    }
                }
            }
        }
    }
}

template <typename K, typename V, int max_size, typename KeyToUint>
void bitonic_sort_pairs_general_size_cpu(K* keys, V* values, int length)
{
    constexpr int pow2_size = next_higher_pow2(max_size);

    K keys_tmp[pow2_size];
    V values_tmp[pow2_size];

    KeyToUint converter;
    K max_key = converter.max_possible_key_value();

    for (int i = 0; i < pow2_size; ++i)
    {
        keys_tmp[i] = i < length ? keys[i] : max_key;
        if (i < length)
            values_tmp[i] = values[i];
    }

    bitonic_sort_pairs_pow2_length_cpu(keys_tmp, values_tmp, pow2_size);

    for (int i = 0; i < length; ++i)
    {
        keys[i] = keys_tmp[i];
        values[i] = values_tmp[i];
    }
}

template <typename V, int max_size>
void bitonic_sort_pairs_general_size_cpu(unsigned int* keys, V* values, int length)
{
    bitonic_sort_pairs_general_size_cpu<unsigned int, V, max_size, UintKeyToUint>(keys, values, length);
}

template <typename V, int max_size>
void bitonic_sort_pairs_general_size_cpu(int* keys, V* values, int length)
{
    bitonic_sort_pairs_general_size_cpu<int, V, max_size, IntKeyToUint>(keys, values, length);
}

template <typename V, int max_size>
void bitonic_sort_pairs_general_size_cpu(float* keys, V* values, int length)
{
    bitonic_sort_pairs_general_size_cpu<float, V, max_size, FloatKeyToUint>(keys, values, length);
}

template <typename K, typename V, typename KeyToUint>
void radix_sort_pairs_cpu_core(K* keys, K* aux_keys, V* values, V* aux_values, int n)
{
    KeyToUint converter;
    unsigned int tables[2][1 << 16];
    memset(tables, 0, sizeof(tables));

    // build histograms
    for (int i = 0; i < n; ++i)
    {
        const unsigned int k = converter.convert(keys[i]);
        const unsigned short low = k & 0xffff;
        const unsigned short high = k >> 16;

        ++tables[0][low];
        ++tables[1][high];
    }

    // convert histograms to offset tables in-place
    unsigned int offlow = 0;
    unsigned int offhigh = 0;

    for (int i = 0; i < 65536; ++i)
    {
        const unsigned int newofflow = offlow + tables[0][i];
        const unsigned int newoffhigh = offhigh + tables[1][i];

        tables[0][i] = offlow;
        tables[1][i] = offhigh;

        offlow = newofflow;
        offhigh = newoffhigh;
    }

    // pass 1 - sort by low 16 bits
    for (int i = 0; i < n; ++i)
    {
        // lookup offset of input
        const K f = keys[i];
        const unsigned int k = converter.convert(f);
        const V v = values[i];
        const unsigned int b = k & 0xffff;

        // find offset and increment
        const unsigned int offset = tables[0][b]++;

        aux_keys[offset] = f;
        aux_values[offset] = v;
    }

    // pass 2 - sort by high 16 bits
    for (int i = 0; i < n; ++i)
    {
        // lookup offset of input
        const K f = aux_keys[i];
        const unsigned int k = converter.convert(f);
        const V v = aux_values[i];

        const unsigned int b = k >> 16;

        const unsigned int offset = tables[1][b]++;

        keys[offset] = f;
        values[offset] = v;
    }
}

template <typename V>
inline void radix_sort_pairs_cpu(
    int* keys_input,
    int* keys_aux,
    V* values_input,
    V* values_aux,
    int num_elements_to_sort)
{
    radix_sort_pairs_cpu_core<int, V, IntKeyToUint>(
        keys_input, keys_aux,
        values_input, values_aux,
        num_elements_to_sort);
}

template <typename V>
inline void radix_sort_pairs_cpu(
    unsigned int* keys_input,
    unsigned int* keys_aux,
    V* values_input,
    V* values_aux,
    int num_elements_to_sort)
{
    radix_sort_pairs_cpu_core<unsigned int, V, UintKeyToUint>(
        keys_input, keys_aux,
        values_input, values_aux,
        num_elements_to_sort);
}

template <typename V>
inline void radix_sort_pairs_cpu(
    float* keys_input,
    float* keys_aux,
    V* values_input,
    V* values_aux,
    int num_elements_to_sort)
{
    radix_sort_pairs_cpu_core<float, V, FloatKeyToUint>(
        keys_input, keys_aux,
        values_input, values_aux,
        num_elements_to_sort);
}

template <typename TileK, typename TileV>
void tile_sort(TileK& t, TileV& t2)
{
    using T = typename TileK::Type;
    using V = typename TileV::Type;

    constexpr int num_elements_to_sort = TileK::Layout::Shape::size();
    T* keys = &t.data(0);
    V* values = &t2.data(0);

    // Trim away the code that won't be used - possible because the number of elements to sort is known at compile time
    if constexpr (num_elements_to_sort <= BITONIC_SORT_THRESHOLD)
    {
        if constexpr (is_power_of_two(num_elements_to_sort))
            bitonic_sort_pairs_pow2_length_cpu<T, V>(keys, values, num_elements_to_sort);
        else
            bitonic_sort_pairs_general_size_cpu<V, num_elements_to_sort>(keys, values, num_elements_to_sort);
    }
    else
    {
        T keys_tmp[num_elements_to_sort];
        V values_tmp[num_elements_to_sort];

        radix_sort_pairs_cpu<V>(keys, keys_tmp, values, values_tmp, num_elements_to_sort);
    }

    WP_TILE_SYNC();
}

template <typename TileK, typename TileV>
void tile_sort(TileK& t, TileV& t2, int start, int length)
{
    using T = typename TileK::Type;
    using V = typename TileV::Type;

    constexpr int max_elements_to_sort = TileK::Layout::Shape::size();
    const int num_elements_to_sort = length;
    T* keys = &t.data(start);
    V* values = &t2.data(start);

    if (num_elements_to_sort <= BITONIC_SORT_THRESHOLD)
    {
        if (is_power_of_two(num_elements_to_sort))
            bitonic_sort_pairs_pow2_length_cpu<T, V>(keys, values, num_elements_to_sort);
        else
            bitonic_sort_pairs_general_size_cpu<V, max_elements_to_sort>(keys, values, num_elements_to_sort);
    }
    else
    {
        if constexpr (max_elements_to_sort > BITONIC_SORT_THRESHOLD)
        {
            T keys_tmp[max_elements_to_sort];
            V values_tmp[max_elements_to_sort];

            radix_sort_pairs_cpu<V>(keys, keys_tmp, values, values_tmp, num_elements_to_sort);
        }
    }

    WP_TILE_SYNC();
}

#endif // !defined(__CUDA_ARCH__)

template <typename TileK, typename TileV>
inline void adj_tile_sort(TileK& t, TileV& t2, TileK& adj_t1, TileV& adj_t2)
{
    // todo: general purpose sort gradients not implemented
}

template <typename TileK, typename TileV>
inline void adj_tile_sort(TileK& t, TileV& t2, int start, int length, TileK& adj_t1, TileV& adj_t2, int adj_start, int adj_length)
{
    // todo: general purpose sort gradients not implemented
}

} // namespace wp

#if defined(__clang__)
#pragma clang diagnostic pop
#endif