warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +139 -0
- warp/__init__.pyi +1 -0
- warp/autograd.py +1142 -0
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +557 -0
- warp/build_dll.py +405 -0
- warp/builtins.py +6855 -0
- warp/codegen.py +3969 -0
- warp/config.py +158 -0
- warp/constants.py +57 -0
- warp/context.py +6812 -0
- warp/dlpack.py +462 -0
- warp/examples/__init__.py +24 -0
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -0
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -0
- warp/examples/assets/nv_humanoid.xml +183 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/assets/quadruped.urdf +268 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +389 -0
- warp/examples/benchmarks/benchmark_cloth.py +296 -0
- warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
- warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
- warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
- warp/examples/benchmarks/benchmark_gemm.py +164 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
- warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
- warp/examples/benchmarks/benchmark_launches.py +301 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/browse.py +37 -0
- warp/examples/core/example_cupy.py +86 -0
- warp/examples/core/example_dem.py +241 -0
- warp/examples/core/example_fluid.py +299 -0
- warp/examples/core/example_graph_capture.py +150 -0
- warp/examples/core/example_marching_cubes.py +194 -0
- warp/examples/core/example_mesh.py +180 -0
- warp/examples/core/example_mesh_intersect.py +211 -0
- warp/examples/core/example_nvdb.py +182 -0
- warp/examples/core/example_raycast.py +111 -0
- warp/examples/core/example_raymarch.py +205 -0
- warp/examples/core/example_render_opengl.py +193 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +411 -0
- warp/examples/core/example_torch.py +211 -0
- warp/examples/core/example_wave.py +269 -0
- warp/examples/fem/example_adaptive_grid.py +286 -0
- warp/examples/fem/example_apic_fluid.py +423 -0
- warp/examples/fem/example_burgers.py +261 -0
- warp/examples/fem/example_convection_diffusion.py +178 -0
- warp/examples/fem/example_convection_diffusion_dg.py +204 -0
- warp/examples/fem/example_deformed_geometry.py +172 -0
- warp/examples/fem/example_diffusion.py +196 -0
- warp/examples/fem/example_diffusion_3d.py +225 -0
- warp/examples/fem/example_diffusion_mgpu.py +220 -0
- warp/examples/fem/example_distortion_energy.py +228 -0
- warp/examples/fem/example_magnetostatics.py +240 -0
- warp/examples/fem/example_mixed_elasticity.py +291 -0
- warp/examples/fem/example_navier_stokes.py +261 -0
- warp/examples/fem/example_nonconforming_contact.py +298 -0
- warp/examples/fem/example_stokes.py +213 -0
- warp/examples/fem/example_stokes_transfer.py +262 -0
- warp/examples/fem/example_streamlines.py +352 -0
- warp/examples/fem/utils.py +1000 -0
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_bounce.py +266 -0
- warp/examples/optim/example_cloth_throw.py +228 -0
- warp/examples/optim/example_diffray.py +561 -0
- warp/examples/optim/example_drone.py +870 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/optim/example_inverse_kinematics.py +182 -0
- warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
- warp/examples/optim/example_softbody_properties.py +400 -0
- warp/examples/optim/example_spring_cage.py +245 -0
- warp/examples/optim/example_trajectory.py +227 -0
- warp/examples/sim/example_cartpole.py +143 -0
- warp/examples/sim/example_cloth.py +225 -0
- warp/examples/sim/example_cloth_self_contact.py +322 -0
- warp/examples/sim/example_granular.py +130 -0
- warp/examples/sim/example_granular_collision_sdf.py +202 -0
- warp/examples/sim/example_jacobian_ik.py +244 -0
- warp/examples/sim/example_particle_chain.py +124 -0
- warp/examples/sim/example_quadruped.py +203 -0
- warp/examples/sim/example_rigid_chain.py +203 -0
- warp/examples/sim/example_rigid_contact.py +195 -0
- warp/examples/sim/example_rigid_force.py +133 -0
- warp/examples/sim/example_rigid_gyroscopic.py +115 -0
- warp/examples/sim/example_rigid_soft_contact.py +140 -0
- warp/examples/sim/example_soft_body.py +196 -0
- warp/examples/tile/example_tile_cholesky.py +87 -0
- warp/examples/tile/example_tile_convolution.py +66 -0
- warp/examples/tile/example_tile_fft.py +55 -0
- warp/examples/tile/example_tile_filtering.py +113 -0
- warp/examples/tile/example_tile_matmul.py +85 -0
- warp/examples/tile/example_tile_mlp.py +383 -0
- warp/examples/tile/example_tile_nbody.py +199 -0
- warp/examples/tile/example_tile_walker.py +327 -0
- warp/fabric.py +355 -0
- warp/fem/__init__.py +106 -0
- warp/fem/adaptivity.py +508 -0
- warp/fem/cache.py +572 -0
- warp/fem/dirichlet.py +202 -0
- warp/fem/domain.py +411 -0
- warp/fem/field/__init__.py +125 -0
- warp/fem/field/field.py +619 -0
- warp/fem/field/nodal_field.py +326 -0
- warp/fem/field/restriction.py +37 -0
- warp/fem/field/virtual.py +848 -0
- warp/fem/geometry/__init__.py +32 -0
- warp/fem/geometry/adaptive_nanogrid.py +857 -0
- warp/fem/geometry/closest_point.py +84 -0
- warp/fem/geometry/deformed_geometry.py +221 -0
- warp/fem/geometry/element.py +776 -0
- warp/fem/geometry/geometry.py +362 -0
- warp/fem/geometry/grid_2d.py +392 -0
- warp/fem/geometry/grid_3d.py +452 -0
- warp/fem/geometry/hexmesh.py +911 -0
- warp/fem/geometry/nanogrid.py +571 -0
- warp/fem/geometry/partition.py +389 -0
- warp/fem/geometry/quadmesh.py +663 -0
- warp/fem/geometry/tetmesh.py +855 -0
- warp/fem/geometry/trimesh.py +806 -0
- warp/fem/integrate.py +2335 -0
- warp/fem/linalg.py +419 -0
- warp/fem/operator.py +293 -0
- warp/fem/polynomial.py +229 -0
- warp/fem/quadrature/__init__.py +17 -0
- warp/fem/quadrature/pic_quadrature.py +299 -0
- warp/fem/quadrature/quadrature.py +591 -0
- warp/fem/space/__init__.py +228 -0
- warp/fem/space/basis_function_space.py +468 -0
- warp/fem/space/basis_space.py +667 -0
- warp/fem/space/dof_mapper.py +251 -0
- warp/fem/space/function_space.py +309 -0
- warp/fem/space/grid_2d_function_space.py +177 -0
- warp/fem/space/grid_3d_function_space.py +227 -0
- warp/fem/space/hexmesh_function_space.py +257 -0
- warp/fem/space/nanogrid_function_space.py +201 -0
- warp/fem/space/partition.py +367 -0
- warp/fem/space/quadmesh_function_space.py +223 -0
- warp/fem/space/restriction.py +179 -0
- warp/fem/space/shape/__init__.py +143 -0
- warp/fem/space/shape/cube_shape_function.py +1105 -0
- warp/fem/space/shape/shape_function.py +133 -0
- warp/fem/space/shape/square_shape_function.py +926 -0
- warp/fem/space/shape/tet_shape_function.py +834 -0
- warp/fem/space/shape/triangle_shape_function.py +672 -0
- warp/fem/space/tetmesh_function_space.py +271 -0
- warp/fem/space/topology.py +424 -0
- warp/fem/space/trimesh_function_space.py +194 -0
- warp/fem/types.py +99 -0
- warp/fem/utils.py +420 -0
- warp/jax.py +187 -0
- warp/jax_experimental/__init__.py +16 -0
- warp/jax_experimental/custom_call.py +351 -0
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +244 -0
- warp/native/array.h +1145 -0
- warp/native/builtin.h +1800 -0
- warp/native/bvh.cpp +492 -0
- warp/native/bvh.cu +791 -0
- warp/native/bvh.h +554 -0
- warp/native/clang/clang.cpp +536 -0
- warp/native/coloring.cpp +613 -0
- warp/native/crt.cpp +51 -0
- warp/native/crt.h +362 -0
- warp/native/cuda_crt.h +1058 -0
- warp/native/cuda_util.cpp +646 -0
- warp/native/cuda_util.h +307 -0
- warp/native/error.cpp +77 -0
- warp/native/error.h +36 -0
- warp/native/exports.h +1878 -0
- warp/native/fabric.h +245 -0
- warp/native/hashgrid.cpp +311 -0
- warp/native/hashgrid.cu +87 -0
- warp/native/hashgrid.h +240 -0
- warp/native/initializer_array.h +41 -0
- warp/native/intersect.h +1230 -0
- warp/native/intersect_adj.h +375 -0
- warp/native/intersect_tri.h +339 -0
- warp/native/marching.cpp +19 -0
- warp/native/marching.cu +514 -0
- warp/native/marching.h +19 -0
- warp/native/mat.h +2220 -0
- warp/native/mathdx.cpp +87 -0
- warp/native/matnn.h +343 -0
- warp/native/mesh.cpp +266 -0
- warp/native/mesh.cu +404 -0
- warp/native/mesh.h +1980 -0
- warp/native/nanovdb/GridHandle.h +366 -0
- warp/native/nanovdb/HostBuffer.h +590 -0
- warp/native/nanovdb/NanoVDB.h +6624 -0
- warp/native/nanovdb/PNanoVDB.h +3390 -0
- warp/native/noise.h +859 -0
- warp/native/quat.h +1371 -0
- warp/native/rand.h +342 -0
- warp/native/range.h +139 -0
- warp/native/reduce.cpp +174 -0
- warp/native/reduce.cu +364 -0
- warp/native/runlength_encode.cpp +79 -0
- warp/native/runlength_encode.cu +61 -0
- warp/native/scan.cpp +47 -0
- warp/native/scan.cu +53 -0
- warp/native/scan.h +23 -0
- warp/native/solid_angle.h +466 -0
- warp/native/sort.cpp +251 -0
- warp/native/sort.cu +277 -0
- warp/native/sort.h +33 -0
- warp/native/sparse.cpp +378 -0
- warp/native/sparse.cu +524 -0
- warp/native/spatial.h +657 -0
- warp/native/svd.h +702 -0
- warp/native/temp_buffer.h +46 -0
- warp/native/tile.h +2584 -0
- warp/native/tile_reduce.h +264 -0
- warp/native/vec.h +1426 -0
- warp/native/volume.cpp +501 -0
- warp/native/volume.cu +67 -0
- warp/native/volume.h +969 -0
- warp/native/volume_builder.cu +477 -0
- warp/native/volume_builder.h +52 -0
- warp/native/volume_impl.h +70 -0
- warp/native/warp.cpp +1082 -0
- warp/native/warp.cu +3636 -0
- warp/native/warp.h +381 -0
- warp/optim/__init__.py +17 -0
- warp/optim/adam.py +163 -0
- warp/optim/linear.py +1137 -0
- warp/optim/sgd.py +112 -0
- warp/paddle.py +407 -0
- warp/render/__init__.py +18 -0
- warp/render/render_opengl.py +3518 -0
- warp/render/render_usd.py +784 -0
- warp/render/utils.py +160 -0
- warp/sim/__init__.py +65 -0
- warp/sim/articulation.py +793 -0
- warp/sim/collide.py +2395 -0
- warp/sim/graph_coloring.py +300 -0
- warp/sim/import_mjcf.py +790 -0
- warp/sim/import_snu.py +227 -0
- warp/sim/import_urdf.py +579 -0
- warp/sim/import_usd.py +894 -0
- warp/sim/inertia.py +324 -0
- warp/sim/integrator.py +242 -0
- warp/sim/integrator_euler.py +1997 -0
- warp/sim/integrator_featherstone.py +2101 -0
- warp/sim/integrator_vbd.py +2048 -0
- warp/sim/integrator_xpbd.py +3292 -0
- warp/sim/model.py +4791 -0
- warp/sim/particles.py +121 -0
- warp/sim/render.py +427 -0
- warp/sim/utils.py +428 -0
- warp/sparse.py +2057 -0
- warp/stubs.py +3333 -0
- warp/tape.py +1203 -0
- warp/tests/__init__.py +1 -0
- warp/tests/__main__.py +4 -0
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/mlp_golden.npy +0 -0
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/spiky.usd +0 -0
- warp/tests/assets/test_grid.nvdb +0 -0
- warp/tests/assets/test_index_grid.nvdb +0 -0
- warp/tests/assets/test_int32_grid.nvdb +0 -0
- warp/tests/assets/test_vec_grid.nvdb +0 -0
- warp/tests/assets/torus.nvdb +0 -0
- warp/tests/assets/torus.usda +105 -0
- warp/tests/aux_test_class_kernel.py +34 -0
- warp/tests/aux_test_compile_consts_dummy.py +18 -0
- warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
- warp/tests/aux_test_dependent.py +29 -0
- warp/tests/aux_test_grad_customs.py +29 -0
- warp/tests/aux_test_instancing_gc.py +26 -0
- warp/tests/aux_test_module_unload.py +23 -0
- warp/tests/aux_test_name_clash1.py +40 -0
- warp/tests/aux_test_name_clash2.py +40 -0
- warp/tests/aux_test_reference.py +9 -0
- warp/tests/aux_test_reference_reference.py +8 -0
- warp/tests/aux_test_square.py +16 -0
- warp/tests/aux_test_unresolved_func.py +22 -0
- warp/tests/aux_test_unresolved_symbol.py +22 -0
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/cuda/test_async.py +676 -0
- warp/tests/cuda/test_ipc.py +124 -0
- warp/tests/cuda/test_mempool.py +233 -0
- warp/tests/cuda/test_multigpu.py +169 -0
- warp/tests/cuda/test_peer.py +139 -0
- warp/tests/cuda/test_pinned.py +84 -0
- warp/tests/cuda/test_streams.py +634 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/geometry/test_bvh.py +200 -0
- warp/tests/geometry/test_hash_grid.py +221 -0
- warp/tests/geometry/test_marching_cubes.py +74 -0
- warp/tests/geometry/test_mesh.py +316 -0
- warp/tests/geometry/test_mesh_query_aabb.py +399 -0
- warp/tests/geometry/test_mesh_query_point.py +932 -0
- warp/tests/geometry/test_mesh_query_ray.py +311 -0
- warp/tests/geometry/test_volume.py +1103 -0
- warp/tests/geometry/test_volume_write.py +346 -0
- warp/tests/interop/__init__.py +0 -0
- warp/tests/interop/test_dlpack.py +729 -0
- warp/tests/interop/test_jax.py +371 -0
- warp/tests/interop/test_paddle.py +800 -0
- warp/tests/interop/test_torch.py +1001 -0
- warp/tests/run_coverage_serial.py +39 -0
- warp/tests/sim/__init__.py +0 -0
- warp/tests/sim/disabled_kinematics.py +244 -0
- warp/tests/sim/flaky_test_sim_grad.py +290 -0
- warp/tests/sim/test_collision.py +604 -0
- warp/tests/sim/test_coloring.py +258 -0
- warp/tests/sim/test_model.py +224 -0
- warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
- warp/tests/sim/test_sim_kinematics.py +98 -0
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_adam.py +163 -0
- warp/tests/test_arithmetic.py +1096 -0
- warp/tests/test_array.py +2972 -0
- warp/tests/test_array_reduce.py +156 -0
- warp/tests/test_assert.py +250 -0
- warp/tests/test_atomic.py +153 -0
- warp/tests/test_bool.py +220 -0
- warp/tests/test_builtins_resolution.py +1298 -0
- warp/tests/test_closest_point_edge_edge.py +327 -0
- warp/tests/test_codegen.py +810 -0
- warp/tests/test_codegen_instancing.py +1495 -0
- warp/tests/test_compile_consts.py +215 -0
- warp/tests/test_conditional.py +252 -0
- warp/tests/test_context.py +42 -0
- warp/tests/test_copy.py +238 -0
- warp/tests/test_ctypes.py +638 -0
- warp/tests/test_dense.py +73 -0
- warp/tests/test_devices.py +97 -0
- warp/tests/test_examples.py +482 -0
- warp/tests/test_fabricarray.py +996 -0
- warp/tests/test_fast_math.py +74 -0
- warp/tests/test_fem.py +2003 -0
- warp/tests/test_fp16.py +136 -0
- warp/tests/test_func.py +454 -0
- warp/tests/test_future_annotations.py +98 -0
- warp/tests/test_generics.py +656 -0
- warp/tests/test_grad.py +893 -0
- warp/tests/test_grad_customs.py +339 -0
- warp/tests/test_grad_debug.py +341 -0
- warp/tests/test_implicit_init.py +411 -0
- warp/tests/test_import.py +45 -0
- warp/tests/test_indexedarray.py +1140 -0
- warp/tests/test_intersect.py +73 -0
- warp/tests/test_iter.py +76 -0
- warp/tests/test_large.py +177 -0
- warp/tests/test_launch.py +411 -0
- warp/tests/test_lerp.py +151 -0
- warp/tests/test_linear_solvers.py +193 -0
- warp/tests/test_lvalue.py +427 -0
- warp/tests/test_mat.py +2089 -0
- warp/tests/test_mat_lite.py +122 -0
- warp/tests/test_mat_scalar_ops.py +2913 -0
- warp/tests/test_math.py +178 -0
- warp/tests/test_mlp.py +282 -0
- warp/tests/test_module_hashing.py +258 -0
- warp/tests/test_modules_lite.py +44 -0
- warp/tests/test_noise.py +252 -0
- warp/tests/test_operators.py +299 -0
- warp/tests/test_options.py +129 -0
- warp/tests/test_overwrite.py +551 -0
- warp/tests/test_print.py +339 -0
- warp/tests/test_quat.py +2315 -0
- warp/tests/test_rand.py +339 -0
- warp/tests/test_reload.py +302 -0
- warp/tests/test_rounding.py +185 -0
- warp/tests/test_runlength_encode.py +196 -0
- warp/tests/test_scalar_ops.py +105 -0
- warp/tests/test_smoothstep.py +108 -0
- warp/tests/test_snippet.py +318 -0
- warp/tests/test_sparse.py +582 -0
- warp/tests/test_spatial.py +2229 -0
- warp/tests/test_special_values.py +361 -0
- warp/tests/test_static.py +592 -0
- warp/tests/test_struct.py +734 -0
- warp/tests/test_tape.py +204 -0
- warp/tests/test_transient_module.py +93 -0
- warp/tests/test_triangle_closest_point.py +145 -0
- warp/tests/test_types.py +562 -0
- warp/tests/test_utils.py +588 -0
- warp/tests/test_vec.py +1487 -0
- warp/tests/test_vec_lite.py +80 -0
- warp/tests/test_vec_scalar_ops.py +2327 -0
- warp/tests/test_verify_fp.py +100 -0
- warp/tests/tile/__init__.py +0 -0
- warp/tests/tile/test_tile.py +780 -0
- warp/tests/tile/test_tile_load.py +407 -0
- warp/tests/tile/test_tile_mathdx.py +208 -0
- warp/tests/tile/test_tile_mlp.py +402 -0
- warp/tests/tile/test_tile_reduce.py +447 -0
- warp/tests/tile/test_tile_shared_memory.py +247 -0
- warp/tests/tile/test_tile_view.py +173 -0
- warp/tests/unittest_serial.py +47 -0
- warp/tests/unittest_suites.py +427 -0
- warp/tests/unittest_utils.py +468 -0
- warp/tests/walkthrough_debug.py +93 -0
- warp/thirdparty/__init__.py +0 -0
- warp/thirdparty/appdirs.py +598 -0
- warp/thirdparty/dlpack.py +145 -0
- warp/thirdparty/unittest_parallel.py +570 -0
- warp/torch.py +391 -0
- warp/types.py +5230 -0
- warp/utils.py +1137 -0
- warp_lang-1.7.0.dist-info/METADATA +516 -0
- warp_lang-1.7.0.dist-info/RECORD +429 -0
- warp_lang-1.7.0.dist-info/WHEEL +5 -0
- warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
- warp_lang-1.7.0.dist-info/top_level.txt +1 -0
warp/native/sparse.cu
ADDED
|
@@ -0,0 +1,524 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
* you may not use this file except in compliance with the License.
|
|
7
|
+
* You may obtain a copy of the License at
|
|
8
|
+
*
|
|
9
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
*
|
|
11
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
* See the License for the specific language governing permissions and
|
|
15
|
+
* limitations under the License.
|
|
16
|
+
*/
|
|
17
|
+
|
|
18
|
+
#include "cuda_util.h"
|
|
19
|
+
#include "warp.h"
|
|
20
|
+
|
|
21
|
+
#define THRUST_IGNORE_CUB_VERSION_CHECK
|
|
22
|
+
|
|
23
|
+
#include <cub/device/device_radix_sort.cuh>
|
|
24
|
+
#include <cub/device/device_run_length_encode.cuh>
|
|
25
|
+
#include <cub/device/device_scan.cuh>
|
|
26
|
+
|
|
27
|
+
namespace
|
|
28
|
+
{
|
|
29
|
+
|
|
30
|
+
// Packed (row, column) key that CUB can radix-sort in a single pass.
// bsr_combine_row_col stores the row in the high 32 bits and the column in the low 32 bits,
// so ascending key order is row-major (row first, then column).
typedef uint64_t BsrRowCol;

// Sentinel key marking discarded entries. Every bit is set, so it compares greater than any
// valid packed key and therefore ends up at the tail after an ascending sort.
static constexpr BsrRowCol PRUNED_ROWCOL = static_cast<BsrRowCol>(~static_cast<BsrRowCol>(0));
|
|
34
|
+
|
|
35
|
+
// Pack a (row, col) pair into one 64-bit key: row in the upper 32 bits, col in the lower 32 bits.
CUDA_CALLABLE BsrRowCol bsr_combine_row_col(uint32_t row, uint32_t col)
{
    const uint64_t upper_half = static_cast<uint64_t>(row) << 32;
    const uint64_t lower_half = col;
    return upper_half | lower_half;
}
|
|
39
|
+
|
|
40
|
+
// Recover the row index stored in the upper 32 bits of a packed key.
CUDA_CALLABLE uint32_t bsr_get_row(const BsrRowCol& row_col)
{
    const BsrRowCol shifted = row_col >> 32;
    return static_cast<uint32_t>(shifted);
}
|
|
41
|
+
|
|
42
|
+
// Recover the column index from the low bits of a packed key.
// The mask is INT_MAX, so bit 31 is discarded. Non-negative int columns round-trip unchanged,
// while the PRUNED_ROWCOL sentinel decodes to INT_MAX.
CUDA_CALLABLE uint32_t bsr_get_col(const BsrRowCol& row_col)
{
    const BsrRowCol column_mask = static_cast<BsrRowCol>(INT_MAX);
    return static_cast<uint32_t>(row_col & column_mask);
}
|
|
43
|
+
|
|
44
|
+
// Predicate deciding whether triplet block `i` is kept because it holds at least one non-zero scalar.
// When `values` is null (numerical-zero pruning disabled), every block is kept.
template <typename T> struct BsrBlockIsNotZero
{
    int block_size;  // number of scalars per block
    const T* values; // block_size contiguous scalars per block; null disables pruning

    CUDA_CALLABLE_DEVICE bool operator()(int i) const
    {
        if (!values)
            return true;

        // Widen before multiplying: i * block_size can exceed INT_MAX for large inputs.
        const T* block_vals = values + static_cast<size_t>(i) * static_cast<size_t>(block_size);

        // Use a distinct loop index so the block index `i` is not shadowed.
        for (int k = 0; k < block_size; ++k)
        {
            if (block_vals[k] != T(0))
                return true;
        }
        return false;
    }
};
|
|
63
|
+
|
|
64
|
+
// Predicate testing whether block (row, col) is present in an existing BSR sparsity pattern.
// When `bsr_offsets` is null (no mask), every block is accepted.
// The binary search assumes the column indices of each row are sorted in increasing order.
struct BsrBlockInMask
{
    const int* bsr_offsets; // row start offsets of the mask pattern, or null for "no mask"
    const int* bsr_columns; // column index of each block of the mask pattern

    CUDA_CALLABLE_DEVICE bool operator()(int row, int col) const
    {
        if (!bsr_offsets)
            return true;

        // Inclusive search window over this row's block slots.
        int lo = bsr_offsets[row];
        int hi = bsr_offsets[row + 1] - 1;

        // Lower-bound search: narrow [lo, hi] to the first column >= col.
        while (lo < hi)
        {
            const int probe = lo + (hi - lo) / 2;
            if (col > bsr_columns[probe])
                lo = probe + 1;
            else
                hi = probe;
        }

        // An empty row leaves lo > hi: nothing to match.
        if (lo != hi)
            return false;

        return bsr_columns[lo] == col;
    }
};
|
|
94
|
+
|
|
95
|
+
// Kernel, one thread per input triplet:
// - Writes the packed (row, col) sort key for triplet `tid` into tpl_row_col.
// - Writes the triplet's own index into block_indices.
// A triplet is replaced by the PRUNED_ROWCOL sentinel when any of these holds:
// - its row lies outside [0, nrow);
// - nonZero rejects its block;
// - mask rejects its (row, col).
template <typename T>
__global__ void bsr_fill_triplet_key_values(const int nnz, const int nrow, const int* tpl_rows, const int* tpl_columns,
                                            const BsrBlockIsNotZero<T> nonZero, const BsrBlockInMask mask,
                                            uint32_t* block_indices, BsrRowCol* tpl_row_col)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < nnz)
    {
        const int r = tpl_rows[tid];
        const int c = tpl_columns[tid];

        // Short-circuit in the same order as the checks are cheap/safe:
        // the mask must not be queried with an out-of-range row.
        bool keep = (r >= 0) && (r < nrow);
        keep = keep && nonZero(tid);
        keep = keep && mask(r, c);

        tpl_row_col[tid] = keep ? bsr_combine_row_col(r, c) : PRUNED_ROWCOL;
        block_indices[tid] = tid;
    }
}
|
|
113
|
+
|
|
114
|
+
// Kernel, one thread per entry of row_offsets (row in [0, row_count]).
// It computes the BSR row start offsets from a sorted array of unique packed keys.
// Inputs:
// - `d_nnz`: device pointer to the number of unique keys.
// - `unique_row_col`: the sorted unique keys.
// If a trailing PRUNED_ROWCOL key is present, its row bits are all ones, so it is never placed
// before any row. As a result, row_offsets[row_count] ends up counting only the non-pruned keys.
template <typename T>
__global__ void bsr_find_row_offsets(uint32_t row_count, const T* d_nnz, const BsrRowCol* unique_row_col,
                                     int* row_offsets)
{
    const uint32_t row = blockIdx.x * blockDim.x + threadIdx.x;

    // Offsets array has row_count + 1 entries
    if (row > row_count)
        return;

    const uint32_t unique_count = *d_nnz;

    // Row 0 always starts at 0, and with no keys every row starts at 0
    if (row == 0 || unique_count == 0)
    {
        row_offsets[row] = 0;
        return;
    }

    // All keys belong to earlier rows: this row starts past the last key
    if (bsr_get_row(unique_row_col[unique_count - 1]) < row)
    {
        row_offsets[row] = unique_count;
        return;
    }

    // Lower-bound search for the first key whose row is >= `row`
    uint32_t first = 0;
    uint32_t last = unique_count - 1;
    while (first < last)
    {
        const uint32_t pivot = first + (last - first) / 2;
        if (row <= bsr_get_row(unique_row_col[pivot]))
            last = pivot;
        else
            first = pivot + 1;
    }

    row_offsets[row] = first;
}
|
|
155
|
+
|
|
156
|
+
// Kernel, one thread per unique output block `i` (i < *d_nnz).
// It writes the block's column index and, when values are requested, sums every input triplet
// block mapped to it.
// Arguments:
// - `block_offsets`: end offset of each run of equal keys (an inclusive prefix sum of run lengths),
//   so run i spans [block_offsets[i - 1], block_offsets[i]) of `sorted_block_indices`.
// - `sorted_block_indices`: original triplet indices in key order.
// - `bsr_values`: may be null, in which case only column indices are written.
template <typename T>
__global__ void bsr_merge_blocks(const int* d_nnz, int block_size, const uint32_t* block_offsets,
                                 const uint32_t* sorted_block_indices, const BsrRowCol* unique_row_cols,
                                 const T* tpl_values, int* bsr_cols, T* bsr_values)

