warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +139 -0
- warp/__init__.pyi +1 -0
- warp/autograd.py +1142 -0
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +557 -0
- warp/build_dll.py +405 -0
- warp/builtins.py +6855 -0
- warp/codegen.py +3969 -0
- warp/config.py +158 -0
- warp/constants.py +57 -0
- warp/context.py +6812 -0
- warp/dlpack.py +462 -0
- warp/examples/__init__.py +24 -0
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -0
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -0
- warp/examples/assets/nv_humanoid.xml +183 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/assets/quadruped.urdf +268 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +389 -0
- warp/examples/benchmarks/benchmark_cloth.py +296 -0
- warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
- warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
- warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
- warp/examples/benchmarks/benchmark_gemm.py +164 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
- warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
- warp/examples/benchmarks/benchmark_launches.py +301 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/browse.py +37 -0
- warp/examples/core/example_cupy.py +86 -0
- warp/examples/core/example_dem.py +241 -0
- warp/examples/core/example_fluid.py +299 -0
- warp/examples/core/example_graph_capture.py +150 -0
- warp/examples/core/example_marching_cubes.py +194 -0
- warp/examples/core/example_mesh.py +180 -0
- warp/examples/core/example_mesh_intersect.py +211 -0
- warp/examples/core/example_nvdb.py +182 -0
- warp/examples/core/example_raycast.py +111 -0
- warp/examples/core/example_raymarch.py +205 -0
- warp/examples/core/example_render_opengl.py +193 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +411 -0
- warp/examples/core/example_torch.py +211 -0
- warp/examples/core/example_wave.py +269 -0
- warp/examples/fem/example_adaptive_grid.py +286 -0
- warp/examples/fem/example_apic_fluid.py +423 -0
- warp/examples/fem/example_burgers.py +261 -0
- warp/examples/fem/example_convection_diffusion.py +178 -0
- warp/examples/fem/example_convection_diffusion_dg.py +204 -0
- warp/examples/fem/example_deformed_geometry.py +172 -0
- warp/examples/fem/example_diffusion.py +196 -0
- warp/examples/fem/example_diffusion_3d.py +225 -0
- warp/examples/fem/example_diffusion_mgpu.py +220 -0
- warp/examples/fem/example_distortion_energy.py +228 -0
- warp/examples/fem/example_magnetostatics.py +240 -0
- warp/examples/fem/example_mixed_elasticity.py +291 -0
- warp/examples/fem/example_navier_stokes.py +261 -0
- warp/examples/fem/example_nonconforming_contact.py +298 -0
- warp/examples/fem/example_stokes.py +213 -0
- warp/examples/fem/example_stokes_transfer.py +262 -0
- warp/examples/fem/example_streamlines.py +352 -0
- warp/examples/fem/utils.py +1000 -0
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_bounce.py +266 -0
- warp/examples/optim/example_cloth_throw.py +228 -0
- warp/examples/optim/example_diffray.py +561 -0
- warp/examples/optim/example_drone.py +870 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/optim/example_inverse_kinematics.py +182 -0
- warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
- warp/examples/optim/example_softbody_properties.py +400 -0
- warp/examples/optim/example_spring_cage.py +245 -0
- warp/examples/optim/example_trajectory.py +227 -0
- warp/examples/sim/example_cartpole.py +143 -0
- warp/examples/sim/example_cloth.py +225 -0
- warp/examples/sim/example_cloth_self_contact.py +322 -0
- warp/examples/sim/example_granular.py +130 -0
- warp/examples/sim/example_granular_collision_sdf.py +202 -0
- warp/examples/sim/example_jacobian_ik.py +244 -0
- warp/examples/sim/example_particle_chain.py +124 -0
- warp/examples/sim/example_quadruped.py +203 -0
- warp/examples/sim/example_rigid_chain.py +203 -0
- warp/examples/sim/example_rigid_contact.py +195 -0
- warp/examples/sim/example_rigid_force.py +133 -0
- warp/examples/sim/example_rigid_gyroscopic.py +115 -0
- warp/examples/sim/example_rigid_soft_contact.py +140 -0
- warp/examples/sim/example_soft_body.py +196 -0
- warp/examples/tile/example_tile_cholesky.py +87 -0
- warp/examples/tile/example_tile_convolution.py +66 -0
- warp/examples/tile/example_tile_fft.py +55 -0
- warp/examples/tile/example_tile_filtering.py +113 -0
- warp/examples/tile/example_tile_matmul.py +85 -0
- warp/examples/tile/example_tile_mlp.py +383 -0
- warp/examples/tile/example_tile_nbody.py +199 -0
- warp/examples/tile/example_tile_walker.py +327 -0
- warp/fabric.py +355 -0
- warp/fem/__init__.py +106 -0
- warp/fem/adaptivity.py +508 -0
- warp/fem/cache.py +572 -0
- warp/fem/dirichlet.py +202 -0
- warp/fem/domain.py +411 -0
- warp/fem/field/__init__.py +125 -0
- warp/fem/field/field.py +619 -0
- warp/fem/field/nodal_field.py +326 -0
- warp/fem/field/restriction.py +37 -0
- warp/fem/field/virtual.py +848 -0
- warp/fem/geometry/__init__.py +32 -0
- warp/fem/geometry/adaptive_nanogrid.py +857 -0
- warp/fem/geometry/closest_point.py +84 -0
- warp/fem/geometry/deformed_geometry.py +221 -0
- warp/fem/geometry/element.py +776 -0
- warp/fem/geometry/geometry.py +362 -0
- warp/fem/geometry/grid_2d.py +392 -0
- warp/fem/geometry/grid_3d.py +452 -0
- warp/fem/geometry/hexmesh.py +911 -0
- warp/fem/geometry/nanogrid.py +571 -0
- warp/fem/geometry/partition.py +389 -0
- warp/fem/geometry/quadmesh.py +663 -0
- warp/fem/geometry/tetmesh.py +855 -0
- warp/fem/geometry/trimesh.py +806 -0
- warp/fem/integrate.py +2335 -0
- warp/fem/linalg.py +419 -0
- warp/fem/operator.py +293 -0
- warp/fem/polynomial.py +229 -0
- warp/fem/quadrature/__init__.py +17 -0
- warp/fem/quadrature/pic_quadrature.py +299 -0
- warp/fem/quadrature/quadrature.py +591 -0
- warp/fem/space/__init__.py +228 -0
- warp/fem/space/basis_function_space.py +468 -0
- warp/fem/space/basis_space.py +667 -0
- warp/fem/space/dof_mapper.py +251 -0
- warp/fem/space/function_space.py +309 -0
- warp/fem/space/grid_2d_function_space.py +177 -0
- warp/fem/space/grid_3d_function_space.py +227 -0
- warp/fem/space/hexmesh_function_space.py +257 -0
- warp/fem/space/nanogrid_function_space.py +201 -0
- warp/fem/space/partition.py +367 -0
- warp/fem/space/quadmesh_function_space.py +223 -0
- warp/fem/space/restriction.py +179 -0
- warp/fem/space/shape/__init__.py +143 -0
- warp/fem/space/shape/cube_shape_function.py +1105 -0
- warp/fem/space/shape/shape_function.py +133 -0
- warp/fem/space/shape/square_shape_function.py +926 -0
- warp/fem/space/shape/tet_shape_function.py +834 -0
- warp/fem/space/shape/triangle_shape_function.py +672 -0
- warp/fem/space/tetmesh_function_space.py +271 -0
- warp/fem/space/topology.py +424 -0
- warp/fem/space/trimesh_function_space.py +194 -0
- warp/fem/types.py +99 -0
- warp/fem/utils.py +420 -0
- warp/jax.py +187 -0
- warp/jax_experimental/__init__.py +16 -0
- warp/jax_experimental/custom_call.py +351 -0
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +244 -0
- warp/native/array.h +1145 -0
- warp/native/builtin.h +1800 -0
- warp/native/bvh.cpp +492 -0
- warp/native/bvh.cu +791 -0
- warp/native/bvh.h +554 -0
- warp/native/clang/clang.cpp +536 -0
- warp/native/coloring.cpp +613 -0
- warp/native/crt.cpp +51 -0
- warp/native/crt.h +362 -0
- warp/native/cuda_crt.h +1058 -0
- warp/native/cuda_util.cpp +646 -0
- warp/native/cuda_util.h +307 -0
- warp/native/error.cpp +77 -0
- warp/native/error.h +36 -0
- warp/native/exports.h +1878 -0
- warp/native/fabric.h +245 -0
- warp/native/hashgrid.cpp +311 -0
- warp/native/hashgrid.cu +87 -0
- warp/native/hashgrid.h +240 -0
- warp/native/initializer_array.h +41 -0
- warp/native/intersect.h +1230 -0
- warp/native/intersect_adj.h +375 -0
- warp/native/intersect_tri.h +339 -0
- warp/native/marching.cpp +19 -0
- warp/native/marching.cu +514 -0
- warp/native/marching.h +19 -0
- warp/native/mat.h +2220 -0
- warp/native/mathdx.cpp +87 -0
- warp/native/matnn.h +343 -0
- warp/native/mesh.cpp +266 -0
- warp/native/mesh.cu +404 -0
- warp/native/mesh.h +1980 -0
- warp/native/nanovdb/GridHandle.h +366 -0
- warp/native/nanovdb/HostBuffer.h +590 -0
- warp/native/nanovdb/NanoVDB.h +6624 -0
- warp/native/nanovdb/PNanoVDB.h +3390 -0
- warp/native/noise.h +859 -0
- warp/native/quat.h +1371 -0
- warp/native/rand.h +342 -0
- warp/native/range.h +139 -0
- warp/native/reduce.cpp +174 -0
- warp/native/reduce.cu +364 -0
- warp/native/runlength_encode.cpp +79 -0
- warp/native/runlength_encode.cu +61 -0
- warp/native/scan.cpp +47 -0
- warp/native/scan.cu +53 -0
- warp/native/scan.h +23 -0
- warp/native/solid_angle.h +466 -0
- warp/native/sort.cpp +251 -0
- warp/native/sort.cu +277 -0
- warp/native/sort.h +33 -0
- warp/native/sparse.cpp +378 -0
- warp/native/sparse.cu +524 -0
- warp/native/spatial.h +657 -0
- warp/native/svd.h +702 -0
- warp/native/temp_buffer.h +46 -0
- warp/native/tile.h +2584 -0
- warp/native/tile_reduce.h +264 -0
- warp/native/vec.h +1426 -0
- warp/native/volume.cpp +501 -0
- warp/native/volume.cu +67 -0
- warp/native/volume.h +969 -0
- warp/native/volume_builder.cu +477 -0
- warp/native/volume_builder.h +52 -0
- warp/native/volume_impl.h +70 -0
- warp/native/warp.cpp +1082 -0
- warp/native/warp.cu +3636 -0
- warp/native/warp.h +381 -0
- warp/optim/__init__.py +17 -0
- warp/optim/adam.py +163 -0
- warp/optim/linear.py +1137 -0
- warp/optim/sgd.py +112 -0
- warp/paddle.py +407 -0
- warp/render/__init__.py +18 -0
- warp/render/render_opengl.py +3518 -0
- warp/render/render_usd.py +784 -0
- warp/render/utils.py +160 -0
- warp/sim/__init__.py +65 -0
- warp/sim/articulation.py +793 -0
- warp/sim/collide.py +2395 -0
- warp/sim/graph_coloring.py +300 -0
- warp/sim/import_mjcf.py +790 -0
- warp/sim/import_snu.py +227 -0
- warp/sim/import_urdf.py +579 -0
- warp/sim/import_usd.py +894 -0
- warp/sim/inertia.py +324 -0
- warp/sim/integrator.py +242 -0
- warp/sim/integrator_euler.py +1997 -0
- warp/sim/integrator_featherstone.py +2101 -0
- warp/sim/integrator_vbd.py +2048 -0
- warp/sim/integrator_xpbd.py +3292 -0
- warp/sim/model.py +4791 -0
- warp/sim/particles.py +121 -0
- warp/sim/render.py +427 -0
- warp/sim/utils.py +428 -0
- warp/sparse.py +2057 -0
- warp/stubs.py +3333 -0
- warp/tape.py +1203 -0
- warp/tests/__init__.py +1 -0
- warp/tests/__main__.py +4 -0
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/mlp_golden.npy +0 -0
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/spiky.usd +0 -0
- warp/tests/assets/test_grid.nvdb +0 -0
- warp/tests/assets/test_index_grid.nvdb +0 -0
- warp/tests/assets/test_int32_grid.nvdb +0 -0
- warp/tests/assets/test_vec_grid.nvdb +0 -0
- warp/tests/assets/torus.nvdb +0 -0
- warp/tests/assets/torus.usda +105 -0
- warp/tests/aux_test_class_kernel.py +34 -0
- warp/tests/aux_test_compile_consts_dummy.py +18 -0
- warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
- warp/tests/aux_test_dependent.py +29 -0
- warp/tests/aux_test_grad_customs.py +29 -0
- warp/tests/aux_test_instancing_gc.py +26 -0
- warp/tests/aux_test_module_unload.py +23 -0
- warp/tests/aux_test_name_clash1.py +40 -0
- warp/tests/aux_test_name_clash2.py +40 -0
- warp/tests/aux_test_reference.py +9 -0
- warp/tests/aux_test_reference_reference.py +8 -0
- warp/tests/aux_test_square.py +16 -0
- warp/tests/aux_test_unresolved_func.py +22 -0
- warp/tests/aux_test_unresolved_symbol.py +22 -0
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/cuda/test_async.py +676 -0
- warp/tests/cuda/test_ipc.py +124 -0
- warp/tests/cuda/test_mempool.py +233 -0
- warp/tests/cuda/test_multigpu.py +169 -0
- warp/tests/cuda/test_peer.py +139 -0
- warp/tests/cuda/test_pinned.py +84 -0
- warp/tests/cuda/test_streams.py +634 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/geometry/test_bvh.py +200 -0
- warp/tests/geometry/test_hash_grid.py +221 -0
- warp/tests/geometry/test_marching_cubes.py +74 -0
- warp/tests/geometry/test_mesh.py +316 -0
- warp/tests/geometry/test_mesh_query_aabb.py +399 -0
- warp/tests/geometry/test_mesh_query_point.py +932 -0
- warp/tests/geometry/test_mesh_query_ray.py +311 -0
- warp/tests/geometry/test_volume.py +1103 -0
- warp/tests/geometry/test_volume_write.py +346 -0
- warp/tests/interop/__init__.py +0 -0
- warp/tests/interop/test_dlpack.py +729 -0
- warp/tests/interop/test_jax.py +371 -0
- warp/tests/interop/test_paddle.py +800 -0
- warp/tests/interop/test_torch.py +1001 -0
- warp/tests/run_coverage_serial.py +39 -0
- warp/tests/sim/__init__.py +0 -0
- warp/tests/sim/disabled_kinematics.py +244 -0
- warp/tests/sim/flaky_test_sim_grad.py +290 -0
- warp/tests/sim/test_collision.py +604 -0
- warp/tests/sim/test_coloring.py +258 -0
- warp/tests/sim/test_model.py +224 -0
- warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
- warp/tests/sim/test_sim_kinematics.py +98 -0
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_adam.py +163 -0
- warp/tests/test_arithmetic.py +1096 -0
- warp/tests/test_array.py +2972 -0
- warp/tests/test_array_reduce.py +156 -0
- warp/tests/test_assert.py +250 -0
- warp/tests/test_atomic.py +153 -0
- warp/tests/test_bool.py +220 -0
- warp/tests/test_builtins_resolution.py +1298 -0
- warp/tests/test_closest_point_edge_edge.py +327 -0
- warp/tests/test_codegen.py +810 -0
- warp/tests/test_codegen_instancing.py +1495 -0
- warp/tests/test_compile_consts.py +215 -0
- warp/tests/test_conditional.py +252 -0
- warp/tests/test_context.py +42 -0
- warp/tests/test_copy.py +238 -0
- warp/tests/test_ctypes.py +638 -0
- warp/tests/test_dense.py +73 -0
- warp/tests/test_devices.py +97 -0
- warp/tests/test_examples.py +482 -0
- warp/tests/test_fabricarray.py +996 -0
- warp/tests/test_fast_math.py +74 -0
- warp/tests/test_fem.py +2003 -0
- warp/tests/test_fp16.py +136 -0
- warp/tests/test_func.py +454 -0
- warp/tests/test_future_annotations.py +98 -0
- warp/tests/test_generics.py +656 -0
- warp/tests/test_grad.py +893 -0
- warp/tests/test_grad_customs.py +339 -0
- warp/tests/test_grad_debug.py +341 -0
- warp/tests/test_implicit_init.py +411 -0
- warp/tests/test_import.py +45 -0
- warp/tests/test_indexedarray.py +1140 -0
- warp/tests/test_intersect.py +73 -0
- warp/tests/test_iter.py +76 -0
- warp/tests/test_large.py +177 -0
- warp/tests/test_launch.py +411 -0
- warp/tests/test_lerp.py +151 -0
- warp/tests/test_linear_solvers.py +193 -0
- warp/tests/test_lvalue.py +427 -0
- warp/tests/test_mat.py +2089 -0
- warp/tests/test_mat_lite.py +122 -0
- warp/tests/test_mat_scalar_ops.py +2913 -0
- warp/tests/test_math.py +178 -0
- warp/tests/test_mlp.py +282 -0
- warp/tests/test_module_hashing.py +258 -0
- warp/tests/test_modules_lite.py +44 -0
- warp/tests/test_noise.py +252 -0
- warp/tests/test_operators.py +299 -0
- warp/tests/test_options.py +129 -0
- warp/tests/test_overwrite.py +551 -0
- warp/tests/test_print.py +339 -0
- warp/tests/test_quat.py +2315 -0
- warp/tests/test_rand.py +339 -0
- warp/tests/test_reload.py +302 -0
- warp/tests/test_rounding.py +185 -0
- warp/tests/test_runlength_encode.py +196 -0
- warp/tests/test_scalar_ops.py +105 -0
- warp/tests/test_smoothstep.py +108 -0
- warp/tests/test_snippet.py +318 -0
- warp/tests/test_sparse.py +582 -0
- warp/tests/test_spatial.py +2229 -0
- warp/tests/test_special_values.py +361 -0
- warp/tests/test_static.py +592 -0
- warp/tests/test_struct.py +734 -0
- warp/tests/test_tape.py +204 -0
- warp/tests/test_transient_module.py +93 -0
- warp/tests/test_triangle_closest_point.py +145 -0
- warp/tests/test_types.py +562 -0
- warp/tests/test_utils.py +588 -0
- warp/tests/test_vec.py +1487 -0
- warp/tests/test_vec_lite.py +80 -0
- warp/tests/test_vec_scalar_ops.py +2327 -0
- warp/tests/test_verify_fp.py +100 -0
- warp/tests/tile/__init__.py +0 -0
- warp/tests/tile/test_tile.py +780 -0
- warp/tests/tile/test_tile_load.py +407 -0
- warp/tests/tile/test_tile_mathdx.py +208 -0
- warp/tests/tile/test_tile_mlp.py +402 -0
- warp/tests/tile/test_tile_reduce.py +447 -0
- warp/tests/tile/test_tile_shared_memory.py +247 -0
- warp/tests/tile/test_tile_view.py +173 -0
- warp/tests/unittest_serial.py +47 -0
- warp/tests/unittest_suites.py +427 -0
- warp/tests/unittest_utils.py +468 -0
- warp/tests/walkthrough_debug.py +93 -0
- warp/thirdparty/__init__.py +0 -0
- warp/thirdparty/appdirs.py +598 -0
- warp/thirdparty/dlpack.py +145 -0
- warp/thirdparty/unittest_parallel.py +570 -0
- warp/torch.py +391 -0
- warp/types.py +5230 -0
- warp/utils.py +1137 -0
- warp_lang-1.7.0.dist-info/METADATA +516 -0
- warp_lang-1.7.0.dist-info/RECORD +429 -0
- warp_lang-1.7.0.dist-info/WHEEL +5 -0
- warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
- warp_lang-1.7.0.dist-info/top_level.txt +1 -0
warp/native/sort.cu
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
* you may not use this file except in compliance with the License.
|
|
7
|
+
* You may obtain a copy of the License at
|
|
8
|
+
*
|
|
9
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
*
|
|
11
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
* See the License for the specific language governing permissions and
|
|
15
|
+
* limitations under the License.
|
|
16
|
+
*/
|
|
17
|
+
|
|
18
|
+
#include "warp.h"
|
|
19
|
+
#include "cuda_util.h"
|
|
20
|
+
#include "sort.h"
|
|
21
|
+
|
|
22
|
+
#define THRUST_IGNORE_CUB_VERSION_CHECK
|
|
23
|
+
|
|
24
|
+
#include <cub/cub.cuh>
|
|
25
|
+
|
|
26
|
+
#include <map>
|
|
27
|
+
|
|
28
|
+
// Scratch allocation handed to CUB as temporary storage by the sort routines in this file.
struct RadixSortTemp
{
    // start of the scratch allocation; null until something has been allocated
    void* mem = nullptr;

