warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +334 -0
- warp/__init__.pyi +5856 -0
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1077 -0
- warp/_src/build.py +620 -0
- warp/_src/build_dll.py +642 -0
- warp/_src/builtins.py +10555 -0
- warp/_src/codegen.py +4361 -0
- warp/_src/config.py +178 -0
- warp/_src/constants.py +59 -0
- warp/_src/context.py +8352 -0
- warp/_src/dlpack.py +464 -0
- warp/_src/fabric.py +362 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +510 -0
- warp/_src/fem/cache.py +689 -0
- warp/_src/fem/dirichlet.py +190 -0
- warp/_src/fem/domain.py +553 -0
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +703 -0
- warp/_src/fem/field/nodal_field.py +403 -0
- warp/_src/fem/field/restriction.py +39 -0
- warp/_src/fem/field/virtual.py +1021 -0
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
- warp/_src/fem/geometry/closest_point.py +99 -0
- warp/_src/fem/geometry/deformed_geometry.py +277 -0
- warp/_src/fem/geometry/element.py +854 -0
- warp/_src/fem/geometry/geometry.py +693 -0
- warp/_src/fem/geometry/grid_2d.py +478 -0
- warp/_src/fem/geometry/grid_3d.py +539 -0
- warp/_src/fem/geometry/hexmesh.py +956 -0
- warp/_src/fem/geometry/nanogrid.py +660 -0
- warp/_src/fem/geometry/partition.py +483 -0
- warp/_src/fem/geometry/quadmesh.py +597 -0
- warp/_src/fem/geometry/tetmesh.py +762 -0
- warp/_src/fem/geometry/trimesh.py +588 -0
- warp/_src/fem/integrate.py +2507 -0
- warp/_src/fem/linalg.py +385 -0
- warp/_src/fem/operator.py +398 -0
- warp/_src/fem/polynomial.py +231 -0
- warp/_src/fem/quadrature/__init__.py +17 -0
- warp/_src/fem/quadrature/pic_quadrature.py +318 -0
- warp/_src/fem/quadrature/quadrature.py +665 -0
- warp/_src/fem/space/__init__.py +248 -0
- warp/_src/fem/space/basis_function_space.py +499 -0
- warp/_src/fem/space/basis_space.py +681 -0
- warp/_src/fem/space/dof_mapper.py +253 -0
- warp/_src/fem/space/function_space.py +312 -0
- warp/_src/fem/space/grid_2d_function_space.py +179 -0
- warp/_src/fem/space/grid_3d_function_space.py +229 -0
- warp/_src/fem/space/hexmesh_function_space.py +255 -0
- warp/_src/fem/space/nanogrid_function_space.py +199 -0
- warp/_src/fem/space/partition.py +435 -0
- warp/_src/fem/space/quadmesh_function_space.py +222 -0
- warp/_src/fem/space/restriction.py +221 -0
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
- warp/_src/fem/space/shape/shape_function.py +134 -0
- warp/_src/fem/space/shape/square_shape_function.py +928 -0
- warp/_src/fem/space/shape/tet_shape_function.py +829 -0
- warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
- warp/_src/fem/space/tetmesh_function_space.py +270 -0
- warp/_src/fem/space/topology.py +461 -0
- warp/_src/fem/space/trimesh_function_space.py +193 -0
- warp/_src/fem/types.py +114 -0
- warp/_src/fem/utils.py +488 -0
- warp/_src/jax.py +188 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +389 -0
- warp/_src/jax_experimental/ffi.py +1286 -0
- warp/_src/jax_experimental/xla_ffi.py +658 -0
- warp/_src/marching_cubes.py +710 -0
- warp/_src/math.py +416 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +165 -0
- warp/_src/optim/linear.py +1608 -0
- warp/_src/optim/sgd.py +114 -0
- warp/_src/paddle.py +408 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +291 -0
- warp/_src/render/render_opengl.py +3638 -0
- warp/_src/render/render_usd.py +939 -0
- warp/_src/render/utils.py +162 -0
- warp/_src/sparse.py +2718 -0
- warp/_src/tape.py +1208 -0
- warp/_src/thirdparty/__init__.py +0 -0
- warp/_src/thirdparty/appdirs.py +598 -0
- warp/_src/thirdparty/dlpack.py +145 -0
- warp/_src/thirdparty/unittest_parallel.py +676 -0
- warp/_src/torch.py +393 -0
- warp/_src/types.py +5888 -0
- warp/_src/utils.py +1695 -0
- warp/autograd.py +33 -0
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +29 -0
- warp/build_dll.py +24 -0
- warp/codegen.py +24 -0
- warp/constants.py +24 -0
- warp/context.py +33 -0
- warp/dlpack.py +24 -0
- warp/examples/__init__.py +24 -0
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +389 -0
- warp/examples/benchmarks/benchmark_cloth.py +296 -0
- warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
- warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
- warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
- warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
- warp/examples/benchmarks/benchmark_gemm.py +164 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
- warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
- warp/examples/benchmarks/benchmark_launches.py +301 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/browse.py +37 -0
- warp/examples/core/example_cupy.py +86 -0
- warp/examples/core/example_dem.py +241 -0
- warp/examples/core/example_fluid.py +299 -0
- warp/examples/core/example_graph_capture.py +150 -0
- warp/examples/core/example_marching_cubes.py +195 -0
- warp/examples/core/example_mesh.py +180 -0
- warp/examples/core/example_mesh_intersect.py +211 -0
- warp/examples/core/example_nvdb.py +182 -0
- warp/examples/core/example_raycast.py +111 -0
- warp/examples/core/example_raymarch.py +205 -0
- warp/examples/core/example_render_opengl.py +290 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +411 -0
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_torch.py +211 -0
- warp/examples/core/example_wave.py +269 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/distributed/example_jacobi_mpi.py +506 -0
- warp/examples/fem/example_adaptive_grid.py +286 -0
- warp/examples/fem/example_apic_fluid.py +469 -0
- warp/examples/fem/example_burgers.py +261 -0
- warp/examples/fem/example_convection_diffusion.py +181 -0
- warp/examples/fem/example_convection_diffusion_dg.py +225 -0
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +172 -0
- warp/examples/fem/example_diffusion.py +196 -0
- warp/examples/fem/example_diffusion_3d.py +225 -0
- warp/examples/fem/example_diffusion_mgpu.py +225 -0
- warp/examples/fem/example_distortion_energy.py +228 -0
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +242 -0
- warp/examples/fem/example_mixed_elasticity.py +293 -0
- warp/examples/fem/example_navier_stokes.py +263 -0
- warp/examples/fem/example_nonconforming_contact.py +300 -0
- warp/examples/fem/example_stokes.py +213 -0
- warp/examples/fem/example_stokes_transfer.py +262 -0
- warp/examples/fem/example_streamlines.py +357 -0
- warp/examples/fem/utils.py +1047 -0
- warp/examples/interop/example_jax_callable.py +146 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +232 -0
- warp/examples/optim/example_diffray.py +561 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +88 -0
- warp/examples/tile/example_tile_convolution.py +66 -0
- warp/examples/tile/example_tile_fft.py +55 -0
- warp/examples/tile/example_tile_filtering.py +113 -0
- warp/examples/tile/example_tile_matmul.py +85 -0
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/examples/tile/example_tile_mlp.py +385 -0
- warp/examples/tile/example_tile_nbody.py +199 -0
- warp/fabric.py +24 -0
- warp/fem/__init__.py +173 -0
- warp/fem/adaptivity.py +26 -0
- warp/fem/cache.py +30 -0
- warp/fem/dirichlet.py +24 -0
- warp/fem/field/__init__.py +24 -0
- warp/fem/field/field.py +26 -0
- warp/fem/geometry/__init__.py +21 -0
- warp/fem/geometry/closest_point.py +31 -0
- warp/fem/linalg.py +38 -0
- warp/fem/operator.py +32 -0
- warp/fem/polynomial.py +29 -0
- warp/fem/space/__init__.py +22 -0
- warp/fem/space/basis_space.py +24 -0
- warp/fem/space/shape/__init__.py +68 -0
- warp/fem/space/topology.py +24 -0
- warp/fem/types.py +24 -0
- warp/fem/utils.py +32 -0
- warp/jax.py +29 -0
- warp/jax_experimental/__init__.py +29 -0
- warp/jax_experimental/custom_call.py +29 -0
- warp/jax_experimental/ffi.py +39 -0
- warp/jax_experimental/xla_ffi.py +24 -0
- warp/marching_cubes.py +24 -0
- warp/math.py +37 -0
- warp/native/array.h +1687 -0
- warp/native/builtin.h +2327 -0
- warp/native/bvh.cpp +562 -0
- warp/native/bvh.cu +826 -0
- warp/native/bvh.h +555 -0
- warp/native/clang/clang.cpp +541 -0
- warp/native/coloring.cpp +622 -0
- warp/native/crt.cpp +51 -0
- warp/native/crt.h +568 -0
- warp/native/cuda_crt.h +1058 -0
- warp/native/cuda_util.cpp +677 -0
- warp/native/cuda_util.h +313 -0
- warp/native/error.cpp +77 -0
- warp/native/error.h +36 -0
- warp/native/exports.h +2023 -0
- warp/native/fabric.h +246 -0
- warp/native/hashgrid.cpp +311 -0
- warp/native/hashgrid.cu +89 -0
- warp/native/hashgrid.h +240 -0
- warp/native/initializer_array.h +41 -0
- warp/native/intersect.h +1253 -0
- warp/native/intersect_adj.h +375 -0
- warp/native/intersect_tri.h +348 -0
- warp/native/mat.h +5189 -0
- warp/native/mathdx.cpp +93 -0
- warp/native/matnn.h +221 -0
- warp/native/mesh.cpp +266 -0
- warp/native/mesh.cu +406 -0
- warp/native/mesh.h +2097 -0
- warp/native/nanovdb/GridHandle.h +533 -0
- warp/native/nanovdb/HostBuffer.h +591 -0
- warp/native/nanovdb/NanoVDB.h +6246 -0
- warp/native/nanovdb/NodeManager.h +323 -0
- warp/native/nanovdb/PNanoVDB.h +3390 -0
- warp/native/noise.h +859 -0
- warp/native/quat.h +1664 -0
- warp/native/rand.h +342 -0
- warp/native/range.h +145 -0
- warp/native/reduce.cpp +174 -0
- warp/native/reduce.cu +363 -0
- warp/native/runlength_encode.cpp +79 -0
- warp/native/runlength_encode.cu +61 -0
- warp/native/scan.cpp +47 -0
- warp/native/scan.cu +55 -0
- warp/native/scan.h +23 -0
- warp/native/solid_angle.h +466 -0
- warp/native/sort.cpp +251 -0
- warp/native/sort.cu +286 -0
- warp/native/sort.h +35 -0
- warp/native/sparse.cpp +241 -0
- warp/native/sparse.cu +435 -0
- warp/native/spatial.h +1306 -0
- warp/native/svd.h +727 -0
- warp/native/temp_buffer.h +46 -0
- warp/native/tile.h +4124 -0
- warp/native/tile_radix_sort.h +1112 -0
- warp/native/tile_reduce.h +838 -0
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +2199 -0
- warp/native/version.h +23 -0
- warp/native/volume.cpp +501 -0
- warp/native/volume.cu +68 -0
- warp/native/volume.h +970 -0
- warp/native/volume_builder.cu +483 -0
- warp/native/volume_builder.h +52 -0
- warp/native/volume_impl.h +70 -0
- warp/native/warp.cpp +1143 -0
- warp/native/warp.cu +4604 -0
- warp/native/warp.h +358 -0
- warp/optim/__init__.py +20 -0
- warp/optim/adam.py +24 -0
- warp/optim/linear.py +35 -0
- warp/optim/sgd.py +24 -0
- warp/paddle.py +24 -0
- warp/py.typed +0 -0
- warp/render/__init__.py +22 -0
- warp/render/imgui_manager.py +29 -0
- warp/render/render_opengl.py +24 -0
- warp/render/render_usd.py +24 -0
- warp/render/utils.py +24 -0
- warp/sparse.py +51 -0
- warp/tape.py +24 -0
- warp/tests/__init__.py +1 -0
- warp/tests/__main__.py +4 -0
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/mlp_golden.npy +0 -0
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/spiky.usd +0 -0
- warp/tests/assets/test_grid.nvdb +0 -0
- warp/tests/assets/test_index_grid.nvdb +0 -0
- warp/tests/assets/test_int32_grid.nvdb +0 -0
- warp/tests/assets/test_vec_grid.nvdb +0 -0
- warp/tests/assets/torus.nvdb +0 -0
- warp/tests/assets/torus.usda +105 -0
- warp/tests/aux_test_class_kernel.py +34 -0
- warp/tests/aux_test_compile_consts_dummy.py +18 -0
- warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
- warp/tests/aux_test_dependent.py +29 -0
- warp/tests/aux_test_grad_customs.py +29 -0
- warp/tests/aux_test_instancing_gc.py +26 -0
- warp/tests/aux_test_module_aot.py +7 -0
- warp/tests/aux_test_module_unload.py +23 -0
- warp/tests/aux_test_name_clash1.py +40 -0
- warp/tests/aux_test_name_clash2.py +40 -0
- warp/tests/aux_test_reference.py +9 -0
- warp/tests/aux_test_reference_reference.py +8 -0
- warp/tests/aux_test_square.py +16 -0
- warp/tests/aux_test_unresolved_func.py +22 -0
- warp/tests/aux_test_unresolved_symbol.py +22 -0
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/cuda/test_async.py +676 -0
- warp/tests/cuda/test_conditional_captures.py +1147 -0
- warp/tests/cuda/test_ipc.py +124 -0
- warp/tests/cuda/test_mempool.py +233 -0
- warp/tests/cuda/test_multigpu.py +169 -0
- warp/tests/cuda/test_peer.py +139 -0
- warp/tests/cuda/test_pinned.py +84 -0
- warp/tests/cuda/test_streams.py +691 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/geometry/test_bvh.py +335 -0
- warp/tests/geometry/test_hash_grid.py +259 -0
- warp/tests/geometry/test_marching_cubes.py +294 -0
- warp/tests/geometry/test_mesh.py +318 -0
- warp/tests/geometry/test_mesh_query_aabb.py +392 -0
- warp/tests/geometry/test_mesh_query_point.py +935 -0
- warp/tests/geometry/test_mesh_query_ray.py +323 -0
- warp/tests/geometry/test_volume.py +1103 -0
- warp/tests/geometry/test_volume_write.py +346 -0
- warp/tests/interop/__init__.py +0 -0
- warp/tests/interop/test_dlpack.py +730 -0
- warp/tests/interop/test_jax.py +1673 -0
- warp/tests/interop/test_paddle.py +800 -0
- warp/tests/interop/test_torch.py +1001 -0
- warp/tests/run_coverage_serial.py +39 -0
- warp/tests/test_adam.py +162 -0
- warp/tests/test_arithmetic.py +1096 -0
- warp/tests/test_array.py +3756 -0
- warp/tests/test_array_reduce.py +156 -0
- warp/tests/test_assert.py +303 -0
- warp/tests/test_atomic.py +336 -0
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +312 -0
- warp/tests/test_bool.py +220 -0
- warp/tests/test_builtins_resolution.py +732 -0
- warp/tests/test_closest_point_edge_edge.py +327 -0
- warp/tests/test_codegen.py +974 -0
- warp/tests/test_codegen_instancing.py +1495 -0
- warp/tests/test_compile_consts.py +215 -0
- warp/tests/test_conditional.py +298 -0
- warp/tests/test_context.py +35 -0
- warp/tests/test_copy.py +319 -0
- warp/tests/test_ctypes.py +618 -0
- warp/tests/test_dense.py +73 -0
- warp/tests/test_devices.py +127 -0
- warp/tests/test_enum.py +136 -0
- warp/tests/test_examples.py +424 -0
- warp/tests/test_fabricarray.py +998 -0
- warp/tests/test_fast_math.py +72 -0
- warp/tests/test_fem.py +2204 -0
- warp/tests/test_fixedarray.py +229 -0
- warp/tests/test_fp16.py +136 -0
- warp/tests/test_func.py +501 -0
- warp/tests/test_future_annotations.py +100 -0
- warp/tests/test_generics.py +656 -0
- warp/tests/test_grad.py +893 -0
- warp/tests/test_grad_customs.py +339 -0
- warp/tests/test_grad_debug.py +341 -0
- warp/tests/test_implicit_init.py +411 -0
- warp/tests/test_import.py +45 -0
- warp/tests/test_indexedarray.py +1140 -0
- warp/tests/test_intersect.py +103 -0
- warp/tests/test_iter.py +76 -0
- warp/tests/test_large.py +177 -0
- warp/tests/test_launch.py +411 -0
- warp/tests/test_lerp.py +151 -0
- warp/tests/test_linear_solvers.py +223 -0
- warp/tests/test_lvalue.py +427 -0
- warp/tests/test_map.py +526 -0
- warp/tests/test_mat.py +3515 -0
- warp/tests/test_mat_assign_copy.py +178 -0
- warp/tests/test_mat_constructors.py +573 -0
- warp/tests/test_mat_lite.py +122 -0
- warp/tests/test_mat_scalar_ops.py +2913 -0
- warp/tests/test_math.py +212 -0
- warp/tests/test_module_aot.py +287 -0
- warp/tests/test_module_hashing.py +258 -0
- warp/tests/test_modules_lite.py +70 -0
- warp/tests/test_noise.py +252 -0
- warp/tests/test_operators.py +299 -0
- warp/tests/test_options.py +129 -0
- warp/tests/test_overwrite.py +551 -0
- warp/tests/test_print.py +408 -0
- warp/tests/test_quat.py +2653 -0
- warp/tests/test_quat_assign_copy.py +145 -0
- warp/tests/test_rand.py +339 -0
- warp/tests/test_reload.py +303 -0
- warp/tests/test_rounding.py +157 -0
- warp/tests/test_runlength_encode.py +196 -0
- warp/tests/test_scalar_ops.py +133 -0
- warp/tests/test_smoothstep.py +108 -0
- warp/tests/test_snippet.py +318 -0
- warp/tests/test_sparse.py +845 -0
- warp/tests/test_spatial.py +2859 -0
- warp/tests/test_spatial_assign_copy.py +160 -0
- warp/tests/test_special_values.py +361 -0
- warp/tests/test_static.py +640 -0
- warp/tests/test_struct.py +901 -0
- warp/tests/test_tape.py +242 -0
- warp/tests/test_transient_module.py +93 -0
- warp/tests/test_triangle_closest_point.py +192 -0
- warp/tests/test_tuple.py +361 -0
- warp/tests/test_types.py +615 -0
- warp/tests/test_utils.py +594 -0
- warp/tests/test_vec.py +1408 -0
- warp/tests/test_vec_assign_copy.py +143 -0
- warp/tests/test_vec_constructors.py +325 -0
- warp/tests/test_vec_lite.py +80 -0
- warp/tests/test_vec_scalar_ops.py +2327 -0
- warp/tests/test_verify_fp.py +100 -0
- warp/tests/test_version.py +75 -0
- warp/tests/tile/__init__.py +0 -0
- warp/tests/tile/test_tile.py +1519 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +608 -0
- warp/tests/tile/test_tile_load.py +724 -0
- warp/tests/tile/test_tile_mathdx.py +156 -0
- warp/tests/tile/test_tile_matmul.py +179 -0
- warp/tests/tile/test_tile_mlp.py +400 -0
- warp/tests/tile/test_tile_reduce.py +950 -0
- warp/tests/tile/test_tile_shared_memory.py +376 -0
- warp/tests/tile/test_tile_sort.py +121 -0
- warp/tests/tile/test_tile_view.py +173 -0
- warp/tests/unittest_serial.py +47 -0
- warp/tests/unittest_suites.py +430 -0
- warp/tests/unittest_utils.py +469 -0
- warp/tests/walkthrough_debug.py +95 -0
- warp/torch.py +24 -0
- warp/types.py +51 -0
- warp/utils.py +31 -0
- warp_lang-1.10.0.dist-info/METADATA +459 -0
- warp_lang-1.10.0.dist-info/RECORD +468 -0
- warp_lang-1.10.0.dist-info/WHEEL +5 -0
- warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp_lang-1.10.0.dist-info/top_level.txt +1 -0
warp/native/sparse.cu
ADDED
|
@@ -0,0 +1,435 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
* you may not use this file except in compliance with the License.
|
|
7
|
+
* You may obtain a copy of the License at
|
|
8
|
+
*
|
|
9
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
*
|
|
11
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
* See the License for the specific language governing permissions and
|
|
15
|
+
* limitations under the License.
|
|
16
|
+
*/
|
|
17
|
+
|
|
18
|
+
#include "cuda_util.h"
|
|
19
|
+
#include "temp_buffer.h"
|
|
20
|
+
#include "warp.h"
|
|
21
|
+
|
|
22
|
+
#include <cstdint>
|
|
23
|
+
|
|
24
|
+
#define THRUST_IGNORE_CUB_VERSION_CHECK
|
|
25
|
+
|
|
26
|
+
#include <cub/device/device_radix_sort.cuh>
|
|
27
|
+
#include <cub/device/device_run_length_encode.cuh>
|
|
28
|
+
#include <cub/device/device_scan.cuh>
|
|
29
|
+
|
|
30
|
+
// Declaration of a function whose definition is not in this file (extern linkage).
// NOTE(review): not referenced by any code visible in this part of the file; confirm it is still needed.
extern CUcontext get_current_context();
|
|
31
|
+
|
|
32
|
+
namespace
|
|
33
|
+
{
|
|
34
|
+
|
|
35
|
+
// Packed (row, col) key stored in a single unsigned 64-bit integer, so that a plain
// CUB radix sort on the key orders blocks by row first and column second.
using BsrRowCol = uint64_t;

// Sentinel key with every bit set, assigned to triplets that must be discarded.
// Being the largest representable key, it sorts after every valid (row, col) key.
static constexpr BsrRowCol PRUNED_ROWCOL = ~BsrRowCol{0};
|
|
39
|
+
|
|
40
|
+
// Pack a (row, col) pair into one key: row in the high 32 bits, col in the low 32 bits.
// Unsigned ordering of the key therefore matches lexicographic (row, col) ordering.
CUDA_CALLABLE BsrRowCol bsr_combine_row_col(uint32_t row, uint32_t col)
{
    const BsrRowCol high_half = BsrRowCol(row) << 32;
    const BsrRowCol low_half = BsrRowCol(col);
    return high_half | low_half;
}
|
|
44
|
+
|
|
45
|
+
// Recover the row index from the high 32 bits of a packed (row, col) key.
CUDA_CALLABLE uint32_t bsr_get_row(const BsrRowCol& row_col)
{
    const BsrRowCol high_half = row_col >> 32;
    return static_cast<uint32_t>(high_half);
}
|
|
46
|
+
|
|
47
|
+
// Recover the column index from the low bits of a packed (row, col) key.
// NOTE(review): the mask is INT_MAX (0x7FFFFFFF), not 0xFFFFFFFF, so bit 31 of the stored
// column is discarded and the result always fits in a non-negative int. Valid column
// indices (which are non-negative ints) are unaffected; confirm this is intentional.
CUDA_CALLABLE uint32_t bsr_get_col(const BsrRowCol& row_col)
{
    const BsrRowCol masked_low_bits = row_col & INT_MAX;
    return static_cast<uint32_t>(masked_low_bits);
}
|
|
48
|
+
|
|
49
|
+
// Predicate reporting whether a block contains at least one scalar considered non-zero.
//
// Block values are read as raw bit patterns of the unsigned integer type T (the caller picks
// T from the scalar byte size). A scalar counts as non-zero when any bit selected by
// `zero_mask` is set. Each block occupies `block_size` consecutive scalars in `values`.
// When `values` is null, every block is reported as non-zero (no value-based pruning).
template <typename T> struct BsrBlockIsNotZero
{
    int block_size;  // number of scalars per block
    const T* values;  // contiguous block storage, block_size scalars per block (may be null)
    T zero_mask;  // bits tested on each scalar; the 64-bit caller mask narrowed to T

    BsrBlockIsNotZero(int block_size, const void* values, const uint64_t zero_mask)
        : block_size(block_size), values(static_cast<const T*>(values)), zero_mask(static_cast<T>(zero_mask))
    {
    }

    // Returns true if block number `block` has at least one scalar with a masked bit set.
    CUDA_CALLABLE_DEVICE bool operator()(int block) const
    {
        if (!values)
            return true;

        // Compute the element offset in size_t: for large triplet counts and/or block sizes,
        // block * block_size can exceed INT_MAX, and a 32-bit int product would overflow.
        const T* val = values + static_cast<size_t>(block) * static_cast<size_t>(block_size);
        for (int i = 0; i < block_size; ++i, ++val)
        {
            if ((*val & zero_mask) != 0)
                return true;
        }
        return false;
    }
};
|
|
73
|
+
|
|
74
|
+
// Specialization used when scalar-level pruning is disabled: every block is kept.
// The constructor has the same shape as the generic one so the launcher can construct
// either variant uniformly; its arguments are deliberately ignored.
template <> struct BsrBlockIsNotZero<void>
{
    BsrBlockIsNotZero(int /*block_size*/, const void* /*values*/, const uint64_t /*zero_mask*/) {}

    CUDA_CALLABLE_DEVICE bool operator()(int /*block*/) const { return true; }
};
|
|
84
|
+
|
|
85
|
+
// Predicate deciding whether triplet number `index`, located at (row, col), may be kept.
//
// The triplet is rejected when any of the following holds:
//  - `device_nnz` is non-null and `index` is not below the count it points to;
//  - (row, col) lies outside [0, nrow) x [0, ncol);
//  - `bsr_offsets` is non-null and `col` is not found among
//    bsr_columns[bsr_offsets[row] .. bsr_offsets[row + 1]).
// The last test is a binary search, so it assumes each row's columns are sorted ascending.
// NOTE(review): that sortedness is relied upon, not checked.
struct BsrBlockInMask
{
    const int nrow;
    const int ncol;
    const int* bsr_offsets;
    const int* bsr_columns;
    const int* device_nnz;

    CUDA_CALLABLE_DEVICE bool operator()(int index, int row, int col) const
    {
        // Triplets at or beyond the device-side count are ignored.
        const bool past_count = (device_nnz != nullptr) && (index >= *device_nnz);
        if (past_count)
            return false;

        const bool row_in_range = row >= 0 && row < nrow;
        const bool col_in_range = col >= 0 && col < ncol;
        if (!(row_in_range && col_in_range))
            return false;

        // No topology mask: any in-range block is accepted.
        if (bsr_offsets == nullptr)
            return true;

        // Lower-bound search for `col` over this row's half-open column range.
        const int row_end = bsr_offsets[row + 1];
        int first = bsr_offsets[row];
        int last = row_end;
        while (first < last)
        {
            const int probe = first + (last - first) / 2;
            if (bsr_columns[probe] < col)
                first = probe + 1;
            else
                last = probe;
        }

        return first < row_end && bsr_columns[first] == col;
    }
};
|
|
125
|
+
|
|
126
|
+
// One thread per triplet. Writes the sort key for the triplet into tpl_row_col and its own
// index into block_indices (so the original triplet can be found again after sorting).
// The key is the packed (row, col) pair when the triplet passes both the topology mask and
// the non-zero test, and PRUNED_ROWCOL otherwise, which sorts after all valid keys.
template <typename T>
__global__ void bsr_fill_triplet_key_values(const int nnz, const int* tpl_rows, const int* tpl_columns,
                                            const BsrBlockIsNotZero<T> nonZero, const BsrBlockInMask mask,
                                            int* block_indices, BsrRowCol* tpl_row_col)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < nnz)
    {
        const int r = tpl_rows[tid];
        const int c = tpl_columns[tid];

        BsrRowCol key = PRUNED_ROWCOL;
        if (mask(tid, r, c) && nonZero(tid))
            key = bsr_combine_row_col(r, c);

        tpl_row_col[tid] = key;
        block_indices[tid] = tid;
    }
}
|
|
144
|
+
|
|
145
|
+
// One thread per entry of row_offsets (threads 0..row_count inclusive).
// row_offsets[row] receives the index of the first key in `unique_row_col` whose row is
// >= `row`, or the key count when every key belongs to an earlier row; row 0 and the
// empty case get 0. *d_nnz is the number of keys in `unique_row_col`, which are assumed
// to be sorted ascending (NOTE(review): relies on the caller's preceding sort).
template <typename T>
__global__ void bsr_find_row_offsets(uint32_t row_count, const T* d_nnz, const BsrRowCol* unique_row_col,
                                     int* row_offsets)
{
    const uint32_t row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row > row_count)
        return;

    const uint32_t key_count = *d_nnz;

    int offset;
    if (row == 0 || key_count == 0)
    {
        offset = 0;
    }
    else if (bsr_get_row(unique_row_col[key_count - 1]) < row)
    {
        // Even the last key belongs to an earlier row.
        offset = key_count;
    }
    else
    {
        // Lower bound: smallest key index whose row is >= `row`.
        uint32_t lo = 0;
        uint32_t hi = key_count - 1;
        while (lo < hi)
        {
            const uint32_t mid = lo + (hi - lo) / 2;
            if (bsr_get_row(unique_row_col[mid]) < row)
                lo = mid + 1;
            else
                hi = mid;
        }
        offset = lo;
    }

    row_offsets[row] = offset;
}
|
|
186
|
+
|
|
187
|
+
// One thread per key: for every index below *d_nnz, extract the column part of
// unique_row_cols[idx] and store it in bsr_cols[idx].
__global__ void bsr_set_column(const int* d_nnz, const BsrRowCol* unique_row_cols, int* bsr_cols)
{
    const uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < *d_nnz)
        bsr_cols[idx] = bsr_get_col(unique_row_cols[idx]);
}
|
|
195
|
+
|
|
196
|
+
// Host-side launcher: builds the non-zero predicate for scalar bit type T and runs
// bsr_fill_triplet_key_values with one thread per triplet on the current context.
template <typename T>
void launch_bsr_fill_triplet_key_values(
    const int block_size,
    const int nnz,
    const BsrBlockInMask& mask,
    const int* tpl_rows,
    const int* tpl_columns,
    const void* tpl_values,
    const uint64_t scalar_zero_mask,
    int* block_indices,
    BsrRowCol* row_col
)
{
    const BsrBlockIsNotZero<T> not_zero(block_size, tpl_values, scalar_zero_mask);
    wp_launch_device(WP_CURRENT_CONTEXT, bsr_fill_triplet_key_values, nnz,
                     (nnz, tpl_rows, tpl_columns, not_zero, mask, block_indices, row_col));
}
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
// One thread per slot of the upper-bound-sized arrays.
// Every slot below nnz_upper_bound gets block_indices[slot] = slot. Slots at or beyond
// bsr_offsets[row_count] (the live block count) are marked PRUNED_ROWCOL. Each live slot
// locates its owning row by binary search over bsr_offsets and stores the transposed
// key packed as (column, row).
__global__ void bsr_transpose_fill_row_col(const int nnz_upper_bound, const int row_count, const int* bsr_offsets,
                                           const int* bsr_columns, int* block_indices, BsrRowCol* transposed_row_col)
{
    const int slot = blockIdx.x * blockDim.x + threadIdx.x;

    // Threads past the allocated arrays have nothing to touch.
    if (slot >= nnz_upper_bound)
        return;

    block_indices[slot] = slot;

    // Inside the allocation but past the live block count: mark as invalid.
    const int live_count = bsr_offsets[row_count];
    if (slot >= live_count)
    {
        transposed_row_col[slot] = PRUNED_ROWCOL;
        return;
    }

    // Owning row r satisfies bsr_offsets[r] <= slot < bsr_offsets[r + 1]:
    // search for the first row whose end offset is greater than `slot`.
    int lo = 0;
    int hi = row_count - 1;
    while (lo < hi)
    {
        const int mid = lo + (hi - lo) / 2;
        if (bsr_offsets[mid + 1] > slot)
            hi = mid;
        else
            lo = mid + 1;
    }

    // Transposed key: the original column becomes the row and vice versa.
    transposed_row_col[slot] = bsr_combine_row_col(bsr_columns[slot], lo);
}
|
|
258
|
+
|
|
259
|
+
} // namespace
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
WP_API void wp_bsr_matrix_from_triplets_device(
|
|
263
|
+
const int block_size,
|
|
264
|
+
int scalar_size,
|
|
265
|
+
const int row_count,
|
|
266
|
+
const int col_count,
|
|
267
|
+
const int nnz,
|
|
268
|
+
const int* tpl_nnz,
|
|
269
|
+
const int* tpl_rows,
|
|
270
|
+
const int* tpl_columns,
|
|
271
|
+
const void* tpl_values,
|
|
272
|
+
const uint64_t scalar_zero_mask,
|
|
273
|
+
const bool masked_topology,
|
|
274
|
+
int* tpl_block_offsets,
|
|
275
|
+
int* tpl_block_indices,
|
|
276
|
+
int* bsr_offsets,
|
|
277
|
+
int* bsr_columns,
|
|
278
|
+
int* bsr_nnz, void* bsr_nnz_event)
|
|
279
|
+
{
|
|
280
|
+
void* context = wp_cuda_context_get_current();
|
|
281
|
+
ContextGuard guard(context);
|
|
282
|
+
|
|
283
|
+
// Per-context cached temporary buffers
|
|
284
|
+
// BsrFromTripletsTemp& bsr_temp = g_bsr_from_triplets_temp_map[context];
|
|
285
|
+
|
|
286
|
+
cudaStream_t stream = static_cast<cudaStream_t>(wp_cuda_stream_get_current());
|
|
287
|
+
|
|
288
|
+
ScopedTemporary<BsrRowCol> combined_row_col(context, 2 * size_t(nnz));
|
|
289
|
+
ScopedTemporary<int> unique_triplet_count(context, 1);
|
|
290
|
+
|
|
291
|
+
bool return_summed_blocks = tpl_block_offsets != nullptr && tpl_block_indices != nullptr;
|
|
292
|
+
if(!return_summed_blocks)
|
|
293
|
+
{
|
|
294
|
+
// if not provided, allocate temporary offset and indices buffers
|
|
295
|
+
tpl_block_offsets = static_cast<int*>(wp_alloc_device(context, size_t(nnz) * sizeof(int)));
|
|
296
|
+
tpl_block_indices = static_cast<int*>(wp_alloc_device(context, size_t(nnz) * sizeof(int)));
|
|
297
|
+
}
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
cub::DoubleBuffer<int> d_keys(tpl_block_indices, tpl_block_offsets);
|
|
301
|
+
cub::DoubleBuffer<BsrRowCol> d_values(combined_row_col.buffer(), combined_row_col.buffer() + nnz);
|
|
302
|
+
|
|
303
|
+
// Combine rows and columns so we can sort on them both,
|
|
304
|
+
// ensuring that blocks that should be pruned are moved to the end
|
|
305
|
+
BsrBlockInMask mask{row_count, col_count, masked_topology ? bsr_offsets : nullptr, bsr_columns, tpl_nnz};
|
|
306
|
+
if (scalar_zero_mask == 0 || tpl_values == nullptr)
|
|
307
|
+
scalar_size = 0;
|
|
308
|
+
switch(scalar_size)
|
|
309
|
+
{
|
|
310
|
+
case sizeof(uint8_t):
|
|
311
|
+
launch_bsr_fill_triplet_key_values<uint8_t>(block_size, nnz, mask, tpl_rows, tpl_columns, tpl_values, scalar_zero_mask, d_keys.Current(), d_values.Current());
|
|
312
|
+
break;
|
|
313
|
+
case sizeof(uint16_t):
|
|
314
|
+
launch_bsr_fill_triplet_key_values<uint16_t>(block_size, nnz, mask, tpl_rows, tpl_columns, tpl_values, scalar_zero_mask, d_keys.Current(), d_values.Current());
|
|
315
|
+
break;
|
|
316
|
+
case sizeof(uint32_t):
|
|
317
|
+
launch_bsr_fill_triplet_key_values<uint32_t>(block_size, nnz, mask, tpl_rows, tpl_columns, tpl_values, scalar_zero_mask, d_keys.Current(), d_values.Current());
|
|
318
|
+
break;
|
|
319
|
+
case sizeof(uint64_t):
|
|
320
|
+
launch_bsr_fill_triplet_key_values<uint64_t>(block_size, nnz, mask, tpl_rows, tpl_columns, tpl_values, scalar_zero_mask, d_keys.Current(), d_values.Current());
|
|
321
|
+
break;
|
|
322
|
+
default:
|
|
323
|
+
// no scalar-level pruning
|
|
324
|
+
launch_bsr_fill_triplet_key_values<void>(block_size, nnz, mask, tpl_rows, tpl_columns, tpl_values, scalar_zero_mask, d_keys.Current(), d_values.Current());
|
|
325
|
+
break;
|
|
326
|
+
}
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
// Sort
|
|
330
|
+
{
|
|
331
|
+
size_t buff_size = 0;
|
|
332
|
+
check_cuda(cub::DeviceRadixSort::SortPairs(nullptr, buff_size, d_values, d_keys, nnz, 0, 64, stream));
|
|
333
|
+
ScopedTemporary<> temp(context, buff_size);
|
|
334
|
+
check_cuda(cub::DeviceRadixSort::SortPairs(temp.buffer(), buff_size, d_values, d_keys, nnz, 0, 64, stream));
|
|
335
|
+
|
|
336
|
+
// Depending on data size and GPU architecture buffers may have been swapped or not
|
|
337
|
+
// Ensures the sorted keys are available in summed_block_indices if needed
|
|
338
|
+
if(return_summed_blocks && d_keys.Current() != tpl_block_indices)
|
|
339
|
+
{
|
|
340
|
+
check_cuda(cudaMemcpyAsync(tpl_block_indices, d_keys.Current(), nnz * sizeof(int), cudaMemcpyDeviceToDevice, stream));
|
|
341
|
+
}
|
|
342
|
+
}
|
|
343
|
+
|
|
344
|
+
// Runlength encode row-col sequences
|
|
345
|
+
{
|
|
346
|
+
size_t buff_size = 0;
|
|
347
|
+
check_cuda(cub::DeviceRunLengthEncode::Encode(nullptr, buff_size, d_values.Current(), d_values.Alternate(),
|
|
348
|
+
tpl_block_offsets, unique_triplet_count.buffer(), nnz, stream));
|
|
349
|
+
ScopedTemporary<> temp(context, buff_size);
|
|
350
|
+
check_cuda(cub::DeviceRunLengthEncode::Encode(temp.buffer(), buff_size, d_values.Current(),
|
|
351
|
+
d_values.Alternate(), tpl_block_offsets, unique_triplet_count.buffer(),
|
|
352
|
+
nnz, stream));
|
|
353
|
+
}
|
|
354
|
+
|
|
355
|
+
// Compute row offsets from sorted unique blocks
|
|
356
|
+
wp_launch_device(WP_CURRENT_CONTEXT, bsr_find_row_offsets, row_count + 1,
|
|
357
|
+
(row_count, unique_triplet_count.buffer(), d_values.Alternate(), bsr_offsets));
|
|
358
|
+
|
|
359
|
+
if (bsr_nnz)
|
|
360
|
+
{
|
|
361
|
+
// Copy nnz to host, and record an event for the completed transfer if desired
|
|
362
|
+
|
|
363
|
+
wp_memcpy_d2h(WP_CURRENT_CONTEXT, bsr_nnz, bsr_offsets + row_count, sizeof(int), stream);
|
|
364
|
+
|
|
365
|
+
if (bsr_nnz_event)
|
|
366
|
+
{
|
|
367
|
+
const bool external = true;
|
|
368
|
+
wp_cuda_event_record(bsr_nnz_event, stream, external);
|
|
369
|
+
}
|
|
370
|
+
}
|
|
371
|
+
|
|
372
|
+
// Set column indices
|
|
373
|
+
wp_launch_device(WP_CURRENT_CONTEXT, bsr_set_column, nnz,
|
|
374
|
+
(bsr_offsets + row_count, d_values.Alternate(),
|
|
375
|
+
bsr_columns));
|
|
376
|
+
|
|
377
|
+
// Scan repeated block counts
|
|
378
|
+
if(return_summed_blocks)
|
|
379
|
+
{
|
|
380
|
+
size_t buff_size = 0;
|
|
381
|
+
check_cuda(
|
|
382
|
+
cub::DeviceScan::InclusiveSum(nullptr, buff_size, tpl_block_offsets, tpl_block_offsets, nnz, stream));
|
|
383
|
+
ScopedTemporary<> temp(context, buff_size);
|
|
384
|
+
check_cuda(cub::DeviceScan::InclusiveSum(temp.buffer(), buff_size, tpl_block_offsets, tpl_block_offsets, nnz,
|
|
385
|
+
stream));
|
|
386
|
+
} else {
|
|
387
|
+
// free our temporary buffers
|
|
388
|
+
wp_free_device(context, tpl_block_offsets);
|
|
389
|
+
wp_free_device(context, tpl_block_indices);
|
|
390
|
+
}
|
|
391
|
+
}
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
WP_API void wp_bsr_transpose_device(int row_count, int col_count, int nnz,
|
|
395
|
+
const int* bsr_offsets, const int* bsr_columns,
|
|
396
|
+
int* transposed_bsr_offsets, int* transposed_bsr_columns,
|
|
397
|
+
int* src_block_indices)
|
|
398
|
+
{
|
|
399
|
+
void* context = wp_cuda_context_get_current();
|
|
400
|
+
ContextGuard guard(context);
|
|
401
|
+
|
|
402
|
+
cudaStream_t stream = static_cast<cudaStream_t>(wp_cuda_stream_get_current());
|
|
403
|
+
|
|
404
|
+
ScopedTemporary<BsrRowCol> combined_row_col(context, 2 * nnz);
|
|
405
|
+
|
|
406
|
+
cub::DoubleBuffer<int> d_keys(src_block_indices + nnz, src_block_indices);
|
|
407
|
+
cub::DoubleBuffer<BsrRowCol> d_values(combined_row_col.buffer(), combined_row_col.buffer() + nnz);
|
|
408
|
+
|
|
409
|
+
wp_launch_device(WP_CURRENT_CONTEXT, bsr_transpose_fill_row_col, nnz,
|
|
410
|
+
(nnz, row_count, bsr_offsets, bsr_columns, d_keys.Current(), d_values.Current()));
|
|
411
|
+
|
|
412
|
+
// Sort blocks
|
|
413
|
+
{
|
|
414
|
+
size_t buff_size = 0;
|
|
415
|
+
check_cuda(cub::DeviceRadixSort::SortPairs(nullptr, buff_size, d_values, d_keys, nnz, 0, 64, stream));
|
|
416
|
+
ScopedTemporary<> temp(context, buff_size);
|
|
417
|
+
check_cuda(cub::DeviceRadixSort::SortPairs(temp.buffer(), buff_size, d_values, d_keys, nnz, 0, 64, stream));
|
|
418
|
+
|
|
419
|
+
// Depending on data size and GPU architecture buffers may have been swapped or not
|
|
420
|
+
// Ensures the sorted keys are available in summed_block_indices if needed
|
|
421
|
+
if(d_keys.Current() != src_block_indices)
|
|
422
|
+
{
|
|
423
|
+
check_cuda(cudaMemcpyAsync(src_block_indices, src_block_indices+nnz, size_t(nnz) * sizeof(int), cudaMemcpyDeviceToDevice, stream));
|
|
424
|
+
}
|
|
425
|
+
}
|
|
426
|
+
|
|
427
|
+
// Compute row offsets from sorted unique blocks
|
|
428
|
+
wp_launch_device(WP_CURRENT_CONTEXT, bsr_find_row_offsets, col_count + 1,
|
|
429
|
+
(col_count, bsr_offsets + row_count, d_values.Current(), transposed_bsr_offsets));
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
wp_launch_device(WP_CURRENT_CONTEXT, bsr_set_column, nnz,
|
|
433
|
+
(bsr_offsets + row_count, d_values.Current(),
|
|
434
|
+
transposed_bsr_columns));
|
|
435
|
+
}
|