{
    const uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;

    if (i >= static_cast<uint32_t>(*d_nnz))
        return;

    const BsrRowCol row_col = unique_row_cols[i];
    bsr_cols[i] = bsr_get_col(row_col);

    // Accumulate merged block values
    if (row_col == PRUNED_ROWCOL || bsr_values == nullptr)
        return;

    const uint32_t beg = i ? block_offsets[i - 1] : 0;
    const uint32_t end = block_offsets[i];

    // Compute element offsets in size_t: a block index times block_size can exceed 32 bits
    // for large matrices, which previously wrapped around in unsigned 32-bit arithmetic.
    const size_t stride = static_cast<size_t>(block_size);
    T* bsr_val = bsr_values + static_cast<size_t>(i) * stride;

    // Initialize from the first triplet of the run
    const T* first_val = tpl_values + static_cast<size_t>(sorted_block_indices[beg]) * stride;
    for (int k = 0; k < block_size; ++k)
    {
        bsr_val[k] = first_val[k];
    }

    // Add the remaining duplicate triplets of the run
    for (uint32_t cur = beg + 1; cur != end; ++cur)
    {
        const T* dup_val = tpl_values + static_cast<size_t>(sorted_block_indices[cur]) * stride;
        for (int k = 0; k < block_size; ++k)
        {
            bsr_val[k] += dup_val[k];
        }
    }
}
|
|
194
|
+
|
|
195
|
+
// Build a BSR matrix on the current CUDA device from (row, column, block value) triplets.
//
// Pipeline:
//   1. pack each triplet's (row, col) into a 64-bit key, tagging rejected triplets with PRUNED_ROWCOL;
//   2. radix-sort the keys together with the original triplet indices;
//   3. run-length encode the sorted keys into unique blocks plus the number of triplets per block;
//   4. derive bsr_offsets from the unique keys (and optionally copy the resulting nnz to the host);
//   5. prefix-sum the run lengths, then write bsr_columns and sum duplicate blocks into bsr_values.
//
// Parameters:
//   rows_per_block, cols_per_block     block dimensions; their product is the scalar count per block
//   row_count                          number of block rows; triplets whose row is outside
//                                      [0, row_count) are dropped
//   nnz                                number of input triplets
//   tpl_rows, tpl_columns, tpl_values  triplet arrays; tpl_values holds block_size scalars per triplet
//   prune_numerical_zeros              if true, triplets whose block is entirely zero are dropped
//   masked                             if true, only triplets whose (row, col) is present in the incoming
//                                      bsr_offsets / bsr_columns pattern are kept
//   bsr_offsets                        output, row_count + 1 row start offsets (read as the mask when masked)
//   bsr_columns                        output, column index per block (read as the mask when masked)
//   bsr_values                         output, merged block values; null skips value accumulation
//   bsr_nnz                            optional host destination for the resulting block count
//   bsr_nnz_event                      optional event recorded on `stream` after that copy is enqueued
template <typename T>
void bsr_matrix_from_triplets_device(const int rows_per_block, const int cols_per_block, const int row_count,
                                     const int nnz, const int* tpl_rows, const int* tpl_columns, const T* tpl_values,
                                     const bool prune_numerical_zeros, const bool masked, int* bsr_offsets,
                                     int* bsr_columns, T* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
{
    const int block_size = rows_per_block * cols_per_block;

    // NOTE(review): ContextGuard presumably keeps `context` current for this scope -- confirm in cuda_util.h
    void* context = cuda_context_get_current();
    ContextGuard guard(context);

    // Temporaries below are scoped to this call (no per-context caching is done here).

    cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());

    // block_indices: two halves of the triplet-index double buffer (2 * nnz), plus one trailing slot
    // that receives the run-length-encode unique-run count.
    // combined_row_col: two halves of the packed-key double buffer.
    ScopedTemporary<uint32_t> block_indices(context, 2 * nnz + 1);
    ScopedTemporary<BsrRowCol> combined_row_col(context, 2 * nnz);

    // Naming caveat: d_keys holds triplet *indices* (the sort payload), while d_values holds the
    // packed row/col *keys* the sort orders by.
    cub::DoubleBuffer<uint32_t> d_keys(block_indices.buffer(), block_indices.buffer() + nnz);
    cub::DoubleBuffer<BsrRowCol> d_values(combined_row_col.buffer(), combined_row_col.buffer() + nnz);

    uint32_t* unique_triplet_count = block_indices.buffer() + 2 * nnz;

    // Combine rows and columns so we can sort on them both.
    // When masked, the fill kernel reads the existing pattern from bsr_offsets / bsr_columns; later launches
    // in this function overwrite those arrays, so this relies on the launches executing in stream order.
    // NOTE(review): assumes wp_launch_device(WP_CURRENT_CONTEXT, ...) targets `stream` -- confirm.
    BsrBlockIsNotZero<T> isNotZero{block_size, prune_numerical_zeros ? tpl_values : nullptr};
    BsrBlockInMask mask{masked ? bsr_offsets : nullptr, bsr_columns};
    wp_launch_device(WP_CURRENT_CONTEXT, bsr_fill_triplet_key_values, nnz,
                     (nnz, row_count, tpl_rows, tpl_columns, isNotZero, mask, d_keys.Current(), d_values.Current()));

    // Sort packed keys (d_values) over all 64 bits, carrying triplet indices (d_keys) along.
    // Afterwards the Current() buffers hold the sorted keys and the matching triplet indices.
    {
        size_t buff_size = 0;
        check_cuda(cub::DeviceRadixSort::SortPairs(nullptr, buff_size, d_values, d_keys, nnz, 0, 64, stream));
        ScopedTemporary<> temp(context, buff_size);
        check_cuda(cub::DeviceRadixSort::SortPairs(temp.buffer(), buff_size, d_values, d_keys, nnz, 0, 64, stream));
    }

    // Run-length encode the sorted keys:
    //   d_values.Alternate() <- unique keys, d_keys.Alternate() <- run lengths,
    //   *unique_triplet_count <- number of runs (a trailing PRUNED_ROWCOL run is included if present).
    // d_keys.Current() still holds the sorted triplet indices.
    {
        size_t buff_size = 0;
        check_cuda(cub::DeviceRunLengthEncode::Encode(nullptr, buff_size, d_values.Current(), d_values.Alternate(),
                                                      d_keys.Alternate(), unique_triplet_count, nnz, stream));
        ScopedTemporary<> temp(context, buff_size);
        check_cuda(cub::DeviceRunLengthEncode::Encode(temp.buffer(), buff_size, d_values.Current(),
                                                      d_values.Alternate(), d_keys.Alternate(), unique_triplet_count,
                                                      nnz, stream));
    }

    // Compute row offsets from sorted unique blocks; bsr_offsets[row_count] becomes the number of
    // unique non-pruned blocks (the pruned sentinel's row bits are all ones, so it is never counted).
    wp_launch_device(WP_CURRENT_CONTEXT, bsr_find_row_offsets, row_count + 1,
                     (row_count, unique_triplet_count, d_values.Alternate(), bsr_offsets));

    if (bsr_nnz)
    {
        // Copy nnz to host, and record an event for the completed transfer if desired

        memcpy_d2h(WP_CURRENT_CONTEXT, bsr_nnz, bsr_offsets + row_count, sizeof(int), stream);

        if (bsr_nnz_event)
        {
            cuda_event_record(bsr_nnz_event, stream);
        }
    }

    // In-place inclusive scan turns run lengths into run end offsets. Entries past the unique-run count
    // are leftover data, but bsr_merge_blocks only reads indices below bsr_offsets[row_count].
    {
        size_t buff_size = 0;
        check_cuda(
            cub::DeviceScan::InclusiveSum(nullptr, buff_size, d_keys.Alternate(), d_keys.Alternate(), nnz, stream));
        ScopedTemporary<> temp(context, buff_size);
        check_cuda(cub::DeviceScan::InclusiveSum(temp.buffer(), buff_size, d_keys.Alternate(), d_keys.Alternate(), nnz,
                                                 stream));
    }

    // Accumulate repeated blocks and set column indices; bsr_offsets + row_count serves as the device-side nnz
    wp_launch_device(WP_CURRENT_CONTEXT, bsr_merge_blocks, nnz,
                     (bsr_offsets + row_count, block_size, d_keys.Alternate(), d_keys.Current(), d_values.Alternate(),
                      tpl_values, bsr_columns, bsr_values));
}
|
|
275
|
+
|
|
276
|
+
// Kernel, one thread per block slot `idx` < nnz_upper_bound, used to prepare a transpose.
// For each actual block it does two things:
// - records its own index in block_indices;
// - writes a packed key with row and column swapped, so sorting the keys orders the blocks
//   by transposed row.
// Slots between the actual block count (bsr_offsets[row_count]) and nnz_upper_bound receive
// PRUNED_ROWCOL; block_indices is not written for them.
__global__ void bsr_transpose_fill_row_col(const int nnz_upper_bound, const int row_count, const int* bsr_offsets,
                                           const int* bsr_columns, int* block_indices, BsrRowCol* transposed_row_col)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;

    // Beyond the allocated storage: nothing to touch
    if (idx >= nnz_upper_bound)
        return;

    // Allocated but unused slot: tag it as pruned
    const int actual_nnz = bsr_offsets[row_count];
    if (idx >= actual_nnz)
    {
        transposed_row_col[idx] = PRUNED_ROWCOL;
        return;
    }

    block_indices[idx] = idx;

    // Find the owning row: the first row r whose end offset bsr_offsets[r + 1] exceeds idx
    int first = 0;
    int last = row_count - 1;
    while (first < last)
    {
        const int pivot = first + (last - first) / 2;
        if (idx < bsr_offsets[pivot + 1])
            last = pivot;
        else
            first = pivot + 1;
    }

    // Column becomes the key's row and row becomes the key's column
    transposed_row_col[idx] = bsr_combine_row_col(bsr_columns[idx], first);
}
|
|
319
|
+
|
|
320
|
+
// Transpose one row-major Rows x Cols block from `src` into a row-major Cols x Rows block in `dest`.
// The block dimensions are compile-time template parameters.
template <int Rows, int Cols, typename T> struct BsrBlockTransposer
{
    void CUDA_CALLABLE_DEVICE operator()(const T* src, T* dest) const
    {
        for (int r = 0; r < Rows; ++r)
        {
            const T* src_row = src + r * Cols; // row r of the source block
            T* dest_col = dest + r;            // column r of the destination block
            for (int c = 0; c < Cols; ++c)
            {
                dest_col[c * Rows] = src_row[c];
            }
        }
    }
};
|
|
333
|
+
|
|
334
|
+
// Runtime-sized variant (selected with Rows = Cols = -1).
// It transposes a row-major row_count x col_count block from `src` into a row-major
// col_count x row_count block in `dest`.
template <typename T> struct BsrBlockTransposer<-1, -1, T>
{

    int row_count; // rows of the source block
    int col_count; // columns of the source block

    void CUDA_CALLABLE_DEVICE operator()(const T* src, T* dest) const
    {
        for (int r = 0; r < row_count; ++r)
        {
            const T* src_row = src + r * col_count; // row r of the source block
            T* dest_col = dest + r;                 // column r of the destination block
            for (int c = 0; c < col_count; ++c)
            {
                dest_col[c * row_count] = src_row[c];
            }
        }
    }
};
|
|
351
|
+
|
|
352
|
+
// Kernel, one thread per output block `i` (i < *nnz). It performs two writes:
// - transposes source block block_indices[i] of `bsr_values` into slot i of `transposed_bsr_values`;
// - writes slot i's column, taken from the packed key transposed_indices[i].
// `block_size` is the number of scalars per block.
template <int Rows, int Cols, typename T>
__global__ void bsr_transpose_blocks(const int* nnz, const int block_size, BsrBlockTransposer<Rows, Cols, T> transposer,
                                     const int* block_indices, const BsrRowCol* transposed_indices, const T* bsr_values,
                                     int* transposed_bsr_columns, T* transposed_bsr_values)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= *nnz)
        return;

    const int src_idx = block_indices[i];

    // Compute element offsets in size_t: block index * block_size can exceed INT_MAX for large matrices.
    const size_t stride = static_cast<size_t>(block_size);
    const T* src_block = bsr_values + static_cast<size_t>(src_idx) * stride;
    T* dst_block = transposed_bsr_values + static_cast<size_t>(i) * stride;

    transposer(src_block, dst_block);

    transposed_bsr_columns[i] = bsr_get_col(transposed_indices[i]);
}
|
|
367
|
+
|
|
368
|
+
// Launches bsr_transpose_blocks over `nnz` threads using the given block transposer.
// Factored out so each specialized dispatch below is a single call.
template <int Rows, int Cols, typename T>
void launch_bsr_transpose_blocks_with(BsrBlockTransposer<Rows, Cols, T> transposer, int nnz, const int* d_nnz,
                                      const int block_size, const int* block_indices,
                                      const BsrRowCol* transposed_indices, const T* bsr_values,
                                      int* transposed_bsr_columns, T* transposed_bsr_values)
{
    wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_blocks, nnz,
                     (d_nnz, block_size, transposer, block_indices, transposed_indices, bsr_values,
                      transposed_bsr_columns, transposed_bsr_values));
}

