warp-lang 1.6.1__py3-none-win_amd64.whl → 1.7.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +21 -7
- warp/autograd.py +14 -6
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +424 -6
- warp/build_dll.py +20 -20
- warp/builtins.py +467 -368
- warp/codegen.py +193 -125
- warp/config.py +56 -12
- warp/constants.py +14 -6
- warp/context.py +524 -277
- warp/dlpack.py +22 -12
- warp/examples/__init__.py +14 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_api.py +14 -6
- warp/examples/benchmarks/benchmark_cloth.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
- warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
- warp/examples/benchmarks/benchmark_gemm.py +82 -48
- warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
- warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
- warp/examples/benchmarks/benchmark_launches.py +14 -6
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/browse.py +14 -6
- warp/examples/core/example_cupy.py +14 -6
- warp/examples/core/example_dem.py +14 -6
- warp/examples/core/example_fluid.py +14 -6
- warp/examples/core/example_graph_capture.py +14 -6
- warp/examples/core/example_marching_cubes.py +14 -6
- warp/examples/core/example_mesh.py +14 -6
- warp/examples/core/example_mesh_intersect.py +14 -6
- warp/examples/core/example_nvdb.py +14 -6
- warp/examples/core/example_raycast.py +14 -6
- warp/examples/core/example_raymarch.py +14 -6
- warp/examples/core/example_render_opengl.py +14 -6
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/core/example_sph.py +14 -6
- warp/examples/core/example_torch.py +14 -6
- warp/examples/core/example_wave.py +14 -6
- warp/examples/fem/example_adaptive_grid.py +14 -6
- warp/examples/fem/example_apic_fluid.py +15 -7
- warp/examples/fem/example_burgers.py +16 -8
- warp/examples/fem/example_convection_diffusion.py +14 -6
- warp/examples/fem/example_convection_diffusion_dg.py +14 -6
- warp/examples/fem/example_deformed_geometry.py +15 -7
- warp/examples/fem/example_diffusion.py +14 -6
- warp/examples/fem/example_diffusion_3d.py +14 -6
- warp/examples/fem/example_diffusion_mgpu.py +14 -6
- warp/examples/fem/example_distortion_energy.py +15 -7
- warp/examples/fem/example_magnetostatics.py +20 -12
- warp/examples/fem/example_mixed_elasticity.py +14 -6
- warp/examples/fem/example_navier_stokes.py +14 -6
- warp/examples/fem/example_nonconforming_contact.py +14 -6
- warp/examples/fem/example_stokes.py +14 -6
- warp/examples/fem/example_stokes_transfer.py +14 -6
- warp/examples/fem/example_streamlines.py +14 -6
- warp/examples/fem/utils.py +24 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_bounce.py +14 -6
- warp/examples/optim/example_cloth_throw.py +14 -6
- warp/examples/optim/example_diffray.py +14 -6
- warp/examples/optim/example_drone.py +14 -6
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/optim/example_inverse_kinematics.py +14 -6
- warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
- warp/examples/optim/example_softbody_properties.py +14 -6
- warp/examples/optim/example_spring_cage.py +14 -6
- warp/examples/optim/example_trajectory.py +14 -6
- warp/examples/sim/example_cartpole.py +14 -6
- warp/examples/sim/example_cloth.py +14 -6
- warp/examples/sim/example_cloth_self_contact.py +14 -6
- warp/examples/sim/example_granular.py +14 -6
- warp/examples/sim/example_granular_collision_sdf.py +14 -6
- warp/examples/sim/example_jacobian_ik.py +14 -6
- warp/examples/sim/example_particle_chain.py +14 -6
- warp/examples/sim/example_quadruped.py +14 -6
- warp/examples/sim/example_rigid_chain.py +14 -6
- warp/examples/sim/example_rigid_contact.py +14 -6
- warp/examples/sim/example_rigid_force.py +14 -6
- warp/examples/sim/example_rigid_gyroscopic.py +14 -6
- warp/examples/sim/example_rigid_soft_contact.py +14 -6
- warp/examples/sim/example_soft_body.py +14 -6
- warp/examples/tile/example_tile_cholesky.py +14 -6
- warp/examples/tile/example_tile_convolution.py +14 -6
- warp/examples/tile/example_tile_fft.py +14 -6
- warp/examples/tile/example_tile_filtering.py +14 -6
- warp/examples/tile/example_tile_matmul.py +16 -10
- warp/examples/tile/example_tile_mlp.py +14 -6
- warp/examples/tile/example_tile_nbody.py +14 -6
- warp/examples/tile/example_tile_walker.py +14 -6
- warp/fabric.py +15 -0
- warp/fem/__init__.py +26 -1
- warp/fem/adaptivity.py +19 -4
- warp/fem/cache.py +15 -0
- warp/fem/dirichlet.py +15 -0
- warp/fem/domain.py +15 -0
- warp/fem/field/__init__.py +15 -0
- warp/fem/field/field.py +15 -0
- warp/fem/field/nodal_field.py +37 -68
- warp/fem/field/restriction.py +15 -0
- warp/fem/field/virtual.py +77 -23
- warp/fem/geometry/__init__.py +15 -0
- warp/fem/geometry/adaptive_nanogrid.py +24 -10
- warp/fem/geometry/closest_point.py +16 -1
- warp/fem/geometry/deformed_geometry.py +20 -2
- warp/fem/geometry/element.py +15 -0
- warp/fem/geometry/geometry.py +20 -0
- warp/fem/geometry/grid_2d.py +27 -12
- warp/fem/geometry/grid_3d.py +27 -15
- warp/fem/geometry/hexmesh.py +20 -7
- warp/fem/geometry/nanogrid.py +24 -11
- warp/fem/geometry/partition.py +15 -0
- warp/fem/geometry/quadmesh.py +28 -13
- warp/fem/geometry/tetmesh.py +18 -4
- warp/fem/geometry/trimesh.py +18 -8
- warp/fem/integrate.py +277 -93
- warp/fem/linalg.py +20 -5
- warp/fem/operator.py +15 -0
- warp/fem/polynomial.py +15 -0
- warp/fem/quadrature/__init__.py +15 -0
- warp/fem/quadrature/pic_quadrature.py +52 -22
- warp/fem/quadrature/quadrature.py +209 -25
- warp/fem/space/__init__.py +16 -1
- warp/fem/space/basis_function_space.py +19 -2
- warp/fem/space/basis_space.py +40 -18
- warp/fem/space/dof_mapper.py +15 -0
- warp/fem/space/function_space.py +15 -0
- warp/fem/space/grid_2d_function_space.py +15 -0
- warp/fem/space/grid_3d_function_space.py +15 -0
- warp/fem/space/hexmesh_function_space.py +17 -2
- warp/fem/space/nanogrid_function_space.py +15 -0
- warp/fem/space/partition.py +21 -2
- warp/fem/space/quadmesh_function_space.py +23 -8
- warp/fem/space/restriction.py +15 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +38 -23
- warp/fem/space/shape/shape_function.py +15 -0
- warp/fem/space/shape/square_shape_function.py +27 -12
- warp/fem/space/shape/tet_shape_function.py +15 -0
- warp/fem/space/shape/triangle_shape_function.py +16 -1
- warp/fem/space/tetmesh_function_space.py +18 -3
- warp/fem/space/topology.py +15 -0
- warp/fem/space/trimesh_function_space.py +17 -2
- warp/fem/types.py +15 -0
- warp/fem/utils.py +27 -6
- warp/jax.py +28 -7
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +103 -6
- warp/native/array.h +28 -6
- warp/native/builtin.h +44 -9
- warp/native/bvh.cpp +18 -7
- warp/native/bvh.cu +57 -20
- warp/native/bvh.h +17 -7
- warp/native/clang/clang.cpp +45 -9
- warp/native/coloring.cpp +15 -6
- warp/native/crt.cpp +15 -6
- warp/native/crt.h +15 -6
- warp/native/cuda_crt.h +15 -6
- warp/native/cuda_util.cpp +29 -6
- warp/native/cuda_util.h +17 -6
- warp/native/error.cpp +15 -6
- warp/native/error.h +15 -6
- warp/native/exports.h +85 -63
- warp/native/fabric.h +15 -6
- warp/native/hashgrid.cpp +15 -6
- warp/native/hashgrid.cu +15 -6
- warp/native/hashgrid.h +15 -6
- warp/native/initializer_array.h +15 -6
- warp/native/intersect.h +41 -32
- warp/native/intersect_adj.h +48 -39
- warp/native/intersect_tri.h +17 -0
- warp/native/marching.cpp +16 -0
- warp/native/marching.cu +16 -7
- warp/native/marching.h +17 -0
- warp/native/mat.h +528 -15
- warp/native/mathdx.cpp +15 -6
- warp/native/matnn.h +15 -6
- warp/native/mesh.cpp +15 -6
- warp/native/mesh.cu +15 -6
- warp/native/mesh.h +25 -16
- warp/native/noise.h +15 -6
- warp/native/quat.h +114 -17
- warp/native/rand.h +21 -6
- warp/native/range.h +15 -6
- warp/native/reduce.cpp +15 -6
- warp/native/reduce.cu +15 -6
- warp/native/runlength_encode.cpp +15 -6
- warp/native/runlength_encode.cu +15 -6
- warp/native/scan.cpp +15 -6
- warp/native/scan.cu +15 -6
- warp/native/scan.h +15 -6
- warp/native/solid_angle.h +17 -0
- warp/native/sort.cpp +137 -65
- warp/native/sort.cu +167 -21
- warp/native/sort.h +23 -7
- warp/native/sparse.cpp +58 -28
- warp/native/sparse.cu +67 -23
- warp/native/spatial.h +15 -6
- warp/native/svd.h +131 -6
- warp/native/temp_buffer.h +15 -6
- warp/native/tile.h +316 -111
- warp/native/tile_reduce.h +61 -9
- warp/native/vec.h +83 -13
- warp/native/volume.cpp +100 -119
- warp/native/volume.cu +15 -6
- warp/native/volume.h +15 -6
- warp/native/volume_builder.cu +40 -16
- warp/native/volume_builder.h +21 -6
- warp/native/volume_impl.h +15 -6
- warp/native/warp.cpp +20 -12
- warp/native/warp.cu +114 -16
- warp/native/warp.h +34 -16
- warp/optim/__init__.py +14 -6
- warp/optim/adam.py +14 -6
- warp/optim/linear.py +25 -10
- warp/optim/sgd.py +14 -6
- warp/paddle.py +14 -6
- warp/render/__init__.py +14 -6
- warp/render/render_opengl.py +14 -6
- warp/render/render_usd.py +14 -6
- warp/render/utils.py +14 -6
- warp/sim/__init__.py +14 -7
- warp/sim/articulation.py +18 -10
- warp/sim/collide.py +35 -16
- warp/sim/graph_coloring.py +14 -6
- warp/sim/import_mjcf.py +463 -162
- warp/sim/import_snu.py +14 -7
- warp/sim/import_urdf.py +46 -18
- warp/sim/import_usd.py +14 -7
- warp/sim/inertia.py +14 -6
- warp/sim/integrator.py +14 -6
- warp/sim/integrator_euler.py +19 -11
- warp/sim/integrator_featherstone.py +17 -16
- warp/sim/integrator_vbd.py +222 -8
- warp/sim/integrator_xpbd.py +19 -11
- warp/sim/model.py +56 -19
- warp/sim/particles.py +14 -6
- warp/sim/render.py +14 -6
- warp/sim/utils.py +17 -2
- warp/sparse.py +657 -555
- warp/stubs.py +231 -19
- warp/tape.py +14 -6
- warp/tests/aux_test_class_kernel.py +14 -6
- warp/tests/aux_test_compile_consts_dummy.py +14 -6
- warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
- warp/tests/aux_test_dependent.py +14 -6
- warp/tests/aux_test_grad_customs.py +14 -6
- warp/tests/aux_test_instancing_gc.py +14 -6
- warp/tests/aux_test_module_unload.py +14 -6
- warp/tests/aux_test_name_clash1.py +14 -6
- warp/tests/aux_test_name_clash2.py +14 -6
- warp/tests/aux_test_unresolved_func.py +14 -6
- warp/tests/aux_test_unresolved_symbol.py +14 -6
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_async.py → cuda/test_async.py} +14 -6
- warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
- warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
- warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
- warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
- warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
- warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
- warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
- warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
- warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
- warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
- warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
- warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
- warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
- warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
- warp/tests/run_coverage_serial.py +14 -6
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
- warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
- warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
- warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
- warp/tests/{test_model.py → sim/test_model.py} +55 -7
- warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_adam.py +14 -6
- warp/tests/test_arithmetic.py +14 -6
- warp/tests/test_array.py +14 -6
- warp/tests/test_array_reduce.py +14 -6
- warp/tests/test_assert.py +14 -6
- warp/tests/test_atomic.py +14 -6
- warp/tests/test_bool.py +15 -7
- warp/tests/test_builtins_resolution.py +14 -6
- warp/tests/test_closest_point_edge_edge.py +14 -6
- warp/tests/test_codegen.py +14 -6
- warp/tests/test_codegen_instancing.py +14 -6
- warp/tests/test_compile_consts.py +14 -6
- warp/tests/test_conditional.py +14 -6
- warp/tests/test_context.py +14 -6
- warp/tests/test_copy.py +14 -6
- warp/tests/test_ctypes.py +14 -6
- warp/tests/test_dense.py +14 -6
- warp/tests/test_devices.py +14 -6
- warp/tests/test_examples.py +42 -42
- warp/tests/test_fabricarray.py +14 -6
- warp/tests/test_fast_math.py +14 -6
- warp/tests/test_fem.py +37 -10
- warp/tests/test_fp16.py +14 -6
- warp/tests/test_func.py +14 -6
- warp/tests/test_future_annotations.py +14 -6
- warp/tests/test_generics.py +14 -6
- warp/tests/test_grad.py +14 -6
- warp/tests/test_grad_customs.py +14 -6
- warp/tests/test_grad_debug.py +14 -6
- warp/tests/test_implicit_init.py +14 -6
- warp/tests/test_import.py +14 -6
- warp/tests/test_indexedarray.py +14 -6
- warp/tests/test_intersect.py +14 -6
- warp/tests/test_iter.py +14 -6
- warp/tests/test_large.py +14 -6
- warp/tests/test_launch.py +14 -6
- warp/tests/test_lerp.py +14 -6
- warp/tests/test_linear_solvers.py +15 -11
- warp/tests/test_lvalue.py +14 -6
- warp/tests/test_mat.py +247 -85
- warp/tests/test_mat_lite.py +14 -6
- warp/tests/test_mat_scalar_ops.py +18 -10
- warp/tests/test_math.py +14 -6
- warp/tests/test_mlp.py +14 -6
- warp/tests/test_module_hashing.py +14 -6
- warp/tests/test_modules_lite.py +14 -6
- warp/tests/test_noise.py +14 -6
- warp/tests/test_operators.py +14 -6
- warp/tests/test_options.py +14 -6
- warp/tests/test_overwrite.py +15 -60
- warp/tests/test_print.py +14 -6
- warp/tests/test_quat.py +81 -52
- warp/tests/test_rand.py +58 -43
- warp/tests/test_reload.py +14 -6
- warp/tests/test_rounding.py +14 -6
- warp/tests/test_runlength_encode.py +14 -6
- warp/tests/test_scalar_ops.py +14 -6
- warp/tests/test_smoothstep.py +14 -6
- warp/tests/test_snippet.py +15 -0
- warp/tests/test_sparse.py +61 -12
- warp/tests/test_spatial.py +89 -6
- warp/tests/test_special_values.py +14 -6
- warp/tests/test_static.py +15 -7
- warp/tests/test_struct.py +14 -6
- warp/tests/test_tape.py +14 -6
- warp/tests/test_transient_module.py +14 -6
- warp/tests/test_triangle_closest_point.py +14 -6
- warp/tests/test_types.py +14 -6
- warp/tests/test_utils.py +98 -10
- warp/tests/test_vec.py +60 -40
- warp/tests/test_vec_lite.py +14 -6
- warp/tests/test_vec_scalar_ops.py +14 -6
- warp/tests/test_verify_fp.py +14 -6
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
- warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
- warp/tests/unittest_serial.py +15 -6
- warp/tests/unittest_suites.py +59 -65
- warp/tests/unittest_utils.py +16 -7
- warp/tests/walkthrough_debug.py +14 -6
- warp/thirdparty/unittest_parallel.py +15 -8
- warp/torch.py +14 -6
- warp/types.py +124 -664
- warp/utils.py +151 -78
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
- warp_lang-1.7.0.dist-info/RECORD +429 -0
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
- warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
- warp/examples/optim/example_walker.py +0 -309
- warp/native/cutlass_gemm.cpp +0 -34
- warp/native/cutlass_gemm.cu +0 -373
- warp/tests/test_matmul.py +0 -503
- warp/tests/test_matmul_lite.py +0 -403
- warp/tests/test_vbd.py +0 -378
- warp/tests/unused_test_misc.py +0 -69
- warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
- warp_lang-1.6.1.dist-info/RECORD +0 -419
- {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/sparse.py
CHANGED
|
@@ -1,10 +1,39 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
1
16
|
import ctypes
|
|
2
17
|
from typing import Any, Generic, Optional, Tuple, TypeVar, Union
|
|
3
18
|
|
|
4
19
|
import warp as wp
|
|
5
20
|
import warp.types
|
|
6
21
|
import warp.utils
|
|
7
|
-
from warp.types import
|
|
22
|
+
from warp.types import (
|
|
23
|
+
Array,
|
|
24
|
+
Cols,
|
|
25
|
+
Rows,
|
|
26
|
+
Scalar,
|
|
27
|
+
Vector,
|
|
28
|
+
is_array,
|
|
29
|
+
scalar_types,
|
|
30
|
+
type_is_matrix,
|
|
31
|
+
type_length,
|
|
32
|
+
type_repr,
|
|
33
|
+
type_scalar_type,
|
|
34
|
+
type_to_warp,
|
|
35
|
+
types_equal,
|
|
36
|
+
)
|
|
8
37
|
|
|
9
38
|
# typing hints
|
|
10
39
|
|
|
@@ -30,50 +59,89 @@ class BsrMatrix(Generic[_BlockType]):
|
|
|
30
59
|
Should not be constructed directly but through functions such as :func:`bsr_zeros`.
|
|
31
60
|
|
|
32
61
|
Attributes:
|
|
33
|
-
nrow (int): Number of rows of blocks
|
|
34
|
-
ncol (int): Number of columns of blocks
|
|
35
|
-
nnz (int): Upper bound for the number of non-zero blocks, used for
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
62
|
+
nrow (int): Number of rows of blocks.
|
|
63
|
+
ncol (int): Number of columns of blocks.
|
|
64
|
+
nnz (int): Upper bound for the number of non-zero blocks, used for
|
|
65
|
+
dimensioning launches. The exact number is at ``offsets[nrow-1]``.
|
|
66
|
+
See also :meth:`nnz_sync`.
|
|
67
|
+
offsets (Array[int]): Array of size at least ``1 + nrow`` such that the
|
|
68
|
+
start and end indices of the blocks of row ``r`` are ``offsets[r]``
|
|
69
|
+
and ``offsets[r+1]``, respectively.
|
|
70
|
+
columns (Array[int]): Array of size at least equal to ``nnz`` containing
|
|
71
|
+
block column indices.
|
|
72
|
+
values (Array[BlockType]): Array of size at least equal to ``nnz``
|
|
73
|
+
containing block values.
|
|
39
74
|
"""
|
|
40
75
|
|
|
41
76
|
@property
|
|
42
77
|
def scalar_type(self) -> Scalar:
|
|
43
|
-
"""Scalar type for individual block coefficients. For CSR matrices, this is the same as the block type"""
|
|
44
|
-
return
|
|
78
|
+
"""Scalar type for individual block coefficients. For CSR matrices, this is the same as the block type."""
|
|
79
|
+
return type_scalar_type(self.values.dtype)
|
|
45
80
|
|
|
46
81
|
@property
|
|
47
82
|
def block_shape(self) -> Tuple[int, int]:
|
|
48
|
-
"""Shape of the individual blocks"""
|
|
83
|
+
"""Shape of the individual blocks."""
|
|
49
84
|
return getattr(self.values.dtype, "_shape_", (1, 1))
|
|
50
85
|
|
|
51
86
|
@property
|
|
52
87
|
def block_size(self) -> int:
|
|
53
|
-
"""Size of the individual blocks, i.e. number of rows per block times number of columns per block"""
|
|
54
|
-
return
|
|
88
|
+
"""Size of the individual blocks, i.e. number of rows per block times number of columns per block."""
|
|
89
|
+
return type_length(self.values.dtype)
|
|
55
90
|
|
|
56
91
|
@property
|
|
57
92
|
def shape(self) -> Tuple[int, int]:
|
|
58
|
-
"""Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/columns per block"""
|
|
93
|
+
"""Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/columns per block."""
|
|
59
94
|
block_shape = self.block_shape
|
|
60
95
|
return (self.nrow * block_shape[0], self.ncol * block_shape[1])
|
|
61
96
|
|
|
62
97
|
@property
|
|
63
98
|
def dtype(self) -> type:
|
|
64
|
-
"""Data type for individual block values"""
|
|
99
|
+
"""Data type for individual block values."""
|
|
65
100
|
return self.values.dtype
|
|
66
101
|
|
|
67
102
|
@property
|
|
68
103
|
def device(self) -> wp.context.Device:
|
|
69
|
-
"""Device on which offsets
|
|
104
|
+
"""Device on which ``offsets``, ``columns``, and ``values`` are allocated -- assumed to be the same for all three arrays."""
|
|
70
105
|
return self.values.device
|
|
71
106
|
|
|
107
|
+
@property
|
|
108
|
+
def scalar_values(self) -> wp.array:
|
|
109
|
+
"""Accesses the ``values`` array as a 3d scalar array."""
|
|
110
|
+
if self.block_shape == (1, 1):
|
|
111
|
+
return self.values.reshape((self.nnz, 1, 1))
|
|
112
|
+
|
|
113
|
+
def _as_3d_array(arr):
|
|
114
|
+
return wp.array(
|
|
115
|
+
ptr=arr.ptr,
|
|
116
|
+
capacity=arr.capacity,
|
|
117
|
+
device=arr.device,
|
|
118
|
+
dtype=self.scalar_type,
|
|
119
|
+
shape=(self.nnz, *self.block_shape),
|
|
120
|
+
grad=None if arr.grad is None else _as_3d_array(arr.grad),
|
|
121
|
+
)
|
|
122
|
+
|
|
123
|
+
values_view = _as_3d_array(self.values)
|
|
124
|
+
values_view._ref = self.values # keep ref in case we're garbage collected
|
|
125
|
+
return values_view
|
|
126
|
+
|
|
127
|
+
def uncompress_rows(self, out: wp.array = None) -> wp.array:
|
|
128
|
+
"""Compute the row index for each non-zero block from the compressed row offsets."""
|
|
129
|
+
if out is None:
|
|
130
|
+
out = wp.empty(self.nnz, dtype=int, device=self.device)
|
|
131
|
+
|
|
132
|
+
wp.launch(
|
|
133
|
+
kernel=_bsr_get_block_row,
|
|
134
|
+
device=self.device,
|
|
135
|
+
dim=self.nnz,
|
|
136
|
+
inputs=[self.nrow, self.offsets, out],
|
|
137
|
+
)
|
|
138
|
+
return out
|
|
139
|
+
|
|
72
140
|
def nnz_sync(self):
|
|
73
|
-
"""
|
|
74
|
-
and
|
|
141
|
+
"""Ensure that any ongoing transfer of the exact nnz number from the device offsets array to the host has completed
|
|
142
|
+
and update the nnz upper bound.
|
|
75
143
|
|
|
76
|
-
See also :meth:`copy_nnz_async
|
|
144
|
+
See also :meth:`copy_nnz_async`.
|
|
77
145
|
"""
|
|
78
146
|
|
|
79
147
|
if self._is_nnz_transfer_setup():
|
|
@@ -84,10 +152,11 @@ class BsrMatrix(Generic[_BlockType]):
|
|
|
84
152
|
|
|
85
153
|
def copy_nnz_async(self, known_nnz: int = None):
|
|
86
154
|
"""
|
|
87
|
-
|
|
155
|
+
Start the asynchronous transfer of the exact nnz from the device offsets array to host and records an event for completion.
|
|
156
|
+
|
|
88
157
|
Needs to be called whenever the offsets array has been modified from outside ``warp.sparse``.
|
|
89
158
|
|
|
90
|
-
See also :meth:`nnz_sync
|
|
159
|
+
See also :meth:`nnz_sync`.
|
|
91
160
|
"""
|
|
92
161
|
if known_nnz is not None:
|
|
93
162
|
self.nnz = int(known_nnz)
|
|
@@ -171,35 +240,33 @@ class BsrMatrix(Generic[_BlockType]):
|
|
|
171
240
|
return _BsrScalingExpression(self, -1.0)
|
|
172
241
|
|
|
173
242
|
def transpose(self):
|
|
174
|
-
"""
|
|
243
|
+
"""Return a transposed copy of this matrix."""
|
|
175
244
|
return bsr_transposed(self)
|
|
176
245
|
|
|
177
246
|
|
|
178
247
|
def bsr_matrix_t(dtype: BlockType):
|
|
179
|
-
dtype =
|
|
248
|
+
dtype = type_to_warp(dtype)
|
|
180
249
|
|
|
181
|
-
if not
|
|
182
|
-
raise ValueError(
|
|
183
|
-
f"BsrMatrix block type must be either warp matrix or scalar; got {warp.types.type_repr(dtype)}"
|
|
184
|
-
)
|
|
250
|
+
if not type_is_matrix(dtype) and dtype not in scalar_types:
|
|
251
|
+
raise ValueError(f"BsrMatrix block type must be either warp matrix or scalar; got {type_repr(dtype)}")
|
|
185
252
|
|
|
186
253
|
class BsrMatrixTyped(BsrMatrix):
|
|
187
254
|
nrow: int
|
|
188
|
-
"""Number of rows of blocks"""
|
|
255
|
+
"""Number of rows of blocks."""
|
|
189
256
|
ncol: int
|
|
190
|
-
"""Number of columns of blocks"""
|
|
257
|
+
"""Number of columns of blocks."""
|
|
191
258
|
nnz: int
|
|
192
|
-
"""Upper bound for the number of non-zeros"""
|
|
259
|
+
"""Upper bound for the number of non-zeros."""
|
|
193
260
|
offsets: wp.array(dtype=int)
|
|
194
|
-
"""Array of size at least 1 +
|
|
261
|
+
"""Array of size at least ``1 + nrow``."""
|
|
195
262
|
columns: wp.array(dtype=int)
|
|
196
|
-
"""Array of size at least equal to nnz"""
|
|
263
|
+
"""Array of size at least equal to ``nnz``."""
|
|
197
264
|
values: wp.array(dtype=dtype)
|
|
198
265
|
|
|
199
266
|
module = wp.get_module(BsrMatrix.__module__)
|
|
200
267
|
|
|
201
268
|
if hasattr(dtype, "_shape_"):
|
|
202
|
-
type_str = f"{
|
|
269
|
+
type_str = f"{type_scalar_type(dtype).__name__}_{dtype._shape_[0]}_{dtype._shape_[1]}"
|
|
203
270
|
else:
|
|
204
271
|
type_str = dtype.__name__
|
|
205
272
|
key = f"{BsrMatrix.__qualname__}_{type_str}"
|
|
@@ -220,16 +287,16 @@ def bsr_zeros(
|
|
|
220
287
|
block_type: BlockType,
|
|
221
288
|
device: wp.context.Devicelike = None,
|
|
222
289
|
) -> BsrMatrix:
|
|
223
|
-
"""
|
|
224
|
-
Constructs and returns an empty BSR or CSR matrix with the given shape
|
|
290
|
+
"""Construct and return an empty BSR or CSR matrix with the given shape.
|
|
225
291
|
|
|
226
292
|
Args:
|
|
227
|
-
bsr: The BSR or CSR matrix to set to zero
|
|
228
|
-
rows_of_blocks: Number of rows of blocks
|
|
229
|
-
cols_of_blocks: Number of columns of blocks
|
|
230
|
-
block_type: Type of individual blocks.
|
|
231
|
-
|
|
232
|
-
|
|
293
|
+
bsr: The BSR or CSR matrix to set to zero.
|
|
294
|
+
rows_of_blocks: Number of rows of blocks.
|
|
295
|
+
cols_of_blocks: Number of columns of blocks.
|
|
296
|
+
block_type: Type of individual blocks.
|
|
297
|
+
For CSR matrices, this should be a scalar type.
|
|
298
|
+
For BSR matrices, this should be a matrix type (e.g. from :func:`warp.mat`).
|
|
299
|
+
device: Device on which to allocate the matrix arrays.
|
|
233
300
|
"""
|
|
234
301
|
|
|
235
302
|
bsr = bsr_matrix_t(block_type)()
|
|
@@ -266,13 +333,12 @@ def bsr_set_zero(
|
|
|
266
333
|
rows_of_blocks: Optional[int] = None,
|
|
267
334
|
cols_of_blocks: Optional[int] = None,
|
|
268
335
|
):
|
|
269
|
-
"""
|
|
270
|
-
Sets a BSR matrix to zero, possibly changing its size
|
|
336
|
+
"""Set a BSR matrix to zero, possibly changing its size.
|
|
271
337
|
|
|
272
338
|
Args:
|
|
273
|
-
bsr: The BSR or CSR matrix to set to zero
|
|
274
|
-
rows_of_blocks: If not ``None``, the new number of rows of blocks
|
|
275
|
-
cols_of_blocks: If not ``None``, the new number of columns of blocks
|
|
339
|
+
bsr: The BSR or CSR matrix to set to zero.
|
|
340
|
+
rows_of_blocks: If not ``None``, the new number of rows of blocks.
|
|
341
|
+
cols_of_blocks: If not ``None``, the new number of columns of blocks.
|
|
276
342
|
"""
|
|
277
343
|
|
|
278
344
|
if rows_of_blocks is not None:
|
|
@@ -289,46 +355,55 @@ def bsr_set_from_triplets(
|
|
|
289
355
|
dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
|
|
290
356
|
rows: "Array[int]",
|
|
291
357
|
columns: "Array[int]",
|
|
292
|
-
values: "Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]",
|
|
358
|
+
values: Optional["Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]"] = None,
|
|
293
359
|
prune_numerical_zeros: bool = True,
|
|
360
|
+
masked: bool = False,
|
|
294
361
|
):
|
|
295
|
-
"""
|
|
296
|
-
Fills a BSR matrix with values defined by coordinate-oriented (COO) triplets, discarding existing blocks.
|
|
362
|
+
"""Fill a BSR matrix with values defined by coordinate-oriented (COO) triplets, discarding existing blocks.
|
|
297
363
|
|
|
298
364
|
The first dimension of the three input arrays must match and indicates the number of COO triplets.
|
|
299
365
|
|
|
300
366
|
Args:
|
|
301
|
-
dest: Sparse matrix to populate
|
|
302
|
-
rows: Row index for each non-zero
|
|
303
|
-
columns: Columns index for each non-zero
|
|
367
|
+
dest: Sparse matrix to populate.
|
|
368
|
+
rows: Row index for each non-zero.
|
|
369
|
+
columns: Columns index for each non-zero.
|
|
304
370
|
values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
|
|
305
|
-
to the
|
|
306
|
-
|
|
371
|
+
to the ``dest`` matrix's block type, or a 3d array with data type equal to the ``dest`` matrix's scalar type.
|
|
372
|
+
If ``None``, the values array of the resulting matrix will be allocated but uninitialized.
|
|
373
|
+
prune_numerical_zeros: If ``True``, will ignore the zero-valued blocks.
|
|
374
|
+
masked: If ``True``, ignore blocks that are not existing non-zeros of ``dest``.
|
|
307
375
|
"""
|
|
308
376
|
|
|
309
|
-
if
|
|
377
|
+
if rows.device != columns.device or rows.device != dest.device:
|
|
310
378
|
raise ValueError("All arguments must reside on the same device")
|
|
311
379
|
|
|
312
|
-
if
|
|
380
|
+
if rows.shape[0] != columns.shape[0]:
|
|
313
381
|
raise ValueError("All triplet arrays must have the same length")
|
|
314
382
|
|
|
315
383
|
# Accept either array1d(dtype) or contiguous array3d(scalar_type) as values
|
|
316
|
-
if values
|
|
317
|
-
if values.
|
|
318
|
-
raise ValueError("
|
|
319
|
-
|
|
320
|
-
if values.shape[
|
|
321
|
-
raise ValueError(
|
|
322
|
-
|
|
323
|
-
|
|
384
|
+
if values is not None:
|
|
385
|
+
if values.device != rows.device:
|
|
386
|
+
raise ValueError("All arguments must reside on the same device")
|
|
387
|
+
|
|
388
|
+
if values.shape[0] != rows.shape[0]:
|
|
389
|
+
raise ValueError("All triplet arrays must have the same length")
|
|
390
|
+
|
|
391
|
+
if values.ndim == 1:
|
|
392
|
+
if values.dtype != dest.values.dtype:
|
|
393
|
+
raise ValueError("Values array type must correspond to that of dest matrix")
|
|
394
|
+
elif values.ndim == 3:
|
|
395
|
+
if values.shape[1:] != dest.block_shape:
|
|
396
|
+
raise ValueError(
|
|
397
|
+
f"Last two dimensions in values array ({values.shape[1:]}) should correspond to matrix block shape {(dest.block_shape)})"
|
|
398
|
+
)
|
|
324
399
|
|
|
325
|
-
|
|
326
|
-
|
|
400
|
+
if type_scalar_type(values.dtype) != dest.scalar_type:
|
|
401
|
+
raise ValueError("Scalar type of values array should correspond to that of matrix")
|
|
327
402
|
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
403
|
+
if not values.is_contiguous:
|
|
404
|
+
raise ValueError("Multi-dimensional values array should be contiguous")
|
|
405
|
+
else:
|
|
406
|
+
raise ValueError("Number of dimension for values array should be 1 or 3")
|
|
332
407
|
|
|
333
408
|
nnz = rows.shape[0]
|
|
334
409
|
if nnz == 0:
|
|
@@ -336,7 +411,8 @@ def bsr_set_from_triplets(
|
|
|
336
411
|
return
|
|
337
412
|
|
|
338
413
|
# Increase dest array sizes if needed
|
|
339
|
-
|
|
414
|
+
if not masked:
|
|
415
|
+
_bsr_ensure_fits(dest, nnz=nnz)
|
|
340
416
|
|
|
341
417
|
device = dest.values.device
|
|
342
418
|
scalar_type = dest.scalar_type
|
|
@@ -366,16 +442,51 @@ def bsr_set_from_triplets(
|
|
|
366
442
|
nnz,
|
|
367
443
|
ctypes.cast(rows.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
368
444
|
ctypes.cast(columns.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
369
|
-
ctypes.cast(values.ptr, ctypes.c_void_p),
|
|
445
|
+
None if values is None else ctypes.cast(values.ptr, ctypes.c_void_p),
|
|
370
446
|
prune_numerical_zeros,
|
|
447
|
+
masked,
|
|
371
448
|
ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
372
449
|
ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
373
|
-
ctypes.cast(dest.values.ptr, ctypes.c_void_p),
|
|
450
|
+
None if values is None else ctypes.cast(dest.values.ptr, ctypes.c_void_p),
|
|
374
451
|
ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
375
452
|
nnz_event,
|
|
376
453
|
)
|
|
377
454
|
|
|
378
455
|
|
|
456
|
+
def bsr_from_triplets(
|
|
457
|
+
rows_of_blocks: int,
|
|
458
|
+
cols_of_blocks: int,
|
|
459
|
+
rows: "Array[int]",
|
|
460
|
+
columns: "Array[int]",
|
|
461
|
+
values: "Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]",
|
|
462
|
+
prune_numerical_zeros: bool = True,
|
|
463
|
+
):
|
|
464
|
+
"""Constructs a BSR matrix with values defined by coordinate-oriented (COO) triplets.
|
|
465
|
+
|
|
466
|
+
The first dimension of the three input arrays must match and indicates the number of COO triplets.
|
|
467
|
+
|
|
468
|
+
Args:
|
|
469
|
+
rows_of_blocks: Number of rows of blocks.
|
|
470
|
+
cols_of_blocks: Number of columns of blocks.
|
|
471
|
+
rows: Row index for each non-zero.
|
|
472
|
+
columns: Columns index for each non-zero.
|
|
473
|
+
values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
|
|
474
|
+
to the ``dest`` matrix's block type, or a 3d array with data type equal to the ``dest`` matrix's scalar type.
|
|
475
|
+
prune_numerical_zeros: If ``True``, will ignore the zero-valued blocks.
|
|
476
|
+
"""
|
|
477
|
+
|
|
478
|
+
if values.ndim == 3:
|
|
479
|
+
block_type = wp.mat(shape=values.shape[1:], dtype=values.dtype)
|
|
480
|
+
else:
|
|
481
|
+
block_type = values.dtype
|
|
482
|
+
|
|
483
|
+
A = bsr_zeros(
|
|
484
|
+
rows_of_blocks=rows_of_blocks, cols_of_blocks=cols_of_blocks, block_type=block_type, device=values.device
|
|
485
|
+
)
|
|
486
|
+
bsr_set_from_triplets(A, rows, columns, values, prune_numerical_zeros=prune_numerical_zeros)
|
|
487
|
+
return A
|
|
488
|
+
|
|
489
|
+
|
|
379
490
|
class _BsrExpression(Generic[_BlockType]):
|
|
380
491
|
pass
|
|
381
492
|
|
|
@@ -486,96 +597,73 @@ def _extract_matrix_and_scale(bsr: BsrMatrixOrExpression):
|
|
|
486
597
|
raise ValueError("Argument cannot be interpreted as a BsrMatrix")
|
|
487
598
|
|
|
488
599
|
|
|
489
|
-
@wp.
|
|
490
|
-
def
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
dest_offsets: wp.array(dtype=int),
|
|
600
|
+
@wp.func
|
|
601
|
+
def _bsr_row_index(
|
|
602
|
+
offsets: wp.array(dtype=int),
|
|
603
|
+
row_count: int,
|
|
604
|
+
block: int,
|
|
495
605
|
):
|
|
496
|
-
row
|
|
606
|
+
"""Index of the row containing a block, or -1 if non-existing."""
|
|
607
|
+
return wp.where(block < offsets[row_count], wp.lower_bound(offsets, 0, row_count + 1, block + 1), 0) - 1
|
|
497
608
|
|
|
498
|
-
base_offset = src_offsets[row] * row_factor * col_factor
|
|
499
|
-
row_count = src_offsets[1 + row] - src_offsets[row]
|
|
500
609
|
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
@wp.kernel
|
|
509
|
-
def _bsr_assign_split_blocks(
|
|
510
|
-
structure_only: wp.bool,
|
|
511
|
-
scale: Any,
|
|
512
|
-
row_factor: int,
|
|
513
|
-
col_factor: int,
|
|
514
|
-
dest_row_count: int,
|
|
515
|
-
src_offsets: wp.array(dtype=int),
|
|
516
|
-
src_columns: wp.array(dtype=int),
|
|
517
|
-
src_values: wp.array3d(dtype=Any),
|
|
518
|
-
dest_offsets: wp.array(dtype=int),
|
|
519
|
-
dest_columns: wp.array(dtype=int),
|
|
520
|
-
dest_values: wp.array3d(dtype=Any),
|
|
610
|
+
@wp.func
|
|
611
|
+
def _bsr_block_index(
|
|
612
|
+
row: int,
|
|
613
|
+
col: int,
|
|
614
|
+
bsr_offsets: wp.array(dtype=int),
|
|
615
|
+
bsr_columns: wp.array(dtype=int),
|
|
521
616
|
):
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
return
|
|
526
|
-
|
|
527
|
-
dest_row = wp.lower_bound(dest_offsets, 0, dest_row_count + 1, dest_block + 1) - 1
|
|
528
|
-
src_row = dest_row // row_factor
|
|
529
|
-
|
|
530
|
-
dest_col_in_row = dest_block - dest_offsets[dest_row]
|
|
531
|
-
src_col_in_row = dest_col_in_row // col_factor
|
|
532
|
-
|
|
533
|
-
src_block = src_offsets[src_row] + src_col_in_row
|
|
617
|
+
"""Index of the block at block-coordinates (row, col), or -1 if non-existing.
|
|
618
|
+
Assumes bsr_columns is sorted.
|
|
619
|
+
"""
|
|
534
620
|
|
|
535
|
-
|
|
536
|
-
|
|
621
|
+
if row < 0:
|
|
622
|
+
return -1
|
|
537
623
|
|
|
538
|
-
|
|
539
|
-
|
|
624
|
+
mask_row_beg = bsr_offsets[row]
|
|
625
|
+
mask_row_end = bsr_offsets[row + 1]
|
|
540
626
|
|
|
541
|
-
|
|
627
|
+
if mask_row_beg == mask_row_end:
|
|
628
|
+
return -1
|
|
542
629
|
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
src_base_j = split_col * dest_cols_per_block
|
|
546
|
-
for i in range(dest_rows_per_block):
|
|
547
|
-
for j in range(dest_cols_per_block):
|
|
548
|
-
dest_values[dest_block, i, j] = dest_values.dtype(
|
|
549
|
-
scale * src_values[src_block, i + src_base_i, j + src_base_j]
|
|
550
|
-
)
|
|
630
|
+
block_index = wp.lower_bound(bsr_columns, mask_row_beg, mask_row_end, col)
|
|
631
|
+
return wp.where(bsr_columns[block_index] == col, block_index, -1)
|
|
551
632
|
|
|
552
633
|
|
|
553
|
-
@wp.kernel
|
|
554
|
-
def
|
|
555
|
-
|
|
556
|
-
|
|
634
|
+
@wp.kernel(enable_backward=False)
|
|
635
|
+
def _bsr_assign_list_blocks(
|
|
636
|
+
src_subrows: int,
|
|
637
|
+
src_subcols: int,
|
|
638
|
+
dest_subrows: int,
|
|
639
|
+
dest_subcols: int,
|
|
557
640
|
src_row_count: int,
|
|
558
641
|
src_offsets: wp.array(dtype=int),
|
|
559
642
|
src_columns: wp.array(dtype=int),
|
|
560
643
|
dest_rows: wp.array(dtype=int),
|
|
561
644
|
dest_cols: wp.array(dtype=int),
|
|
562
645
|
):
|
|
563
|
-
block = wp.tid()
|
|
646
|
+
block, subrow, subcol = wp.tid()
|
|
647
|
+
dest_block = (block * src_subcols + subcol) * src_subrows + subrow
|
|
564
648
|
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
649
|
+
row = _bsr_row_index(src_offsets, src_row_count, block)
|
|
650
|
+
if row == -1:
|
|
651
|
+
dest_rows[dest_block] = row # invalid
|
|
652
|
+
dest_cols[dest_block] = row
|
|
568
653
|
else:
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
654
|
+
dest_subrow = row * src_subrows + subrow
|
|
655
|
+
dest_subcol = src_columns[block] * src_subcols + subcol
|
|
656
|
+
dest_rows[dest_block] = dest_subrow // dest_subrows
|
|
657
|
+
dest_cols[dest_block] = dest_subcol // dest_subcols
|
|
572
658
|
|
|
573
659
|
|
|
574
660
|
@wp.kernel
|
|
575
|
-
def
|
|
661
|
+
def _bsr_assign_copy_blocks(
|
|
576
662
|
scale: Any,
|
|
577
|
-
|
|
578
|
-
|
|
663
|
+
src_subrows: int,
|
|
664
|
+
src_subcols: int,
|
|
665
|
+
dest_subrows: int,
|
|
666
|
+
dest_subcols: int,
|
|
579
667
|
src_row_count: int,
|
|
580
668
|
src_offsets: wp.array(dtype=int),
|
|
581
669
|
src_columns: wp.array(dtype=int),
|
|
@@ -585,61 +673,58 @@ def _bsr_assign_merge_blocks(
|
|
|
585
673
|
dest_values: wp.array3d(dtype=Any),
|
|
586
674
|
):
|
|
587
675
|
src_block = wp.tid()
|
|
676
|
+
src_block, subrow, subcol = wp.tid()
|
|
588
677
|
|
|
589
|
-
|
|
678
|
+
src_row = _bsr_row_index(src_offsets, src_row_count, src_block)
|
|
679
|
+
if src_row == -1:
|
|
590
680
|
return
|
|
591
681
|
|
|
592
|
-
src_row = wp.lower_bound(src_offsets, 0, src_row_count + 1, src_block + 1) - 1
|
|
593
682
|
src_col = src_columns[src_block]
|
|
594
683
|
|
|
595
|
-
|
|
596
|
-
|
|
684
|
+
dest_subrow = src_row * src_subrows + subrow
|
|
685
|
+
dest_subcol = src_col * src_subcols + subcol
|
|
686
|
+
dest_row = dest_subrow // dest_subrows
|
|
687
|
+
dest_col = dest_subcol // dest_subcols
|
|
597
688
|
|
|
598
|
-
dest_block =
|
|
689
|
+
dest_block = _bsr_block_index(dest_row, dest_col, dest_offsets, dest_columns)
|
|
690
|
+
if dest_block == -1:
|
|
691
|
+
return
|
|
692
|
+
|
|
693
|
+
split_row = dest_subrow - dest_subrows * dest_row
|
|
694
|
+
split_col = dest_subcol - dest_subcols * dest_col
|
|
599
695
|
|
|
600
|
-
|
|
601
|
-
|
|
696
|
+
rows_per_subblock = src_values.shape[1] // src_subrows
|
|
697
|
+
cols_per_subblock = src_values.shape[2] // src_subcols
|
|
602
698
|
|
|
603
|
-
|
|
604
|
-
|
|
699
|
+
dest_base_i = split_row * rows_per_subblock
|
|
700
|
+
dest_base_j = split_col * cols_per_subblock
|
|
605
701
|
|
|
606
|
-
|
|
607
|
-
|
|
702
|
+
src_base_i = subrow * rows_per_subblock
|
|
703
|
+
src_base_j = subcol * cols_per_subblock
|
|
608
704
|
|
|
609
|
-
for i in range(
|
|
610
|
-
for j in range(
|
|
705
|
+
for i in range(rows_per_subblock):
|
|
706
|
+
for j in range(cols_per_subblock):
|
|
611
707
|
dest_values[dest_block, i + dest_base_i, j + dest_base_j] = dest_values.dtype(
|
|
612
|
-
scale * src_values[src_block, i, j]
|
|
708
|
+
scale * src_values[src_block, i + src_base_i, j + src_base_j]
|
|
613
709
|
)
|
|
614
710
|
|
|
615
711
|
|
|
616
|
-
def _bsr_values_as_3d_array(A: BsrMatrix) -> wp.array:
|
|
617
|
-
if A.block_shape == (1, 1):
|
|
618
|
-
return A.values.reshape((A.values.shape[0], 1, 1))
|
|
619
|
-
|
|
620
|
-
return wp.array(
|
|
621
|
-
data=None,
|
|
622
|
-
ptr=A.values.ptr,
|
|
623
|
-
capacity=A.values.capacity,
|
|
624
|
-
device=A.device,
|
|
625
|
-
dtype=A.scalar_type,
|
|
626
|
-
shape=(A.values.shape[0], A.block_shape[0], A.block_shape[1]),
|
|
627
|
-
)
|
|
628
|
-
|
|
629
|
-
|
|
630
712
|
def bsr_assign(
|
|
631
713
|
dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
|
|
632
714
|
src: BsrMatrixOrExpression[BlockType[Any, Any, Any]],
|
|
633
715
|
structure_only: bool = False,
|
|
716
|
+
masked: bool = False,
|
|
634
717
|
):
|
|
635
|
-
"""
|
|
718
|
+
"""Copy the content of the ``src`` BSR matrix to ``dest``.
|
|
636
719
|
|
|
637
720
|
Args:
|
|
638
|
-
src: Matrix to be copied
|
|
639
|
-
dest: Destination matrix. May have a different block shape
|
|
721
|
+
src: Matrix to be copied.
|
|
722
|
+
dest: Destination matrix. May have a different block shape or scalar type
|
|
723
|
+
than ``src``, in which case the required casting will be performed.
|
|
640
724
|
structure_only: If ``True``, only the non-zeros indices are copied, and uninitialized value storage is allocated
|
|
641
|
-
to accommodate at least
|
|
725
|
+
to accommodate at least ``src.nnz`` blocks. If ``structure_only`` is ``False``, values are also copied with implicit
|
|
642
726
|
casting if the two matrices use distinct scalar types.
|
|
727
|
+
masked: If ``True``, prevent the assignment operation from adding new non-zeros blocks to ``dest``.
|
|
643
728
|
"""
|
|
644
729
|
|
|
645
730
|
src, src_scale = _extract_matrix_and_scale(src)
|
|
@@ -647,13 +732,50 @@ def bsr_assign(
|
|
|
647
732
|
if dest.values.device != src.values.device:
|
|
648
733
|
raise ValueError("Source and destination matrices must reside on the same device")
|
|
649
734
|
|
|
650
|
-
if
|
|
651
|
-
|
|
652
|
-
|
|
735
|
+
if src.block_shape[0] >= dest.block_shape[0]:
|
|
736
|
+
src_subrows = src.block_shape[0] // dest.block_shape[0]
|
|
737
|
+
dest_subrows = 1
|
|
738
|
+
else:
|
|
739
|
+
dest_subrows = dest.block_shape[0] // src.block_shape[0]
|
|
740
|
+
src_subrows = 1
|
|
741
|
+
|
|
742
|
+
if src_subrows * dest.block_shape[0] != src.block_shape[0] * dest_subrows:
|
|
743
|
+
raise ValueError(
|
|
744
|
+
f"Incompatible dest and src block shapes; block rows must evenly divide one another (Got {src.block_shape[0]}, {dest.block_shape[0]})"
|
|
745
|
+
)
|
|
746
|
+
|
|
747
|
+
if src.block_shape[1] >= dest.block_shape[1]:
|
|
748
|
+
src_subcols = src.block_shape[1] // dest.block_shape[1]
|
|
749
|
+
dest_subcols = 1
|
|
750
|
+
else:
|
|
751
|
+
dest_subcols = dest.block_shape[1] // src.block_shape[1]
|
|
752
|
+
src_subcols = 1
|
|
753
|
+
|
|
754
|
+
if src_subcols * dest.block_shape[1] != src.block_shape[1] * dest_subcols:
|
|
755
|
+
raise ValueError(
|
|
756
|
+
f"Incompatible dest and src block shapes; block columns must evenly divide one another (Got {src.block_shape[1]}, {dest.block_shape[1]})"
|
|
757
|
+
)
|
|
653
758
|
|
|
654
|
-
|
|
759
|
+
dest_nrow = (src.nrow * src_subrows) // dest_subrows
|
|
760
|
+
dest_ncol = (src.ncol * src_subcols) // dest_subcols
|
|
761
|
+
|
|
762
|
+
if src.nrow * src_subrows != dest_nrow * dest_subrows or src.ncol * src_subcols != dest_ncol * dest_subcols:
|
|
763
|
+
raise ValueError("The requested block shape does not evenly divide the source matrix")
|
|
764
|
+
|
|
765
|
+
nnz_alloc = src.nnz * src_subrows * src_subcols
|
|
766
|
+
if masked:
|
|
767
|
+
if dest_nrow != dest.nrow or dest_ncol != dest.ncol:
|
|
768
|
+
raise ValueError(
|
|
769
|
+
f"Incompatible destination matrix size, expected ({dest_nrow}, {dest_ncol}), got ({dest.nrow}, {dest.ncol})"
|
|
770
|
+
)
|
|
771
|
+
else:
|
|
772
|
+
dest.nrow = dest_nrow
|
|
773
|
+
dest.ncol = dest_ncol
|
|
655
774
|
_bsr_ensure_fits(dest, nnz=nnz_alloc)
|
|
656
775
|
|
|
776
|
+
if dest.block_shape == src.block_shape and not masked:
|
|
777
|
+
# Direct copy
|
|
778
|
+
|
|
657
779
|
wp.copy(dest=dest.offsets, src=src.offsets, count=src.nrow + 1)
|
|
658
780
|
dest.copy_nnz_async()
|
|
659
781
|
|
|
@@ -664,86 +786,29 @@ def bsr_assign(
|
|
|
664
786
|
warp.utils.array_cast(out_array=dest.values, in_array=src.values, count=nnz_alloc)
|
|
665
787
|
bsr_scale(dest, src_scale)
|
|
666
788
|
|
|
667
|
-
|
|
668
|
-
#
|
|
669
|
-
|
|
670
|
-
row_factor = src.block_shape[0] // dest.block_shape[0]
|
|
671
|
-
col_factor = src.block_shape[1] // dest.block_shape[1]
|
|
672
|
-
|
|
673
|
-
if (
|
|
674
|
-
row_factor * dest.block_shape[0] != src.block_shape[0]
|
|
675
|
-
or col_factor * dest.block_shape[1] != src.block_shape[1]
|
|
676
|
-
):
|
|
677
|
-
raise ValueError(
|
|
678
|
-
f"Dest block shape {dest.block_shape} is not an exact divider of src block shape {src.block_shape}"
|
|
679
|
-
)
|
|
680
|
-
|
|
681
|
-
dest.nrow = src.nrow * row_factor
|
|
682
|
-
dest.ncol = src.ncol * col_factor
|
|
683
|
-
|
|
684
|
-
nnz_alloc = src.nnz * row_factor * col_factor
|
|
685
|
-
_bsr_ensure_fits(dest, nnz=nnz_alloc)
|
|
789
|
+
else:
|
|
790
|
+
# Masked and/or multiple src blocks per dest block, go through COO format
|
|
686
791
|
|
|
792
|
+
# Compute destination rows and columns
|
|
793
|
+
dest_rows = wp.empty(nnz_alloc, dtype=int, device=dest.device)
|
|
794
|
+
dest_cols = wp.empty(nnz_alloc, dtype=int, device=dest.device)
|
|
687
795
|
wp.launch(
|
|
688
|
-
|
|
689
|
-
dim=src.
|
|
690
|
-
device=dest.device,
|
|
691
|
-
inputs=[row_factor, col_factor, src.offsets, dest.offsets],
|
|
692
|
-
)
|
|
693
|
-
wp.launch(
|
|
694
|
-
_bsr_assign_split_blocks,
|
|
695
|
-
dim=dest.nnz,
|
|
796
|
+
_bsr_assign_list_blocks,
|
|
797
|
+
dim=(src.nnz, src_subrows, src_subcols),
|
|
696
798
|
device=dest.device,
|
|
697
799
|
inputs=[
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
800
|
+
src_subrows,
|
|
801
|
+
src_subcols,
|
|
802
|
+
dest_subrows,
|
|
803
|
+
dest_subcols,
|
|
804
|
+
src.nrow,
|
|
703
805
|
src.offsets,
|
|
704
806
|
src.columns,
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
dest.columns,
|
|
708
|
-
_bsr_values_as_3d_array(dest),
|
|
807
|
+
dest_rows,
|
|
808
|
+
dest_cols,
|
|
709
809
|
],
|
|
710
810
|
)
|
|
711
811
|
|
|
712
|
-
elif src.block_shape[0] <= dest.block_shape[0] and src.block_shape[1] <= dest.block_shape[1]:
|
|
713
|
-
# Merge blocks
|
|
714
|
-
|
|
715
|
-
row_factor = dest.block_shape[0] // src.block_shape[0]
|
|
716
|
-
col_factor = dest.block_shape[1] // src.block_shape[1]
|
|
717
|
-
|
|
718
|
-
if (
|
|
719
|
-
row_factor * src.block_shape[0] != dest.block_shape[0]
|
|
720
|
-
or col_factor * src.block_shape[1] != dest.block_shape[1]
|
|
721
|
-
):
|
|
722
|
-
raise ValueError(
|
|
723
|
-
f"Dest block shape {dest.block_shape} is not an exact multiple of src block shape {src.block_shape}"
|
|
724
|
-
)
|
|
725
|
-
|
|
726
|
-
if src.nrow % row_factor != 0 or src.ncol % col_factor != 0:
|
|
727
|
-
raise ValueError(
|
|
728
|
-
"The total rows and columns of the src matrix cannot be evenly divided using the requested block shape"
|
|
729
|
-
)
|
|
730
|
-
|
|
731
|
-
dest.nrow = src.nrow // row_factor
|
|
732
|
-
dest.ncol = src.ncol // col_factor
|
|
733
|
-
|
|
734
|
-
nnz_alloc = src.nnz # Conservative, in case all nnz in src belong to distinct merged blocks
|
|
735
|
-
_bsr_ensure_fits(dest, nnz=nnz_alloc)
|
|
736
|
-
|
|
737
|
-
# Compute destination rows and columns
|
|
738
|
-
dest_rows = wp.empty_like(src.columns)
|
|
739
|
-
dest_cols = wp.empty_like(src.columns)
|
|
740
|
-
wp.launch(
|
|
741
|
-
_bsr_assign_merge_row_col,
|
|
742
|
-
dim=src.nnz,
|
|
743
|
-
device=dest.device,
|
|
744
|
-
inputs=[row_factor, col_factor, src.nrow, src.offsets, src.columns, dest_rows, dest_cols],
|
|
745
|
-
)
|
|
746
|
-
|
|
747
812
|
# Compute destination offsets from triplets
|
|
748
813
|
from warp.context import runtime
|
|
749
814
|
|
|
@@ -758,11 +823,12 @@ def bsr_assign(
|
|
|
758
823
|
dest.block_shape[0],
|
|
759
824
|
dest.block_shape[1],
|
|
760
825
|
dest.nrow,
|
|
761
|
-
|
|
826
|
+
nnz_alloc,
|
|
762
827
|
ctypes.cast(dest_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
763
828
|
ctypes.cast(dest_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
764
829
|
0,
|
|
765
830
|
False,
|
|
831
|
+
masked,
|
|
766
832
|
ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
767
833
|
ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
768
834
|
0,
|
|
@@ -774,26 +840,25 @@ def bsr_assign(
|
|
|
774
840
|
if not structure_only:
|
|
775
841
|
dest.values.zero_()
|
|
776
842
|
wp.launch(
|
|
777
|
-
|
|
778
|
-
dim=src.nnz,
|
|
843
|
+
_bsr_assign_copy_blocks,
|
|
844
|
+
dim=(src.nnz, src_subrows, src_subcols),
|
|
779
845
|
device=dest.device,
|
|
780
846
|
inputs=[
|
|
781
847
|
src.scalar_type(src_scale),
|
|
782
|
-
|
|
783
|
-
|
|
848
|
+
src_subrows,
|
|
849
|
+
src_subcols,
|
|
850
|
+
dest_subrows,
|
|
851
|
+
dest_subcols,
|
|
784
852
|
src.nrow,
|
|
785
853
|
src.offsets,
|
|
786
854
|
src.columns,
|
|
787
|
-
|
|
855
|
+
src.scalar_values,
|
|
788
856
|
dest.offsets,
|
|
789
857
|
dest.columns,
|
|
790
|
-
|
|
858
|
+
dest.scalar_values,
|
|
791
859
|
],
|
|
792
860
|
)
|
|
793
861
|
|
|
794
|
-
else:
|
|
795
|
-
raise ValueError("Incompatible dest and src block shapes")
|
|
796
|
-
|
|
797
862
|
|
|
798
863
|
def bsr_copy(
|
|
799
864
|
A: BsrMatrixOrExpression,
|
|
@@ -801,15 +866,15 @@ def bsr_copy(
|
|
|
801
866
|
block_shape: Optional[Tuple[int, int]] = None,
|
|
802
867
|
structure_only: bool = False,
|
|
803
868
|
):
|
|
804
|
-
"""
|
|
869
|
+
"""Return a copy of matrix ``A``, possibly changing its scalar type.
|
|
805
870
|
|
|
806
871
|
Args:
|
|
807
|
-
A: Matrix to be copied
|
|
808
|
-
scalar_type: If provided, the returned matrix will use this scalar type instead of the one from
|
|
809
|
-
block_shape: If provided, the returned matrix will use blocks of this shape instead of the one from
|
|
810
|
-
Both dimensions of
|
|
872
|
+
A: Matrix to be copied.
|
|
873
|
+
scalar_type: If provided, the returned matrix will use this scalar type instead of the one from ``A``.
|
|
874
|
+
block_shape: If provided, the returned matrix will use blocks of this shape instead of the one from ``A``.
|
|
875
|
+
Both dimensions of ``block_shape`` must be either a multiple or an exact divider of the ones from ``A``.
|
|
811
876
|
structure_only: If ``True``, only the non-zeros indices are copied, and uninitialized value storage is allocated
|
|
812
|
-
to accommodate at least
|
|
877
|
+
to accommodate at least ``src.nnz`` blocks. If ``structure_only`` is ``False``, values are also copied with implicit
|
|
813
878
|
casting if the two matrices use distinct scalar types.
|
|
814
879
|
"""
|
|
815
880
|
if scalar_type is None:
|
|
@@ -820,7 +885,7 @@ def bsr_copy(
|
|
|
820
885
|
if block_shape == (1, 1):
|
|
821
886
|
block_type = scalar_type
|
|
822
887
|
else:
|
|
823
|
-
block_type = wp.
|
|
888
|
+
block_type = wp.mat(shape=block_shape, dtype=scalar_type)
|
|
824
889
|
|
|
825
890
|
copy = bsr_zeros(
|
|
826
891
|
rows_of_blocks=A.nrow,
|
|
@@ -836,7 +901,7 @@ def bsr_set_transpose(
|
|
|
836
901
|
dest: BsrMatrix[BlockType[Cols, Rows, Scalar]],
|
|
837
902
|
src: BsrMatrixOrExpression[BlockType[Rows, Cols, Scalar]],
|
|
838
903
|
):
|
|
839
|
-
"""
|
|
904
|
+
"""Assign the transposed matrix ``src`` to matrix ``dest``."""
|
|
840
905
|
|
|
841
906
|
src, src_scale = _extract_matrix_and_scale(src)
|
|
842
907
|
|
|
@@ -897,13 +962,13 @@ def bsr_set_transpose(
|
|
|
897
962
|
bsr_scale(dest, src_scale)
|
|
898
963
|
|
|
899
964
|
|
|
900
|
-
def bsr_transposed(A: BsrMatrixOrExpression):
|
|
901
|
-
"""
|
|
965
|
+
def bsr_transposed(A: BsrMatrixOrExpression) -> BsrMatrix:
|
|
966
|
+
"""Return a copy of the transposed matrix ``A``."""
|
|
902
967
|
|
|
903
968
|
if A.block_shape == (1, 1):
|
|
904
969
|
block_type = A.values.dtype
|
|
905
970
|
else:
|
|
906
|
-
block_type = wp.
|
|
971
|
+
block_type = wp.mat(shape=A.block_shape[::-1], dtype=A.scalar_type)
|
|
907
972
|
|
|
908
973
|
transposed = bsr_zeros(
|
|
909
974
|
rows_of_blocks=A.ncol,
|
|
@@ -924,21 +989,18 @@ def _bsr_get_diag_kernel(
|
|
|
924
989
|
out: wp.array(dtype=Any),
|
|
925
990
|
):
|
|
926
991
|
row = wp.tid()
|
|
927
|
-
beg = A_offsets[row]
|
|
928
|
-
end = A_offsets[row + 1]
|
|
929
992
|
|
|
930
|
-
diag =
|
|
931
|
-
if diag
|
|
932
|
-
|
|
933
|
-
out[row] = scale * A_values[diag]
|
|
993
|
+
diag = _bsr_block_index(row, row, A_offsets, A_columns)
|
|
994
|
+
if diag != -1:
|
|
995
|
+
out[row] = scale * A_values[diag]
|
|
934
996
|
|
|
935
997
|
|
|
936
998
|
def bsr_get_diag(A: BsrMatrixOrExpression[BlockType], out: "Optional[Array[BlockType]]" = None) -> "Array[BlockType]":
|
|
937
|
-
"""
|
|
999
|
+
"""Return the array of blocks that constitute the diagonal of a sparse matrix.
|
|
938
1000
|
|
|
939
1001
|
Args:
|
|
940
|
-
A:
|
|
941
|
-
out:
|
|
1002
|
+
A: The sparse matrix from which to extract the diagonal.
|
|
1003
|
+
out: If provided, the array into which to store the diagonal blocks.
|
|
942
1004
|
"""
|
|
943
1005
|
|
|
944
1006
|
A, scale = _extract_matrix_and_scale(A)
|
|
@@ -965,36 +1027,16 @@ def bsr_get_diag(A: BsrMatrixOrExpression[BlockType], out: "Optional[Array[Block
|
|
|
965
1027
|
return out
|
|
966
1028
|
|
|
967
1029
|
|
|
968
|
-
@wp.kernel
|
|
1030
|
+
@wp.kernel(enable_backward=False)
|
|
969
1031
|
def _bsr_set_diag_kernel(
|
|
970
|
-
|
|
1032
|
+
nnz: int,
|
|
971
1033
|
A_offsets: wp.array(dtype=int),
|
|
972
1034
|
A_columns: wp.array(dtype=int),
|
|
973
|
-
A_values: wp.array(dtype=Any),
|
|
974
1035
|
):
|
|
975
1036
|
row = wp.tid()
|
|
976
|
-
A_offsets[row
|
|
977
|
-
|
|
978
|
-
|
|
979
|
-
|
|
980
|
-
if row == 0:
|
|
981
|
-
A_offsets[0] = 0
|
|
982
|
-
|
|
983
|
-
|
|
984
|
-
@wp.kernel
|
|
985
|
-
def _bsr_set_diag_constant_kernel(
|
|
986
|
-
diag_value: Any,
|
|
987
|
-
A_offsets: wp.array(dtype=int),
|
|
988
|
-
A_columns: wp.array(dtype=int),
|
|
989
|
-
A_values: wp.array(dtype=Any),
|
|
990
|
-
):
|
|
991
|
-
row = wp.tid()
|
|
992
|
-
A_offsets[row + 1] = row + 1
|
|
993
|
-
A_columns[row] = row
|
|
994
|
-
A_values[row] = diag_value
|
|
995
|
-
|
|
996
|
-
if row == 0:
|
|
997
|
-
A_offsets[0] = 0
|
|
1037
|
+
A_offsets[row] = wp.min(row, nnz)
|
|
1038
|
+
if row < nnz:
|
|
1039
|
+
A_columns[row] = row
|
|
998
1040
|
|
|
999
1041
|
|
|
1000
1042
|
def bsr_set_diag(
|
|
@@ -1002,20 +1044,26 @@ def bsr_set_diag(
|
|
|
1002
1044
|
diag: "Union[BlockType, Array[BlockType]]",
|
|
1003
1045
|
rows_of_blocks: Optional[int] = None,
|
|
1004
1046
|
cols_of_blocks: Optional[int] = None,
|
|
1005
|
-
):
|
|
1006
|
-
"""
|
|
1047
|
+
) -> None:
|
|
1048
|
+
"""Set ``A`` as a block-diagonal matrix.
|
|
1007
1049
|
|
|
1008
1050
|
Args:
|
|
1009
|
-
A:
|
|
1010
|
-
diag:
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
|
|
1051
|
+
A: The sparse matrix to modify.
|
|
1052
|
+
diag: Specifies the values for diagonal blocks. Can be one of:
|
|
1053
|
+
|
|
1054
|
+
- A Warp array of type ``A.values.dtype``: Each element defines one block of the diagonal
|
|
1055
|
+
- A constant value of type ``A.values.dtype``: This value is assigned to all diagonal blocks
|
|
1056
|
+
- ``None``: Diagonal block values are left uninitialized
|
|
1057
|
+
|
|
1058
|
+
rows_of_blocks: If not ``None``, the new number of rows of blocks.
|
|
1059
|
+
cols_of_blocks: If not ``None``, the new number of columns of blocks.
|
|
1060
|
+
|
|
1061
|
+
The shape of the matrix will be defined one of the following, in this order:
|
|
1014
1062
|
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1063
|
+
- ``rows_of_blocks`` and ``cols_of_blocks``, if provided.
|
|
1064
|
+
If only one is given, the second is assumed equal.
|
|
1065
|
+
- The first dimension of ``diag``, if ``diag`` is an array
|
|
1066
|
+
- The current dimensions of ``A`` otherwise
|
|
1019
1067
|
"""
|
|
1020
1068
|
|
|
1021
1069
|
if rows_of_blocks is None and cols_of_blocks is not None:
|
|
@@ -1023,7 +1071,7 @@ def bsr_set_diag(
|
|
|
1023
1071
|
if cols_of_blocks is None and rows_of_blocks is not None:
|
|
1024
1072
|
cols_of_blocks = rows_of_blocks
|
|
1025
1073
|
|
|
1026
|
-
if
|
|
1074
|
+
if is_array(diag):
|
|
1027
1075
|
if rows_of_blocks is None:
|
|
1028
1076
|
rows_of_blocks = diag.shape[0]
|
|
1029
1077
|
cols_of_blocks = diag.shape[0]
|
|
@@ -1035,43 +1083,45 @@ def bsr_set_diag(
|
|
|
1035
1083
|
nnz = min(A.nrow, A.ncol)
|
|
1036
1084
|
_bsr_ensure_fits(A, nnz=nnz)
|
|
1037
1085
|
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
|
|
1041
|
-
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
|
|
1045
|
-
|
|
1046
|
-
|
|
1047
|
-
|
|
1048
|
-
|
|
1049
|
-
wp.launch(
|
|
1050
|
-
kernel=_bsr_set_diag_constant_kernel,
|
|
1051
|
-
dim=nnz,
|
|
1052
|
-
device=A.values.device,
|
|
1053
|
-
inputs=[diag, A.offsets, A.columns, A.values],
|
|
1054
|
-
)
|
|
1086
|
+
wp.launch(
|
|
1087
|
+
kernel=_bsr_set_diag_kernel,
|
|
1088
|
+
dim=nnz + 1,
|
|
1089
|
+
device=A.offsets.device,
|
|
1090
|
+
inputs=[nnz, A.offsets, A.columns],
|
|
1091
|
+
)
|
|
1092
|
+
|
|
1093
|
+
if is_array(diag):
|
|
1094
|
+
wp.copy(src=diag, dest=A.values, count=nnz)
|
|
1095
|
+
elif diag is not None:
|
|
1096
|
+
A.values.fill_(diag)
|
|
1055
1097
|
|
|
1056
1098
|
A.copy_nnz_async(known_nnz=nnz)
|
|
1057
1099
|
|
|
1058
1100
|
|
|
1059
1101
|
def bsr_diag(
|
|
1060
|
-
diag:
|
|
1102
|
+
diag: Optional[Union[BlockType, Array[BlockType]]] = None,
|
|
1061
1103
|
rows_of_blocks: Optional[int] = None,
|
|
1062
1104
|
cols_of_blocks: Optional[int] = None,
|
|
1105
|
+
block_type: Optional[BlockType] = None,
|
|
1106
|
+
device=None,
|
|
1063
1107
|
) -> BsrMatrix["BlockType"]:
|
|
1064
|
-
"""
|
|
1108
|
+
"""Create and return a block-diagonal BSR matrix from an given block value or array of block values.
|
|
1065
1109
|
|
|
1066
1110
|
Args:
|
|
1067
|
-
diag:
|
|
1068
|
-
|
|
1111
|
+
diag: Specifies the values for diagonal blocks. Can be one of:
|
|
1112
|
+
|
|
1113
|
+
- A Warp array of type ``A.values.dtype``: Each element defines one block of the diagonal
|
|
1114
|
+
- A constant value of type ``A.values.dtype``: This value is assigned to all diagonal blocks
|
|
1069
1115
|
rows_of_blocks: If not ``None``, the new number of rows of blocks
|
|
1070
1116
|
cols_of_blocks: If not ``None``, the new number of columns of blocks
|
|
1117
|
+
block_type: If ``diag`` is ``None``, block type of the matrix. Otherwise deduced from ``diag``
|
|
1118
|
+
device: If ``diag`` is not a Warp array, device on which to allocate the matrix. Otherwise deduced from ``diag``
|
|
1119
|
+
|
|
1120
|
+
The shape of the matrix will be defined one of the following, in this order:
|
|
1071
1121
|
|
|
1072
|
-
|
|
1073
|
-
|
|
1074
|
-
|
|
1122
|
+
- ``rows_of_blocks`` and ``cols_of_blocks``, if provided.
|
|
1123
|
+
If only one is given, the second is assumed equal.
|
|
1124
|
+
- The first dimension of ``diag`` if ``diag`` is an array.
|
|
1075
1125
|
"""
|
|
1076
1126
|
|
|
1077
1127
|
if rows_of_blocks is None and cols_of_blocks is not None:
|
|
@@ -1079,43 +1129,39 @@ def bsr_diag(
|
|
|
1079
1129
|
if cols_of_blocks is None and rows_of_blocks is not None:
|
|
1080
1130
|
cols_of_blocks = rows_of_blocks
|
|
1081
1131
|
|
|
1082
|
-
if
|
|
1132
|
+
if is_array(diag):
|
|
1083
1133
|
if rows_of_blocks is None:
|
|
1084
1134
|
rows_of_blocks = diag.shape[0]
|
|
1085
1135
|
cols_of_blocks = diag.shape[0]
|
|
1086
1136
|
|
|
1087
|
-
|
|
1088
|
-
|
|
1089
|
-
cols_of_blocks,
|
|
1090
|
-
block_type=diag.dtype,
|
|
1091
|
-
device=diag.device,
|
|
1092
|
-
)
|
|
1137
|
+
block_type = diag.dtype
|
|
1138
|
+
device = diag.device
|
|
1093
1139
|
else:
|
|
1094
1140
|
if rows_of_blocks is None:
|
|
1095
1141
|
raise ValueError(
|
|
1096
1142
|
"rows_of_blocks and/or cols_of_blocks must be provided for constructing a diagonal matrix with uniform diagonal"
|
|
1097
1143
|
)
|
|
1098
1144
|
|
|
1145
|
+
if block_type is None:
|
|
1146
|
+
if diag is None:
|
|
1147
|
+
raise ValueError("Either `diag` or `block_type` needs to be provided")
|
|
1148
|
+
|
|
1099
1149
|
block_type = type(diag)
|
|
1100
|
-
if not
|
|
1150
|
+
if not type_is_matrix(block_type) and len(getattr(diag, "shape", ())) == 2:
|
|
1101
1151
|
block_type = wp.mat(shape=diag.shape, dtype=diag.dtype)
|
|
1102
1152
|
|
|
1103
|
-
|
|
1104
|
-
rows_of_blocks,
|
|
1105
|
-
cols_of_blocks,
|
|
1106
|
-
block_type=block_type,
|
|
1107
|
-
)
|
|
1108
|
-
|
|
1153
|
+
A = bsr_zeros(rows_of_blocks, cols_of_blocks, block_type=block_type, device=device)
|
|
1109
1154
|
bsr_set_diag(A, diag)
|
|
1110
1155
|
return A
|
|
1111
1156
|
|
|
1112
1157
|
|
|
1113
|
-
def bsr_set_identity(A: BsrMatrix, rows_of_blocks: Optional[int] = None):
|
|
1114
|
-
"""
|
|
1158
|
+
def bsr_set_identity(A: BsrMatrix, rows_of_blocks: Optional[int] = None) -> None:
|
|
1159
|
+
"""Set ``A`` as the identity matrix.
|
|
1115
1160
|
|
|
1116
1161
|
Args:
|
|
1117
|
-
A:
|
|
1118
|
-
rows_of_blocks:
|
|
1162
|
+
A: The sparse matrix to modify.
|
|
1163
|
+
rows_of_blocks: If provided, the matrix will be resized as a square
|
|
1164
|
+
matrix with ``rows_of_blocks`` rows and columns.
|
|
1119
1165
|
"""
|
|
1120
1166
|
|
|
1121
1167
|
if A.block_shape == (1, 1):
|
|
@@ -1133,11 +1179,11 @@ def bsr_identity(
|
|
|
1133
1179
|
block_type: BlockType[Rows, Rows, Scalar],
|
|
1134
1180
|
device: wp.context.Devicelike = None,
|
|
1135
1181
|
) -> BsrMatrix[BlockType[Rows, Rows, Scalar]]:
|
|
1136
|
-
"""
|
|
1182
|
+
"""Create and return a square identity matrix.
|
|
1137
1183
|
|
|
1138
1184
|
Args:
|
|
1139
1185
|
rows_of_blocks: Number of rows and columns of blocks in the created matrix.
|
|
1140
|
-
block_type: Block type for the newly created matrix
|
|
1186
|
+
block_type: Block type for the newly created matrix. Must be square
|
|
1141
1187
|
device: Device onto which to allocate the data arrays
|
|
1142
1188
|
"""
|
|
1143
1189
|
A = bsr_zeros(
|
|
@@ -1159,9 +1205,7 @@ def _bsr_scale_kernel(
|
|
|
1159
1205
|
|
|
1160
1206
|
|
|
1161
1207
|
def bsr_scale(x: BsrMatrixOrExpression, alpha: Scalar) -> BsrMatrix:
|
|
1162
|
-
"""
|
|
1163
|
-
Performs the operation ``x := alpha * x`` on BSR matrix `x` and returns `x`
|
|
1164
|
-
"""
|
|
1208
|
+
"""Perform the operation ``x := alpha * x`` on BSR matrix ``x`` and return ``x``."""
|
|
1165
1209
|
|
|
1166
1210
|
x, scale = _extract_matrix_and_scale(x)
|
|
1167
1211
|
alpha *= scale
|
|
@@ -1170,8 +1214,7 @@ def bsr_scale(x: BsrMatrixOrExpression, alpha: Scalar) -> BsrMatrix:
|
|
|
1170
1214
|
if alpha == 0.0:
|
|
1171
1215
|
bsr_set_zero(x)
|
|
1172
1216
|
else:
|
|
1173
|
-
|
|
1174
|
-
alpha = x.scalar_type(alpha)
|
|
1217
|
+
alpha = x.scalar_type(alpha)
|
|
1175
1218
|
|
|
1176
1219
|
wp.launch(
|
|
1177
1220
|
kernel=_bsr_scale_kernel,
|
|
@@ -1183,15 +1226,10 @@ def bsr_scale(x: BsrMatrixOrExpression, alpha: Scalar) -> BsrMatrix:
|
|
|
1183
1226
|
return x
|
|
1184
1227
|
|
|
1185
1228
|
|
|
1186
|
-
@wp.kernel
|
|
1187
|
-
def _bsr_get_block_row(
|
|
1188
|
-
|
|
1189
|
-
|
|
1190
|
-
if i >= bsr_offsets[row_count]:
|
|
1191
|
-
rows[dest_offset + i] = -1 # invalid
|
|
1192
|
-
else:
|
|
1193
|
-
row = wp.lower_bound(bsr_offsets, 0, row_count + 1, i + 1) - 1
|
|
1194
|
-
rows[dest_offset + i] = row
|
|
1229
|
+
@wp.kernel(enable_backward=False)
|
|
1230
|
+
def _bsr_get_block_row(row_count: int, bsr_offsets: wp.array(dtype=int), rows: wp.array(dtype=int)):
|
|
1231
|
+
block = wp.tid()
|
|
1232
|
+
rows[block] = _bsr_row_index(bsr_offsets, row_count, block)
|
|
1195
1233
|
|
|
1196
1234
|
|
|
1197
1235
|
@wp.kernel
|
|
@@ -1207,21 +1245,15 @@ def _bsr_axpy_add_block(
|
|
|
1207
1245
|
):
|
|
1208
1246
|
i = wp.tid()
|
|
1209
1247
|
row = rows[i + src_offset]
|
|
1210
|
-
|
|
1211
|
-
if row < 0:
|
|
1212
|
-
return
|
|
1213
|
-
|
|
1214
1248
|
col = cols[i + src_offset]
|
|
1215
|
-
beg = dst_offsets[row]
|
|
1216
|
-
end = dst_offsets[row + 1]
|
|
1217
1249
|
|
|
1218
|
-
block =
|
|
1219
|
-
|
|
1220
|
-
|
|
1250
|
+
block = _bsr_block_index(row, col, dst_offsets, dst_columns)
|
|
1251
|
+
if block != -1:
|
|
1252
|
+
dst_values[block] += scale * src_values[i]
|
|
1221
1253
|
|
|
1222
1254
|
|
|
1223
1255
|
class bsr_axpy_work_arrays:
|
|
1224
|
-
"""Opaque structure for persisting :func:`bsr_axpy` temporary work buffers across calls"""
|
|
1256
|
+
"""Opaque structure for persisting :func:`bsr_axpy` temporary work buffers across calls."""
|
|
1225
1257
|
|
|
1226
1258
|
def __init__(self):
|
|
1227
1259
|
self._reset(None)
|
|
@@ -1251,25 +1283,33 @@ def bsr_axpy(
|
|
|
1251
1283
|
y: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
|
|
1252
1284
|
alpha: Scalar = 1.0,
|
|
1253
1285
|
beta: Scalar = 1.0,
|
|
1286
|
+
masked: bool = False,
|
|
1254
1287
|
work_arrays: Optional[bsr_axpy_work_arrays] = None,
|
|
1255
1288
|
) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
|
|
1256
1289
|
"""
|
|
1257
|
-
|
|
1290
|
+
Perform the sparse matrix addition ``y := alpha * X + beta * y`` on BSR matrices ``x`` and ``y`` and return ``y``.
|
|
1258
1291
|
|
|
1259
|
-
The
|
|
1292
|
+
The ``x`` and ``y`` matrices are allowed to alias.
|
|
1260
1293
|
|
|
1261
1294
|
Args:
|
|
1262
1295
|
x: Read-only right-hand-side.
|
|
1263
|
-
y: Mutable left-hand-side. If
|
|
1264
|
-
alpha: Uniform scaling factor for
|
|
1265
|
-
beta: Uniform scaling factor for
|
|
1266
|
-
|
|
1296
|
+
y: Mutable left-hand-side. If ``y`` is not provided, it will be allocated and treated as zero.
|
|
1297
|
+
alpha: Uniform scaling factor for ``x``.
|
|
1298
|
+
beta: Uniform scaling factor for ``y``.
|
|
1299
|
+
masked: If ``True``, discard all blocks from ``x`` which are not
|
|
1300
|
+
existing non-zeros of ``y``.
|
|
1301
|
+
work_arrays: In most cases, this function will require the use of temporary storage.
|
|
1302
|
+
This storage can be reused across calls by passing an instance of
|
|
1303
|
+
:class:`bsr_axpy_work_arrays` in ``work_arrays``.
|
|
1267
1304
|
"""
|
|
1268
1305
|
|
|
1269
1306
|
x, x_scale = _extract_matrix_and_scale(x)
|
|
1270
1307
|
alpha *= x_scale
|
|
1271
1308
|
|
|
1272
1309
|
if y is None:
|
|
1310
|
+
if masked:
|
|
1311
|
+
raise ValueError("Left-hand-side 'y' matrix must be provided for masked addition")
|
|
1312
|
+
|
|
1273
1313
|
# If not output matrix is provided, allocate it for convenience
|
|
1274
1314
|
y = bsr_zeros(x.nrow, x.ncol, block_type=x.values.dtype, device=x.values.device)
|
|
1275
1315
|
beta = 0.0
|
|
@@ -1313,27 +1353,17 @@ def bsr_axpy(
|
|
|
1313
1353
|
work_arrays._allocate(device, y, sum_nnz)
|
|
1314
1354
|
|
|
1315
1355
|
wp.copy(work_arrays._sum_cols, y.columns, 0, 0, y_nnz)
|
|
1316
|
-
|
|
1317
|
-
kernel=_bsr_get_block_row,
|
|
1318
|
-
device=device,
|
|
1319
|
-
dim=y_nnz,
|
|
1320
|
-
inputs=[0, y.nrow, y.offsets, work_arrays._sum_rows],
|
|
1321
|
-
)
|
|
1356
|
+
y.uncompress_rows(out=work_arrays._sum_rows)
|
|
1322
1357
|
|
|
1323
1358
|
wp.copy(work_arrays._sum_cols, x.columns, y_nnz, 0, x_nnz)
|
|
1324
|
-
|
|
1325
|
-
kernel=_bsr_get_block_row,
|
|
1326
|
-
device=device,
|
|
1327
|
-
dim=x_nnz,
|
|
1328
|
-
inputs=[y_nnz, x.nrow, x.offsets, work_arrays._sum_rows],
|
|
1329
|
-
)
|
|
1359
|
+
x.uncompress_rows(out=work_arrays._sum_rows[y_nnz:])
|
|
1330
1360
|
|
|
1331
1361
|
# Save old y values before overwriting matrix
|
|
1332
1362
|
wp.copy(dest=work_arrays._old_y_values, src=y.values, count=y_nnz)
|
|
1333
1363
|
|
|
1334
1364
|
# Increase dest array sizes if needed
|
|
1335
|
-
if
|
|
1336
|
-
y
|
|
1365
|
+
if not masked:
|
|
1366
|
+
_bsr_ensure_fits(y, nnz=sum_nnz)
|
|
1337
1367
|
|
|
1338
1368
|
from warp.context import runtime
|
|
1339
1369
|
|
|
@@ -1355,6 +1385,7 @@ def bsr_axpy(
|
|
|
1355
1385
|
ctypes.cast(work_arrays._sum_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1356
1386
|
0,
|
|
1357
1387
|
False,
|
|
1388
|
+
masked,
|
|
1358
1389
|
ctypes.cast(y.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1359
1390
|
ctypes.cast(y.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1360
1391
|
0,
|
|
@@ -1362,8 +1393,6 @@ def bsr_axpy(
|
|
|
1362
1393
|
nnz_event,
|
|
1363
1394
|
)
|
|
1364
1395
|
|
|
1365
|
-
_bsr_ensure_fits(y, nnz=sum_nnz)
|
|
1366
|
-
|
|
1367
1396
|
y.values.zero_()
|
|
1368
1397
|
|
|
1369
1398
|
wp.launch(
|
|
@@ -1401,55 +1430,90 @@ def bsr_axpy(
|
|
|
1401
1430
|
return y
|
|
1402
1431
|
|
|
1403
1432
|
|
|
1404
|
-
@wp.kernel
|
|
1433
|
+
@wp.kernel(enable_backward=False)
|
|
1405
1434
|
def _bsr_mm_count_coeffs(
|
|
1435
|
+
y_ncol: int,
|
|
1406
1436
|
z_nnz: int,
|
|
1407
1437
|
x_offsets: wp.array(dtype=int),
|
|
1408
1438
|
x_columns: wp.array(dtype=int),
|
|
1409
1439
|
y_offsets: wp.array(dtype=int),
|
|
1410
|
-
|
|
1440
|
+
y_columns: wp.array(dtype=int),
|
|
1441
|
+
row_min: wp.array(dtype=int),
|
|
1442
|
+
block_counts: wp.array(dtype=int),
|
|
1411
1443
|
):
|
|
1412
1444
|
row = wp.tid()
|
|
1413
|
-
|
|
1445
|
+
row_count = int(0)
|
|
1414
1446
|
|
|
1415
1447
|
x_beg = x_offsets[row]
|
|
1416
1448
|
x_end = x_offsets[row + 1]
|
|
1417
1449
|
|
|
1450
|
+
min_col = y_ncol
|
|
1451
|
+
max_col = int(0)
|
|
1452
|
+
|
|
1418
1453
|
for x_block in range(x_beg, x_end):
|
|
1419
1454
|
x_col = x_columns[x_block]
|
|
1420
|
-
|
|
1421
|
-
|
|
1422
|
-
|
|
1455
|
+
y_row_end = y_offsets[x_col + 1]
|
|
1456
|
+
y_row_beg = y_offsets[x_col]
|
|
1457
|
+
block_count = y_row_end - y_row_beg
|
|
1458
|
+
if block_count != 0:
|
|
1459
|
+
min_col = wp.min(y_columns[y_row_beg], min_col)
|
|
1460
|
+
max_col = wp.max(y_columns[y_row_end - 1], max_col)
|
|
1461
|
+
|
|
1462
|
+
block_counts[x_block + 1] = block_count
|
|
1463
|
+
row_count += block_count
|
|
1464
|
+
|
|
1465
|
+
if row_count > wp.max(0, max_col - min_col):
|
|
1466
|
+
row_min[row] = min_col
|
|
1467
|
+
block_counts[x_end] = max_col + 1 - min_col
|
|
1468
|
+
for x_block in range(x_beg, x_end - 1):
|
|
1469
|
+
block_counts[x_block + 1] = 0
|
|
1470
|
+
else:
|
|
1471
|
+
row_min[row] = -1
|
|
1423
1472
|
|
|
1424
1473
|
if row == 0:
|
|
1425
|
-
|
|
1474
|
+
block_counts[0] = z_nnz
|
|
1426
1475
|
|
|
1427
1476
|
|
|
1428
|
-
@wp.kernel
|
|
1477
|
+
@wp.kernel(enable_backward=False)
|
|
1429
1478
|
def _bsr_mm_list_coeffs(
|
|
1479
|
+
x_nrow: int,
|
|
1430
1480
|
x_offsets: wp.array(dtype=int),
|
|
1431
1481
|
x_columns: wp.array(dtype=int),
|
|
1432
1482
|
y_offsets: wp.array(dtype=int),
|
|
1433
1483
|
y_columns: wp.array(dtype=int),
|
|
1484
|
+
mm_row_min: wp.array(dtype=int),
|
|
1434
1485
|
mm_offsets: wp.array(dtype=int),
|
|
1435
1486
|
mm_rows: wp.array(dtype=int),
|
|
1436
1487
|
mm_cols: wp.array(dtype=int),
|
|
1437
1488
|
):
|
|
1438
|
-
|
|
1439
|
-
mm_block = mm_offsets[
|
|
1489
|
+
x_block = wp.tid()
|
|
1490
|
+
mm_block = mm_offsets[x_block]
|
|
1440
1491
|
|
|
1441
|
-
|
|
1442
|
-
|
|
1492
|
+
row = _bsr_row_index(x_offsets, x_nrow, x_block)
|
|
1493
|
+
if row == -1:
|
|
1494
|
+
return
|
|
1443
1495
|
|
|
1444
|
-
|
|
1496
|
+
row_min_col = mm_row_min[row]
|
|
1497
|
+
if row_min_col != -1:
|
|
1445
1498
|
x_col = x_columns[x_block]
|
|
1446
1499
|
|
|
1447
1500
|
y_beg = y_offsets[x_col]
|
|
1448
1501
|
y_end = y_offsets[x_col + 1]
|
|
1502
|
+
|
|
1449
1503
|
for y_block in range(y_beg, y_end):
|
|
1450
|
-
|
|
1451
|
-
mm_rows[mm_block] = row
|
|
1452
|
-
mm_block
|
|
1504
|
+
col = y_columns[y_block]
|
|
1505
|
+
mm_rows[mm_block + col - row_min_col] = row
|
|
1506
|
+
mm_cols[mm_block + col - row_min_col] = col
|
|
1507
|
+
|
|
1508
|
+
return
|
|
1509
|
+
|
|
1510
|
+
x_col = x_columns[x_block]
|
|
1511
|
+
y_beg = y_offsets[x_col]
|
|
1512
|
+
y_end = y_offsets[x_col + 1]
|
|
1513
|
+
for y_block in range(y_beg, y_end):
|
|
1514
|
+
mm_cols[mm_block] = y_columns[y_block]
|
|
1515
|
+
mm_rows[mm_block] = row
|
|
1516
|
+
mm_block += 1
|
|
1453
1517
|
|
|
1454
1518
|
|
|
1455
1519
|
@wp.kernel
|
|
@@ -1468,7 +1532,10 @@ def _bsr_mm_compute_values(
|
|
|
1468
1532
|
):
|
|
1469
1533
|
mm_block = wp.tid()
|
|
1470
1534
|
|
|
1471
|
-
row =
|
|
1535
|
+
row = _bsr_row_index(mm_offsets, mm_row_count, mm_block)
|
|
1536
|
+
if row == -1:
|
|
1537
|
+
return
|
|
1538
|
+
|
|
1472
1539
|
col = mm_cols[mm_block]
|
|
1473
1540
|
|
|
1474
1541
|
mm_val = mm_values.dtype(type(alpha)(0.0))
|
|
@@ -1477,26 +1544,23 @@ def _bsr_mm_compute_values(
|
|
|
1477
1544
|
x_end = x_offsets[row + 1]
|
|
1478
1545
|
for x_block in range(x_beg, x_end):
|
|
1479
1546
|
x_col = x_columns[x_block]
|
|
1480
|
-
|
|
1481
|
-
|
|
1482
|
-
|
|
1483
|
-
y_block = wp.lower_bound(y_columns, y_beg, y_end, col)
|
|
1484
|
-
if y_block < y_end:
|
|
1485
|
-
if y_columns[y_block] == col:
|
|
1486
|
-
mm_val += x_values[x_block] * y_values[y_block]
|
|
1547
|
+
y_block = _bsr_block_index(x_col, col, y_offsets, y_columns)
|
|
1548
|
+
if y_block != -1:
|
|
1549
|
+
mm_val += x_values[x_block] * y_values[y_block]
|
|
1487
1550
|
|
|
1488
1551
|
mm_values[mm_block] += alpha * mm_val
|
|
1489
1552
|
|
|
1490
1553
|
|
|
1491
1554
|
class bsr_mm_work_arrays:
|
|
1492
|
-
"""Opaque structure for persisting :func:`bsr_mm` temporary work buffers across calls"""
|
|
1555
|
+
"""Opaque structure for persisting :func:`bsr_mm` temporary work buffers across calls."""
|
|
1493
1556
|
|
|
1494
1557
|
def __init__(self):
|
|
1495
1558
|
self._reset(None)
|
|
1496
1559
|
|
|
1497
1560
|
def _reset(self, device):
|
|
1498
1561
|
self.device = device
|
|
1499
|
-
self.
|
|
1562
|
+
self._mm_row_min = None
|
|
1563
|
+
self._mm_block_counts = None
|
|
1500
1564
|
self._mm_rows = None
|
|
1501
1565
|
self._mm_cols = None
|
|
1502
1566
|
self._old_z_values = None
|
|
@@ -1504,7 +1568,7 @@ class bsr_mm_work_arrays:
|
|
|
1504
1568
|
self._old_z_columns = None
|
|
1505
1569
|
self._mm_nnz = 0
|
|
1506
1570
|
|
|
1507
|
-
def _allocate_stage_1(self, device, z: BsrMatrix, beta: float, z_aliasing: bool):
|
|
1571
|
+
def _allocate_stage_1(self, device, x_nnz: int, z: BsrMatrix, beta: float, z_aliasing: bool):
|
|
1508
1572
|
if self.device != device:
|
|
1509
1573
|
self._reset(device)
|
|
1510
1574
|
|
|
@@ -1512,8 +1576,10 @@ class bsr_mm_work_arrays:
|
|
|
1512
1576
|
z_nnz = z.nnz_sync()
|
|
1513
1577
|
self._copied_z_nnz = z_nnz if beta != 0.0 or z_aliasing else 0
|
|
1514
1578
|
|
|
1515
|
-
if self.
|
|
1516
|
-
self.
|
|
1579
|
+
if self._mm_row_min is None or self._mm_block_counts.size < z.nrow + 1:
|
|
1580
|
+
self._mm_row_min = wp.empty(shape=(z.nrow + 1,), dtype=int, device=self.device)
|
|
1581
|
+
if self._mm_block_counts is None or self._mm_block_counts.size < x_nnz + 1:
|
|
1582
|
+
self._mm_block_counts = wp.empty(shape=(x_nnz + 1,), dtype=int, device=self.device)
|
|
1517
1583
|
|
|
1518
1584
|
if self._copied_z_nnz > 0:
|
|
1519
1585
|
if self._old_z_values is None or self._old_z_values.size < self._copied_z_nnz:
|
|
@@ -1540,25 +1606,31 @@ def bsr_mm(
|
|
|
1540
1606
|
z: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
|
|
1541
1607
|
alpha: Scalar = 1.0,
|
|
1542
1608
|
beta: Scalar = 0.0,
|
|
1609
|
+
masked: bool = False,
|
|
1543
1610
|
work_arrays: Optional[bsr_mm_work_arrays] = None,
|
|
1544
1611
|
reuse_topology: bool = False,
|
|
1545
1612
|
) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
|
|
1546
1613
|
"""
|
|
1547
|
-
|
|
1614
|
+
Perform the sparse matrix-matrix multiplication ``z := alpha * x @ y + beta * z`` on BSR matrices ``x``, ``y`` and ``z``, and return ``z``.
|
|
1548
1615
|
|
|
1549
|
-
The
|
|
1550
|
-
If the matrix
|
|
1616
|
+
The ``x``, ``y`` and ``z`` matrices are allowed to alias.
|
|
1617
|
+
If the matrix ``z`` is not provided as input, it will be allocated and treated as zero.
|
|
1551
1618
|
|
|
1552
1619
|
Args:
|
|
1553
1620
|
x: Read-only left factor of the matrix-matrix product.
|
|
1554
1621
|
y: Read-only right factor of the matrix-matrix product.
|
|
1555
|
-
z: Mutable left-hand-side. If
|
|
1556
|
-
alpha: Uniform scaling factor for the ``x
|
|
1557
|
-
beta: Uniform scaling factor for
|
|
1558
|
-
|
|
1559
|
-
|
|
1560
|
-
|
|
1561
|
-
|
|
1622
|
+
z: Mutable left-hand-side. If ``z`` is not provided, it will be allocated and treated as zero.
|
|
1623
|
+
alpha: Uniform scaling factor for the ``x @ y`` product
|
|
1624
|
+
beta: Uniform scaling factor for ``z``
|
|
1625
|
+
masked: If ``True``, ignore all blocks from ``x @ y`` which are not existing non-zeros of ``y``
|
|
1626
|
+
work_arrays: In most cases, this function will require the use of temporary storage.
|
|
1627
|
+
This storage can be reused across calls by passing an instance of
|
|
1628
|
+
:class:`bsr_mm_work_arrays` in ``work_arrays``.
|
|
1629
|
+
reuse_topology: If ``True``, reuse the product topology information
|
|
1630
|
+
stored in ``work_arrays`` rather than recompute it from scratch.
|
|
1631
|
+
The matrices ``x``, ``y`` and ``z`` must be structurally similar to
|
|
1632
|
+
the previous call in which ``work_arrays`` were populated.
|
|
1633
|
+
This is necessary for ``bsr_mm`` to be captured in a CUDA graph.
|
|
1562
1634
|
"""
|
|
1563
1635
|
|
|
1564
1636
|
x, x_scale = _extract_matrix_and_scale(x)
|
|
@@ -1567,12 +1639,15 @@ def bsr_mm(
|
|
|
1567
1639
|
alpha *= y_scale
|
|
1568
1640
|
|
|
1569
1641
|
if z is None:
|
|
1642
|
+
if masked:
|
|
1643
|
+
raise ValueError("Left-hand-side 'z' matrix must be provided for masked multiplication")
|
|
1644
|
+
|
|
1570
1645
|
# If not output matrix is provided, allocate it for convenience
|
|
1571
1646
|
z_block_shape = (x.block_shape[0], y.block_shape[1])
|
|
1572
1647
|
if z_block_shape == (1, 1):
|
|
1573
1648
|
z_block_type = x.scalar_type
|
|
1574
1649
|
else:
|
|
1575
|
-
z_block_type = wp.
|
|
1650
|
+
z_block_type = wp.mat(shape=z_block_shape, dtype=x.scalar_type)
|
|
1576
1651
|
z = bsr_zeros(x.nrow, y.ncol, block_type=z_block_type, device=x.values.device)
|
|
1577
1652
|
beta = 0.0
|
|
1578
1653
|
|
|
@@ -1598,14 +1673,22 @@ def bsr_mm(
|
|
|
1598
1673
|
# Easy case
|
|
1599
1674
|
return bsr_scale(z, beta)
|
|
1600
1675
|
|
|
1601
|
-
if not isinstance(alpha, z.scalar_type):
|
|
1602
|
-
alpha = z.scalar_type(alpha)
|
|
1603
|
-
if not isinstance(beta, z.scalar_type):
|
|
1604
|
-
beta = z.scalar_type(beta)
|
|
1605
|
-
|
|
1606
1676
|
z_aliasing = z == x or z == y
|
|
1607
1677
|
|
|
1608
|
-
if
|
|
1678
|
+
if masked:
|
|
1679
|
+
# no need to copy z, scale in-place
|
|
1680
|
+
copied_z_nnz = 0
|
|
1681
|
+
mm_nnz = z.nnz
|
|
1682
|
+
|
|
1683
|
+
if z_aliasing:
|
|
1684
|
+
raise ValueError("`masked=True` is not supported for aliased inputs")
|
|
1685
|
+
|
|
1686
|
+
if beta == 0.0:
|
|
1687
|
+
# do not bsr_scale(0), this would not preserve topology
|
|
1688
|
+
z.values.zero_()
|
|
1689
|
+
else:
|
|
1690
|
+
bsr_scale(z, beta)
|
|
1691
|
+
elif reuse_topology:
|
|
1609
1692
|
if work_arrays is None:
|
|
1610
1693
|
raise ValueError("`work_arrays` must not be ``None`` in order to reuse matrix-matrix product topology")
|
|
1611
1694
|
|
|
@@ -1618,133 +1701,142 @@ def bsr_mm(
|
|
|
1618
1701
|
if work_arrays is None:
|
|
1619
1702
|
work_arrays = bsr_mm_work_arrays()
|
|
1620
1703
|
|
|
1621
|
-
work_arrays._allocate_stage_1(device, z, beta, z_aliasing)
|
|
1704
|
+
work_arrays._allocate_stage_1(device, x.nnz, z, beta, z_aliasing)
|
|
1622
1705
|
copied_z_nnz = work_arrays._copied_z_nnz
|
|
1623
1706
|
|
|
1624
1707
|
# Prefix sum of number of (unmerged) mm blocks per row
|
|
1708
|
+
work_arrays._mm_block_counts.zero_()
|
|
1625
1709
|
wp.launch(
|
|
1626
1710
|
kernel=_bsr_mm_count_coeffs,
|
|
1627
1711
|
device=device,
|
|
1628
1712
|
dim=z.nrow,
|
|
1629
1713
|
inputs=[
|
|
1714
|
+
y.ncol,
|
|
1630
1715
|
copied_z_nnz,
|
|
1631
1716
|
x.offsets,
|
|
1632
1717
|
x.columns,
|
|
1633
1718
|
y.offsets,
|
|
1634
|
-
|
|
1719
|
+
y.columns,
|
|
1720
|
+
work_arrays._mm_row_min,
|
|
1721
|
+
work_arrays._mm_block_counts,
|
|
1635
1722
|
],
|
|
1636
1723
|
)
|
|
1637
|
-
warp.utils.array_scan(work_arrays.
|
|
1724
|
+
warp.utils.array_scan(work_arrays._mm_block_counts, work_arrays._mm_block_counts)
|
|
1638
1725
|
|
|
1639
1726
|
# Get back total counts on host -- we need a synchronization here
|
|
1640
1727
|
# Use pinned buffer from z, we are going to need it later anyway
|
|
1641
1728
|
nnz_buf, _ = z._nnz_transfer_buf_and_event()
|
|
1642
1729
|
stream = wp.get_stream(device) if device.is_cuda else None
|
|
1643
|
-
wp.copy(dest=nnz_buf, src=work_arrays.
|
|
1730
|
+
wp.copy(dest=nnz_buf, src=work_arrays._mm_block_counts, src_offset=x.nnz, count=1, stream=stream)
|
|
1644
1731
|
if device.is_cuda:
|
|
1645
1732
|
wp.synchronize_stream(stream)
|
|
1646
1733
|
mm_nnz = int(nnz_buf.numpy()[0])
|
|
1647
1734
|
|
|
1735
|
+
if mm_nnz == copied_z_nnz:
|
|
1736
|
+
# x@y = 0
|
|
1737
|
+
return bsr_scale(z, beta)
|
|
1738
|
+
|
|
1648
1739
|
work_arrays._allocate_stage_2(mm_nnz)
|
|
1649
1740
|
|
|
1650
1741
|
# If z has a non-zero scale, save current data before overwriting it
|
|
1651
1742
|
if copied_z_nnz > 0:
|
|
1652
1743
|
# Copy z row and column indices
|
|
1653
1744
|
wp.copy(dest=work_arrays._mm_cols, src=z.columns, count=copied_z_nnz)
|
|
1654
|
-
|
|
1655
|
-
kernel=_bsr_get_block_row,
|
|
1656
|
-
device=device,
|
|
1657
|
-
dim=copied_z_nnz,
|
|
1658
|
-
inputs=[0, z.nrow, z.offsets, work_arrays._mm_rows],
|
|
1659
|
-
)
|
|
1745
|
+
z.uncompress_rows(out=work_arrays._mm_rows)
|
|
1660
1746
|
if z_aliasing:
|
|
1661
1747
|
# If z is aliasing with x or y, need to save topology as well
|
|
1662
1748
|
wp.copy(src=z.columns, dest=work_arrays._old_z_columns, count=copied_z_nnz)
|
|
1663
1749
|
wp.copy(src=z.offsets, dest=work_arrays._old_z_offsets, count=z.nrow + 1)
|
|
1664
1750
|
|
|
1665
1751
|
# Fill unmerged mm blocks rows and columns
|
|
1752
|
+
work_arrays._mm_rows[copied_z_nnz:].fill_(-1)
|
|
1666
1753
|
wp.launch(
|
|
1667
1754
|
kernel=_bsr_mm_list_coeffs,
|
|
1668
1755
|
device=device,
|
|
1669
|
-
dim=
|
|
1756
|
+
dim=x.nnz,
|
|
1670
1757
|
inputs=[
|
|
1758
|
+
x.nrow,
|
|
1671
1759
|
x.offsets,
|
|
1672
1760
|
x.columns,
|
|
1673
1761
|
y.offsets,
|
|
1674
1762
|
y.columns,
|
|
1675
|
-
work_arrays.
|
|
1763
|
+
work_arrays._mm_row_min,
|
|
1764
|
+
work_arrays._mm_block_counts,
|
|
1676
1765
|
work_arrays._mm_rows,
|
|
1677
1766
|
work_arrays._mm_cols,
|
|
1678
1767
|
],
|
|
1679
1768
|
)
|
|
1680
1769
|
|
|
1770
|
+
alpha = z.scalar_type(alpha)
|
|
1771
|
+
beta = z.scalar_type(beta)
|
|
1772
|
+
|
|
1681
1773
|
if copied_z_nnz > 0:
|
|
1682
1774
|
# Save current z values in temporary buffer
|
|
1683
1775
|
wp.copy(src=z.values, dest=work_arrays._old_z_values, count=copied_z_nnz)
|
|
1684
1776
|
|
|
1685
|
-
|
|
1686
|
-
|
|
1687
|
-
z.columns
|
|
1777
|
+
if not masked:
|
|
1778
|
+
# Increase dest array size if needed
|
|
1779
|
+
if z.columns.shape[0] < mm_nnz:
|
|
1780
|
+
z.columns = wp.empty(shape=(mm_nnz,), dtype=int, device=device)
|
|
1688
1781
|
|
|
1689
|
-
|
|
1782
|
+
from warp.context import runtime
|
|
1690
1783
|
|
|
1691
|
-
|
|
1692
|
-
|
|
1693
|
-
|
|
1694
|
-
|
|
1784
|
+
if device.is_cpu:
|
|
1785
|
+
native_func = runtime.core.bsr_matrix_from_triplets_float_host
|
|
1786
|
+
else:
|
|
1787
|
+
native_func = runtime.core.bsr_matrix_from_triplets_float_device
|
|
1695
1788
|
|
|
1696
|
-
|
|
1789
|
+
nnz_buf, nnz_event = z._nnz_transfer_buf_and_event()
|
|
1697
1790
|
|
|
1698
|
-
|
|
1699
|
-
|
|
1700
|
-
|
|
1701
|
-
|
|
1702
|
-
|
|
1703
|
-
|
|
1704
|
-
|
|
1705
|
-
|
|
1706
|
-
|
|
1707
|
-
|
|
1708
|
-
|
|
1709
|
-
|
|
1710
|
-
|
|
1711
|
-
|
|
1712
|
-
|
|
1713
|
-
|
|
1791
|
+
with wp.ScopedDevice(z.device):
|
|
1792
|
+
native_func(
|
|
1793
|
+
z.block_shape[0],
|
|
1794
|
+
z.block_shape[1],
|
|
1795
|
+
z.nrow,
|
|
1796
|
+
mm_nnz,
|
|
1797
|
+
ctypes.cast(work_arrays._mm_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1798
|
+
ctypes.cast(work_arrays._mm_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1799
|
+
0,
|
|
1800
|
+
False,
|
|
1801
|
+
masked,
|
|
1802
|
+
ctypes.cast(z.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1803
|
+
ctypes.cast(z.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1804
|
+
0,
|
|
1805
|
+
ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1806
|
+
nnz_event,
|
|
1807
|
+
)
|
|
1714
1808
|
|
|
1715
|
-
|
|
1716
|
-
|
|
1717
|
-
|
|
1718
|
-
|
|
1719
|
-
_bsr_ensure_fits(z, nnz=work_arrays.result_nnz)
|
|
1809
|
+
# Resize z to fit mm result if necessary
|
|
1810
|
+
# If we are not reusing the product topology, this needs another synchronization
|
|
1811
|
+
if not reuse_topology:
|
|
1812
|
+
work_arrays.result_nnz = z.nnz_sync()
|
|
1720
1813
|
|
|
1721
|
-
|
|
1814
|
+
_bsr_ensure_fits(z, nnz=work_arrays.result_nnz)
|
|
1815
|
+
z.values.zero_()
|
|
1722
1816
|
|
|
1723
|
-
|
|
1724
|
-
|
|
1725
|
-
|
|
1726
|
-
|
|
1727
|
-
|
|
1728
|
-
|
|
1729
|
-
|
|
1730
|
-
|
|
1731
|
-
|
|
1732
|
-
|
|
1733
|
-
|
|
1734
|
-
|
|
1735
|
-
|
|
1736
|
-
|
|
1737
|
-
|
|
1738
|
-
|
|
1739
|
-
|
|
1817
|
+
if copied_z_nnz > 0:
|
|
1818
|
+
# Add back original z values
|
|
1819
|
+
wp.launch(
|
|
1820
|
+
kernel=_bsr_axpy_add_block,
|
|
1821
|
+
device=device,
|
|
1822
|
+
dim=copied_z_nnz,
|
|
1823
|
+
inputs=[
|
|
1824
|
+
0,
|
|
1825
|
+
beta,
|
|
1826
|
+
work_arrays._mm_rows,
|
|
1827
|
+
work_arrays._mm_cols,
|
|
1828
|
+
z.offsets,
|
|
1829
|
+
z.columns,
|
|
1830
|
+
work_arrays._old_z_values,
|
|
1831
|
+
z.values,
|
|
1832
|
+
],
|
|
1833
|
+
)
|
|
1740
1834
|
|
|
1741
1835
|
# Add mm blocks to z values
|
|
1742
|
-
if (
|
|
1743
|
-
warp.types.type_is_matrix(z.values.dtype)
|
|
1744
|
-
):
|
|
1836
|
+
if (type_is_matrix(x.values.dtype) or type_is_matrix(y.values.dtype)) and not (type_is_matrix(z.values.dtype)):
|
|
1745
1837
|
# Result block type is scalar, but operands are matrices
|
|
1746
1838
|
# Cast result to (1x1) matrix to perform multiplication
|
|
1747
|
-
mm_values = z.values.view(wp.
|
|
1839
|
+
mm_values = z.values.view(wp.mat(shape=(1, 1), dtype=z.scalar_type))
|
|
1748
1840
|
else:
|
|
1749
1841
|
mm_values = z.values
|
|
1750
1842
|
|
|
@@ -1817,15 +1909,31 @@ def _bsr_mv_transpose_kernel(
|
|
|
1817
1909
|
wp.atomic_add(y, A_columns[block], v)
|
|
1818
1910
|
|
|
1819
1911
|
|
|
1820
|
-
def
|
|
1821
|
-
|
|
1912
|
+
def _vec_array_view(array: wp.array, dtype: type, expected_scalar_count: int) -> wp.array:
|
|
1913
|
+
# cast a 1d or 2d array to a 1d array with the target dtype, adjusting shape as required
|
|
1914
|
+
|
|
1915
|
+
scalar_count = array.size * type_length(array.dtype)
|
|
1916
|
+
if scalar_count != expected_scalar_count:
|
|
1917
|
+
raise ValueError(f"Invalid array scalar size, expected {expected_scalar_count}, got {scalar_count}")
|
|
1918
|
+
|
|
1919
|
+
if array.ndim == 1 and types_equal(array.dtype, dtype):
|
|
1822
1920
|
return array
|
|
1823
1921
|
|
|
1922
|
+
if type_scalar_type(array.dtype) != type_scalar_type(dtype):
|
|
1923
|
+
raise ValueError(f"Incompatible scalar types, {type_repr(array.dtype)} vs {type_repr(dtype)}")
|
|
1924
|
+
|
|
1824
1925
|
if array.ndim > 2:
|
|
1825
1926
|
raise ValueError(f"Incompatible array number of dimensions {array.ndim}")
|
|
1826
1927
|
|
|
1827
1928
|
if not array.is_contiguous:
|
|
1828
|
-
raise ValueError("
|
|
1929
|
+
raise ValueError("Array must be contiguous")
|
|
1930
|
+
|
|
1931
|
+
vec_length = type_length(dtype)
|
|
1932
|
+
vec_count = scalar_count // vec_length
|
|
1933
|
+
if vec_count * vec_length != scalar_count:
|
|
1934
|
+
raise ValueError(
|
|
1935
|
+
f"Array of shape {array.shape} and type {type_repr(array.dtype)} cannot be reshaped to an array of type {type_repr(dtype)}"
|
|
1936
|
+
)
|
|
1829
1937
|
|
|
1830
1938
|
def vec_view(array):
|
|
1831
1939
|
return wp.array(
|
|
@@ -1833,8 +1941,8 @@ def _bsr_mv_as_vec_array(array: wp.array) -> wp.array:
|
|
|
1833
1941
|
ptr=array.ptr,
|
|
1834
1942
|
capacity=array.capacity,
|
|
1835
1943
|
device=array.device,
|
|
1836
|
-
dtype=
|
|
1837
|
-
shape=
|
|
1944
|
+
dtype=dtype,
|
|
1945
|
+
shape=vec_count,
|
|
1838
1946
|
grad=None if array.grad is None else vec_view(array.grad),
|
|
1839
1947
|
)
|
|
1840
1948
|
|
|
@@ -1852,20 +1960,20 @@ def bsr_mv(
|
|
|
1852
1960
|
transpose: bool = False,
|
|
1853
1961
|
work_buffer: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
|
|
1854
1962
|
) -> "Array[Vector[Rows, Scalar] | Scalar]":
|
|
1855
|
-
"""
|
|
1856
|
-
Performs the sparse matrix-vector product ``y := alpha * A * x + beta * y`` and returns `y`.
|
|
1963
|
+
"""Perform the sparse matrix-vector product ``y := alpha * A * x + beta * y`` and return ``y``.
|
|
1857
1964
|
|
|
1858
|
-
The
|
|
1965
|
+
The ``x`` and ``y`` vectors are allowed to alias.
|
|
1859
1966
|
|
|
1860
1967
|
Args:
|
|
1861
1968
|
A: Read-only, left matrix factor of the matrix-vector product.
|
|
1862
1969
|
x: Read-only, right vector factor of the matrix-vector product.
|
|
1863
|
-
y: Mutable left-hand-side. If
|
|
1864
|
-
alpha: Uniform scaling factor for
|
|
1865
|
-
beta: Uniform scaling factor for
|
|
1866
|
-
transpose: If ``True``, use the transpose of the matrix
|
|
1867
|
-
work_buffer: Temporary storage is required if and only if
|
|
1868
|
-
|
|
1970
|
+
y: Mutable left-hand-side. If ``y`` is not provided, it will be allocated and treated as zero.
|
|
1971
|
+
alpha: Uniform scaling factor for ``x``. If zero, ``x`` will not be read and may be left uninitialized.
|
|
1972
|
+
beta: Uniform scaling factor for ``y``. If zero, ``y`` will not be read and may be left uninitialized.
|
|
1973
|
+
transpose: If ``True``, use the transpose of the matrix ``A``. In this case the result is **non-deterministic**.
|
|
1974
|
+
work_buffer: Temporary storage is required if and only if ``x`` and ``y`` are the same vector.
|
|
1975
|
+
If provided, the ``work_buffer`` array will be used for this purpose,
|
|
1976
|
+
otherwise a temporary allocation will be performed.
|
|
1869
1977
|
"""
|
|
1870
1978
|
|
|
1871
1979
|
A, A_scale = _extract_matrix_and_scale(A)
|
|
@@ -1885,22 +1993,11 @@ def bsr_mv(
|
|
|
1885
1993
|
y = wp.empty(shape=(nrow,), device=A.values.device, dtype=y_dtype)
|
|
1886
1994
|
beta = 0.0
|
|
1887
1995
|
|
|
1888
|
-
|
|
1889
|
-
|
|
1890
|
-
if not isinstance(beta, A.scalar_type):
|
|
1891
|
-
beta = A.scalar_type(beta)
|
|
1996
|
+
alpha = A.scalar_type(alpha)
|
|
1997
|
+
beta = A.scalar_type(beta)
|
|
1892
1998
|
|
|
1893
1999
|
if A.values.device != x.device or A.values.device != y.device:
|
|
1894
|
-
raise ValueError("A, x and y must reside on the same device")
|
|
1895
|
-
|
|
1896
|
-
if x.shape[0] != ncol:
|
|
1897
|
-
raise ValueError("Number of columns of A must match number of rows of x")
|
|
1898
|
-
if y.shape[0] != nrow:
|
|
1899
|
-
raise ValueError("Number of rows of A must match number of rows of y")
|
|
1900
|
-
|
|
1901
|
-
# View 2d arrays as arrays of vecs
|
|
1902
|
-
x = _bsr_mv_as_vec_array(x)
|
|
1903
|
-
y = _bsr_mv_as_vec_array(y)
|
|
2000
|
+
raise ValueError("A, x, and y must reside on the same device")
|
|
1904
2001
|
|
|
1905
2002
|
if x.ptr == y.ptr:
|
|
1906
2003
|
# Aliasing case, need temporary storage
|
|
@@ -1908,24 +2005,29 @@ def bsr_mv(
|
|
|
1908
2005
|
work_buffer = wp.empty_like(y)
|
|
1909
2006
|
elif work_buffer.size < y.size:
|
|
1910
2007
|
raise ValueError(f"Work buffer size is insufficient, needs to be at least {y.size}")
|
|
1911
|
-
elif not
|
|
1912
|
-
raise ValueError(f"Work buffer must have same data type as y, {
|
|
2008
|
+
elif not types_equal(work_buffer.dtype, y.dtype):
|
|
2009
|
+
raise ValueError(f"Work buffer must have same data type as y, {type_repr(y.dtype)}")
|
|
1913
2010
|
|
|
1914
2011
|
# Save old y values before overwriting vector
|
|
1915
2012
|
wp.copy(dest=work_buffer, src=y, count=y.size)
|
|
1916
2013
|
x = work_buffer
|
|
1917
2014
|
|
|
1918
2015
|
# Promote scalar vectors to length-1 vecs and conversely
|
|
1919
|
-
if
|
|
1920
|
-
|
|
1921
|
-
|
|
1922
|
-
if block_shape[1] == 1 and x.dtype == A.scalar_type:
|
|
1923
|
-
x = x.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
|
|
2016
|
+
if type_is_matrix(A.values.dtype):
|
|
2017
|
+
x_dtype = wp.vec(length=block_shape[1], dtype=A.scalar_type)
|
|
2018
|
+
y_dtype = wp.vec(length=block_shape[0], dtype=A.scalar_type)
|
|
1924
2019
|
else:
|
|
1925
|
-
|
|
1926
|
-
|
|
1927
|
-
|
|
1928
|
-
|
|
2020
|
+
x_dtype = A.scalar_type
|
|
2021
|
+
y_dtype = A.scalar_type
|
|
2022
|
+
|
|
2023
|
+
try:
|
|
2024
|
+
x_view = _vec_array_view(x, x_dtype, expected_scalar_count=ncol * block_shape[1])
|
|
2025
|
+
except ValueError as err:
|
|
2026
|
+
raise ValueError("Incompatible 'x' vector for bsr_mv") from err
|
|
2027
|
+
try:
|
|
2028
|
+
y_view = _vec_array_view(y, y_dtype, expected_scalar_count=nrow * block_shape[0])
|
|
2029
|
+
except ValueError as err:
|
|
2030
|
+
raise ValueError("Incompatible 'y' vector for bsr_mv") from err
|
|
1929
2031
|
|
|
1930
2032
|
if transpose:
|
|
1931
2033
|
if beta.value == 0.0:
|
|
@@ -1942,14 +2044,14 @@ def bsr_mv(
|
|
|
1942
2044
|
kernel=_bsr_mv_transpose_kernel,
|
|
1943
2045
|
device=A.values.device,
|
|
1944
2046
|
dim=ncol,
|
|
1945
|
-
inputs=[alpha, A.offsets, A.columns, A.values,
|
|
2047
|
+
inputs=[alpha, A.offsets, A.columns, A.values, x_view, y_view],
|
|
1946
2048
|
)
|
|
1947
2049
|
else:
|
|
1948
2050
|
wp.launch(
|
|
1949
2051
|
kernel=_bsr_mv_kernel,
|
|
1950
2052
|
device=A.values.device,
|
|
1951
2053
|
dim=nrow,
|
|
1952
|
-
inputs=[alpha, A.offsets, A.columns, A.values,
|
|
2054
|
+
inputs=[alpha, A.offsets, A.columns, A.values, x_view, beta, y_view],
|
|
1953
2055
|
)
|
|
1954
2056
|
|
|
1955
2057
|
return y
|