    // capacity of `mem` in bytes
    size_t size = 0;
};
|
|
34
|
+
|
|
35
|
+
// map temp buffers to CUDA contexts
|
|
36
|
+
static std::map<void*, RadixSortTemp> g_radix_sort_temp_map;
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
template <typename KeyType>
|
|
40
|
+
void radix_sort_reserve_internal(void* context, int n, void** mem_out, size_t* size_out)
|
|
41
|
+
{
|
|
42
|
+
ContextGuard guard(context);
|
|
43
|
+
|
|
44
|
+
cub::DoubleBuffer<KeyType> d_keys;
|
|
45
|
+
cub::DoubleBuffer<int> d_values;
|
|
46
|
+
|
|
47
|
+
// compute temporary memory required
|
|
48
|
+
size_t sort_temp_size;
|
|
49
|
+
check_cuda(cub::DeviceRadixSort::SortPairs(
|
|
50
|
+
NULL,
|
|
51
|
+
sort_temp_size,
|
|
52
|
+
d_keys,
|
|
53
|
+
d_values,
|
|
54
|
+
n, 0, sizeof(KeyType)*8,
|
|
55
|
+
(cudaStream_t)cuda_stream_get_current()));
|
|
56
|
+
|
|
57
|
+
if (!context)
|
|
58
|
+
context = cuda_context_get_current();
|
|
59
|
+
|
|
60
|
+
RadixSortTemp& temp = g_radix_sort_temp_map[context];
|
|
61
|
+
|
|
62
|
+
if (sort_temp_size > temp.size)
|
|
63
|
+
{
|
|
64
|
+
free_device(WP_CURRENT_CONTEXT, temp.mem);
|
|
65
|
+
temp.mem = alloc_device(WP_CURRENT_CONTEXT, sort_temp_size);
|
|
66
|
+
temp.size = sort_temp_size;
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
if (mem_out)
|
|
70
|
+
*mem_out = temp.mem;
|
|
71
|
+
if (size_out)
|
|
72
|
+
*size_out = temp.size;
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
void radix_sort_reserve(void* context, int n, void** mem_out, size_t* size_out)
|
|
76
|
+
{
|
|
77
|
+
radix_sort_reserve_internal<int>(context, n, mem_out, size_out);
|
|
78
|
+
}
|
|
79
|
+
|
|
80
|
+
template <typename KeyType>
|
|
81
|
+
void radix_sort_pairs_device(void* context, KeyType* keys, int* values, int n)
|
|
82
|
+
{
|
|
83
|
+
ContextGuard guard(context);
|
|
84
|
+
|
|
85
|
+
cub::DoubleBuffer<KeyType> d_keys(keys, keys + n);
|
|
86
|
+
cub::DoubleBuffer<int> d_values(values, values + n);
|
|
87
|
+
|
|
88
|
+
RadixSortTemp temp;
|
|
89
|
+
radix_sort_reserve_internal<KeyType>(WP_CURRENT_CONTEXT, n, &temp.mem, &temp.size);
|
|
90
|
+
|
|
91
|
+
// sort
|
|
92
|
+
check_cuda(cub::DeviceRadixSort::SortPairs(
|
|
93
|
+
temp.mem,
|
|
94
|
+
temp.size,
|
|
95
|
+
d_keys,
|
|
96
|
+
d_values,
|
|
97
|
+
n, 0, sizeof(KeyType)*8,
|
|
98
|
+
(cudaStream_t)cuda_stream_get_current()));
|
|
99
|
+
|
|
100
|
+
if (d_keys.Current() != keys)
|
|
101
|
+
memcpy_d2d(WP_CURRENT_CONTEXT, keys, d_keys.Current(), sizeof(KeyType)*n);
|
|
102
|
+
|
|
103
|
+
if (d_values.Current() != values)
|
|
104
|
+
memcpy_d2d(WP_CURRENT_CONTEXT, values, d_values.Current(), sizeof(int)*n);
|
|
105
|
+
}
|
|
106
|
+
|
|
107
|
+
void radix_sort_pairs_device(void* context, int* keys, int* values, int n)
|
|
108
|
+
{
|
|
109
|
+
radix_sort_pairs_device<int>(context, keys, values, n);
|
|
110
|
+
}
|
|
111
|
+
|
|
112
|
+
void radix_sort_pairs_device(void* context, float* keys, int* values, int n)
|
|
113
|
+
{
|
|
114
|
+
radix_sort_pairs_device<float>(context, keys, values, n);
|
|
115
|
+
}
|
|
116
|
+
|
|
117
|
+
void radix_sort_pairs_device(void* context, int64_t* keys, int* values, int n)
|
|
118
|
+
{
|
|
119
|
+
radix_sort_pairs_device<int64_t>(context, keys, values, n);
|
|
120
|
+
}
|
|
121
|
+
|
|
122
|
+
void radix_sort_pairs_int_device(uint64_t keys, uint64_t values, int n)
|
|
123
|
+
{
|
|
124
|
+
radix_sort_pairs_device(
|
|
125
|
+
WP_CURRENT_CONTEXT,
|
|
126
|
+
reinterpret_cast<int *>(keys),
|
|
127
|
+
reinterpret_cast<int *>(values), n);
|
|
128
|
+
}
|
|
129
|
+
|
|
130
|
+
void radix_sort_pairs_float_device(uint64_t keys, uint64_t values, int n)
|
|
131
|
+
{
|
|
132
|
+
radix_sort_pairs_device(
|
|
133
|
+
WP_CURRENT_CONTEXT,
|
|
134
|
+
reinterpret_cast<float *>(keys),
|
|
135
|
+
reinterpret_cast<int *>(values), n);
|
|
136
|
+
}
|
|
137
|
+
|
|
138
|
+
void radix_sort_pairs_int64_device(uint64_t keys, uint64_t values, int n)
|
|
139
|
+
{
|
|
140
|
+
radix_sort_pairs_device(
|
|
141
|
+
WP_CURRENT_CONTEXT,
|
|
142
|
+
reinterpret_cast<int64_t *>(keys),
|
|
143
|
+
reinterpret_cast<int *>(values), n);
|
|
144
|
+
}
|
|
145
|
+
|
|
146
|
+
void segmented_sort_reserve(void* context, int n, int num_segments, void** mem_out, size_t* size_out)
|
|
147
|
+
{
|
|
148
|
+
ContextGuard guard(context);
|
|
149
|
+
|
|
150
|
+
cub::DoubleBuffer<int> d_keys;
|
|
151
|
+
cub::DoubleBuffer<int> d_values;
|
|
152
|
+
|
|
153
|
+
int* start_indices = NULL;
|
|
154
|
+
int* end_indices = NULL;
|
|
155
|
+
|
|
156
|
+
// compute temporary memory required
|
|
157
|
+
size_t sort_temp_size;
|
|
158
|
+
check_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
|
|
159
|
+
NULL,
|
|
160
|
+
sort_temp_size,
|
|
161
|
+
d_keys,
|
|
162
|
+
d_values,
|
|
163
|
+
n,
|
|
164
|
+
num_segments,
|
|
165
|
+
start_indices,
|
|
166
|
+
end_indices,
|
|
167
|
+
0,
|
|
168
|
+
32,
|
|
169
|
+
(cudaStream_t)cuda_stream_get_current()));
|
|
170
|
+
|
|
171
|
+
if (!context)
|
|
172
|
+
context = cuda_context_get_current();
|
|
173
|
+
|
|
174
|
+
RadixSortTemp& temp = g_radix_sort_temp_map[context];
|
|
175
|
+
|
|
176
|
+
if (sort_temp_size > temp.size)
|
|
177
|
+
{
|
|
178
|
+
free_device(WP_CURRENT_CONTEXT, temp.mem);
|
|
179
|
+
temp.mem = alloc_device(WP_CURRENT_CONTEXT, sort_temp_size);
|
|
180
|
+
temp.size = sort_temp_size;
|
|
181
|
+
}
|
|
182
|
+
|
|
183
|
+
if (mem_out)
|
|
184
|
+
*mem_out = temp.mem;
|
|
185
|
+
if (size_out)
|
|
186
|
+
*size_out = temp.size;
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
// segment_start_indices and segment_end_indices are arrays of length num_segments, where segment_start_indices[i] is the index of the first element
|
|
190
|
+
// in the i-th segment and segment_end_indices[i] is the index after the last element in the i-th segment
|
|
191
|
+
// https://nvidia.github.io/cccl/cub/api/structcub_1_1DeviceSegmentedRadixSort.html
|
|
192
|
+
void segmented_sort_pairs_device(void* context, float* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments)
|
|
193
|
+
{
|
|
194
|
+
ContextGuard guard(context);
|
|
195
|
+
|
|
196
|
+
cub::DoubleBuffer<float> d_keys(keys, keys + n);
|
|
197
|
+
cub::DoubleBuffer<int> d_values(values, values + n);
|
|
198
|
+
|
|
199
|
+
RadixSortTemp temp;
|
|
200
|
+
segmented_sort_reserve(WP_CURRENT_CONTEXT, n, num_segments, &temp.mem, &temp.size);
|
|
201
|
+
|
|
202
|
+
// sort
|
|
203
|
+
check_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
|
|
204
|
+
temp.mem,
|
|
205
|
+
temp.size,
|
|
206
|
+
d_keys,
|
|
207
|
+
d_values,
|
|
208
|
+
n,
|
|
209
|
+
num_segments,
|
|
210
|
+
segment_start_indices,
|
|
211
|
+
segment_end_indices,
|
|
212
|
+
0,
|
|
213
|
+
32,
|
|
214
|
+
(cudaStream_t)cuda_stream_get_current()));
|
|
215
|
+
|
|
216
|
+
if (d_keys.Current() != keys)
|
|
217
|
+
memcpy_d2d(WP_CURRENT_CONTEXT, keys, d_keys.Current(), sizeof(float)*n);
|
|
218
|
+
|
|
219
|
+
if (d_values.Current() != values)
|
|
220
|
+
memcpy_d2d(WP_CURRENT_CONTEXT, values, d_values.Current(), sizeof(int)*n);
|
|
221
|
+
}
|
|
222
|
+
|
|
223
|
+
void segmented_sort_pairs_float_device(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments)
|
|
224
|
+
{
|
|
225
|
+
segmented_sort_pairs_device(
|
|
226
|
+
WP_CURRENT_CONTEXT,
|
|
227
|
+
reinterpret_cast<float *>(keys),
|
|
228
|
+
reinterpret_cast<int *>(values), n,
|
|
229
|
+
reinterpret_cast<int *>(segment_start_indices),
|
|
230
|
+
reinterpret_cast<int *>(segment_end_indices),
|
|
231
|
+
num_segments);
|
|
232
|
+
}
|
|
233
|
+
|
|
234
|
+
// segment_indices is an array of length num_segments + 1, where segment_indices[i] is the index of the first element in the i-th segment
|
|
235
|
+
// The end of a segment is given by segment_indices[i+1]
|
|
236
|
+
// https://nvidia.github.io/cccl/cub/api/structcub_1_1DeviceSegmentedSort.html#a-simple-example
|
|
237
|
+
void segmented_sort_pairs_device(void* context, int* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments)
|
|
238
|
+
{
|
|
239
|
+
ContextGuard guard(context);
|
|
240
|
+
|
|
241
|
+
cub::DoubleBuffer<int> d_keys(keys, keys + n);
|
|
242
|
+
cub::DoubleBuffer<int> d_values(values, values + n);
|
|
243
|
+
|
|
244
|
+
RadixSortTemp temp;
|
|
245
|
+
segmented_sort_reserve(WP_CURRENT_CONTEXT, n, num_segments, &temp.mem, &temp.size);
|
|
246
|
+
|
|
247
|
+
// sort
|
|
248
|
+
check_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
|
|
249
|
+
temp.mem,
|
|
250
|
+
temp.size,
|
|
251
|
+
d_keys,
|
|
252
|
+
d_values,
|
|
253
|
+
n,
|
|
254
|
+
num_segments,
|
|
255
|
+
segment_start_indices,
|
|
256
|
+
segment_end_indices,
|
|
257
|
+
0,
|
|
258
|
+
32,
|
|
259
|
+
(cudaStream_t)cuda_stream_get_current()));
|
|
260
|
+
|
|
261
|
+
if (d_keys.Current() != keys)
|
|
262
|
+
memcpy_d2d(WP_CURRENT_CONTEXT, keys, d_keys.Current(), sizeof(float)*n);
|
|
263
|
+
|
|
264
|
+
if (d_values.Current() != values)
|
|
265
|
+
memcpy_d2d(WP_CURRENT_CONTEXT, values, d_values.Current(), sizeof(int)*n);
|
|
266
|
+
}
|
|
267
|
+
|
|
268
|
+
void segmented_sort_pairs_int_device(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments)
|
|
269
|
+
{
|
|
270
|
+
segmented_sort_pairs_device(
|
|
271
|
+
WP_CURRENT_CONTEXT,
|
|
272
|
+
reinterpret_cast<int *>(keys),
|
|
273
|
+
reinterpret_cast<int *>(values), n,
|
|
274
|
+
reinterpret_cast<int *>(segment_start_indices),
|
|
275
|
+
reinterpret_cast<int *>(segment_end_indices),
|
|
276
|
+
num_segments);
|
|
277
|
+
}
|
warp/native/sort.h
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
* you may not use this file except in compliance with the License.
|
|
7
|
+
* You may obtain a copy of the License at
|
|
8
|
+
*
|
|
9
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
*
|
|
11
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
* See the License for the specific language governing permissions and
|
|
15
|
+
* limitations under the License.
|
|
16
|
+
*/
|
|
17
|
+
|
|
18
|
+
#pragma once
|
|
19
|
+
|
|
20
|
+
#include <stddef.h>
|
|
21
|
+
|
|
22
|
+
void radix_sort_reserve(void* context, int n, void** mem_out=NULL, size_t* size_out=NULL);
|
|
23
|
+
void radix_sort_pairs_host(int* keys, int* values, int n);
|
|
24
|
+
void radix_sort_pairs_host(float* keys, int* values, int n);
|
|
25
|
+
void radix_sort_pairs_host(int64_t* keys, int* values, int n);
|
|
26
|
+
void radix_sort_pairs_device(void* context, int* keys, int* values, int n);
|
|
27
|
+
void radix_sort_pairs_device(void* context, float* keys, int* values, int n);
|
|
28
|
+
void radix_sort_pairs_device(void* context, int64_t* keys, int* values, int n);
|
|
29
|
+
|
|
30
|
+
void segmented_sort_pairs_host(float* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
|
|
31
|
+
void segmented_sort_pairs_device(void* context, float* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
|
|
32
|
+
void segmented_sort_pairs_host(void* context, int* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
|
|
33
|
+
void segmented_sort_pairs_device(void* context, int* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
|
warp/native/sparse.cpp
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
* you may not use this file except in compliance with the License.
|
|
7
|
+
* You may obtain a copy of the License at
|
|
8
|
+
*
|
|
9
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
*
|
|
11
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
* See the License for the specific language governing permissions and
|
|
15
|
+
* limitations under the License.
|
|
16
|
+
*/
|
|
17
|
+
|
|
18
|
+
#include "warp.h"

#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>
|
|
23
|
+
|
|
24
|
+
namespace
|
|
25
|
+
{
|
|
26
|
+
|
|
27
|
+
// Specialized is_zero and accumulation function for common block sizes
|
|
28
|
+
// Rely on compiler to unroll loops when block size is known
|
|
29
|
+
|
|
30
|
+
// Returns true when each of the first N scalars at `val` compares equal to T(0).
//
// N is the compile-time block size. `value_size` is not read; the parameter exists so
// that this function has the same signature as bsr_dyn_block_is_zero and both can be
// stored in the same function pointer.
template <int N, typename T> bool bsr_fixed_block_is_zero(const T* val, int value_size)
{
    (void)value_size;

    // Compare in T. Taking the element as `float` would narrow double inputs first, so
    // non-zero doubles too small for float would be rounded to zero and misreported.
    return std::all_of(val, val + N, [](const T& v) { return v == T(0); });
}
|
|
34
|
+
|
|
35
|
+
// Returns true when each of the first `value_size` scalars at `val` compares equal to T(0).
// An empty range (value_size == 0) is reported as zero.
template <typename T> bool bsr_dyn_block_is_zero(const T* val, int value_size)
{
    // Compare in T. Taking the element as `float` would narrow double inputs first, so
    // non-zero doubles too small for float would be rounded to zero and misreported.
    return std::all_of(val, val + value_size, [](const T& v) { return v == T(0); });
}
|
|
39
|
+
|
|
40
|
+
// Adds the first N scalars of `val` element-wise into `sum`.
// The element count is the template argument N; the runtime `value_size` is not read.
template <int N, typename T> void bsr_fixed_block_accumulate(const T* val, T* sum, int value_size)
{
    for (int k = 0; k < N; ++k)
    {
        sum[k] += val[k];
    }
}
|
|
47
|
+
|
|
48
|
+
// Adds the first `value_size` scalars of `val` element-wise into `sum`.
// Nothing is touched when `value_size` is zero or negative.
template <typename T> void bsr_dyn_block_accumulate(const T* val, T* sum, int value_size)
{
    for (int k = 0; k < value_size; ++k)
    {
        sum[k] += val[k];
    }
}
|
|
55
|
+
|
|
56
|
+
// Writes the transpose of the row-major Rows x Cols block `src` into `dest`
// (row-major Cols x Rows). The block shape comes from the template arguments;
// the runtime `row_count` / `col_count` arguments are not read.
template <int Rows, int Cols, typename T>
void bsr_fixed_block_transpose(const T* src, T* dest, int row_count, int col_count)
{
    // the source is consumed front to back, so a running index replaces r * Cols + c
    int src_index = 0;
    for (int r = 0; r < Rows; ++r)
    {
        for (int c = 0; c < Cols; ++c)
        {
            dest[c * Rows + r] = src[src_index];
            ++src_index;
        }
    }
}
|
|
67
|
+
|
|
68
|
+
// Writes the transpose of the row-major row_count x col_count block `src` into `dest`
// (row-major col_count x row_count).
template <typename T> void bsr_dyn_block_transpose(const T* src, T* dest, int row_count, int col_count)
{
    // the source is consumed front to back, so a running index replaces r * col_count + c
    int src_index = 0;
    for (int r = 0; r < row_count; ++r)
    {
        for (int c = 0; c < col_count; ++c)
        {
            dest[c * row_count + r] = src[src_index];
            ++src_index;
        }
    }
}
|
|
78
|
+
|
|
79
|
+
} // namespace
|
|
80
|
+
|
|
81
|
+
template <typename T>
|
|
82
|
+
int bsr_matrix_from_triplets_host(const int rows_per_block, const int cols_per_block, const int row_count,
|
|
83
|
+
const int nnz, const int* tpl_rows, const int* tpl_columns, const T* tpl_values,
|
|
84
|
+
const bool prune_numerical_zeros, const bool masked, int* bsr_offsets,
|
|
85
|
+
int* bsr_columns, T* bsr_values)
|
|
86
|
+
{
|
|
87
|
+
|
|
88
|
+
// get specialized accumulator for common block sizes (1,1), (1,2), (1,3),
|
|
89
|
+
// (2,2), (2,3), (3,3)
|
|
90
|
+
const int block_size = rows_per_block * cols_per_block;
|
|
91
|
+
void (*block_accumulate_func)(const T*, T*, int);
|
|
92
|
+
bool (*block_is_zero_func)(const T*, int);
|
|
93
|
+
switch (block_size)
|
|
94
|
+
{
|
|
95
|
+
case 1:
|
|
96
|
+
block_accumulate_func = bsr_fixed_block_accumulate<1, T>;
|
|
97
|
+
block_is_zero_func = bsr_fixed_block_is_zero<1, T>;
|
|
98
|
+
break;
|
|
99
|
+
case 2:
|
|
100
|
+
block_accumulate_func = bsr_fixed_block_accumulate<2, T>;
|
|
101
|
+
block_is_zero_func = bsr_fixed_block_is_zero<2, T>;
|
|
102
|
+
break;
|
|
103
|
+
case 3:
|
|
104
|
+
block_accumulate_func = bsr_fixed_block_accumulate<3, T>;
|
|
105
|
+
block_is_zero_func = bsr_fixed_block_is_zero<3, T>;
|
|
106
|
+
break;
|
|
107
|
+
case 4:
|
|
108
|
+
block_accumulate_func = bsr_fixed_block_accumulate<4, T>;
|
|
109
|
+
block_is_zero_func = bsr_fixed_block_is_zero<4, T>;
|
|
110
|
+
break;
|
|
111
|
+
case 6:
|
|
112
|
+
block_accumulate_func = bsr_fixed_block_accumulate<6, T>;
|
|
113
|
+
block_is_zero_func = bsr_fixed_block_is_zero<6, T>;
|
|
114
|
+
break;
|
|
115
|
+
case 9:
|
|
116
|
+
block_accumulate_func = bsr_fixed_block_accumulate<9, T>;
|
|
117
|
+
block_is_zero_func = bsr_fixed_block_is_zero<9, T>;
|
|
118
|
+
break;
|
|
119
|
+
default:
|
|
120
|
+
block_accumulate_func = bsr_dyn_block_accumulate<T>;
|
|
121
|
+
block_is_zero_func = bsr_dyn_block_is_zero<T>;
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
std::vector<int> block_indices(nnz);
|
|
125
|
+
std::iota(block_indices.begin(), block_indices.end(), 0);
|
|
126
|
+
|
|
127
|
+
// remove zero blocks and invalid row indices
|
|
128
|
+
|
|
129
|
+
auto discard_block = [&](int i)
|
|
130
|
+
{
|
|
131
|
+
const int row = tpl_rows[i];
|
|
132
|
+
if (row < 0 || row >= row_count)
|
|
133
|
+
{
|
|
134
|
+
return true;
|
|
135
|
+
}
|
|
136
|
+
|
|
137
|
+
if (prune_numerical_zeros && tpl_values && block_is_zero_func(tpl_values + i * block_size, block_size))
|
|
138
|
+
{
|
|
139
|
+
return true;
|
|
140
|
+
}
|
|
141
|
+
|
|
142
|
+
if (!masked)
|
|
143
|
+
{
|
|
144
|
+
return false;
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
const int* beg = bsr_columns + bsr_offsets[row];
|
|
148
|
+
const int* end = bsr_columns + bsr_offsets[row + 1];
|
|
149
|
+
const int col = tpl_columns[i];
|
|
150
|
+
const int* block = std::lower_bound(beg, end, col);
|
|
151
|
+
return block == end || *block != col;
|
|
152
|
+
};
|
|
153
|
+
|
|
154
|
+
block_indices.erase(std::remove_if(block_indices.begin(), block_indices.end(), discard_block), block_indices.end());
|
|
155
|
+
|
|
156
|
+
// sort block indices according to lexico order
|
|
157
|
+
std::sort(block_indices.begin(), block_indices.end(), [tpl_rows, tpl_columns](int i, int j) -> bool
|
|
158
|
+
{ return tpl_rows[i] < tpl_rows[j] || (tpl_rows[i] == tpl_rows[j] && tpl_columns[i] < tpl_columns[j]); });
|
|
159
|
+
|
|
160
|
+
// accumulate blocks at same locations, count blocks per row
|
|
161
|
+
std::fill_n(bsr_offsets, row_count + 1, 0);
|
|
162
|
+
|
|
163
|
+
int current_row = -1;
|
|
164
|
+
int current_col = -1;
|
|
165
|
+
|
|
166
|
+
// so that we get back to the start for the first block
|
|
167
|
+
if (bsr_values)
|
|
168
|
+
{
|
|
169
|
+
bsr_values -= block_size;
|
|
170
|
+
}
|
|
171
|
+
|
|
172
|
+
for (int i = 0; i < block_indices.size(); ++i)
|
|
173
|
+
{
|
|
174
|
+
int idx = block_indices[i];
|
|
175
|
+
int row = tpl_rows[idx];
|
|
176
|
+
int col = tpl_columns[idx];
|
|
177
|
+
const T* val = tpl_values + idx * block_size;
|
|
178
|
+
|
|
179
|
+
if (row == current_row && col == current_col)
|
|
180
|
+
{
|
|
181
|
+
if (bsr_values)
|
|
182
|
+
{
|
|
183
|
+
block_accumulate_func(val, bsr_values, block_size);
|
|
184
|
+
}
|
|
185
|
+
}
|
|
186
|
+
else
|
|
187
|
+
{
|
|
188
|
+
*(bsr_columns++) = col;
|
|
189
|
+
|
|
190
|
+
if (bsr_values)
|
|
191
|
+
{
|
|
192
|
+
bsr_values += block_size;
|
|
193
|
+
std::copy_n(val, block_size, bsr_values);
|
|
194
|
+
}
|
|
195
|
+
|
|
196
|
+
bsr_offsets[row + 1]++;
|
|
197
|
+
|
|
198
|
+
current_row = row;
|
|
199
|
+
current_col = col;
|
|
200
|
+
}
|
|
201
|
+
}
|
|
202
|
+
|
|
203
|
+
// build postfix sum of row counts
|
|
204
|
+
std::partial_sum(bsr_offsets, bsr_offsets + row_count + 1, bsr_offsets);
|
|
205
|
+
|
|
206
|
+
return bsr_offsets[row_count];
|
|
207
|
+
}
|
|
208
|
+
|
|
209
|
+
template <typename T>
|
|
210
|
+
void bsr_transpose_host(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz_up,
|
|
211
|
+
const int* bsr_offsets, const int* bsr_columns, const T* bsr_values,
|
|
212
|
+
int* transposed_bsr_offsets, int* transposed_bsr_columns, T* transposed_bsr_values)
|
|
213
|
+
{
|
|
214
|
+
const int nnz = bsr_offsets[row_count];
|
|
215
|
+
const int block_size = rows_per_block * cols_per_block;
|
|
216
|
+
|
|
217
|
+
void (*block_transpose_func)(const T*, T*, int, int) = bsr_dyn_block_transpose<T>;
|
|
218
|
+
switch (rows_per_block)
|
|
219
|
+
{
|
|
220
|
+
case 1:
|
|
221
|
+
switch (cols_per_block)
|
|
222
|
+
{
|
|
223
|
+
case 1:
|
|
224
|
+
block_transpose_func = bsr_fixed_block_transpose<1, 1, T>;
|
|
225
|
+
break;
|
|
226
|
+
case 2:
|
|
227
|
+
block_transpose_func = bsr_fixed_block_transpose<1, 2, T>;
|
|
228
|
+
break;
|
|
229
|
+
case 3:
|
|
230
|
+
block_transpose_func = bsr_fixed_block_transpose<1, 3, T>;
|
|
231
|
+
break;
|
|
232
|
+
}
|
|
233
|
+
break;
|
|
234
|
+
case 2:
|
|
235
|
+
switch (cols_per_block)
|
|
236
|
+
{
|
|
237
|
+
case 1:
|
|
238
|
+
block_transpose_func = bsr_fixed_block_transpose<2, 1, T>;
|
|
239
|
+
break;
|
|
240
|
+
case 2:
|
|
241
|
+
block_transpose_func = bsr_fixed_block_transpose<2, 2, T>;
|
|
242
|
+
break;
|
|
243
|
+
case 3:
|
|
244
|
+
block_transpose_func = bsr_fixed_block_transpose<2, 3, T>;
|
|
245
|
+
break;
|
|
246
|
+
}
|
|
247
|
+
break;
|
|
248
|
+
case 3:
|
|
249
|
+
switch (cols_per_block)
|
|
250
|
+
{
|
|
251
|
+
case 1:
|
|
252
|
+
block_transpose_func = bsr_fixed_block_transpose<3, 1, T>;
|
|
253
|
+
break;
|
|
254
|
+
case 2:
|
|
255
|
+
block_transpose_func = bsr_fixed_block_transpose<3, 2, T>;
|
|
256
|
+
break;
|
|
257
|
+
case 3:
|
|
258
|
+
block_transpose_func = bsr_fixed_block_transpose<3, 3, T>;
|
|
259
|
+
break;
|
|
260
|
+
}
|
|
261
|
+
break;
|
|
262
|
+
}
|
|
263
|
+
|
|
264
|
+
std::vector<int> block_indices(nnz), bsr_rows(nnz);
|
|
265
|
+
std::iota(block_indices.begin(), block_indices.end(), 0);
|
|
266
|
+
|
|
267
|
+
// Fill row indices from offsets
|
|
268
|
+
for (int row = 0; row < row_count; ++row)
|
|
269
|
+
{
|
|
270
|
+
std::fill(bsr_rows.begin() + bsr_offsets[row], bsr_rows.begin() + bsr_offsets[row + 1], row);
|
|
271
|
+
}
|
|
272
|
+
|
|
273
|
+
// sort block indices according to (transposed) lexico order
|
|
274
|
+
std::sort(
|
|
275
|
+
block_indices.begin(), block_indices.end(), [&bsr_rows, bsr_columns](int i, int j) -> bool
|
|
276
|
+
{ return bsr_columns[i] < bsr_columns[j] || (bsr_columns[i] == bsr_columns[j] && bsr_rows[i] < bsr_rows[j]); });
|
|
277
|
+
|
|
278
|
+
// Count blocks per column and transpose blocks
|
|
279
|
+
std::fill_n(transposed_bsr_offsets, col_count + 1, 0);
|
|
280
|
+
|
|
281
|
+
for (int i = 0; i < nnz; ++i)
|
|
282
|
+
{
|
|
283
|
+
int idx = block_indices[i];
|
|
284
|
+
int row = bsr_rows[idx];
|
|
285
|
+
int col = bsr_columns[idx];
|
|
286
|
+
|
|
287
|
+
++transposed_bsr_offsets[col + 1];
|
|
288
|
+
transposed_bsr_columns[i] = row;
|
|
289
|
+
|
|
290
|
+
if (transposed_bsr_values != nullptr)
|
|
291
|
+
{
|
|
292
|
+
const T* src_block = bsr_values + idx * block_size;
|
|
293
|
+
T* dst_block = transposed_bsr_values + i * block_size;
|
|
294
|
+
block_transpose_func(src_block, dst_block, rows_per_block, cols_per_block);
|
|
295
|
+
}
|
|
296
|
+
}
|
|
297
|
+
|
|
298
|
+
// build postfix sum of column counts
|
|
299
|
+
std::partial_sum(transposed_bsr_offsets, transposed_bsr_offsets + col_count + 1, transposed_bsr_offsets);
|
|
300
|
+
}
|
|
301
|
+
|
|
302
|
+
// Exported entry point: assembles a float BSR matrix on the host from block triplets.
// The untyped value buffers are cast to float and the work is forwarded to
// bsr_matrix_from_triplets_host<float>. When bsr_nnz is non-null it receives the final
// block count, read back from bsr_offsets[row_count]. bsr_nnz_event is not used here.
WP_API void bsr_matrix_from_triplets_float_host(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                                int* tpl_rows, int* tpl_columns, void* tpl_values,
                                                bool prune_numerical_zeros, bool masked, int* bsr_offsets,
                                                int* bsr_columns, void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
{
    const float* triplet_blocks = static_cast<const float*>(tpl_values);
    float* matrix_blocks = static_cast<float*>(bsr_values);

    bsr_matrix_from_triplets_host<float>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns,
                                         triplet_blocks, prune_numerical_zeros, masked, bsr_offsets, bsr_columns,
                                         matrix_blocks);

    if (!bsr_nnz)
    {
        return;
    }
    *bsr_nnz = bsr_offsets[row_count];
}
|
|
315
|
+
|
|
316
|
+
// Exported entry point: assembles a double BSR matrix on the host from block triplets.
// The untyped value buffers are cast to double and the work is forwarded to
// bsr_matrix_from_triplets_host<double>. When bsr_nnz is non-null it receives the final
// block count, read back from bsr_offsets[row_count]. bsr_nnz_event is not used here.
WP_API void bsr_matrix_from_triplets_double_host(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                                 int* tpl_rows, int* tpl_columns, void* tpl_values,
                                                 bool prune_numerical_zeros, bool masked, int* bsr_offsets,
                                                 int* bsr_columns, void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
{
    const double* triplet_blocks = static_cast<const double*>(tpl_values);
    double* matrix_blocks = static_cast<double*>(bsr_values);

    bsr_matrix_from_triplets_host<double>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns,
                                          triplet_blocks, prune_numerical_zeros, masked, bsr_offsets, bsr_columns,
                                          matrix_blocks);

    if (!bsr_nnz)
    {
        return;
    }
    *bsr_nnz = bsr_offsets[row_count];
}
|
|
329
|
+
|
|
330
|
+
// Exported entry point: transposes a float BSR matrix on the host.
// Casts the untyped value buffers to float and forwards every argument to bsr_transpose_host.
WP_API void bsr_transpose_float_host(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
                                     int* bsr_offsets, int* bsr_columns, void* bsr_values, int* transposed_bsr_offsets,
                                     int* transposed_bsr_columns, void* transposed_bsr_values)
{
    const float* source_blocks = static_cast<const float*>(bsr_values);
    float* transposed_blocks = static_cast<float*>(transposed_bsr_values);

    bsr_transpose_host<float>(rows_per_block, cols_per_block, row_count, col_count, nnz, bsr_offsets, bsr_columns,
                              source_blocks, transposed_bsr_offsets, transposed_bsr_columns, transposed_blocks);
}
|
|
338
|
+
|
|
339
|
+
// Exported entry point: transposes a double BSR matrix on the host.
// Casts the untyped value buffers to double and forwards every argument to bsr_transpose_host.
WP_API void bsr_transpose_double_host(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
                                      int* bsr_offsets, int* bsr_columns, void* bsr_values, int* transposed_bsr_offsets,
                                      int* transposed_bsr_columns, void* transposed_bsr_values)
{
    const double* source_blocks = static_cast<const double*>(bsr_values);
    double* transposed_blocks = static_cast<double*>(transposed_bsr_values);

    bsr_transpose_host<double>(rows_per_block, cols_per_block, row_count, col_count, nnz, bsr_offsets, bsr_columns,
                               source_blocks, transposed_bsr_offsets, transposed_bsr_columns, transposed_blocks);
}
|
|
347
|
+
|
|
348
|
+
#if !WP_ENABLE_CUDA
|
|
349
|
+
WP_API void bsr_matrix_from_triplets_float_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
|
|
350
|
+
int* tpl_rows, int* tpl_columns, void* tpl_values,
|
|
351
|
+
bool prune_numerical_zeros, bool masked, int* bsr_offsets,
|
|
352
|
+
int* bsr_columns, void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
|
|
353
|
+
{
|
|
354
|
+
}
|
|
355
|
+
|
|
356
|
+
WP_API void bsr_matrix_from_triplets_double_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
|
|
357
|
+
int* tpl_rows, int* tpl_columns, void* tpl_values,
|
|
358
|
+
bool prune_numerical_zeros, bool masked, int* bsr_offsets,
|
|
359
|
+
int* bsr_columns, void* bsr_values, int* bsr_nnz,
|
|
360
|
+
void* bsr_nnz_event)
|
|
361
|
+
{
|
|
362
|
+
}
|
|
363
|
+
|
|
364
|
+
WP_API void bsr_transpose_float_device(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
|
|
365
|
+
int* bsr_offsets, int* bsr_columns, void* bsr_values,
|
|
366
|
+
int* transposed_bsr_offsets, int* transposed_bsr_columns,
|
|
367
|
+
void* transposed_bsr_values)
|
|
368
|
+
{
|
|
369
|
+
}
|
|
370
|
+
|
|
371
|
+
WP_API void bsr_transpose_double_device(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
|
|
372
|
+
int* bsr_offsets, int* bsr_columns, void* bsr_values,
|
|
373
|
+
int* transposed_bsr_offsets, int* transposed_bsr_columns,
|
|
374
|
+
void* transposed_bsr_values)
|
|
375
|
+
{
|
|
376
|
+
}
|
|
377
|
+
|
|
378
|
+
#endif
|