// Dispatches the block-transpose kernel on the block shape.
//
// Block shapes with 1..3 rows and 1..3 columns use a compile-time-sized transposer.
// Any other shape falls back to the runtime-sized BsrBlockTransposer<-1, -1, T>.
//
// nnz   : host-side launch size (number of threads)
// d_nnz : device pointer to the actual block count, checked inside the kernel
//
// Each outer case ends with an explicit `break`. This avoids implicit fallthrough
// into the next row-count case when cols_per_block is not specialized.
template <typename T>
void launch_bsr_transpose_blocks(int nnz, const int* d_nnz, const int block_size, const int rows_per_block,
                                 const int cols_per_block, const int* block_indices,
                                 const BsrRowCol* transposed_indices, const T* bsr_values, int* transposed_bsr_columns,
                                 T* transposed_bsr_values)
{
    switch (rows_per_block)
    {
    case 1:
        switch (cols_per_block)
        {
        case 1:
            launch_bsr_transpose_blocks_with(BsrBlockTransposer<1, 1, T>{}, nnz, d_nnz, block_size, block_indices,
                                             transposed_indices, bsr_values, transposed_bsr_columns,
                                             transposed_bsr_values);
            return;
        case 2:
            launch_bsr_transpose_blocks_with(BsrBlockTransposer<1, 2, T>{}, nnz, d_nnz, block_size, block_indices,
                                             transposed_indices, bsr_values, transposed_bsr_columns,
                                             transposed_bsr_values);
            return;
        case 3:
            launch_bsr_transpose_blocks_with(BsrBlockTransposer<1, 3, T>{}, nnz, d_nnz, block_size, block_indices,
                                             transposed_indices, bsr_values, transposed_bsr_columns,
                                             transposed_bsr_values);
            return;
        }
        break;
    case 2:
        switch (cols_per_block)
        {
        case 1:
            launch_bsr_transpose_blocks_with(BsrBlockTransposer<2, 1, T>{}, nnz, d_nnz, block_size, block_indices,
                                             transposed_indices, bsr_values, transposed_bsr_columns,
                                             transposed_bsr_values);
            return;
        case 2:
            launch_bsr_transpose_blocks_with(BsrBlockTransposer<2, 2, T>{}, nnz, d_nnz, block_size, block_indices,
                                             transposed_indices, bsr_values, transposed_bsr_columns,
                                             transposed_bsr_values);
            return;
        case 3:
            launch_bsr_transpose_blocks_with(BsrBlockTransposer<2, 3, T>{}, nnz, d_nnz, block_size, block_indices,
                                             transposed_indices, bsr_values, transposed_bsr_columns,
                                             transposed_bsr_values);
            return;
        }
        break;
    case 3:
        switch (cols_per_block)
        {
        case 1:
            launch_bsr_transpose_blocks_with(BsrBlockTransposer<3, 1, T>{}, nnz, d_nnz, block_size, block_indices,
                                             transposed_indices, bsr_values, transposed_bsr_columns,
                                             transposed_bsr_values);
            return;
        case 2:
            launch_bsr_transpose_blocks_with(BsrBlockTransposer<3, 2, T>{}, nnz, d_nnz, block_size, block_indices,
                                             transposed_indices, bsr_values, transposed_bsr_columns,
                                             transposed_bsr_values);
            return;
        case 3:
            launch_bsr_transpose_blocks_with(BsrBlockTransposer<3, 3, T>{}, nnz, d_nnz, block_size, block_indices,
                                             transposed_indices, bsr_values, transposed_bsr_columns,
                                             transposed_bsr_values);
            return;
        }
        break;
    default:
        break;
    }

    // Generic fallback: block dimensions known only at runtime
    launch_bsr_transpose_blocks_with(BsrBlockTransposer<-1, -1, T>{rows_per_block, cols_per_block}, nnz, d_nnz,
                                     block_size, block_indices, transposed_indices, bsr_values,
                                     transposed_bsr_columns, transposed_bsr_values);
}
|
|
440
|
+
|
|
441
|
+
// Writes only the transposed column indices (no values) for output block i.
// Threads with i >= *nnz (device-side block count) exit immediately.
__global__ void bsr_transpose_fill_columns(const int* nnz, const BsrRowCol* transposed_indices,
                                           int* transposed_bsr_columns)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= *nnz)
        return;

    transposed_bsr_columns[i] = bsr_get_col(transposed_indices[i]);
}

// Computes the transpose of a BSR matrix on the current CUDA device/stream.
//
// Steps:
//   1. build, for every block, its transposed (row, col) key and its source index;
//   2. radix-sort the source indices by that 64-bit key;
//   3. derive the transposed row offsets from the sorted keys;
//   4. write transposed column indices and, when a values buffer is given,
//      move and transpose each block into place.
//
// `nnz` is a host-side upper bound used for buffer sizes and launch dimensions.
// The kernels themselves are bounded by the device-side count at bsr_offsets[row_count].
// `transposed_bsr_values` may be nullptr. In that case only offsets and columns are produced.
template <typename T>
void bsr_transpose_device(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
                          const int* bsr_offsets, const int* bsr_columns, const T* bsr_values,
                          int* transposed_bsr_offsets, int* transposed_bsr_columns, T* transposed_bsr_values)
{
    const int block_size = rows_per_block * cols_per_block;

    void* context = cuda_context_get_current();
    ContextGuard guard(context);

    cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());

    // Device-side block count, read by the kernels as their upper bound
    const int* d_nnz = bsr_offsets + row_count;

    // Double buffers for the radix sort: two halves of nnz entries each.
    // Sizes are computed in size_t so that 2 * nnz cannot overflow int.
    const size_t nnz_size = static_cast<size_t>(nnz);
    ScopedTemporary<int> block_indices(context, 2 * nnz_size);
    ScopedTemporary<BsrRowCol> combined_row_col(context, 2 * nnz_size);

    // Sort payload: index of the source block
    cub::DoubleBuffer<int> d_block_indices(block_indices.buffer(), block_indices.buffer() + nnz);
    // Sort key: combined (transposed row, transposed col)
    cub::DoubleBuffer<BsrRowCol> d_row_col(combined_row_col.buffer(), combined_row_col.buffer() + nnz);

    wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_fill_row_col, nnz,
                     (nnz, row_count, bsr_offsets, bsr_columns, d_block_indices.Current(), d_row_col.Current()));

    // Sort block indices by their transposed (row, col) key over all 64 key bits
    {
        size_t buff_size = 0;
        check_cuda(
            cub::DeviceRadixSort::SortPairs(nullptr, buff_size, d_row_col, d_block_indices, nnz, 0, 64, stream));
        ScopedTemporary<> temp(context, buff_size);
        check_cuda(cub::DeviceRadixSort::SortPairs(temp.buffer(), buff_size, d_row_col, d_block_indices, nnz, 0, 64,
                                                   stream));
    }

    // Compute transposed row offsets from the sorted keys
    wp_launch_device(WP_CURRENT_CONTEXT, bsr_find_row_offsets, col_count + 1,
                     (col_count, d_nnz, d_row_col.Current(), transposed_bsr_offsets));

    if (transposed_bsr_values != nullptr)
    {
        // Move and transpose individual blocks; this kernel also writes the column indices
        launch_bsr_transpose_blocks(nnz, d_nnz, block_size, rows_per_block, cols_per_block,
                                    d_block_indices.Current(), d_row_col.Current(), bsr_values,
                                    transposed_bsr_columns, transposed_bsr_values);
    }
    else
    {
        // No values to move, but the transposed column indices must still be written
        wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_fill_columns, nnz,
                         (d_nnz, d_row_col.Current(), transposed_bsr_columns));
    }
}
|
|
483
|
+
|
|
484
|
+
} // namespace
|
|
485
|
+
|
|
486
|
+
// Single-precision (float) entry point for building a device BSR matrix from triplets.
// Casts the untyped value buffers to float and forwards every other argument
// unchanged to bsr_matrix_from_triplets_device<float>.
void bsr_matrix_from_triplets_float_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                           int* tpl_rows, int* tpl_columns, void* tpl_values,
                                           bool prune_numerical_zeros, bool masked, int* bsr_offsets, int* bsr_columns,
                                           void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
{
    const float* const typed_tpl_values = static_cast<const float*>(tpl_values);
    float* const typed_bsr_values = static_cast<float*>(bsr_values);

    bsr_matrix_from_triplets_device<float>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns,
                                           typed_tpl_values, prune_numerical_zeros, masked, bsr_offsets,
                                           bsr_columns, typed_bsr_values, bsr_nnz, bsr_nnz_event);
}
|
|
496
|
+
|
|
497
|
+
// Double-precision (double) entry point for building a device BSR matrix from triplets.
// Casts the untyped value buffers to double and forwards every other argument
// unchanged to bsr_matrix_from_triplets_device<double>.
void bsr_matrix_from_triplets_double_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                            int* tpl_rows, int* tpl_columns, void* tpl_values,
                                            bool prune_numerical_zeros, bool masked, int* bsr_offsets, int* bsr_columns,
                                            void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
{
    const double* const typed_tpl_values = static_cast<const double*>(tpl_values);
    double* const typed_bsr_values = static_cast<double*>(bsr_values);

    bsr_matrix_from_triplets_device<double>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns,
                                            typed_tpl_values, prune_numerical_zeros, masked, bsr_offsets,
                                            bsr_columns, typed_bsr_values, bsr_nnz, bsr_nnz_event);
}
|
|
507
|
+
|
|
508
|
+
// Single-precision (float) entry point for transposing a device BSR matrix.
// Casts the untyped value buffers to float and forwards to bsr_transpose_device.
void bsr_transpose_float_device(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
                                int* bsr_offsets, int* bsr_columns, void* bsr_values, int* transposed_bsr_offsets,
                                int* transposed_bsr_columns, void* transposed_bsr_values)
{
    const float* const src_values = static_cast<const float*>(bsr_values);
    float* const dst_values = static_cast<float*>(transposed_bsr_values);

    bsr_transpose_device(rows_per_block, cols_per_block, row_count, col_count, nnz, bsr_offsets, bsr_columns,
                         src_values, transposed_bsr_offsets, transposed_bsr_columns, dst_values);
}
|
|
516
|
+
|
|
517
|
+
// Double-precision (double) entry point for transposing a device BSR matrix.
// Casts the untyped value buffers to double and forwards to bsr_transpose_device.
void bsr_transpose_double_device(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
                                 int* bsr_offsets, int* bsr_columns, void* bsr_values, int* transposed_bsr_offsets,
                                 int* transposed_bsr_columns, void* transposed_bsr_values)
{
    const double* const src_values = static_cast<const double*>(bsr_values);
    double* const dst_values = static_cast<double*>(transposed_bsr_values);

    bsr_transpose_device(rows_per_block, cols_per_block, row_count, col_count, nnz, bsr_offsets, bsr_columns,
                         src_values, transposed_bsr_offsets, transposed_bsr_columns, dst_values);
}
|