warp-lang 1.9.1__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +301 -287
- warp/__init__.pyi +794 -305
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1075 -0
- warp/_src/build.py +618 -0
- warp/_src/build_dll.py +640 -0
- warp/{builtins.py → _src/builtins.py} +1382 -377
- warp/_src/codegen.py +4359 -0
- warp/{config.py → _src/config.py} +178 -169
- warp/_src/constants.py +57 -0
- warp/_src/context.py +8294 -0
- warp/_src/dlpack.py +462 -0
- warp/_src/fabric.py +355 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +508 -0
- warp/_src/fem/cache.py +687 -0
- warp/_src/fem/dirichlet.py +188 -0
- warp/{fem → _src/fem}/domain.py +40 -30
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +701 -0
- warp/{fem → _src/fem}/field/nodal_field.py +30 -15
- warp/{fem → _src/fem}/field/restriction.py +1 -1
- warp/{fem → _src/fem}/field/virtual.py +53 -27
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
- warp/_src/fem/geometry/closest_point.py +97 -0
- warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
- warp/{fem → _src/fem}/geometry/element.py +32 -10
- warp/{fem → _src/fem}/geometry/geometry.py +48 -20
- warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
- warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
- warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
- warp/{fem → _src/fem}/geometry/partition.py +121 -63
- warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
- warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
- warp/{fem → _src/fem}/integrate.py +164 -158
- warp/_src/fem/linalg.py +383 -0
- warp/_src/fem/operator.py +396 -0
- warp/_src/fem/polynomial.py +229 -0
- warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
- warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
- warp/_src/fem/space/__init__.py +248 -0
- warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
- warp/_src/fem/space/basis_space.py +679 -0
- warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
- warp/{fem → _src/fem}/space/function_space.py +14 -13
- warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
- warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
- warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
- warp/{fem → _src/fem}/space/partition.py +117 -60
- warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/restriction.py +66 -33
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
- warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
- warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
- warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
- warp/_src/fem/space/topology.py +459 -0
- warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
- warp/_src/fem/types.py +112 -0
- warp/_src/fem/utils.py +486 -0
- warp/_src/jax.py +186 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +387 -0
- warp/_src/jax_experimental/ffi.py +1284 -0
- warp/_src/jax_experimental/xla_ffi.py +656 -0
- warp/_src/marching_cubes.py +708 -0
- warp/_src/math.py +414 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +163 -0
- warp/_src/optim/linear.py +1606 -0
- warp/_src/optim/sgd.py +112 -0
- warp/_src/paddle.py +406 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +289 -0
- warp/_src/render/render_opengl.py +3636 -0
- warp/_src/render/render_usd.py +937 -0
- warp/_src/render/utils.py +160 -0
- warp/_src/sparse.py +2716 -0
- warp/_src/tape.py +1206 -0
- warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
- warp/_src/torch.py +391 -0
- warp/_src/types.py +5870 -0
- warp/_src/utils.py +1693 -0
- warp/autograd.py +12 -1054
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +8 -588
- warp/build_dll.py +6 -721
- warp/codegen.py +6 -4251
- warp/constants.py +6 -39
- warp/context.py +12 -8062
- warp/dlpack.py +6 -444
- warp/examples/distributed/example_jacobi_mpi.py +4 -5
- warp/examples/fem/example_adaptive_grid.py +1 -1
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +8 -8
- warp/examples/fem/example_diffusion.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_mixed_elasticity.py +2 -2
- warp/examples/fem/example_navier_stokes.py +1 -1
- warp/examples/fem/example_nonconforming_contact.py +7 -7
- warp/examples/fem/example_stokes.py +1 -1
- warp/examples/fem/example_stokes_transfer.py +1 -1
- warp/examples/fem/utils.py +2 -2
- warp/examples/interop/example_jax_callable.py +1 -1
- warp/examples/interop/example_jax_ffi_callback.py +1 -1
- warp/examples/interop/example_jax_kernel.py +1 -1
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/fabric.py +6 -337
- warp/fem/__init__.py +159 -97
- warp/fem/adaptivity.py +7 -489
- warp/fem/cache.py +9 -648
- warp/fem/dirichlet.py +6 -184
- warp/fem/field/__init__.py +8 -109
- warp/fem/field/field.py +7 -652
- warp/fem/geometry/__init__.py +7 -18
- warp/fem/geometry/closest_point.py +11 -77
- warp/fem/linalg.py +18 -366
- warp/fem/operator.py +11 -369
- warp/fem/polynomial.py +9 -209
- warp/fem/space/__init__.py +5 -211
- warp/fem/space/basis_space.py +6 -662
- warp/fem/space/shape/__init__.py +41 -118
- warp/fem/space/topology.py +6 -437
- warp/fem/types.py +6 -81
- warp/fem/utils.py +11 -444
- warp/jax.py +8 -165
- warp/jax_experimental/__init__.py +14 -1
- warp/jax_experimental/custom_call.py +8 -365
- warp/jax_experimental/ffi.py +17 -873
- warp/jax_experimental/xla_ffi.py +5 -605
- warp/marching_cubes.py +5 -689
- warp/math.py +16 -393
- warp/native/array.h +385 -37
- warp/native/builtin.h +314 -37
- warp/native/bvh.cpp +43 -9
- warp/native/bvh.cu +62 -27
- warp/native/bvh.h +310 -309
- warp/native/clang/clang.cpp +102 -97
- warp/native/coloring.cpp +0 -1
- warp/native/crt.h +208 -0
- warp/native/exports.h +156 -0
- warp/native/hashgrid.cu +2 -0
- warp/native/intersect.h +24 -1
- warp/native/intersect_tri.h +44 -35
- warp/native/mat.h +1456 -276
- warp/native/mesh.cpp +4 -4
- warp/native/mesh.cu +4 -2
- warp/native/mesh.h +176 -61
- warp/native/quat.h +0 -52
- warp/native/scan.cu +2 -0
- warp/native/sparse.cu +7 -3
- warp/native/spatial.h +12 -0
- warp/native/tile.h +681 -89
- warp/native/tile_radix_sort.h +1 -1
- warp/native/tile_reduce.h +394 -46
- warp/native/tile_scan.h +4 -4
- warp/native/vec.h +469 -0
- warp/native/version.h +23 -0
- warp/native/volume.cpp +1 -1
- warp/native/volume.cu +1 -0
- warp/native/volume.h +1 -1
- warp/native/volume_builder.cu +2 -0
- warp/native/warp.cpp +57 -29
- warp/native/warp.cu +253 -171
- warp/native/warp.h +11 -8
- warp/optim/__init__.py +6 -3
- warp/optim/adam.py +6 -145
- warp/optim/linear.py +14 -1585
- warp/optim/sgd.py +6 -94
- warp/paddle.py +6 -388
- warp/render/__init__.py +8 -4
- warp/render/imgui_manager.py +7 -267
- warp/render/render_opengl.py +6 -3618
- warp/render/render_usd.py +6 -919
- warp/render/utils.py +6 -142
- warp/sparse.py +37 -2563
- warp/tape.py +6 -1188
- warp/tests/__main__.py +1 -1
- warp/tests/cuda/test_async.py +4 -4
- warp/tests/cuda/test_conditional_captures.py +1 -1
- warp/tests/cuda/test_multigpu.py +1 -1
- warp/tests/cuda/test_streams.py +58 -1
- warp/tests/geometry/test_bvh.py +157 -22
- warp/tests/geometry/test_marching_cubes.py +0 -1
- warp/tests/geometry/test_mesh.py +5 -3
- warp/tests/geometry/test_mesh_query_aabb.py +5 -12
- warp/tests/geometry/test_mesh_query_point.py +5 -2
- warp/tests/geometry/test_mesh_query_ray.py +15 -3
- warp/tests/geometry/test_volume_write.py +5 -5
- warp/tests/interop/test_dlpack.py +14 -14
- warp/tests/interop/test_jax.py +772 -49
- warp/tests/interop/test_paddle.py +1 -1
- warp/tests/test_adam.py +0 -1
- warp/tests/test_arithmetic.py +9 -9
- warp/tests/test_array.py +527 -100
- warp/tests/test_array_reduce.py +3 -3
- warp/tests/test_atomic.py +12 -8
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +4 -4
- warp/tests/test_bool.py +2 -2
- warp/tests/test_builtins_resolution.py +5 -571
- warp/tests/test_codegen.py +33 -14
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_context.py +6 -6
- warp/tests/test_copy.py +242 -161
- warp/tests/test_ctypes.py +3 -3
- warp/tests/test_devices.py +24 -2
- warp/tests/test_examples.py +16 -84
- warp/tests/test_fabricarray.py +35 -35
- warp/tests/test_fast_math.py +0 -2
- warp/tests/test_fem.py +56 -10
- warp/tests/test_fixedarray.py +3 -3
- warp/tests/test_func.py +8 -5
- warp/tests/test_generics.py +1 -1
- warp/tests/test_indexedarray.py +24 -24
- warp/tests/test_intersect.py +39 -9
- warp/tests/test_large.py +1 -1
- warp/tests/test_lerp.py +3 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_map.py +35 -4
- warp/tests/test_mat.py +52 -62
- warp/tests/test_mat_constructors.py +4 -5
- warp/tests/test_mat_lite.py +1 -1
- warp/tests/test_mat_scalar_ops.py +121 -121
- warp/tests/test_math.py +34 -0
- warp/tests/test_module_aot.py +4 -4
- warp/tests/test_modules_lite.py +28 -2
- warp/tests/test_print.py +11 -11
- warp/tests/test_quat.py +93 -58
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +38 -10
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +126 -15
- warp/tests/test_spatial.py +105 -87
- warp/tests/test_special_values.py +6 -6
- warp/tests/test_static.py +7 -7
- warp/tests/test_struct.py +13 -2
- warp/tests/test_triangle_closest_point.py +48 -1
- warp/tests/test_types.py +27 -15
- warp/tests/test_utils.py +52 -52
- warp/tests/test_vec.py +29 -29
- warp/tests/test_vec_constructors.py +5 -5
- warp/tests/test_vec_scalar_ops.py +97 -97
- warp/tests/test_version.py +75 -0
- warp/tests/tile/test_tile.py +178 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +7 -4
- warp/tests/tile/test_tile_load.py +26 -2
- warp/tests/tile/test_tile_mathdx.py +3 -3
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +2 -4
- warp/tests/tile/test_tile_reduce.py +214 -13
- warp/tests/unittest_suites.py +6 -14
- warp/tests/unittest_utils.py +10 -9
- warp/tests/walkthrough_debug.py +3 -1
- warp/torch.py +6 -373
- warp/types.py +29 -5764
- warp/utils.py +10 -1659
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
- warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp/examples/assets/cartpole.urdf +0 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/nv_ant.xml +0 -92
- warp/examples/assets/nv_humanoid.xml +0 -183
- warp/examples/assets/quadruped.urdf +0 -268
- warp/examples/optim/example_bounce.py +0 -266
- warp/examples/optim/example_cloth_throw.py +0 -228
- warp/examples/optim/example_drone.py +0 -870
- warp/examples/optim/example_inverse_kinematics.py +0 -182
- warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
- warp/examples/optim/example_softbody_properties.py +0 -400
- warp/examples/optim/example_spring_cage.py +0 -245
- warp/examples/optim/example_trajectory.py +0 -227
- warp/examples/sim/example_cartpole.py +0 -143
- warp/examples/sim/example_cloth.py +0 -225
- warp/examples/sim/example_cloth_self_contact.py +0 -316
- warp/examples/sim/example_granular.py +0 -130
- warp/examples/sim/example_granular_collision_sdf.py +0 -202
- warp/examples/sim/example_jacobian_ik.py +0 -244
- warp/examples/sim/example_particle_chain.py +0 -124
- warp/examples/sim/example_quadruped.py +0 -203
- warp/examples/sim/example_rigid_chain.py +0 -203
- warp/examples/sim/example_rigid_contact.py +0 -195
- warp/examples/sim/example_rigid_force.py +0 -133
- warp/examples/sim/example_rigid_gyroscopic.py +0 -115
- warp/examples/sim/example_rigid_soft_contact.py +0 -140
- warp/examples/sim/example_soft_body.py +0 -196
- warp/examples/tile/example_tile_walker.py +0 -327
- warp/sim/__init__.py +0 -74
- warp/sim/articulation.py +0 -793
- warp/sim/collide.py +0 -2570
- warp/sim/graph_coloring.py +0 -307
- warp/sim/import_mjcf.py +0 -791
- warp/sim/import_snu.py +0 -227
- warp/sim/import_urdf.py +0 -579
- warp/sim/import_usd.py +0 -898
- warp/sim/inertia.py +0 -357
- warp/sim/integrator.py +0 -245
- warp/sim/integrator_euler.py +0 -2000
- warp/sim/integrator_featherstone.py +0 -2101
- warp/sim/integrator_vbd.py +0 -2487
- warp/sim/integrator_xpbd.py +0 -3295
- warp/sim/model.py +0 -4821
- warp/sim/particles.py +0 -121
- warp/sim/render.py +0 -431
- warp/sim/utils.py +0 -431
- warp/tests/sim/disabled_kinematics.py +0 -244
- warp/tests/sim/test_cloth.py +0 -863
- warp/tests/sim/test_collision.py +0 -743
- warp/tests/sim/test_coloring.py +0 -347
- warp/tests/sim/test_inertia.py +0 -161
- warp/tests/sim/test_model.py +0 -226
- warp/tests/sim/test_sim_grad.py +0 -287
- warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
- warp/tests/sim/test_sim_kinematics.py +0 -98
- warp/thirdparty/__init__.py +0 -0
- warp_lang-1.9.1.dist-info/RECORD +0 -456
- /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
- /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/sparse.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c)
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -13,2565 +13,39 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
from warp.
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
"
|
|
52
|
-
"bsr_scale",
|
|
53
|
-
"bsr_set_diag",
|
|
54
|
-
"bsr_set_from_triplets",
|
|
55
|
-
"bsr_set_identity",
|
|
56
|
-
"bsr_set_transpose",
|
|
57
|
-
"bsr_set_zero",
|
|
58
|
-
"bsr_transposed",
|
|
59
|
-
"bsr_zeros",
|
|
60
|
-
]
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
# typing hints
|
|
64
|
-
|
|
65
|
-
_BlockType = TypeVar("BlockType") # noqa: PLC0132
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
class _MatrixBlockType(Generic[Rows, Cols, Scalar]):
|
|
69
|
-
pass
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
class _ScalarBlockType(Generic[Scalar]):
|
|
73
|
-
pass
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
BlockType = Union[_MatrixBlockType[Rows, Cols, Scalar], _ScalarBlockType[Scalar]]
|
|
77
|
-
|
|
78
|
-
_struct_cache = {}
|
|
79
|
-
_transfer_buffer_cache = {}
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
class BsrMatrix(Generic[_BlockType]):
|
|
83
|
-
"""Untyped base class for BSR and CSR matrices.
|
|
84
|
-
|
|
85
|
-
Should not be constructed directly but through functions such as :func:`bsr_zeros`.
|
|
86
|
-
|
|
87
|
-
Attributes:
|
|
88
|
-
nrow (int): Number of rows of blocks.
|
|
89
|
-
ncol (int): Number of columns of blocks.
|
|
90
|
-
nnz (int): Upper bound for the number of non-zero blocks, used for
|
|
91
|
-
dimensioning launches. The exact number is at ``offsets[nrow-1]``.
|
|
92
|
-
See also :meth:`nnz_sync`.
|
|
93
|
-
offsets (Array[int]): Array of size at least ``1 + nrow`` such that the
|
|
94
|
-
start and end indices of the blocks of row ``r`` are ``offsets[r]``
|
|
95
|
-
and ``offsets[r+1]``, respectively.
|
|
96
|
-
columns (Array[int]): Array of size at least equal to ``nnz`` containing
|
|
97
|
-
block column indices.
|
|
98
|
-
values (Array[BlockType]): Array of size at least equal to ``nnz``
|
|
99
|
-
containing block values.
|
|
100
|
-
"""
|
|
101
|
-
|
|
102
|
-
@property
|
|
103
|
-
def scalar_type(self) -> Scalar:
|
|
104
|
-
"""Scalar type for individual block coefficients. For CSR matrices, this is the same as the block type."""
|
|
105
|
-
return type_scalar_type(self.values.dtype)
|
|
106
|
-
|
|
107
|
-
@property
|
|
108
|
-
def block_shape(self) -> Tuple[int, int]:
|
|
109
|
-
"""Shape of the individual blocks."""
|
|
110
|
-
return getattr(self.values.dtype, "_shape_", (1, 1))
|
|
111
|
-
|
|
112
|
-
@property
|
|
113
|
-
def block_size(self) -> int:
|
|
114
|
-
"""Size of the individual blocks, i.e. number of rows per block times number of columns per block."""
|
|
115
|
-
return type_size(self.values.dtype)
|
|
116
|
-
|
|
117
|
-
@property
|
|
118
|
-
def shape(self) -> Tuple[int, int]:
|
|
119
|
-
"""Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/columns per block."""
|
|
120
|
-
block_shape = self.block_shape
|
|
121
|
-
return (self.nrow * block_shape[0], self.ncol * block_shape[1])
|
|
122
|
-
|
|
123
|
-
@property
|
|
124
|
-
def dtype(self) -> type:
|
|
125
|
-
"""Data type for individual block values."""
|
|
126
|
-
return self.values.dtype
|
|
127
|
-
|
|
128
|
-
@property
|
|
129
|
-
def device(self) -> wp.context.Device:
|
|
130
|
-
"""Device on which ``offsets``, ``columns``, and ``values`` are allocated -- assumed to be the same for all three arrays."""
|
|
131
|
-
return self.values.device
|
|
132
|
-
|
|
133
|
-
@property
|
|
134
|
-
def requires_grad(self) -> bool:
|
|
135
|
-
"""Read-only property indicating whether the matrix participates in adjoint computations."""
|
|
136
|
-
return self.values.requires_grad
|
|
137
|
-
|
|
138
|
-
@property
|
|
139
|
-
def scalar_values(self) -> wp.array:
|
|
140
|
-
"""Accesses the ``values`` array as a 3d scalar array."""
|
|
141
|
-
values_view = _as_3d_array(self.values, self.block_shape)
|
|
142
|
-
values_view._ref = self.values # keep ref in case we're garbage collected
|
|
143
|
-
return values_view
|
|
144
|
-
|
|
145
|
-
def uncompress_rows(self, out: wp.array = None) -> wp.array:
|
|
146
|
-
"""Compute the row index for each non-zero block from the compressed row offsets."""
|
|
147
|
-
if out is None:
|
|
148
|
-
out = wp.empty(self.nnz, dtype=int, device=self.device)
|
|
149
|
-
|
|
150
|
-
wp.launch(
|
|
151
|
-
kernel=_bsr_get_block_row,
|
|
152
|
-
device=self.device,
|
|
153
|
-
dim=self.nnz,
|
|
154
|
-
inputs=[self.nrow, self.offsets, out],
|
|
155
|
-
)
|
|
156
|
-
return out
|
|
157
|
-
|
|
158
|
-
def nnz_sync(self):
|
|
159
|
-
"""Ensures that any ongoing transfer of the exact nnz number from the device offsets array to the host has completed,
|
|
160
|
-
or, if none has been scheduled yet, starts a new transfer and waits for it to complete.
|
|
161
|
-
Then updates the nnz upper bound.
|
|
162
|
-
|
|
163
|
-
See also :meth:`copy_nnz_async`.
|
|
164
|
-
"""
|
|
165
|
-
|
|
166
|
-
buf, event = self._nnz_transfer_if_any()
|
|
167
|
-
if buf is None:
|
|
168
|
-
self.copy_nnz_async()
|
|
169
|
-
buf, event = self._nnz_transfer_if_any()
|
|
170
|
-
|
|
171
|
-
if event is not None:
|
|
172
|
-
wp.synchronize_event(event)
|
|
173
|
-
self.nnz = int(buf.numpy()[0])
|
|
174
|
-
return self.nnz
|
|
175
|
-
|
|
176
|
-
def copy_nnz_async(self) -> None:
|
|
177
|
-
"""
|
|
178
|
-
Start the asynchronous transfer of the exact nnz from the device offsets array to host and records an event for completion.
|
|
179
|
-
|
|
180
|
-
Needs to be called whenever the offsets array has been modified from outside ``warp.sparse``.
|
|
181
|
-
|
|
182
|
-
See also :meth:`nnz_sync`.
|
|
183
|
-
"""
|
|
184
|
-
|
|
185
|
-
buf, event = self._setup_nnz_transfer()
|
|
186
|
-
stream = wp.get_stream(self.device) if self.device.is_cuda else None
|
|
187
|
-
wp.copy(src=self.offsets, dest=buf, src_offset=self.nrow, count=1, stream=stream)
|
|
188
|
-
if event is not None:
|
|
189
|
-
stream.record_event(event)
|
|
190
|
-
|
|
191
|
-
def _setup_nnz_transfer(self):
|
|
192
|
-
buf, event = self._nnz_transfer_if_any()
|
|
193
|
-
if buf is not None:
|
|
194
|
-
return buf, event
|
|
195
|
-
|
|
196
|
-
buf, event = _allocate_transfer_buf(self.device)
|
|
197
|
-
if buf is not None:
|
|
198
|
-
BsrMatrix.__setattr__(self, "_nnz_transfer", (buf, event))
|
|
199
|
-
|
|
200
|
-
return buf, event
|
|
201
|
-
|
|
202
|
-
def _nnz_transfer_if_any(self):
|
|
203
|
-
return getattr(self, "_nnz_transfer", (None, None))
|
|
204
|
-
|
|
205
|
-
def __del__(self):
|
|
206
|
-
buf, event = self._nnz_transfer_if_any()
|
|
207
|
-
if buf is not None:
|
|
208
|
-
_redeem_transfer_buf(self.device, buf, event)
|
|
209
|
-
|
|
210
|
-
# Overloaded math operators
|
|
211
|
-
def __add__(self, y):
|
|
212
|
-
return bsr_axpy(y, bsr_copy(self))
|
|
213
|
-
|
|
214
|
-
def __iadd__(self, y):
|
|
215
|
-
return bsr_axpy(y, self)
|
|
216
|
-
|
|
217
|
-
def __radd__(self, x):
|
|
218
|
-
return bsr_axpy(x, bsr_copy(self))
|
|
219
|
-
|
|
220
|
-
def __sub__(self, y):
|
|
221
|
-
return bsr_axpy(y, bsr_copy(self), alpha=-1.0)
|
|
222
|
-
|
|
223
|
-
def __rsub__(self, x):
|
|
224
|
-
return bsr_axpy(x, bsr_copy(self), beta=-1.0)
|
|
225
|
-
|
|
226
|
-
def __isub__(self, y):
|
|
227
|
-
return bsr_axpy(y, self, alpha=-1.0)
|
|
228
|
-
|
|
229
|
-
def __mul__(self, y):
|
|
230
|
-
return _BsrScalingExpression(self, y)
|
|
231
|
-
|
|
232
|
-
def __rmul__(self, x):
|
|
233
|
-
return _BsrScalingExpression(self, x)
|
|
234
|
-
|
|
235
|
-
def __imul__(self, y):
|
|
236
|
-
return bsr_scale(self, y)
|
|
237
|
-
|
|
238
|
-
def __matmul__(self, y):
|
|
239
|
-
if isinstance(y, wp.array):
|
|
240
|
-
return bsr_mv(self, y)
|
|
241
|
-
|
|
242
|
-
return bsr_mm(self, y)
|
|
243
|
-
|
|
244
|
-
def __rmatmul__(self, x):
|
|
245
|
-
if isinstance(x, wp.array):
|
|
246
|
-
return bsr_mv(self, x, transpose=True)
|
|
247
|
-
|
|
248
|
-
return bsr_mm(x, self)
|
|
249
|
-
|
|
250
|
-
def __imatmul__(self, y):
|
|
251
|
-
return bsr_mm(self, y, self)
|
|
252
|
-
|
|
253
|
-
def __truediv__(self, y):
|
|
254
|
-
return _BsrScalingExpression(self, 1.0 / y)
|
|
255
|
-
|
|
256
|
-
def __neg__(self):
|
|
257
|
-
return _BsrScalingExpression(self, -1.0)
|
|
258
|
-
|
|
259
|
-
def transpose(self):
|
|
260
|
-
"""Return a transposed copy of this matrix."""
|
|
261
|
-
return bsr_transposed(self)
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
def _allocate_transfer_buf(device):
|
|
265
|
-
if device.ordinal in _transfer_buffer_cache:
|
|
266
|
-
all_, pool = _transfer_buffer_cache[device.ordinal]
|
|
267
|
-
else:
|
|
268
|
-
all_ = []
|
|
269
|
-
pool = []
|
|
270
|
-
_transfer_buffer_cache[device.ordinal] = (all_, pool)
|
|
271
|
-
|
|
272
|
-
if pool:
|
|
273
|
-
return pool.pop()
|
|
274
|
-
|
|
275
|
-
if device.is_capturing:
|
|
276
|
-
return None, None
|
|
277
|
-
|
|
278
|
-
buf = wp.empty(dtype=int, shape=(1,), device="cpu", pinned=device.is_cuda)
|
|
279
|
-
event = wp.Event(device) if device.is_cuda else None
|
|
280
|
-
all_.append((buf, event)) # keep a reference to the buffer and event, prevent garbage collection before redeem
|
|
281
|
-
return buf, event
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
def _redeem_transfer_buf(device, buf, event):
|
|
285
|
-
all_, pool = _transfer_buffer_cache[device.ordinal]
|
|
286
|
-
pool.append((buf, event))
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
def bsr_matrix_t(dtype: BlockType):
|
|
290
|
-
dtype = type_to_warp(dtype)
|
|
291
|
-
|
|
292
|
-
if not type_is_matrix(dtype) and dtype not in scalar_types:
|
|
293
|
-
raise ValueError(f"BsrMatrix block type must be either warp matrix or scalar; got {type_repr(dtype)}")
|
|
294
|
-
|
|
295
|
-
class BsrMatrixTyped(BsrMatrix):
|
|
296
|
-
nrow: int
|
|
297
|
-
"""Number of rows of blocks."""
|
|
298
|
-
ncol: int
|
|
299
|
-
"""Number of columns of blocks."""
|
|
300
|
-
nnz: int
|
|
301
|
-
"""Upper bound for the number of non-zeros."""
|
|
302
|
-
offsets: wp.array(dtype=int)
|
|
303
|
-
"""Array of size at least ``1 + nrow``."""
|
|
304
|
-
columns: wp.array(dtype=int)
|
|
305
|
-
"""Array of size at least equal to ``nnz``."""
|
|
306
|
-
values: wp.array(dtype=dtype)
|
|
307
|
-
|
|
308
|
-
module = wp.get_module(BsrMatrix.__module__)
|
|
309
|
-
|
|
310
|
-
if hasattr(dtype, "_shape_"):
|
|
311
|
-
type_str = f"{type_scalar_type(dtype).__name__}_{dtype._shape_[0]}_{dtype._shape_[1]}"
|
|
312
|
-
else:
|
|
313
|
-
type_str = dtype.__name__
|
|
314
|
-
key = f"{BsrMatrix.__qualname__}_{type_str}"
|
|
315
|
-
|
|
316
|
-
if key not in _struct_cache:
|
|
317
|
-
_struct_cache[key] = wp.codegen.Struct(
|
|
318
|
-
key=key,
|
|
319
|
-
cls=BsrMatrixTyped,
|
|
320
|
-
module=module,
|
|
321
|
-
)
|
|
322
|
-
|
|
323
|
-
return _struct_cache[key]
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
def bsr_zeros(
|
|
327
|
-
rows_of_blocks: int,
|
|
328
|
-
cols_of_blocks: int,
|
|
329
|
-
block_type: BlockType,
|
|
330
|
-
device: wp.context.Devicelike = None,
|
|
331
|
-
) -> BsrMatrix:
|
|
332
|
-
"""Construct and return an empty BSR or CSR matrix with the given shape.
|
|
333
|
-
|
|
334
|
-
Args:
|
|
335
|
-
bsr: The BSR or CSR matrix to set to zero.
|
|
336
|
-
rows_of_blocks: Number of rows of blocks.
|
|
337
|
-
cols_of_blocks: Number of columns of blocks.
|
|
338
|
-
block_type: Type of individual blocks.
|
|
339
|
-
For CSR matrices, this should be a scalar type.
|
|
340
|
-
For BSR matrices, this should be a matrix type (e.g. from :func:`warp.mat`).
|
|
341
|
-
device: Device on which to allocate the matrix arrays.
|
|
342
|
-
"""
|
|
343
|
-
|
|
344
|
-
bsr = bsr_matrix_t(block_type)()
|
|
345
|
-
|
|
346
|
-
bsr.nrow = int(rows_of_blocks)
|
|
347
|
-
bsr.ncol = int(cols_of_blocks)
|
|
348
|
-
bsr.nnz = 0
|
|
349
|
-
bsr.columns = wp.empty(shape=(0,), dtype=int, device=device)
|
|
350
|
-
bsr.values = wp.empty(shape=(0,), dtype=block_type, device=device)
|
|
351
|
-
bsr.offsets = wp.zeros(shape=(bsr.nrow + 1,), dtype=int, device=device)
|
|
352
|
-
|
|
353
|
-
return bsr
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
def _bsr_ensure_fits(bsr: BsrMatrix, nrow: Optional[int] = None, nnz: Optional[int] = None) -> None:
|
|
357
|
-
if nrow is None:
|
|
358
|
-
nrow = bsr.nrow
|
|
359
|
-
if nnz is None:
|
|
360
|
-
nnz = bsr.nnz
|
|
361
|
-
else:
|
|
362
|
-
# update nnz upper bound
|
|
363
|
-
bsr.nnz = int(nnz)
|
|
364
|
-
|
|
365
|
-
if bsr.offsets.size < nrow + 1:
|
|
366
|
-
bsr.offsets = wp.empty(shape=(nrow + 1,), dtype=int, device=bsr.offsets.device)
|
|
367
|
-
if bsr.columns.size < nnz:
|
|
368
|
-
bsr.columns = wp.empty(shape=(nnz,), dtype=int, device=bsr.columns.device)
|
|
369
|
-
if bsr.values.size < nnz:
|
|
370
|
-
bsr.values = wp.empty(
|
|
371
|
-
shape=(nnz,), dtype=bsr.values.dtype, device=bsr.values.device, requires_grad=bsr.values.requires_grad
|
|
372
|
-
)
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
def bsr_set_zero(
|
|
376
|
-
bsr: BsrMatrix,
|
|
377
|
-
rows_of_blocks: Optional[int] = None,
|
|
378
|
-
cols_of_blocks: Optional[int] = None,
|
|
379
|
-
):
|
|
380
|
-
"""Set a BSR matrix to zero, possibly changing its size.
|
|
381
|
-
|
|
382
|
-
Args:
|
|
383
|
-
bsr: The BSR or CSR matrix to set to zero.
|
|
384
|
-
rows_of_blocks: If not ``None``, the new number of rows of blocks.
|
|
385
|
-
cols_of_blocks: If not ``None``, the new number of columns of blocks.
|
|
386
|
-
"""
|
|
387
|
-
|
|
388
|
-
if rows_of_blocks is not None:
|
|
389
|
-
bsr.nrow = int(rows_of_blocks)
|
|
390
|
-
if cols_of_blocks is not None:
|
|
391
|
-
bsr.ncol = int(cols_of_blocks)
|
|
392
|
-
|
|
393
|
-
_bsr_ensure_fits(bsr, nnz=0)
|
|
394
|
-
bsr.offsets.zero_()
|
|
395
|
-
bsr.copy_nnz_async()
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
def _as_3d_array(arr, block_shape):
|
|
399
|
-
return wp.array(
|
|
400
|
-
ptr=arr.ptr,
|
|
401
|
-
capacity=arr.capacity,
|
|
402
|
-
device=arr.device,
|
|
403
|
-
dtype=type_scalar_type(arr.dtype),
|
|
404
|
-
shape=(arr.shape[0], *block_shape),
|
|
405
|
-
grad=None if arr.grad is None else _as_3d_array(arr.grad, block_shape),
|
|
406
|
-
)
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
def _optional_ctypes_pointer(array: Optional[wp.array], ctype):
|
|
410
|
-
return None if array is None else ctypes.cast(array.ptr, ctypes.POINTER(ctype))
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
def _optional_ctypes_event(event: Optional[wp.Event]):
|
|
414
|
-
return None if event is None else event.cuda_event
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
_zero_value_masks = {
|
|
418
|
-
wp.float16: 0x7FFF,
|
|
419
|
-
wp.float32: 0x7FFFFFFF,
|
|
420
|
-
wp.float64: 0x7FFFFFFFFFFFFFFF,
|
|
421
|
-
wp.int8: 0xFF,
|
|
422
|
-
wp.int16: 0xFFFF,
|
|
423
|
-
wp.int32: 0xFFFFFFFF,
|
|
424
|
-
wp.int64: 0xFFFFFFFFFFFFFFFF,
|
|
425
|
-
}
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
@wp.kernel
|
|
429
|
-
def _bsr_accumulate_triplet_values(
|
|
430
|
-
row_count: int,
|
|
431
|
-
tpl_summed_offsets: wp.array(dtype=int),
|
|
432
|
-
tpl_summed_indices: wp.array(dtype=int),
|
|
433
|
-
tpl_values: wp.array3d(dtype=Any),
|
|
434
|
-
bsr_offsets: wp.array(dtype=int),
|
|
435
|
-
bsr_values: wp.array3d(dtype=Any),
|
|
436
|
-
):
|
|
437
|
-
block, i, j = wp.tid()
|
|
438
|
-
|
|
439
|
-
if block >= bsr_offsets[row_count]:
|
|
440
|
-
return
|
|
441
|
-
|
|
442
|
-
if block == 0:
|
|
443
|
-
beg = 0
|
|
444
|
-
else:
|
|
445
|
-
beg = tpl_summed_offsets[block - 1]
|
|
446
|
-
end = tpl_summed_offsets[block]
|
|
447
|
-
|
|
448
|
-
val = tpl_values[tpl_summed_indices[beg], i, j]
|
|
449
|
-
for k in range(beg + 1, end):
|
|
450
|
-
val += tpl_values[tpl_summed_indices[k], i, j]
|
|
451
|
-
|
|
452
|
-
bsr_values[block, i, j] = val
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
def bsr_set_from_triplets(
    dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    rows: "Array[int]",
    columns: "Array[int]",
    values: Optional["Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]"] = None,
    count: Optional["Array[int]"] = None,
    prune_numerical_zeros: bool = True,
    masked: bool = False,
):
    """Fill a BSR matrix with values defined by coordinate-oriented (COO) triplets, discarding existing blocks.

    The first dimension of the three input arrays must match and indicates the number of COO triplets.
    Triplets sharing the same (row, column) coordinates are summed.

    Args:
        dest: Sparse matrix to populate.
        rows: Row index for each non-zero.
        columns: Columns index for each non-zero.
        values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
          to the ``dest`` matrix's block type, or a 3d array with data type equal to the ``dest`` matrix's scalar type.
          If ``None``, the values array of the resulting matrix will be allocated but uninitialized.
        count: Single-element array indicating the number of triplets. If ``None``, the number of triplets is determined from the shape of
          ``rows`` and ``columns`` arrays.
        prune_numerical_zeros: If ``True``, will ignore the zero-valued blocks.
        masked: If ``True``, ignore blocks that are not existing non-zeros of ``dest``.

    Raises:
        ValueError: on device, length or shape mismatches between the arguments.
        TypeError: if ``rows``, ``columns`` or ``count`` are not int32 arrays.
    """

    if rows.device != columns.device or rows.device != dest.device:
        raise ValueError(
            f"Rows and columns must reside on the destination matrix device, got {rows.device}, {columns.device} and {dest.device}"
        )

    if rows.shape[0] != columns.shape[0]:
        raise ValueError(
            f"Rows and columns arrays must have the same length, got {rows.shape[0]} and {columns.shape[0]}"
        )

    if rows.dtype != wp.int32 or columns.dtype != wp.int32:
        raise TypeError("Rows and columns arrays must be of type int32")

    if count is not None:
        if count.device != rows.device:
            raise ValueError(f"Count and rows must reside on the same device, got {count.device} and {rows.device}")

        if count.shape != (1,):
            raise ValueError(f"Count array must be a single-element array, got {count.shape}")

        if count.dtype != wp.int32:
            raise TypeError("Count array must be of type int32")

    # Accept either array1d(dtype) or array3d(scalar_type) as values
    if values is not None:
        if values.device != rows.device:
            raise ValueError(f"Values and rows must reside on the same device, got {values.device} and {rows.device}")

        if values.shape[0] != rows.shape[0]:
            raise ValueError(
                f"Values and rows arrays must have the same length, got {values.shape[0]} and {rows.shape[0]}"
            )

        if values.ndim == 1:
            if not types_equal(values.dtype, dest.values.dtype):
                raise ValueError(
                    f"Values array type must correspond to that of the dest matrix, got {type_repr(values.dtype)} and {type_repr(dest.values.dtype)}"
                )
        elif values.ndim == 3:
            if values.shape[1:] != dest.block_shape:
                raise ValueError(
                    f"Last two dimensions in values array ({values.shape[1:]}) should correspond to matrix block shape {dest.block_shape}"
                )

            if type_scalar_type(values.dtype) != dest.scalar_type:
                raise ValueError(
                    f"Scalar type of values array ({type_repr(values.dtype)}) should correspond to that of matrix ({type_repr(dest.scalar_type)})"
                )
        else:
            raise ValueError(f"Number of dimension for values array should be 1 or 3, got {values.ndim}")

        # The native pruning pass reads the raw value buffer, hence the contiguity requirement
        if prune_numerical_zeros and not values.is_contiguous:
            raise ValueError("Values array should be contiguous for numerical zero pruning")

    nnz = rows.shape[0]
    if nnz == 0:
        bsr_set_zero(dest)
        return

    # Increase dest array sizes if needed (a masked assignment keeps the existing topology)
    if not masked:
        _bsr_ensure_fits(dest, nnz=nnz)

    device = dest.values.device
    scalar_type = dest.scalar_type
    zero_value_mask = _zero_value_masks.get(scalar_type, 0) if prune_numerical_zeros else 0

    # compute the BSR topology

    from warp.context import runtime

    if device.is_cpu:
        native_func = runtime.core.wp_bsr_matrix_from_triplets_host
    else:
        native_func = runtime.core.wp_bsr_matrix_from_triplets_device

    nnz_buf, nnz_event = dest._setup_nnz_transfer()
    summed_triplet_offsets = wp.empty(shape=(nnz,), dtype=wp.int32, device=device)
    summed_triplet_indices = wp.empty(shape=(nnz,), dtype=wp.int32, device=device)

    with wp.ScopedDevice(device):
        native_func(
            dest.block_size,
            type_size_in_bytes(scalar_type),
            dest.nrow,
            dest.ncol,
            nnz,
            _optional_ctypes_pointer(count, ctype=ctypes.c_int32),
            ctypes.cast(rows.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(columns.ptr, ctypes.POINTER(ctypes.c_int32)),
            _optional_ctypes_pointer(values, ctype=ctypes.c_int32),
            zero_value_mask,
            masked,
            ctypes.cast(summed_triplet_offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(summed_triplet_indices.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
            _optional_ctypes_pointer(nnz_buf, ctype=ctypes.c_int32),
            _optional_ctypes_event(nnz_event),
        )

        # Without triplet values only the topology is set; the value storage stays uninitialized
        # as documented (and there is nothing to reinterpret as a 3d array).
        if values is None:
            return

        # now accumulate repeated blocks
        wp.launch(
            _bsr_accumulate_triplet_values,
            dim=(nnz, *dest.block_shape),
            device=device,
            inputs=[
                dest.nrow,
                summed_triplet_offsets,
                summed_triplet_indices,
                _as_3d_array(values, dest.block_shape),
                dest.offsets,
            ],
            outputs=[dest.scalar_values],
        )
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
def bsr_from_triplets(
    rows_of_blocks: int,
    cols_of_blocks: int,
    rows: "Array[int]",
    columns: "Array[int]",
    values: "Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]",
    prune_numerical_zeros: bool = True,
):
    """Constructs a BSR matrix with values defined by coordinate-oriented (COO) triplets.

    The first dimension of the three input arrays must match and indicates the number of COO triplets.
    The resulting matrix lives on the device of ``values``, and its values require gradients
    whenever ``values`` does.

    Args:
        rows_of_blocks: Number of rows of blocks.
        cols_of_blocks: Number of columns of blocks.
        rows: Row index for each non-zero.
        columns: Columns index for each non-zero.
        values: Block values for each non-zero. Either a one-dimensional array whose data type is the block
          type of the resulting matrix, or a 3d array of scalars whose two trailing dimensions give the block shape.
        prune_numerical_zeros: If ``True``, will ignore the zero-valued blocks.
    """

    # 3d scalar input: derive the matrix block type from the trailing dimensions
    block_type = wp.mat(shape=values.shape[1:], dtype=values.dtype) if values.ndim == 3 else values.dtype

    matrix = bsr_zeros(
        rows_of_blocks=rows_of_blocks,
        cols_of_blocks=cols_of_blocks,
        block_type=block_type,
        device=values.device,
    )
    matrix.values.requires_grad = values.requires_grad

    bsr_set_from_triplets(matrix, rows, columns, values, prune_numerical_zeros=prune_numerical_zeros)
    return matrix
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
class _BsrExpression(Generic[_BlockType]):
    """Common base for lazily-evaluated expressions over BSR matrices."""
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
class _BsrScalingExpression(_BsrExpression):
    """Lazy ``scale * mat`` expression over a BSR matrix.

    Read-only matrix properties are forwarded to the wrapped matrix, so the expression can be passed to
    functions accepting a ``BsrMatrixOrExpression``. Arithmetic operators hand ``scale`` to the underlying
    sparse routine (``bsr_axpy``, ``bsr_mv``, ``bsr_mm``) instead of materializing the scaled matrix first.
    """

    def __init__(self, mat, scale):
        self.mat = mat
        self.scale = scale

    def eval(self):
        """Materialize the expression into a new BSR matrix."""
        return bsr_copy(self)

    @property
    def nrow(self) -> int:
        return self.mat.nrow

    @property
    def ncol(self) -> int:
        return self.mat.ncol

    @property
    def nnz(self) -> int:
        return self.mat.nnz

    @property
    def offsets(self) -> wp.array:
        return self.mat.offsets

    @property
    def columns(self) -> wp.array:
        return self.mat.columns

    @property
    def scalar_type(self) -> Scalar:
        return self.mat.scalar_type

    @property
    def block_shape(self) -> Tuple[int, int]:
        return self.mat.block_shape

    @property
    def block_size(self) -> int:
        return self.mat.block_size

    @property
    def shape(self) -> Tuple[int, int]:
        return self.mat.shape

    @property
    def dtype(self) -> type:
        return self.mat.dtype

    @property
    def requires_grad(self) -> bool:
        return self.mat.requires_grad

    @property
    def device(self) -> wp.context.Device:
        return self.mat.device

    # Overloaded math operators
    #
    # bsr_axpy(x, y, alpha, beta) returns ``alpha * x + beta * y`` stored in ``y``; the scale of this
    # expression therefore goes in ``beta`` whenever the copy of ``self.mat`` is passed as ``y``.

    def __add__(self, y):
        # scale * mat + y
        return bsr_axpy(y, bsr_copy(self.mat), beta=self.scale)

    def __radd__(self, x):
        # x + scale * mat
        return bsr_axpy(x, bsr_copy(self.mat), beta=self.scale)

    def __sub__(self, y):
        # scale * mat - y
        return bsr_axpy(y, bsr_copy(self.mat), alpha=-1.0, beta=self.scale)

    def __rsub__(self, x):
        # x - scale * mat
        return bsr_axpy(x, bsr_copy(self.mat), beta=-self.scale)

    def __mul__(self, y):
        return _BsrScalingExpression(self.mat, y * self.scale)

    def __rmul__(self, x):
        return _BsrScalingExpression(self.mat, x * self.scale)

    def __matmul__(self, y):
        if isinstance(y, wp.array):
            return bsr_mv(self.mat, y, alpha=self.scale)

        return bsr_mm(self.mat, y, alpha=self.scale)

    def __rmatmul__(self, x):
        if isinstance(x, wp.array):
            # x^T (scale * A) == scale * A^T x
            return bsr_mv(self.mat, x, alpha=self.scale, transpose=True)

        return bsr_mm(x, self.mat, alpha=self.scale)

    def __truediv__(self, y):
        return _BsrScalingExpression(self.mat, self.scale / y)

    def __neg__(self):
        return _BsrScalingExpression(self.mat, -self.scale)

    def transpose(self):
        """Returns a scaling expression over a transposed copy of the wrapped matrix"""
        return _BsrScalingExpression(self.mat.transpose(), self.scale)
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
# Type accepted by functions that take either a concrete BsrMatrix or a lazy expression wrapping one
# (such arguments are unpacked with _extract_matrix_and_scale)
BsrMatrixOrExpression = Union[BsrMatrix[_BlockType], _BsrExpression[_BlockType]]
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
def _extract_matrix_and_scale(bsr: BsrMatrixOrExpression):
    """Split a matrix-or-expression argument into its underlying matrix and scalar factor.

    Returns:
        A ``(matrix, scale)`` pair; ``scale`` is ``1.0`` for a plain ``BsrMatrix``.

    Raises:
        ValueError: if ``bsr`` is neither a ``BsrMatrix`` nor a scaling expression.
    """
    if isinstance(bsr, _BsrScalingExpression):
        return bsr.mat, bsr.scale

    if isinstance(bsr, BsrMatrix):
        return bsr, 1.0

    raise ValueError("Argument cannot be interpreted as a BsrMatrix")
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
@wp.func
def _bsr_row_index(
    offsets: wp.array(dtype=int),
    row_count: int,
    block: int,
):
    """Index of the row containing a block, or -1 if non-existing.

    Args:
        offsets: Row offsets of the matrix (``row_count + 1`` entries).
        row_count: Number of rows of blocks.
        block: Index of the block in the columns/values arrays.
    """
    # For a valid block (block < offsets[row_count], the total block count), the search for the first
    # offset >= block + 1 lands on index row + 1; invalid blocks select 0, so both cases subtract 1.
    return wp.where(block < offsets[row_count], wp.lower_bound(offsets, 0, row_count + 1, block + 1), 0) - 1
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
@wp.func
def _bsr_block_index(
    row: int,
    col: int,
    bsr_offsets: wp.array(dtype=int),
    bsr_columns: wp.array(dtype=int),
):
    """Index of the block at block-coordinates (row, col), or -1 if non-existing.
    Assumes bsr_columns is sorted.
    """

    # Negative rows are used as an "invalid" marker by callers such as _bsr_row_index
    if row < 0:
        return -1

    mask_row_beg = bsr_offsets[row]
    mask_row_end = bsr_offsets[row + 1]

    # Empty row: nothing to search (and the lower bound below would not point inside the row)
    if mask_row_beg == mask_row_end:
        return -1

    # Binary search for col within this row's sorted column indices
    # NOTE(review): the result is dereferenced without checking block_index < mask_row_end; this relies on
    # wp.lower_bound returning an in-range index even when every column is < col -- confirm, otherwise a
    # column belonging to the next row could be matched
    block_index = wp.lower_bound(bsr_columns, mask_row_beg, mask_row_end, col)
    return wp.where(bsr_columns[block_index] == col, block_index, -1)
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
@wp.kernel(enable_backward=False)
def _bsr_assign_list_blocks(
    src_subrows: int,
    src_subcols: int,
    dest_subrows: int,
    dest_subcols: int,
    src_row_count: int,
    src_offsets: wp.array(dtype=int),
    src_columns: wp.array(dtype=int),
    dest_rows: wp.array(dtype=int),
    dest_cols: wp.array(dtype=int),
):
    """List the destination block coordinates covered by each source sub-block (COO rows/cols).

    One thread per ``(block, subrow, subcol)``. Each source block is split into ``src_subrows x src_subcols``
    sub-blocks; each sub-block coordinate is mapped to its destination block by dividing by
    ``dest_subrows`` / ``dest_subcols``. Threads whose source block is not a valid non-zero write ``-1``
    to both outputs.
    """
    block, subrow, subcol = wp.tid()
    # Flat output slot for this (block, subrow, subcol) triple
    dest_block = (block * src_subcols + subcol) * src_subrows + subrow

    row = _bsr_row_index(src_offsets, src_row_count, block)
    if row == -1:
        dest_rows[dest_block] = row  # invalid
        dest_cols[dest_block] = row
    else:
        # Sub-block coordinates in the global sub-block grid
        dest_subrow = row * src_subrows + subrow
        dest_subcol = src_columns[block] * src_subcols + subcol
        dest_rows[dest_block] = dest_subrow // dest_subrows
        dest_cols[dest_block] = dest_subcol // dest_subcols
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
@wp.kernel
def _bsr_assign_copy_blocks(
    scale: Any,
    src_subrows: int,
    src_subcols: int,
    dest_subrows: int,
    dest_subcols: int,
    src_row_count: int,
    src_offsets: wp.array(dtype=int),
    src_columns: wp.array(dtype=int),
    src_values: wp.array3d(dtype=Any),
    dest_offsets: wp.array(dtype=int),
    dest_columns: wp.array(dtype=int),
    dest_values: wp.array3d(dtype=Any),
):
    """Copy (scaled and cast) source sub-block values into the matching region of destination blocks.

    One thread per ``(src_block, subrow, subcol)``. Each thread copies a
    ``rows_per_subblock x cols_per_subblock`` tile from the source block to the position it occupies
    in the destination block located with ``_bsr_block_index``. Sub-blocks that are not valid source
    non-zeros, or whose destination block does not exist, are skipped.
    """
    src_block, subrow, subcol = wp.tid()

    src_row = _bsr_row_index(src_offsets, src_row_count, src_block)
    if src_row == -1:
        return

    src_col = src_columns[src_block]

    # Global sub-block coordinates, then the destination block containing them
    dest_subrow = src_row * src_subrows + subrow
    dest_subcol = src_col * src_subcols + subcol
    dest_row = dest_subrow // dest_subrows
    dest_col = dest_subcol // dest_subcols

    dest_block = _bsr_block_index(dest_row, dest_col, dest_offsets, dest_columns)
    if dest_block == -1:
        return

    # Position of this sub-block inside the destination block
    split_row = dest_subrow - dest_subrows * dest_row
    split_col = dest_subcol - dest_subcols * dest_col

    rows_per_subblock = src_values.shape[1] // src_subrows
    cols_per_subblock = src_values.shape[2] // src_subcols

    dest_base_i = split_row * rows_per_subblock
    dest_base_j = split_col * cols_per_subblock

    src_base_i = subrow * rows_per_subblock
    src_base_j = subcol * cols_per_subblock

    for i in range(rows_per_subblock):
        for j in range(cols_per_subblock):
            dest_values[dest_block, i + dest_base_i, j + dest_base_j] = dest_values.dtype(
                scale * src_values[src_block, i + src_base_i, j + src_base_j]
            )
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
def bsr_assign(
    dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    src: BsrMatrixOrExpression[BlockType[Any, Any, Any]],
    structure_only: bool = False,
    masked: bool = False,
):
    """Copy the content of the ``src`` BSR matrix to ``dest``.

    Args:
        src: Matrix to be copied. If a scaling expression is given, its scale is applied to the copied values.
        dest: Destination matrix. May have a different block shape or scalar type
            than ``src``, in which case the required casting will be performed.
            Each block dimension of ``dest`` must be a multiple or an exact divider of that of ``src``.
        structure_only: If ``True``, only the non-zero indices are copied, and uninitialized value storage is allocated
            to accommodate at least ``src.nnz`` blocks. If ``structure_only`` is ``False``, values are also copied with implicit
            casting if the two matrices use distinct scalar types.
        masked: If ``True``, prevent the assignment operation from adding new non-zero blocks to ``dest``.

    Raises:
        ValueError: if the matrices reside on different devices, if block shapes are incompatible, or if
            ``masked`` is set and ``dest`` does not have the required number of rows/columns of blocks.
    """

    src, src_scale = _extract_matrix_and_scale(src)

    if dest.values.device != src.values.device:
        raise ValueError("Source and destination matrices must reside on the same device")

    # Row split factors: at most one of src_subrows / dest_subrows differs from 1
    if src.block_shape[0] >= dest.block_shape[0]:
        src_subrows = src.block_shape[0] // dest.block_shape[0]
        dest_subrows = 1
    else:
        dest_subrows = dest.block_shape[0] // src.block_shape[0]
        src_subrows = 1

    # The integer divisions above are exact only if one block height divides the other
    if src_subrows * dest.block_shape[0] != src.block_shape[0] * dest_subrows:
        raise ValueError(
            f"Incompatible dest and src block shapes; block rows must evenly divide one another (Got {dest.block_shape[0]}, {src.block_shape[0]})"
        )

    # Same for columns
    if src.block_shape[1] >= dest.block_shape[1]:
        src_subcols = src.block_shape[1] // dest.block_shape[1]
        dest_subcols = 1
    else:
        dest_subcols = dest.block_shape[1] // src.block_shape[1]
        src_subcols = 1

    if src_subcols * dest.block_shape[1] != src.block_shape[1] * dest_subcols:
        raise ValueError(
            f"Incompatible dest and src block shapes; block columns must evenly divide one another (Got {dest.block_shape[1]}, {src.block_shape[1]})"
        )

    # Number of rows/columns of blocks of the destination, in units of dest blocks
    dest_nrow = (src.nrow * src_subrows) // dest_subrows
    dest_ncol = (src.ncol * src_subcols) // dest_subcols

    if src.nrow * src_subrows != dest_nrow * dest_subrows or src.ncol * src_subcols != dest_ncol * dest_subcols:
        raise ValueError(
            f"The requested block shape {dest.block_shape} does not evenly divide the source matrix of total size {src.shape}"
        )

    # Upper bound on the number of destination blocks: one per source sub-block
    nnz_alloc = src.nnz * src_subrows * src_subcols
    if masked:
        # Masked assignment keeps dest's topology, so dimensions must already agree
        if dest_nrow != dest.nrow or dest_ncol != dest.ncol:
            raise ValueError(
                f"Incompatible destination matrix size, expected ({dest_nrow}, {dest_ncol}), got ({dest.nrow}, {dest.ncol})"
            )
    else:
        dest.nrow = dest_nrow
        dest.ncol = dest_ncol
        _bsr_ensure_fits(dest, nnz=nnz_alloc)

    if dest.block_shape == src.block_shape and not masked:
        # Direct copy

        wp.copy(dest=dest.offsets, src=src.offsets, count=src.nrow + 1)
        dest.copy_nnz_async()

        if nnz_alloc > 0:
            wp.copy(dest=dest.columns, src=src.columns, count=nnz_alloc)

            if not structure_only:
                # Casts between scalar types if needed, then applies the expression scale in place
                warp.utils.array_cast(out_array=dest.values, in_array=src.values, count=nnz_alloc)
                bsr_scale(dest, src_scale)

    else:
        # Masked and/or multiple src blocks per dest block, go through COO format

        # Compute destination rows and columns
        dest_rows = wp.empty(nnz_alloc, dtype=int, device=dest.device)
        dest_cols = wp.empty(nnz_alloc, dtype=int, device=dest.device)
        wp.launch(
            _bsr_assign_list_blocks,
            dim=(src.nnz, src_subrows, src_subcols),
            device=dest.device,
            inputs=[
                src_subrows,
                src_subcols,
                dest_subrows,
                dest_subcols,
                src.nrow,
                src.offsets,
                src.columns,
                dest_rows,
                dest_cols,
            ],
        )

        # Compute destination offsets from triplets
        from warp.context import runtime

        if dest.device.is_cpu:
            native_func = runtime.core.wp_bsr_matrix_from_triplets_host
        else:
            native_func = runtime.core.wp_bsr_matrix_from_triplets_device

        nnz_buf, nnz_event = dest._setup_nnz_transfer()
        with wp.ScopedDevice(dest.device):
            # Topology-only call: no triplet values, no pruning, no summation outputs requested
            native_func(
                dest.block_size,
                0,  # scalar_size_in_bytes
                dest.nrow,
                dest.ncol,
                nnz_alloc,
                None,  # device nnz
                ctypes.cast(dest_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
                ctypes.cast(dest_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
                None,  # triplet values
                0,  # zero_value_mask
                masked,
                None,  # summed block offsets
                None,  # summed block indices
                ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
                ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
                _optional_ctypes_pointer(nnz_buf, ctype=ctypes.c_int32),
                _optional_ctypes_event(nnz_event),
            )

        # merge block values
        if not structure_only:
            # Each source sub-block writes its own region of a destination block; regions no source
            # sub-block covers must read as zero
            dest.values.zero_()
            wp.launch(
                _bsr_assign_copy_blocks,
                dim=(src.nnz, src_subrows, src_subcols),
                device=dest.device,
                inputs=[
                    src.scalar_type(src_scale),
                    src_subrows,
                    src_subcols,
                    dest_subrows,
                    dest_subcols,
                    src.nrow,
                    src.offsets,
                    src.columns,
                    src.scalar_values,
                    dest.offsets,
                    dest.columns,
                    dest.scalar_values,
                ],
            )
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
def bsr_copy(
    A: BsrMatrixOrExpression,
    scalar_type: Optional[Scalar] = None,
    block_shape: Optional[Tuple[int, int]] = None,
    structure_only: bool = False,
):
    """Return a copy of matrix ``A``, possibly changing its scalar type and/or block shape.

    Args:
        A: Matrix (or scaling expression) to be copied.
        scalar_type: If provided, the returned matrix will use this scalar type instead of the one from ``A``.
        block_shape: If provided, the returned matrix will use blocks of this shape instead of the one from ``A``.
          Both dimensions of ``block_shape`` must be either a multiple or an exact divider of the ones from ``A``.
        structure_only: If ``True``, only the non-zeros indices are copied, and uninitialized value storage is allocated
          to accommodate at least ``A.nnz`` blocks. If ``structure_only`` is ``False``, values are also copied with implicit
          casting if the two matrices use distinct scalar types.
    """
    target_scalar = A.scalar_type if scalar_type is None else scalar_type
    target_shape = A.block_shape if block_shape is None else block_shape

    # 1x1 blocks are stored as plain scalars rather than 1x1 matrices
    if target_shape == (1, 1):
        block_type = target_scalar
    else:
        block_type = wp.mat(shape=target_shape, dtype=target_scalar)

    result = bsr_zeros(rows_of_blocks=A.nrow, cols_of_blocks=A.ncol, block_type=block_type, device=A.device)
    result.values.requires_grad = A.requires_grad

    bsr_assign(dest=result, src=A, structure_only=structure_only)
    return result
|
|
1050
|
-
|
|
1051
|
-
|
|
1052
|
-
@wp.kernel
def _bsr_transpose_values(
    col_count: int,
    scale: Any,
    bsr_values: wp.array3d(dtype=Any),
    block_index_map: wp.array(dtype=int),
    transposed_bsr_offsets: wp.array(dtype=int),
    transposed_bsr_values: wp.array3d(dtype=Any),
):
    """Fill the values of a transposed BSR matrix from the source values.

    One thread per ``(block, i, j)`` entry of the transposed matrix. ``block_index_map[block]`` gives the
    source block that becomes ``block`` after transposition; its entry ``(j, i)`` is scaled by ``scale``.
    """
    block, i, j = wp.tid()

    # Threads past the transposed matrix's block count (its rows are the source's columns) exit early
    if block >= transposed_bsr_offsets[col_count]:
        return

    transposed_bsr_values[block, i, j] = bsr_values[block_index_map[block], j, i] * scale
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
def bsr_set_transpose(
    dest: BsrMatrix[BlockType[Cols, Rows, Scalar]],
    src: BsrMatrixOrExpression[BlockType[Rows, Cols, Scalar]],
):
    """Assign the transposed matrix ``src`` to matrix ``dest``.

    ``dest`` is resized to ``src.ncol`` rows and ``src.nrow`` columns of blocks. Both matrices must share the
    same device and scalar type, and ``dest`` must use the transposed block shape of ``src``. A scale carried
    by a ``src`` expression is applied to the transposed values.

    Raises:
        ValueError: on device, scalar type or block shape mismatch.
    """

    src, src_scale = _extract_matrix_and_scale(src)

    if dest.values.device != src.values.device:
        raise ValueError(
            f"All arguments must reside on the same device, got {dest.values.device} and {src.values.device}"
        )

    if dest.scalar_type != src.scalar_type:
        raise ValueError(f"All arguments must have the same scalar type, got {dest.scalar_type} and {src.scalar_type}")

    transpose_block_shape = src.block_shape[::-1]
    if dest.block_shape != transpose_block_shape:
        raise ValueError(f"Destination block shape must be {transpose_block_shape}, got {dest.block_shape}")

    nnz = src.nnz
    dest.nrow, dest.ncol = src.ncol, src.nrow

    if nnz == 0:
        bsr_set_zero(dest)
        return

    # Grow dest storage when required
    _bsr_ensure_fits(dest, nnz=nnz)

    from warp.context import runtime

    core = runtime.core
    native_func = core.wp_bsr_transpose_host if dest.values.device.is_cpu else core.wp_bsr_transpose_device

    # Scratch space for the native routine; its first entries map transposed blocks to source blocks
    block_index_map = wp.empty(shape=2 * nnz, dtype=int, device=src.device)

    int_pointers = [
        _optional_ctypes_pointer(arr, ctype=ctypes.c_int32)
        for arr in (src.offsets, src.columns, dest.offsets, dest.columns, block_index_map)
    ]

    with wp.ScopedDevice(dest.device):
        native_func(src.nrow, src.ncol, nnz, *int_pointers)

    dest.copy_nnz_async()

    wp.launch(
        _bsr_transpose_values,
        dim=(nnz, *dest.block_shape),
        device=dest.device,
        inputs=[src.ncol, dest.scalar_type(src_scale), src.scalar_values, block_index_map, dest.offsets],
        outputs=[dest.scalar_values],
    )
|
|
1131
|
-
|
|
1132
|
-
|
|
1133
|
-
def bsr_transposed(A: BsrMatrixOrExpression) -> BsrMatrix:
    """Return a copy of the transposed matrix ``A``.

    Args:
        A: Matrix or scaling expression to transpose. The scale of an expression is applied to the result values.

    Returns:
        A new matrix with ``A.ncol`` rows and ``A.nrow`` columns of blocks, using the transposed block shape.
    """

    if A.block_shape == (1, 1):
        # ``dtype`` is available on both BsrMatrix and scaling expressions, whereas expressions have
        # no ``values`` attribute (``A.values.dtype`` used to fail for e.g. ``bsr_transposed(2.0 * A)``)
        block_type = A.dtype
    else:
        block_type = wp.mat(shape=A.block_shape[::-1], dtype=A.scalar_type)

    transposed = bsr_zeros(
        rows_of_blocks=A.ncol,
        cols_of_blocks=A.nrow,
        block_type=block_type,
        device=A.device,
    )
    transposed.values.requires_grad = A.requires_grad
    bsr_set_transpose(dest=transposed, src=A)
    return transposed
|
|
1150
|
-
|
|
1151
|
-
|
|
1152
|
-
@wp.kernel
def _bsr_get_diag_kernel(
    scale: Any,
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array3d(dtype=Any),
    out: wp.array3d(dtype=Any),
):
    """Copy the scaled diagonal blocks of a BSR matrix into ``out``.

    One thread per ``(row, br, bc)`` entry. Entries whose diagonal block ``(row, row)`` is not stored
    in the matrix are left untouched in ``out``.
    """
    row, br, bc = wp.tid()

    diag = _bsr_block_index(row, row, A_offsets, A_columns)
    if diag != -1:
        out[row, br, bc] = scale * A_values[diag, br, bc]
|
|
1165
|
-
|
|
1166
|
-
|
|
1167
|
-
def bsr_get_diag(A: BsrMatrixOrExpression[BlockType], out: "Optional[Array[BlockType]]" = None) -> "Array[BlockType]":
    """Return the array of blocks that constitute the diagonal of a sparse matrix.

    Diagonal blocks that are not stored in ``A`` are returned as zero.

    Args:
        A: The sparse matrix (or scaling expression) from which to extract the diagonal.
        out: If provided, the array into which to store the diagonal blocks. It must have the block data type
            of ``A``, reside on the same device, and hold at least ``min(A.nrow, A.ncol)`` blocks; only that
            leading range is written.

    Returns:
        The array of diagonal blocks (``out`` itself when provided).

    Raises:
        ValueError: if ``out`` has the wrong data type, device, or length.
    """

    A, scale = _extract_matrix_and_scale(A)

    dim = min(A.nrow, A.ncol)

    if out is None:
        out = wp.zeros(shape=(dim,), dtype=A.values.dtype, device=A.values.device)
    else:
        if not types_equal(out.dtype, A.values.dtype):
            raise ValueError(f"Output array must have type {A.values.dtype}, got {out.dtype}")
        if out.device != A.values.device:
            raise ValueError(f"Output array must reside on device {A.values.device}, got {out.device}")
        if out.shape[0] < dim:
            raise ValueError(f"Output array must be of length at least {dim}, got {out.shape[0]}")

        # The kernel only writes entries whose diagonal block exists in A; clear the leading range so
        # that missing blocks read as zero instead of whatever ``out`` previously contained.
        if dim > 0:
            out[:dim].zero_()

    wp.launch(
        kernel=_bsr_get_diag_kernel,
        dim=(dim, *A.block_shape),
        device=A.values.device,
        inputs=[A.scalar_type(scale), A.offsets, A.columns, A.scalar_values, _as_3d_array(out, A.block_shape)],
    )

    return out
|
|
1197
|
-
|
|
1198
|
-
|
|
1199
|
-
@wp.kernel(enable_backward=False)
def _bsr_set_diag_kernel(
    nnz: int,
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
):
    """Write the topology of a block-diagonal matrix with ``nnz`` diagonal blocks.

    One thread per ``row``: row ``row`` starts at offset ``min(row, nnz)``, and for ``row < nnz``
    its single block sits in column ``row``.
    NOTE(review): presumably launched with ``nrow + 1`` threads so that the final offset is written
    too -- the launch site is outside this view; confirm.
    """
    row = wp.tid()
    A_offsets[row] = wp.min(row, nnz)
    if row < nnz:
        A_columns[row] = row
|
|
1209
|
-
|
|
1210
|
-
|
|
1211
|
-
def bsr_set_diag(
|
|
1212
|
-
A: BsrMatrix[BlockType],
|
|
1213
|
-
diag: "Union[BlockType, Array[BlockType]]",
|
|
1214
|
-
rows_of_blocks: Optional[int] = None,
|
|
1215
|
-
cols_of_blocks: Optional[int] = None,
|
|
1216
|
-
) -> None:
|
|
1217
|
-
"""Set ``A`` as a block-diagonal matrix.
|
|
1218
|
-
|
|
1219
|
-
Args:
|
|
1220
|
-
A: The sparse matrix to modify.
|
|
1221
|
-
diag: Specifies the values for diagonal blocks. Can be one of:
|
|
1222
|
-
|
|
1223
|
-
- A Warp array of type ``A.values.dtype``: Each element defines one block of the diagonal
|
|
1224
|
-
- A constant value of type ``A.values.dtype``: This value is assigned to all diagonal blocks
|
|
1225
|
-
- ``None``: Diagonal block values are left uninitialized
|
|
1226
|
-
|
|
1227
|
-
rows_of_blocks: If not ``None``, the new number of rows of blocks.
|
|
1228
|
-
cols_of_blocks: If not ``None``, the new number of columns of blocks.
|
|
1229
|
-
|
|
1230
|
-
The shape of the matrix will be defined one of the following, in this order:
|
|
1231
|
-
|
|
1232
|
-
- ``rows_of_blocks`` and ``cols_of_blocks``, if provided.
|
|
1233
|
-
If only one is given, the second is assumed equal.
|
|
1234
|
-
- The first dimension of ``diag``, if ``diag`` is an array
|
|
1235
|
-
- The current dimensions of ``A`` otherwise
|
|
1236
|
-
"""
|
|
1237
|
-
|
|
1238
|
-
if rows_of_blocks is None and cols_of_blocks is not None:
|
|
1239
|
-
rows_of_blocks = cols_of_blocks
|
|
1240
|
-
if cols_of_blocks is None and rows_of_blocks is not None:
|
|
1241
|
-
cols_of_blocks = rows_of_blocks
|
|
1242
|
-
|
|
1243
|
-
if is_array(diag):
|
|
1244
|
-
if rows_of_blocks is None:
|
|
1245
|
-
rows_of_blocks = diag.shape[0]
|
|
1246
|
-
cols_of_blocks = diag.shape[0]
|
|
1247
|
-
|
|
1248
|
-
if rows_of_blocks is not None:
|
|
1249
|
-
A.nrow = rows_of_blocks
|
|
1250
|
-
A.ncol = cols_of_blocks
|
|
1251
|
-
|
|
1252
|
-
nnz = min(A.nrow, A.ncol)
|
|
1253
|
-
_bsr_ensure_fits(A, nnz=nnz)
|
|
1254
|
-
|
|
1255
|
-
wp.launch(
|
|
1256
|
-
kernel=_bsr_set_diag_kernel,
|
|
1257
|
-
dim=nnz + 1,
|
|
1258
|
-
device=A.offsets.device,
|
|
1259
|
-
inputs=[nnz, A.offsets, A.columns],
|
|
1260
|
-
)
|
|
1261
|
-
|
|
1262
|
-
if is_array(diag):
|
|
1263
|
-
wp.copy(src=diag, dest=A.values, count=nnz)
|
|
1264
|
-
elif diag is not None:
|
|
1265
|
-
A.values.fill_(diag)
|
|
1266
|
-
|
|
1267
|
-
A.copy_nnz_async()
|
|
1268
|
-
|
|
1269
|
-
|
|
1270
|
-
def bsr_diag(
|
|
1271
|
-
diag: Optional[Union[BlockType, Array[BlockType]]] = None,
|
|
1272
|
-
rows_of_blocks: Optional[int] = None,
|
|
1273
|
-
cols_of_blocks: Optional[int] = None,
|
|
1274
|
-
block_type: Optional[BlockType] = None,
|
|
1275
|
-
device=None,
|
|
1276
|
-
) -> BsrMatrix["BlockType"]:
|
|
1277
|
-
"""Create and return a block-diagonal BSR matrix from an given block value or array of block values.
|
|
1278
|
-
|
|
1279
|
-
Args:
|
|
1280
|
-
diag: Specifies the values for diagonal blocks. Can be one of:
|
|
1281
|
-
|
|
1282
|
-
- A Warp array of type ``A.values.dtype``: Each element defines one block of the diagonal
|
|
1283
|
-
- A constant value of type ``A.values.dtype``: This value is assigned to all diagonal blocks
|
|
1284
|
-
rows_of_blocks: If not ``None``, the new number of rows of blocks
|
|
1285
|
-
cols_of_blocks: If not ``None``, the new number of columns of blocks
|
|
1286
|
-
block_type: If ``diag`` is ``None``, block type of the matrix. Otherwise deduced from ``diag``
|
|
1287
|
-
device: If ``diag`` is not a Warp array, device on which to allocate the matrix. Otherwise deduced from ``diag``
|
|
1288
|
-
|
|
1289
|
-
The shape of the matrix will be defined one of the following, in this order:
|
|
1290
|
-
|
|
1291
|
-
- ``rows_of_blocks`` and ``cols_of_blocks``, if provided.
|
|
1292
|
-
If only one is given, the second is assumed equal.
|
|
1293
|
-
- The first dimension of ``diag`` if ``diag`` is an array.
|
|
1294
|
-
"""
|
|
1295
|
-
|
|
1296
|
-
if rows_of_blocks is None and cols_of_blocks is not None:
|
|
1297
|
-
rows_of_blocks = cols_of_blocks
|
|
1298
|
-
if cols_of_blocks is None and rows_of_blocks is not None:
|
|
1299
|
-
cols_of_blocks = rows_of_blocks
|
|
1300
|
-
|
|
1301
|
-
if is_array(diag):
|
|
1302
|
-
if rows_of_blocks is None:
|
|
1303
|
-
rows_of_blocks = diag.shape[0]
|
|
1304
|
-
cols_of_blocks = diag.shape[0]
|
|
1305
|
-
|
|
1306
|
-
block_type = diag.dtype
|
|
1307
|
-
device = diag.device
|
|
1308
|
-
else:
|
|
1309
|
-
if rows_of_blocks is None:
|
|
1310
|
-
raise ValueError(
|
|
1311
|
-
"rows_of_blocks and/or cols_of_blocks must be provided for constructing a diagonal matrix with uniform diagonal"
|
|
1312
|
-
)
|
|
1313
|
-
|
|
1314
|
-
if block_type is None:
|
|
1315
|
-
if diag is None:
|
|
1316
|
-
raise ValueError("Either `diag` or `block_type` needs to be provided")
|
|
1317
|
-
|
|
1318
|
-
block_type = type(diag)
|
|
1319
|
-
if not type_is_matrix(block_type) and len(getattr(diag, "shape", ())) == 2:
|
|
1320
|
-
block_type = wp.mat(shape=diag.shape, dtype=diag.dtype)
|
|
1321
|
-
|
|
1322
|
-
A = bsr_zeros(rows_of_blocks, cols_of_blocks, block_type=block_type, device=device)
|
|
1323
|
-
if is_array(diag):
|
|
1324
|
-
A.values.requires_grad = diag.requires_grad
|
|
1325
|
-
bsr_set_diag(A, diag)
|
|
1326
|
-
return A
|
|
1327
|
-
|
|
1328
|
-
|
|
1329
|
-
def bsr_set_identity(A: BsrMatrix, rows_of_blocks: Optional[int] = None) -> None:
|
|
1330
|
-
"""Set ``A`` as the identity matrix.
|
|
1331
|
-
|
|
1332
|
-
Args:
|
|
1333
|
-
A: The sparse matrix to modify.
|
|
1334
|
-
rows_of_blocks: If provided, the matrix will be resized as a square
|
|
1335
|
-
matrix with ``rows_of_blocks`` rows and columns.
|
|
1336
|
-
"""
|
|
1337
|
-
|
|
1338
|
-
if A.block_shape == (1, 1):
|
|
1339
|
-
identity = A.scalar_type(1.0)
|
|
1340
|
-
else:
|
|
1341
|
-
from numpy import eye
|
|
1342
|
-
|
|
1343
|
-
identity = eye(A.block_shape[0])
|
|
1344
|
-
|
|
1345
|
-
bsr_set_diag(A, diag=identity, rows_of_blocks=rows_of_blocks, cols_of_blocks=rows_of_blocks)
|
|
1346
|
-
|
|
1347
|
-
|
|
1348
|
-
def bsr_identity(
|
|
1349
|
-
rows_of_blocks: int,
|
|
1350
|
-
block_type: BlockType[Rows, Rows, Scalar],
|
|
1351
|
-
device: wp.context.Devicelike = None,
|
|
1352
|
-
) -> BsrMatrix[BlockType[Rows, Rows, Scalar]]:
|
|
1353
|
-
"""Create and return a square identity matrix.
|
|
1354
|
-
|
|
1355
|
-
Args:
|
|
1356
|
-
rows_of_blocks: Number of rows and columns of blocks in the created matrix.
|
|
1357
|
-
block_type: Block type for the newly created matrix. Must be square
|
|
1358
|
-
device: Device onto which to allocate the data arrays
|
|
1359
|
-
"""
|
|
1360
|
-
A = bsr_zeros(
|
|
1361
|
-
rows_of_blocks=rows_of_blocks,
|
|
1362
|
-
cols_of_blocks=rows_of_blocks,
|
|
1363
|
-
block_type=block_type,
|
|
1364
|
-
device=device,
|
|
1365
|
-
)
|
|
1366
|
-
bsr_set_identity(A)
|
|
1367
|
-
return A
|
|
1368
|
-
|
|
1369
|
-
|
|
1370
|
-
@wp.kernel
|
|
1371
|
-
def _bsr_scale_kernel(
|
|
1372
|
-
alpha: Any,
|
|
1373
|
-
values: wp.array(dtype=Any),
|
|
1374
|
-
):
|
|
1375
|
-
row = wp.tid()
|
|
1376
|
-
values[row] = alpha * values[row]
|
|
1377
|
-
|
|
1378
|
-
|
|
1379
|
-
@wp.kernel
|
|
1380
|
-
def _bsr_scale_kernel(
|
|
1381
|
-
alpha: Any,
|
|
1382
|
-
values: wp.array3d(dtype=Any),
|
|
1383
|
-
):
|
|
1384
|
-
row, br, bc = wp.tid()
|
|
1385
|
-
values[row, br, bc] = alpha * values[row, br, bc]
|
|
1386
|
-
|
|
1387
|
-
|
|
1388
|
-
def bsr_scale(x: BsrMatrixOrExpression, alpha: Scalar) -> BsrMatrix:
|
|
1389
|
-
"""Perform the operation ``x := alpha * x`` on BSR matrix ``x`` and return ``x``."""
|
|
1390
|
-
|
|
1391
|
-
x, scale = _extract_matrix_and_scale(x)
|
|
1392
|
-
alpha *= scale
|
|
1393
|
-
|
|
1394
|
-
if alpha != 1.0 and x.nnz > 0:
|
|
1395
|
-
if alpha == 0.0:
|
|
1396
|
-
bsr_set_zero(x)
|
|
1397
|
-
else:
|
|
1398
|
-
alpha = x.scalar_type(alpha)
|
|
1399
|
-
|
|
1400
|
-
wp.launch(
|
|
1401
|
-
kernel=_bsr_scale_kernel,
|
|
1402
|
-
dim=(x.nnz, *x.block_shape),
|
|
1403
|
-
device=x.values.device,
|
|
1404
|
-
inputs=[alpha, x.scalar_values],
|
|
1405
|
-
)
|
|
1406
|
-
|
|
1407
|
-
return x
|
|
1408
|
-
|
|
1409
|
-
|
|
1410
|
-
@wp.kernel(enable_backward=False)
|
|
1411
|
-
def _bsr_get_block_row(row_count: int, bsr_offsets: wp.array(dtype=int), rows: wp.array(dtype=int)):
|
|
1412
|
-
block = wp.tid()
|
|
1413
|
-
rows[block] = _bsr_row_index(bsr_offsets, row_count, block)
|
|
1414
|
-
|
|
1415
|
-
|
|
1416
|
-
@wp.kernel
|
|
1417
|
-
def _bsr_axpy_add_block(
|
|
1418
|
-
src_offset: int,
|
|
1419
|
-
scale: Any,
|
|
1420
|
-
rows: wp.array(dtype=int),
|
|
1421
|
-
cols: wp.array(dtype=int),
|
|
1422
|
-
dst_offsets: wp.array(dtype=int),
|
|
1423
|
-
dst_columns: wp.array(dtype=int),
|
|
1424
|
-
src_values: wp.array3d(dtype=Any),
|
|
1425
|
-
dst_values: wp.array3d(dtype=Any),
|
|
1426
|
-
):
|
|
1427
|
-
i, br, bc = wp.tid()
|
|
1428
|
-
row = rows[i + src_offset]
|
|
1429
|
-
col = cols[i + src_offset]
|
|
1430
|
-
|
|
1431
|
-
block = _bsr_block_index(row, col, dst_offsets, dst_columns)
|
|
1432
|
-
if block != -1:
|
|
1433
|
-
dst_values[block, br, bc] += scale * src_values[i, br, bc]
|
|
1434
|
-
|
|
1435
|
-
|
|
1436
|
-
class bsr_axpy_work_arrays:
|
|
1437
|
-
"""Opaque structure for persisting :func:`bsr_axpy` temporary work buffers across calls."""
|
|
1438
|
-
|
|
1439
|
-
def __init__(self):
|
|
1440
|
-
self._reset(None)
|
|
1441
|
-
|
|
1442
|
-
def _reset(self, device):
|
|
1443
|
-
self.device = device
|
|
1444
|
-
self._sum_rows = None
|
|
1445
|
-
self._sum_cols = None
|
|
1446
|
-
self._old_y_values = None
|
|
1447
|
-
self._old_x_values = None
|
|
1448
|
-
|
|
1449
|
-
def _allocate(self, device, y: BsrMatrix, sum_nnz: int):
|
|
1450
|
-
if self.device != device:
|
|
1451
|
-
self._reset(device)
|
|
1452
|
-
|
|
1453
|
-
if self._sum_rows is None or self._sum_rows.size < sum_nnz:
|
|
1454
|
-
self._sum_rows = wp.empty(shape=(sum_nnz), dtype=int, device=self.device)
|
|
1455
|
-
if self._sum_cols is None or self._sum_cols.size < sum_nnz:
|
|
1456
|
-
self._sum_cols = wp.empty(shape=(sum_nnz), dtype=int, device=self.device)
|
|
1457
|
-
|
|
1458
|
-
if self._old_y_values is None or self._old_y_values.size < y.nnz:
|
|
1459
|
-
self._old_y_values = wp.empty_like(y.values[: y.nnz])
|
|
1460
|
-
|
|
1461
|
-
|
|
1462
|
-
def bsr_axpy(
|
|
1463
|
-
x: BsrMatrixOrExpression,
|
|
1464
|
-
y: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
|
|
1465
|
-
alpha: Scalar = 1.0,
|
|
1466
|
-
beta: Scalar = 1.0,
|
|
1467
|
-
masked: bool = False,
|
|
1468
|
-
work_arrays: Optional[bsr_axpy_work_arrays] = None,
|
|
1469
|
-
) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
|
|
1470
|
-
"""
|
|
1471
|
-
Perform the sparse matrix addition ``y := alpha * X + beta * y`` on BSR matrices ``x`` and ``y`` and return ``y``.
|
|
1472
|
-
|
|
1473
|
-
The ``x`` and ``y`` matrices are allowed to alias.
|
|
1474
|
-
|
|
1475
|
-
Args:
|
|
1476
|
-
x: Read-only first operand.
|
|
1477
|
-
y: Mutable second operand and output matrix. If ``y`` is not provided, it will be allocated and treated as zero.
|
|
1478
|
-
alpha: Uniform scaling factor for ``x``.
|
|
1479
|
-
beta: Uniform scaling factor for ``y``.
|
|
1480
|
-
masked: If ``True``, discard all blocks from ``x`` which are not
|
|
1481
|
-
existing non-zeros of ``y``.
|
|
1482
|
-
work_arrays: In most cases, this function will require the use of temporary storage.
|
|
1483
|
-
This storage can be reused across calls by passing an instance of
|
|
1484
|
-
:class:`bsr_axpy_work_arrays` in ``work_arrays``.
|
|
1485
|
-
"""
|
|
1486
|
-
|
|
1487
|
-
x, x_scale = _extract_matrix_and_scale(x)
|
|
1488
|
-
alpha *= x_scale
|
|
1489
|
-
|
|
1490
|
-
if y is None:
|
|
1491
|
-
if masked:
|
|
1492
|
-
raise ValueError("Left-hand-side 'y' matrix must be provided for masked addition")
|
|
1493
|
-
|
|
1494
|
-
# If not output matrix is provided, allocate it for convenience
|
|
1495
|
-
y = bsr_zeros(x.nrow, x.ncol, block_type=x.values.dtype, device=x.values.device)
|
|
1496
|
-
y.values.requires_grad = x.requires_grad
|
|
1497
|
-
beta = 0.0
|
|
1498
|
-
|
|
1499
|
-
x_nnz = x.nnz
|
|
1500
|
-
y_nnz = y.nnz
|
|
1501
|
-
|
|
1502
|
-
# Handle easy cases first
|
|
1503
|
-
if beta == 0.0 or y_nnz == 0:
|
|
1504
|
-
bsr_assign(src=x, dest=y)
|
|
1505
|
-
return bsr_scale(y, alpha=alpha)
|
|
1506
|
-
|
|
1507
|
-
if alpha == 0.0 or x_nnz == 0:
|
|
1508
|
-
return bsr_scale(y, alpha=beta)
|
|
1509
|
-
|
|
1510
|
-
if not isinstance(alpha, y.scalar_type):
|
|
1511
|
-
alpha = y.scalar_type(alpha)
|
|
1512
|
-
if not isinstance(beta, y.scalar_type):
|
|
1513
|
-
beta = y.scalar_type(beta)
|
|
1514
|
-
|
|
1515
|
-
if x == y:
|
|
1516
|
-
# Aliasing case
|
|
1517
|
-
return bsr_scale(y, alpha=alpha.value + beta.value)
|
|
1518
|
-
|
|
1519
|
-
# General case
|
|
1520
|
-
|
|
1521
|
-
if x.values.device != y.values.device:
|
|
1522
|
-
raise ValueError(f"All arguments must reside on the same device, got {x.values.device} and {y.values.device}")
|
|
1523
|
-
|
|
1524
|
-
if x.scalar_type != y.scalar_type or x.block_shape != y.block_shape:
|
|
1525
|
-
raise ValueError(
|
|
1526
|
-
f"Matrices must have the same block type, got ({x.block_shape}, {x.scalar_type}) and ({y.block_shape}, {y.scalar_type})"
|
|
1527
|
-
)
|
|
1528
|
-
|
|
1529
|
-
if x.nrow != y.nrow or x.ncol != y.ncol:
|
|
1530
|
-
raise ValueError(
|
|
1531
|
-
f"Matrices must have the same number of rows and columns, got ({x.nrow}, {x.ncol}) and ({y.nrow}, {y.ncol})"
|
|
1532
|
-
)
|
|
1533
|
-
|
|
1534
|
-
if work_arrays is None:
|
|
1535
|
-
work_arrays = bsr_axpy_work_arrays()
|
|
1536
|
-
|
|
1537
|
-
sum_nnz = x_nnz + y_nnz
|
|
1538
|
-
device = y.values.device
|
|
1539
|
-
work_arrays._allocate(device, y, sum_nnz)
|
|
1540
|
-
|
|
1541
|
-
wp.copy(work_arrays._sum_cols, y.columns, 0, 0, y_nnz)
|
|
1542
|
-
y.uncompress_rows(out=work_arrays._sum_rows)
|
|
1543
|
-
|
|
1544
|
-
wp.copy(work_arrays._sum_cols, x.columns, y_nnz, 0, x_nnz)
|
|
1545
|
-
x.uncompress_rows(out=work_arrays._sum_rows[y_nnz:])
|
|
1546
|
-
|
|
1547
|
-
# Save old y values before overwriting matrix
|
|
1548
|
-
wp.copy(dest=work_arrays._old_y_values, src=y.values, count=y.nnz)
|
|
1549
|
-
|
|
1550
|
-
# Increase dest array sizes if needed
|
|
1551
|
-
if not masked:
|
|
1552
|
-
_bsr_ensure_fits(y, nnz=sum_nnz)
|
|
1553
|
-
|
|
1554
|
-
from warp.context import runtime
|
|
1555
|
-
|
|
1556
|
-
if device.is_cpu:
|
|
1557
|
-
native_func = runtime.core.wp_bsr_matrix_from_triplets_host
|
|
1558
|
-
else:
|
|
1559
|
-
native_func = runtime.core.wp_bsr_matrix_from_triplets_device
|
|
1560
|
-
|
|
1561
|
-
old_y_nnz = y_nnz
|
|
1562
|
-
nnz_buf, nnz_event = y._setup_nnz_transfer()
|
|
1563
|
-
|
|
1564
|
-
with wp.ScopedDevice(y.device):
|
|
1565
|
-
native_func(
|
|
1566
|
-
y.block_size,
|
|
1567
|
-
0, # scalar_size_in_bytes
|
|
1568
|
-
y.nrow,
|
|
1569
|
-
y.ncol,
|
|
1570
|
-
sum_nnz,
|
|
1571
|
-
None, # device nnz
|
|
1572
|
-
ctypes.cast(work_arrays._sum_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1573
|
-
ctypes.cast(work_arrays._sum_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1574
|
-
None, # triplet values
|
|
1575
|
-
0, # zero_value_mask
|
|
1576
|
-
masked,
|
|
1577
|
-
None, # summed block offsets
|
|
1578
|
-
None, # summed block indices
|
|
1579
|
-
ctypes.cast(y.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1580
|
-
ctypes.cast(y.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1581
|
-
_optional_ctypes_pointer(nnz_buf, ctype=ctypes.c_int32),
|
|
1582
|
-
_optional_ctypes_event(nnz_event),
|
|
1583
|
-
)
|
|
1584
|
-
|
|
1585
|
-
y.values.zero_()
|
|
1586
|
-
|
|
1587
|
-
wp.launch(
|
|
1588
|
-
kernel=_bsr_axpy_add_block,
|
|
1589
|
-
device=device,
|
|
1590
|
-
dim=(old_y_nnz, y.block_shape[0], y.block_shape[1]),
|
|
1591
|
-
inputs=[
|
|
1592
|
-
0,
|
|
1593
|
-
beta,
|
|
1594
|
-
work_arrays._sum_rows,
|
|
1595
|
-
work_arrays._sum_cols,
|
|
1596
|
-
y.offsets,
|
|
1597
|
-
y.columns,
|
|
1598
|
-
_as_3d_array(work_arrays._old_y_values, y.block_shape),
|
|
1599
|
-
y.scalar_values,
|
|
1600
|
-
],
|
|
1601
|
-
)
|
|
1602
|
-
|
|
1603
|
-
wp.launch(
|
|
1604
|
-
kernel=_bsr_axpy_add_block,
|
|
1605
|
-
device=device,
|
|
1606
|
-
dim=(x_nnz, y.block_shape[0], y.block_shape[1]),
|
|
1607
|
-
inputs=[
|
|
1608
|
-
old_y_nnz,
|
|
1609
|
-
alpha,
|
|
1610
|
-
work_arrays._sum_rows,
|
|
1611
|
-
work_arrays._sum_cols,
|
|
1612
|
-
y.offsets,
|
|
1613
|
-
y.columns,
|
|
1614
|
-
x.scalar_values,
|
|
1615
|
-
y.scalar_values,
|
|
1616
|
-
],
|
|
1617
|
-
)
|
|
1618
|
-
|
|
1619
|
-
return y
|
|
1620
|
-
|
|
1621
|
-
|
|
1622
|
-
def make_bsr_mm_count_coeffs(tile_size):
|
|
1623
|
-
from warp.fem.cache import dynamic_kernel
|
|
1624
|
-
|
|
1625
|
-
@dynamic_kernel(suffix=tile_size)
|
|
1626
|
-
def bsr_mm_count_coeffs(
|
|
1627
|
-
y_ncol: int,
|
|
1628
|
-
z_nnz: int,
|
|
1629
|
-
x_offsets: wp.array(dtype=int),
|
|
1630
|
-
x_columns: wp.array(dtype=int),
|
|
1631
|
-
y_offsets: wp.array(dtype=int),
|
|
1632
|
-
y_columns: wp.array(dtype=int),
|
|
1633
|
-
row_min: wp.array(dtype=int),
|
|
1634
|
-
block_counts: wp.array(dtype=int),
|
|
1635
|
-
):
|
|
1636
|
-
row, lane = wp.tid()
|
|
1637
|
-
row_count = int(0)
|
|
1638
|
-
|
|
1639
|
-
x_beg = x_offsets[row]
|
|
1640
|
-
x_end = x_offsets[row + 1]
|
|
1641
|
-
|
|
1642
|
-
min_col = y_ncol
|
|
1643
|
-
max_col = int(0)
|
|
1644
|
-
|
|
1645
|
-
for x_block in range(x_beg + lane, x_end, tile_size):
|
|
1646
|
-
x_col = x_columns[x_block]
|
|
1647
|
-
y_row_end = y_offsets[x_col + 1]
|
|
1648
|
-
y_row_beg = y_offsets[x_col]
|
|
1649
|
-
block_count = y_row_end - y_row_beg
|
|
1650
|
-
if block_count != 0:
|
|
1651
|
-
min_col = wp.min(y_columns[y_row_beg], min_col)
|
|
1652
|
-
max_col = wp.max(y_columns[y_row_end - 1], max_col)
|
|
1653
|
-
|
|
1654
|
-
block_counts[x_block + 1] = block_count
|
|
1655
|
-
row_count += block_count
|
|
1656
|
-
|
|
1657
|
-
if wp.static(tile_size) > 1:
|
|
1658
|
-
row_count = wp.tile_sum(wp.tile(row_count))[0]
|
|
1659
|
-
min_col = wp.tile_min(wp.tile(min_col))[0]
|
|
1660
|
-
max_col = wp.tile_max(wp.tile(max_col))[0]
|
|
1661
|
-
col_range_size = wp.max(0, max_col - min_col + 1)
|
|
1662
|
-
|
|
1663
|
-
if row_count > col_range_size:
|
|
1664
|
-
# Optimization for deep products.
|
|
1665
|
-
# Do not store the whole whole list of src product terms, they would be highly redundant
|
|
1666
|
-
# Instead just mark a range in the output matrix
|
|
1667
|
-
|
|
1668
|
-
if lane == 0:
|
|
1669
|
-
row_min[row] = min_col
|
|
1670
|
-
block_counts[x_end] = col_range_size
|
|
1671
|
-
|
|
1672
|
-
for x_block in range(x_beg + lane, x_end - 1, tile_size):
|
|
1673
|
-
block_counts[x_block + 1] = 0
|
|
1674
|
-
elif lane == 0:
|
|
1675
|
-
row_min[row] = -1
|
|
1676
|
-
|
|
1677
|
-
if lane == 0 and row == 0:
|
|
1678
|
-
block_counts[0] = z_nnz
|
|
1679
|
-
|
|
1680
|
-
return bsr_mm_count_coeffs
|
|
1681
|
-
|
|
1682
|
-
|
|
1683
|
-
@wp.kernel(enable_backward=False)
|
|
1684
|
-
def _bsr_mm_list_coeffs(
|
|
1685
|
-
copied_z_nnz: int,
|
|
1686
|
-
x_nrow: int,
|
|
1687
|
-
x_offsets: wp.array(dtype=int),
|
|
1688
|
-
x_columns: wp.array(dtype=int),
|
|
1689
|
-
y_offsets: wp.array(dtype=int),
|
|
1690
|
-
y_columns: wp.array(dtype=int),
|
|
1691
|
-
mm_row_min: wp.array(dtype=int),
|
|
1692
|
-
mm_offsets: wp.array(dtype=int),
|
|
1693
|
-
mm_rows: wp.array(dtype=int),
|
|
1694
|
-
mm_cols: wp.array(dtype=int),
|
|
1695
|
-
mm_src_blocks: wp.array(dtype=int),
|
|
1696
|
-
):
|
|
1697
|
-
mm_block = wp.tid() + copied_z_nnz
|
|
1698
|
-
|
|
1699
|
-
x_nnz = x_offsets[x_nrow]
|
|
1700
|
-
x_block = wp.lower_bound(mm_offsets, 0, x_nnz + 1, mm_block + 1) - 1
|
|
1701
|
-
pos = mm_block - mm_offsets[x_block]
|
|
1702
|
-
|
|
1703
|
-
row = _bsr_row_index(x_offsets, x_nrow, x_block)
|
|
1704
|
-
|
|
1705
|
-
row_min_col = mm_row_min[row]
|
|
1706
|
-
if row_min_col == -1:
|
|
1707
|
-
x_col = x_columns[x_block]
|
|
1708
|
-
y_beg = y_offsets[x_col]
|
|
1709
|
-
y_block = y_beg + pos
|
|
1710
|
-
col = y_columns[y_block]
|
|
1711
|
-
src_block = x_block
|
|
1712
|
-
else:
|
|
1713
|
-
col = row_min_col + pos
|
|
1714
|
-
src_block = -1
|
|
1715
|
-
|
|
1716
|
-
mm_cols[mm_block] = col
|
|
1717
|
-
mm_rows[mm_block] = row
|
|
1718
|
-
mm_src_blocks[mm_block] = src_block
|
|
1719
|
-
|
|
1720
|
-
|
|
1721
|
-
@wp.func
|
|
1722
|
-
def _bsr_mm_use_triplets(
|
|
1723
|
-
row: int,
|
|
1724
|
-
mm_block: int,
|
|
1725
|
-
mm_row_min: wp.array(dtype=int),
|
|
1726
|
-
row_offsets: wp.array(dtype=int),
|
|
1727
|
-
summed_triplet_offsets: wp.array(dtype=int),
|
|
1728
|
-
):
|
|
1729
|
-
x_beg = row_offsets[row]
|
|
1730
|
-
x_end = row_offsets[row + 1]
|
|
1731
|
-
|
|
1732
|
-
if mm_row_min:
|
|
1733
|
-
if mm_row_min[row] == -1:
|
|
1734
|
-
if mm_block == 0:
|
|
1735
|
-
block_beg = 0
|
|
1736
|
-
else:
|
|
1737
|
-
block_beg = summed_triplet_offsets[mm_block - 1]
|
|
1738
|
-
block_end = summed_triplet_offsets[mm_block]
|
|
1739
|
-
|
|
1740
|
-
if x_end - x_beg > 3 * (block_end - block_beg):
|
|
1741
|
-
return True, block_beg, block_end
|
|
1742
|
-
|
|
1743
|
-
return False, x_beg, x_end
|
|
1744
|
-
|
|
1745
|
-
|
|
1746
|
-
@wp.kernel(enable_backward=False)
|
|
1747
|
-
def _bsr_mm_compute_values(
|
|
1748
|
-
alpha: Any,
|
|
1749
|
-
x_offsets: wp.array(dtype=int),
|
|
1750
|
-
x_columns: wp.array(dtype=int),
|
|
1751
|
-
x_values: wp.array(dtype=Any),
|
|
1752
|
-
y_offsets: wp.array(dtype=int),
|
|
1753
|
-
y_columns: wp.array(dtype=int),
|
|
1754
|
-
y_values: wp.array(dtype=Any),
|
|
1755
|
-
mm_row_min: wp.array(dtype=int),
|
|
1756
|
-
summed_triplet_offsets: wp.array(dtype=int),
|
|
1757
|
-
summed_triplet_src_blocks: wp.indexedarray(dtype=int),
|
|
1758
|
-
mm_row_count: int,
|
|
1759
|
-
mm_offsets: wp.array(dtype=int),
|
|
1760
|
-
mm_cols: wp.array(dtype=int),
|
|
1761
|
-
mm_values: wp.array(dtype=Any),
|
|
1762
|
-
):
|
|
1763
|
-
mm_block = wp.tid()
|
|
1764
|
-
|
|
1765
|
-
row = _bsr_row_index(mm_offsets, mm_row_count, mm_block)
|
|
1766
|
-
if row == -1:
|
|
1767
|
-
return
|
|
1768
|
-
|
|
1769
|
-
use_triplets, block_beg, block_end = _bsr_mm_use_triplets(
|
|
1770
|
-
row, mm_block, mm_row_min, x_offsets, summed_triplet_offsets
|
|
1771
|
-
)
|
|
1772
|
-
|
|
1773
|
-
mm_val = mm_values.dtype(type(alpha)(0.0))
|
|
1774
|
-
col = mm_cols[mm_block]
|
|
1775
|
-
if use_triplets:
|
|
1776
|
-
for tpl_idx in range(block_beg, block_end):
|
|
1777
|
-
x_block = summed_triplet_src_blocks[tpl_idx]
|
|
1778
|
-
x_col = x_columns[x_block]
|
|
1779
|
-
if x_block != -1:
|
|
1780
|
-
y_block = _bsr_block_index(x_col, col, y_offsets, y_columns)
|
|
1781
|
-
mm_val += x_values[x_block] * y_values[y_block]
|
|
1782
|
-
else:
|
|
1783
|
-
for x_block in range(block_beg, block_end):
|
|
1784
|
-
x_col = x_columns[x_block]
|
|
1785
|
-
y_block = _bsr_block_index(x_col, col, y_offsets, y_columns)
|
|
1786
|
-
if y_block != -1:
|
|
1787
|
-
mm_val += x_values[x_block] * y_values[y_block]
|
|
1788
|
-
|
|
1789
|
-
mm_values[mm_block] += alpha * mm_val
|
|
1790
|
-
|
|
1791
|
-
|
|
1792
|
-
def make_bsr_mm_compute_values_tiled_outer(subblock_rows, subblock_cols, block_depth, scalar_type, tile_size):
|
|
1793
|
-
from warp.fem.cache import dynamic_func, dynamic_kernel
|
|
1794
|
-
|
|
1795
|
-
mm_type = wp.mat(dtype=scalar_type, shape=(subblock_rows, subblock_cols))
|
|
1796
|
-
|
|
1797
|
-
x_col_vec_t = wp.vec(dtype=scalar_type, length=subblock_rows)
|
|
1798
|
-
y_row_vec_t = wp.vec(dtype=scalar_type, length=subblock_cols)
|
|
1799
|
-
|
|
1800
|
-
suffix = f"{subblock_rows}{subblock_cols}{block_depth}{tile_size}{scalar_type.__name__}"
|
|
1801
|
-
|
|
1802
|
-
@dynamic_func(suffix=suffix)
|
|
1803
|
-
def _outer_product(
|
|
1804
|
-
x_values: wp.array2d(dtype=scalar_type),
|
|
1805
|
-
y_values: wp.array2d(dtype=scalar_type),
|
|
1806
|
-
brow_off: int,
|
|
1807
|
-
bcol_off: int,
|
|
1808
|
-
block_col: int,
|
|
1809
|
-
brow_count: int,
|
|
1810
|
-
bcol_count: int,
|
|
1811
|
-
):
|
|
1812
|
-
x_col_vec = x_col_vec_t()
|
|
1813
|
-
y_row_vec = y_row_vec_t()
|
|
1814
|
-
|
|
1815
|
-
for k in range(brow_count):
|
|
1816
|
-
x_col_vec[k] = x_values[brow_off + k, block_col]
|
|
1817
|
-
for k in range(bcol_count):
|
|
1818
|
-
y_row_vec[k] = y_values[block_col, bcol_off + k]
|
|
1819
|
-
|
|
1820
|
-
return wp.outer(x_col_vec, y_row_vec)
|
|
1821
|
-
|
|
1822
|
-
@dynamic_kernel(suffix=suffix, kernel_options={"enable_backward": False})
|
|
1823
|
-
def bsr_mm_compute_values(
|
|
1824
|
-
alpha: scalar_type,
|
|
1825
|
-
x_offsets: wp.array(dtype=int),
|
|
1826
|
-
x_columns: wp.array(dtype=int),
|
|
1827
|
-
x_values: wp.array3d(dtype=scalar_type),
|
|
1828
|
-
y_offsets: wp.array(dtype=int),
|
|
1829
|
-
y_columns: wp.array(dtype=int),
|
|
1830
|
-
y_values: wp.array3d(dtype=scalar_type),
|
|
1831
|
-
mm_row_min: wp.array(dtype=int),
|
|
1832
|
-
summed_triplet_offsets: wp.array(dtype=int),
|
|
1833
|
-
summed_triplet_src_blocks: wp.indexedarray(dtype=int),
|
|
1834
|
-
mm_row_count: int,
|
|
1835
|
-
mm_offsets: wp.array(dtype=int),
|
|
1836
|
-
mm_cols: wp.array(dtype=int),
|
|
1837
|
-
mm_values: wp.array3d(dtype=scalar_type),
|
|
1838
|
-
):
|
|
1839
|
-
mm_block, subrow, subcol, lane = wp.tid()
|
|
1840
|
-
|
|
1841
|
-
brow_off = subrow * wp.static(subblock_rows)
|
|
1842
|
-
bcol_off = subcol * wp.static(subblock_cols)
|
|
1843
|
-
|
|
1844
|
-
brow_count = wp.min(mm_values.shape[1] - brow_off, subblock_rows)
|
|
1845
|
-
bcol_count = wp.min(mm_values.shape[2] - bcol_off, subblock_cols)
|
|
1846
|
-
|
|
1847
|
-
mm_row = _bsr_row_index(mm_offsets, mm_row_count, mm_block)
|
|
1848
|
-
if mm_row == -1:
|
|
1849
|
-
return
|
|
1850
|
-
|
|
1851
|
-
lane_val = mm_type()
|
|
1852
|
-
|
|
1853
|
-
use_triplets, block_beg, block_end = _bsr_mm_use_triplets(
|
|
1854
|
-
mm_row, mm_block, mm_row_min, x_offsets, summed_triplet_offsets
|
|
1855
|
-
)
|
|
1856
|
-
|
|
1857
|
-
col_count = (block_end - block_beg) * block_depth
|
|
1858
|
-
|
|
1859
|
-
mm_col = mm_cols[mm_block]
|
|
1860
|
-
if use_triplets:
|
|
1861
|
-
for col in range(lane, col_count, tile_size):
|
|
1862
|
-
tpl_block = col // wp.static(block_depth)
|
|
1863
|
-
block_col = col - tpl_block * wp.static(block_depth)
|
|
1864
|
-
tpl_block += block_beg
|
|
1865
|
-
|
|
1866
|
-
x_block = summed_triplet_src_blocks[tpl_block]
|
|
1867
|
-
if x_block != -1:
|
|
1868
|
-
x_col = x_columns[x_block]
|
|
1869
|
-
y_block = _bsr_block_index(x_col, mm_col, y_offsets, y_columns)
|
|
1870
|
-
lane_val += _outer_product(
|
|
1871
|
-
x_values[x_block], y_values[y_block], brow_off, bcol_off, block_col, brow_count, bcol_count
|
|
1872
|
-
)
|
|
1873
|
-
else:
|
|
1874
|
-
for col in range(lane, col_count, tile_size):
|
|
1875
|
-
x_block = col // wp.static(block_depth)
|
|
1876
|
-
block_col = col - x_block * wp.static(block_depth)
|
|
1877
|
-
x_block += block_beg
|
|
1878
|
-
|
|
1879
|
-
x_col = x_columns[x_block]
|
|
1880
|
-
y_block = _bsr_block_index(x_col, mm_col, y_offsets, y_columns)
|
|
1881
|
-
|
|
1882
|
-
if y_block != -1:
|
|
1883
|
-
lane_val += _outer_product(
|
|
1884
|
-
x_values[x_block], y_values[y_block], brow_off, bcol_off, block_col, brow_count, bcol_count
|
|
1885
|
-
)
|
|
1886
|
-
|
|
1887
|
-
mm_val = wp.tile_sum(wp.tile(lane_val, preserve_type=True))[0]
|
|
1888
|
-
|
|
1889
|
-
for coef in range(lane, wp.static(subblock_cols * subblock_rows), tile_size):
|
|
1890
|
-
br = coef // subblock_cols
|
|
1891
|
-
bc = coef - br * subblock_cols
|
|
1892
|
-
if br < brow_count and bc < bcol_count:
|
|
1893
|
-
mm_values[mm_block, br + brow_off, bc + bcol_off] += mm_val[br, bc] * alpha
|
|
1894
|
-
|
|
1895
|
-
return bsr_mm_compute_values
|
|
1896
|
-
|
|
1897
|
-
|
|
1898
|
-
class bsr_mm_work_arrays:
|
|
1899
|
-
"""Opaque structure for persisting :func:`bsr_mm` temporary work buffers across calls."""
|
|
1900
|
-
|
|
1901
|
-
def __init__(self):
|
|
1902
|
-
self._reset(None)
|
|
1903
|
-
|
|
1904
|
-
def _reset(self, device):
|
|
1905
|
-
self.device = device
|
|
1906
|
-
self._mm_row_min = None
|
|
1907
|
-
self._mm_block_counts = None
|
|
1908
|
-
self._mm_rows = None
|
|
1909
|
-
self._mm_cols = None
|
|
1910
|
-
self._mm_src_blocks = None
|
|
1911
|
-
self._old_z_values = None
|
|
1912
|
-
self._old_z_offsets = None
|
|
1913
|
-
self._old_z_columns = None
|
|
1914
|
-
self._mm_nnz = 0
|
|
1915
|
-
|
|
1916
|
-
def _allocate_stage_1(self, device, x_nnz: int, z: BsrMatrix, beta: float, z_aliasing: bool):
|
|
1917
|
-
if self.device != device:
|
|
1918
|
-
self._reset(device)
|
|
1919
|
-
|
|
1920
|
-
# Allocations that do not depend on any computation
|
|
1921
|
-
z_nnz = z.nnz_sync()
|
|
1922
|
-
self._copied_z_nnz = z_nnz if beta != 0.0 or z_aliasing else 0
|
|
1923
|
-
|
|
1924
|
-
if self._mm_row_min is None or self._mm_block_counts.size < z.nrow + 1:
|
|
1925
|
-
self._mm_row_min = wp.empty(shape=(z.nrow + 1,), dtype=int, device=self.device)
|
|
1926
|
-
if self._mm_block_counts is None or self._mm_block_counts.size < x_nnz + 1:
|
|
1927
|
-
self._mm_block_counts = wp.empty(shape=(x_nnz + 1,), dtype=int, device=self.device)
|
|
1928
|
-
|
|
1929
|
-
if self._copied_z_nnz > 0:
|
|
1930
|
-
if self._old_z_values is None or self._old_z_values.size < self._copied_z_nnz:
|
|
1931
|
-
self._old_z_values = wp.empty(shape=(self._copied_z_nnz,), dtype=z.values.dtype, device=self.device)
|
|
1932
|
-
|
|
1933
|
-
if z_aliasing:
|
|
1934
|
-
if self._old_z_columns is None or self._old_z_columns.size < z_nnz:
|
|
1935
|
-
self._old_z_columns = wp.empty(shape=(z_nnz,), dtype=z.columns.dtype, device=self.device)
|
|
1936
|
-
if self._old_z_offsets is None or self._old_z_offsets.size < z.nrow + 1:
|
|
1937
|
-
self._old_z_offsets = wp.empty(shape=(z.nrow + 1,), dtype=z.offsets.dtype, device=self.device)
|
|
1938
|
-
|
|
1939
|
-
def _allocate_stage_2(self, mm_nnz: int):
|
|
1940
|
-
# Allocations that depend on unmerged nnz estimate
|
|
1941
|
-
self._mm_nnz = mm_nnz
|
|
1942
|
-
if self._mm_rows is None or self._mm_rows.size < mm_nnz:
|
|
1943
|
-
self._mm_rows = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
|
|
1944
|
-
if self._mm_cols is None or self._mm_cols.size < mm_nnz:
|
|
1945
|
-
self._mm_cols = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
|
|
1946
|
-
if self._mm_src_blocks is None or self._mm_src_blocks.size < mm_nnz:
|
|
1947
|
-
self._mm_src_blocks = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
|
|
1948
|
-
|
|
1949
|
-
|
|
1950
|
-
def bsr_mm(
|
|
1951
|
-
x: BsrMatrixOrExpression[BlockType[Rows, Any, Scalar]],
|
|
1952
|
-
y: BsrMatrixOrExpression[BlockType[Any, Cols, Scalar]],
|
|
1953
|
-
z: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
|
|
1954
|
-
alpha: Scalar = 1.0,
|
|
1955
|
-
beta: Scalar = 0.0,
|
|
1956
|
-
masked: bool = False,
|
|
1957
|
-
work_arrays: Optional[bsr_mm_work_arrays] = None,
|
|
1958
|
-
reuse_topology: bool = False,
|
|
1959
|
-
tile_size: int = 0,
|
|
1960
|
-
) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
|
|
1961
|
-
"""
|
|
1962
|
-
Perform the sparse matrix-matrix multiplication ``z := alpha * x @ y + beta * z`` on BSR matrices ``x``, ``y`` and ``z``, and return ``z``.
|
|
1963
|
-
|
|
1964
|
-
The ``x``, ``y`` and ``z`` matrices are allowed to alias.
|
|
1965
|
-
If the matrix ``z`` is not provided as input, it will be allocated and treated as zero.
|
|
1966
|
-
|
|
1967
|
-
Args:
|
|
1968
|
-
x: Read-only left operand of the matrix-matrix product.
|
|
1969
|
-
y: Read-only right operand of the matrix-matrix product.
|
|
1970
|
-
z: Mutable affine operand and result matrix. If ``z`` is not provided, it will be allocated and treated as zero.
|
|
1971
|
-
alpha: Uniform scaling factor for the ``x @ y`` product
|
|
1972
|
-
beta: Uniform scaling factor for ``z``
|
|
1973
|
-
masked: If ``True``, ignore all blocks from ``x @ y`` which are not existing non-zeros of ``y``
|
|
1974
|
-
work_arrays: In most cases, this function will require the use of temporary storage.
|
|
1975
|
-
This storage can be reused across calls by passing an instance of
|
|
1976
|
-
:class:`bsr_mm_work_arrays` in ``work_arrays``.
|
|
1977
|
-
reuse_topology: If ``True``, reuse the product topology information
|
|
1978
|
-
stored in ``work_arrays`` rather than recompute it from scratch.
|
|
1979
|
-
The matrices ``x``, ``y`` and ``z`` must be structurally similar to
|
|
1980
|
-
the previous call in which ``work_arrays`` were populated.
|
|
1981
|
-
This is necessary for ``bsr_mm`` to be captured in a CUDA graph.
|
|
1982
|
-
tile_size: If a positive integer, use tiles of this size to compute the matrix-matrix product.
|
|
1983
|
-
If negative, disable tile-based computation. Defaults to ``0``, which determines whether to
|
|
1984
|
-
use tiles using using an heuristic based on the matrix shape and number of non-zeros..
|
|
1985
|
-
"""
|
|
1986
|
-
|
|
1987
|
-
x, x_scale = _extract_matrix_and_scale(x)
|
|
1988
|
-
alpha *= x_scale
|
|
1989
|
-
y, y_scale = _extract_matrix_and_scale(y)
|
|
1990
|
-
alpha *= y_scale
|
|
1991
|
-
|
|
1992
|
-
if z is None:
|
|
1993
|
-
if masked:
|
|
1994
|
-
raise ValueError("Left-hand-side 'z' matrix must be provided for masked multiplication")
|
|
1995
|
-
|
|
1996
|
-
# If not output matrix is provided, allocate it for convenience
|
|
1997
|
-
z_block_shape = (x.block_shape[0], y.block_shape[1])
|
|
1998
|
-
if z_block_shape == (1, 1):
|
|
1999
|
-
z_block_type = x.scalar_type
|
|
2000
|
-
else:
|
|
2001
|
-
z_block_type = wp.mat(shape=z_block_shape, dtype=x.scalar_type)
|
|
2002
|
-
z = bsr_zeros(x.nrow, y.ncol, block_type=z_block_type, device=x.values.device)
|
|
2003
|
-
z.values.requires_grad = x.requires_grad or y.requires_grad
|
|
2004
|
-
beta = 0.0
|
|
2005
|
-
|
|
2006
|
-
if x.values.device != y.values.device or x.values.device != z.values.device:
|
|
2007
|
-
raise ValueError(
|
|
2008
|
-
f"All arguments must reside on the same device, got {x.values.device}, {y.values.device} and {z.values.device}"
|
|
2009
|
-
)
|
|
2010
|
-
|
|
2011
|
-
if x.scalar_type != y.scalar_type or x.scalar_type != z.scalar_type:
|
|
2012
|
-
raise ValueError(
|
|
2013
|
-
f"Matrices must have the same scalar type, got {x.scalar_type}, {y.scalar_type} and {z.scalar_type}"
|
|
2014
|
-
)
|
|
2015
|
-
|
|
2016
|
-
if (
|
|
2017
|
-
x.block_shape[0] != z.block_shape[0]
|
|
2018
|
-
or y.block_shape[1] != z.block_shape[1]
|
|
2019
|
-
or x.block_shape[1] != y.block_shape[0]
|
|
2020
|
-
):
|
|
2021
|
-
raise ValueError(
|
|
2022
|
-
f"Incompatible block sizes for matrix multiplication, got ({x.block_shape}, {y.block_shape}) and ({z.block_shape})"
|
|
2023
|
-
)
|
|
2024
|
-
|
|
2025
|
-
if x.nrow != z.nrow or z.ncol != y.ncol or x.ncol != y.nrow:
|
|
2026
|
-
raise ValueError(
|
|
2027
|
-
f"Incompatible number of rows/columns for matrix multiplication, got ({x.nrow}, {x.ncol}) and ({y.nrow}, {y.ncol})"
|
|
2028
|
-
)
|
|
2029
|
-
|
|
2030
|
-
device = z.values.device
|
|
2031
|
-
|
|
2032
|
-
if alpha == 0.0 or x.nnz == 0 or y.nnz == 0:
|
|
2033
|
-
# Easy case
|
|
2034
|
-
return bsr_scale(z, beta)
|
|
2035
|
-
|
|
2036
|
-
z_aliasing = z == x or z == y
|
|
2037
|
-
|
|
2038
|
-
if masked:
|
|
2039
|
-
# no need to copy z, scale in-place
|
|
2040
|
-
copied_z_nnz = 0
|
|
2041
|
-
mm_nnz = z.nnz
|
|
2042
|
-
|
|
2043
|
-
if z_aliasing:
|
|
2044
|
-
raise ValueError("`masked=True` is not supported for aliased inputs")
|
|
2045
|
-
|
|
2046
|
-
if beta == 0.0:
|
|
2047
|
-
# do not bsr_scale(0), this would not preserve topology
|
|
2048
|
-
z.values.zero_()
|
|
2049
|
-
else:
|
|
2050
|
-
bsr_scale(z, beta)
|
|
2051
|
-
elif reuse_topology:
|
|
2052
|
-
if work_arrays is None:
|
|
2053
|
-
raise ValueError("`work_arrays` must not be ``None`` in order to reuse matrix-matrix product topology")
|
|
2054
|
-
|
|
2055
|
-
copied_z_nnz = work_arrays._copied_z_nnz
|
|
2056
|
-
mm_nnz = work_arrays._mm_nnz
|
|
2057
|
-
else:
|
|
2058
|
-
if device.is_capturing:
|
|
2059
|
-
raise RuntimeError(
|
|
2060
|
-
"`bsr_mm` requires either `reuse_topology=True` or `masked=True` for use in graph capture"
|
|
2061
|
-
)
|
|
2062
|
-
|
|
2063
|
-
if work_arrays is None:
|
|
2064
|
-
work_arrays = bsr_mm_work_arrays()
|
|
2065
|
-
|
|
2066
|
-
work_arrays._allocate_stage_1(device, x.nnz, z, beta, z_aliasing)
|
|
2067
|
-
copied_z_nnz = work_arrays._copied_z_nnz
|
|
2068
|
-
|
|
2069
|
-
# Prefix sum of number of (unmerged) mm blocks per row
|
|
2070
|
-
# Use either a thread or a block per row depending on avg nnz/row
|
|
2071
|
-
work_arrays._mm_block_counts.zero_()
|
|
2072
|
-
count_tile_size = 32
|
|
2073
|
-
if not device.is_cuda or x.nnz < 3 * count_tile_size * x.nrow:
|
|
2074
|
-
count_tile_size = 1
|
|
2075
|
-
|
|
2076
|
-
wp.launch(
|
|
2077
|
-
kernel=make_bsr_mm_count_coeffs(count_tile_size),
|
|
2078
|
-
device=device,
|
|
2079
|
-
dim=(z.nrow, count_tile_size),
|
|
2080
|
-
block_dim=count_tile_size if count_tile_size > 1 else 256,
|
|
2081
|
-
inputs=[
|
|
2082
|
-
y.ncol,
|
|
2083
|
-
copied_z_nnz,
|
|
2084
|
-
x.offsets,
|
|
2085
|
-
x.columns,
|
|
2086
|
-
y.offsets,
|
|
2087
|
-
y.columns,
|
|
2088
|
-
work_arrays._mm_row_min,
|
|
2089
|
-
work_arrays._mm_block_counts,
|
|
2090
|
-
],
|
|
2091
|
-
)
|
|
2092
|
-
warp.utils.array_scan(work_arrays._mm_block_counts[: x.nnz + 1], work_arrays._mm_block_counts[: x.nnz + 1])
|
|
2093
|
-
|
|
2094
|
-
# Get back total counts on host -- we need a synchronization here
|
|
2095
|
-
# Use pinned buffer from z, we are going to need it later anyway
|
|
2096
|
-
nnz_buf, _ = z._setup_nnz_transfer()
|
|
2097
|
-
stream = wp.get_stream(device) if device.is_cuda else None
|
|
2098
|
-
wp.copy(dest=nnz_buf, src=work_arrays._mm_block_counts, src_offset=x.nnz, count=1, stream=stream)
|
|
2099
|
-
if device.is_cuda:
|
|
2100
|
-
wp.synchronize_stream(stream)
|
|
2101
|
-
mm_nnz = int(nnz_buf.numpy()[0])
|
|
2102
|
-
|
|
2103
|
-
if mm_nnz == copied_z_nnz:
|
|
2104
|
-
# x@y = 0
|
|
2105
|
-
return bsr_scale(z, beta)
|
|
2106
|
-
|
|
2107
|
-
work_arrays._allocate_stage_2(mm_nnz)
|
|
2108
|
-
|
|
2109
|
-
# If z has a non-zero scale, save current data before overwriting it
|
|
2110
|
-
if copied_z_nnz > 0:
|
|
2111
|
-
# Copy z row and column indices
|
|
2112
|
-
wp.copy(dest=work_arrays._mm_cols, src=z.columns, count=copied_z_nnz)
|
|
2113
|
-
z.uncompress_rows(out=work_arrays._mm_rows)
|
|
2114
|
-
work_arrays._mm_src_blocks[:copied_z_nnz].fill_(-1)
|
|
2115
|
-
if z_aliasing:
|
|
2116
|
-
# If z is aliasing with x or y, need to save topology as well
|
|
2117
|
-
wp.copy(src=z.columns, dest=work_arrays._old_z_columns, count=copied_z_nnz)
|
|
2118
|
-
wp.copy(src=z.offsets, dest=work_arrays._old_z_offsets, count=z.nrow + 1)
|
|
2119
|
-
|
|
2120
|
-
# Fill unmerged mm blocks rows and columns
|
|
2121
|
-
wp.launch(
|
|
2122
|
-
kernel=_bsr_mm_list_coeffs,
|
|
2123
|
-
device=device,
|
|
2124
|
-
dim=mm_nnz - copied_z_nnz,
|
|
2125
|
-
inputs=[
|
|
2126
|
-
copied_z_nnz,
|
|
2127
|
-
x.nrow,
|
|
2128
|
-
x.offsets,
|
|
2129
|
-
x.columns,
|
|
2130
|
-
y.offsets,
|
|
2131
|
-
y.columns,
|
|
2132
|
-
work_arrays._mm_row_min,
|
|
2133
|
-
work_arrays._mm_block_counts,
|
|
2134
|
-
work_arrays._mm_rows,
|
|
2135
|
-
work_arrays._mm_cols,
|
|
2136
|
-
work_arrays._mm_src_blocks,
|
|
2137
|
-
],
|
|
2138
|
-
)
|
|
2139
|
-
|
|
2140
|
-
alpha = z.scalar_type(alpha)
|
|
2141
|
-
beta = z.scalar_type(beta)
|
|
2142
|
-
|
|
2143
|
-
if copied_z_nnz > 0:
|
|
2144
|
-
# Save current z values in temporary buffer
|
|
2145
|
-
wp.copy(src=z.values, dest=work_arrays._old_z_values, count=copied_z_nnz)
|
|
2146
|
-
|
|
2147
|
-
if not masked:
|
|
2148
|
-
# Increase dest array size if needed
|
|
2149
|
-
if z.columns.shape[0] < mm_nnz:
|
|
2150
|
-
z.columns = wp.empty(shape=(mm_nnz,), dtype=int, device=device)
|
|
2151
|
-
|
|
2152
|
-
from warp.context import runtime
|
|
2153
|
-
|
|
2154
|
-
if device.is_cpu:
|
|
2155
|
-
native_func = runtime.core.wp_bsr_matrix_from_triplets_host
|
|
2156
|
-
else:
|
|
2157
|
-
native_func = runtime.core.wp_bsr_matrix_from_triplets_device
|
|
2158
|
-
|
|
2159
|
-
nnz_buf, nnz_event = z._setup_nnz_transfer()
|
|
2160
|
-
summed_triplet_offsets = wp.empty(shape=(mm_nnz,), dtype=wp.int32, device=device)
|
|
2161
|
-
summed_triplet_indices = wp.empty(shape=(mm_nnz,), dtype=wp.int32, device=device)
|
|
2162
|
-
|
|
2163
|
-
with wp.ScopedDevice(z.device):
|
|
2164
|
-
native_func(
|
|
2165
|
-
z.block_size,
|
|
2166
|
-
0, # scalar_size_in_bytes
|
|
2167
|
-
z.nrow,
|
|
2168
|
-
z.ncol,
|
|
2169
|
-
mm_nnz,
|
|
2170
|
-
None, # device nnz
|
|
2171
|
-
ctypes.cast(work_arrays._mm_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
2172
|
-
ctypes.cast(work_arrays._mm_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
2173
|
-
None, # triplet values
|
|
2174
|
-
0, # zero_value_mask
|
|
2175
|
-
False, # masked_topology
|
|
2176
|
-
ctypes.cast(summed_triplet_offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
2177
|
-
ctypes.cast(summed_triplet_indices.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
2178
|
-
ctypes.cast(z.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
2179
|
-
ctypes.cast(z.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
2180
|
-
_optional_ctypes_pointer(nnz_buf, ctype=ctypes.c_int32),
|
|
2181
|
-
_optional_ctypes_event(nnz_event),
|
|
2182
|
-
)
|
|
2183
|
-
|
|
2184
|
-
# Resize z to fit mm result if necessary
|
|
2185
|
-
# If we are not reusing the product topology, this needs another synchronization
|
|
2186
|
-
if not reuse_topology:
|
|
2187
|
-
work_arrays.result_nnz = z.nnz_sync()
|
|
2188
|
-
|
|
2189
|
-
_bsr_ensure_fits(z, nnz=work_arrays.result_nnz)
|
|
2190
|
-
z.values.zero_()
|
|
2191
|
-
|
|
2192
|
-
if copied_z_nnz > 0:
|
|
2193
|
-
# Add back original z values
|
|
2194
|
-
wp.launch(
|
|
2195
|
-
kernel=_bsr_axpy_add_block,
|
|
2196
|
-
device=device,
|
|
2197
|
-
dim=(copied_z_nnz, z.block_shape[0], z.block_shape[1]),
|
|
2198
|
-
inputs=[
|
|
2199
|
-
0,
|
|
2200
|
-
beta,
|
|
2201
|
-
work_arrays._mm_rows,
|
|
2202
|
-
work_arrays._mm_cols,
|
|
2203
|
-
z.offsets,
|
|
2204
|
-
z.columns,
|
|
2205
|
-
_as_3d_array(work_arrays._old_z_values, z.block_shape),
|
|
2206
|
-
z.scalar_values,
|
|
2207
|
-
],
|
|
2208
|
-
)
|
|
2209
|
-
|
|
2210
|
-
max_subblock_dim = 12
|
|
2211
|
-
if tile_size > 0:
|
|
2212
|
-
use_tiles = True
|
|
2213
|
-
elif tile_size < 0:
|
|
2214
|
-
use_tiles = False
|
|
2215
|
-
else:
|
|
2216
|
-
# Heuristic for using tiled variant: few or very large blocks
|
|
2217
|
-
tile_size = 64
|
|
2218
|
-
max_tiles_per_sm = 2048 // tile_size # assume 64 resident warps per SM
|
|
2219
|
-
use_tiles = device.is_cuda and (
|
|
2220
|
-
max(x.block_size, y.block_size, z.block_size) > max_subblock_dim**2
|
|
2221
|
-
or mm_nnz < max_tiles_per_sm * device.sm_count
|
|
2222
|
-
)
|
|
2223
|
-
|
|
2224
|
-
if use_tiles:
|
|
2225
|
-
subblock_rows = min(max_subblock_dim, z.block_shape[0])
|
|
2226
|
-
subblock_cols = min(max_subblock_dim, z.block_shape[1])
|
|
2227
|
-
|
|
2228
|
-
wp.launch(
|
|
2229
|
-
kernel=make_bsr_mm_compute_values_tiled_outer(
|
|
2230
|
-
subblock_rows, subblock_cols, x.block_shape[1], z.scalar_type, tile_size
|
|
2231
|
-
),
|
|
2232
|
-
device=device,
|
|
2233
|
-
dim=(
|
|
2234
|
-
z.nnz,
|
|
2235
|
-
(z.block_shape[0] + subblock_rows - 1) // subblock_rows,
|
|
2236
|
-
(z.block_shape[1] + subblock_cols - 1) // subblock_cols,
|
|
2237
|
-
tile_size,
|
|
2238
|
-
),
|
|
2239
|
-
block_dim=tile_size,
|
|
2240
|
-
inputs=[
|
|
2241
|
-
alpha,
|
|
2242
|
-
work_arrays._old_z_offsets if x == z else x.offsets,
|
|
2243
|
-
work_arrays._old_z_columns if x == z else x.columns,
|
|
2244
|
-
_as_3d_array(work_arrays._old_z_values, z.block_shape) if x == z else x.scalar_values,
|
|
2245
|
-
work_arrays._old_z_offsets if y == z else y.offsets,
|
|
2246
|
-
work_arrays._old_z_columns if y == z else y.columns,
|
|
2247
|
-
_as_3d_array(work_arrays._old_z_values, z.block_shape) if y == z else y.scalar_values,
|
|
2248
|
-
None if masked else work_arrays._mm_row_min,
|
|
2249
|
-
None if masked else summed_triplet_offsets,
|
|
2250
|
-
None if masked else work_arrays._mm_src_blocks[summed_triplet_indices],
|
|
2251
|
-
z.nrow,
|
|
2252
|
-
z.offsets,
|
|
2253
|
-
z.columns,
|
|
2254
|
-
z.scalar_values,
|
|
2255
|
-
],
|
|
2256
|
-
)
|
|
2257
|
-
|
|
2258
|
-
return z
|
|
2259
|
-
|
|
2260
|
-
# Add mm blocks to z values
|
|
2261
|
-
if (type_is_matrix(x.values.dtype) or type_is_matrix(y.values.dtype)) and not (type_is_matrix(z.values.dtype)):
|
|
2262
|
-
# Result block type is scalar, but operands are matrices
|
|
2263
|
-
# Cast result to (1x1) matrix to perform multiplication
|
|
2264
|
-
mm_values = z.values.view(wp.mat(shape=(1, 1), dtype=z.scalar_type))
|
|
2265
|
-
else:
|
|
2266
|
-
mm_values = z.values
|
|
2267
|
-
|
|
2268
|
-
wp.launch(
|
|
2269
|
-
kernel=_bsr_mm_compute_values,
|
|
2270
|
-
device=device,
|
|
2271
|
-
dim=z.nnz,
|
|
2272
|
-
inputs=[
|
|
2273
|
-
alpha,
|
|
2274
|
-
work_arrays._old_z_offsets if x == z else x.offsets,
|
|
2275
|
-
work_arrays._old_z_columns if x == z else x.columns,
|
|
2276
|
-
work_arrays._old_z_values if x == z else x.values,
|
|
2277
|
-
work_arrays._old_z_offsets if y == z else y.offsets,
|
|
2278
|
-
work_arrays._old_z_columns if y == z else y.columns,
|
|
2279
|
-
work_arrays._old_z_values if y == z else y.values,
|
|
2280
|
-
None if masked else work_arrays._mm_row_min,
|
|
2281
|
-
None if masked else summed_triplet_offsets,
|
|
2282
|
-
None if masked else work_arrays._mm_src_blocks[summed_triplet_indices],
|
|
2283
|
-
z.nrow,
|
|
2284
|
-
z.offsets,
|
|
2285
|
-
z.columns,
|
|
2286
|
-
mm_values,
|
|
2287
|
-
],
|
|
2288
|
-
)
|
|
2289
|
-
|
|
2290
|
-
return z
|
|
2291
|
-
|
|
2292
|
-
|
|
2293
|
-
def make_bsr_mv_kernel(block_cols: int):
    """Build a Warp kernel computing ``y := alpha * A @ x + beta * y`` for a BSR matrix ``A``.

    Each thread ``(row, subrow)`` produces one scalar entry of ``y``: scalar row ``subrow`` of
    block-row ``row`` is dotted with ``x`` over all blocks stored in that block-row.

    Args:
        block_cols: Number of scalar columns per block of ``A``. It is baked into the kernel as a
            static inner-loop bound and used as the kernel name suffix.

    Returns:
        The generated kernel (backward pass disabled).
    """
    from warp.fem.cache import dynamic_kernel

    @dynamic_kernel(suffix=f"{block_cols}", kernel_options={"enable_backward": False})
    def bsr_mv_kernel(
        alpha: Any,
        A_offsets: wp.array(dtype=int),
        A_columns: wp.array(dtype=int),
        A_values: wp.array3d(dtype=Any),
        x: wp.array(dtype=Any),
        beta: Any,
        y: wp.array(dtype=Any),
    ):
        row, subrow = wp.tid()

        block_rows = A_values.shape[1]

        # flat index of the scalar entry of y written by this thread
        yi = row * block_rows + subrow

        # zero of alpha's scalar type; used as the accumulator start value and for the
        # alpha/beta == 0 tests (callers presumably pass alpha with the same scalar type as y)
        scalar_zero = type(alpha)(0)
        v = scalar_zero

        # skip reading A and x entirely when alpha is zero
        if alpha != scalar_zero:
            beg = A_offsets[row]
            end = A_offsets[row + 1]
            for block in range(beg, end):
                # first scalar index of x touched by this block
                xs = A_columns[block] * block_cols
                for col in range(wp.static(block_cols)):
                    v += A_values[block, subrow, col] * x[xs + col]
            v *= alpha

        # y is only read when beta is non-zero, so it may be uninitialized otherwise
        if beta != scalar_zero:
            v += beta * y[yi]

        y[yi] = v

    return bsr_mv_kernel
|
|
2331
|
-
|
|
2332
|
-
|
|
2333
|
-
def make_bsr_mv_tiled_kernel(tile_size: int):
    """Build a tile-based Warp kernel computing ``y := alpha * A @ x + beta * y`` for a BSR matrix ``A``.

    Thread ``(row, subrow, lane)`` accumulates a strided subset of the scalar columns of scalar
    row ``subrow`` of block-row ``row``. The ``tile_size`` lanes are then reduced with tile
    primitives. This assumes the last launch dimension equals ``tile_size`` and matches the
    thread-block size, as in the ``bsr_mv`` launch.

    Args:
        tile_size: Number of lanes cooperating on one scalar row. It is also the column stride
            and the kernel name suffix.

    Returns:
        The generated kernel (backward pass disabled).
    """
    from warp.fem.cache import dynamic_kernel

    @dynamic_kernel(suffix=f"{tile_size}", kernel_options={"enable_backward": False})
    def bsr_mv_tiled_kernel(
        alpha: Any,
        A_offsets: wp.array(dtype=int),
        A_columns: wp.array(dtype=int),
        A_values: wp.array3d(dtype=Any),
        x: wp.array(dtype=Any),
        beta: Any,
        y: wp.array(dtype=Any),
    ):
        row, subrow, lane = wp.tid()

        scalar_zero = type(alpha)(0)
        block_rows = A_values.shape[1]
        block_cols = A_values.shape[2]

        # flat index of the scalar entry of y produced by this group of lanes
        yi = row * block_rows + subrow

        # start from zero, or from beta * y[yi] (presumably loaded as a 1-element tile at offset yi)
        if beta == scalar_zero:
            subrow_sum = wp.tile_zeros(shape=(1,), dtype=y.dtype)
        else:
            subrow_sum = beta * wp.tile_load(y, 1, yi)

        if alpha != scalar_zero:
            block_beg = A_offsets[row]
            # total number of scalar columns stored in this block-row
            col_count = (A_offsets[row + 1] - block_beg) * block_cols

            # NOTE(review): `col` is rebound by the loop below; this pre-assignment looks
            # redundant -- kept in case Warp codegen relies on it, confirm before removing
            col = lane
            lane_sum = y.dtype(0)

            # each lane visits columns lane, lane + tile_size, lane + 2 * tile_size, ...
            for col in range(lane, col_count, tile_size):
                # split the flat column index into (block offset within the row, column within the block)
                block = col // block_cols
                block_col = col - block * block_cols
                block += block_beg

                xi = x[A_columns[block] * block_cols + block_col]
                lane_sum += A_values[block, subrow, block_col] * xi

            lane_sum *= alpha
            # gather each lane's partial sum into a tile and reduce it across lanes
            subrow_sum += wp.tile_sum(wp.tile(lane_sum))

        wp.tile_store(y, subrow_sum, yi)

    return bsr_mv_tiled_kernel
|
|
2380
|
-
|
|
2381
|
-
|
|
2382
|
-
def make_bsr_mv_transpose_kernel(block_rows: int):
    """Build a Warp kernel accumulating ``y += alpha * A^T @ x`` for a BSR matrix ``A``.

    One thread handles one ``(stored block, block column)`` pair. It computes the dot product of
    that block column with the matching segment of ``x``, then atomically adds the scaled result
    into ``y``. Because of the atomic accumulation, the summation order, and hence the rounding,
    is not deterministic. ``y`` is not scaled or cleared here.

    Args:
        block_rows: Number of scalar rows per block of ``A``. It is baked in as a static loop
            bound and used as the kernel name suffix.

    Returns:
        The generated kernel (backward pass disabled).
    """
    from warp.fem.cache import dynamic_kernel

    @dynamic_kernel(suffix=f"{block_rows}", kernel_options={"enable_backward": False})
    def bsr_mv_transpose_kernel(
        alpha: Any,
        A_row_count: int,
        A_offsets: wp.array(dtype=int),
        A_columns: wp.array(dtype=int),
        A_values: wp.array3d(dtype=Any),
        x: wp.array(dtype=Any),
        y: wp.array(dtype=Any),
    ):
        block, subcol = wp.tid()

        # block-row owning this stored block; -1 presumably means the block index lies outside
        # the rows' stored range (confirm against _bsr_row_index), in which case nothing is done
        row = _bsr_row_index(A_offsets, A_row_count, block)
        if row == -1:
            return

        block_cols = A_values.shape[2]

        A_block = A_values[block]

        # dot product of column `subcol` of this block with the x segment of block-row `row`
        col_sum = type(alpha)(0)
        for subrow in range(wp.static(block_rows)):
            col_sum += A_block[subrow, subcol] * x[row * block_rows + subrow]

        # several blocks may target the same entry of y, hence the atomic add
        wp.atomic_add(y, A_columns[block] * block_cols + subcol, alpha * col_sum)

    return bsr_mv_transpose_kernel
|
|
2412
|
-
|
|
2413
|
-
|
|
2414
|
-
def _vec_array_view(array: wp.array, dtype: type, expected_scalar_count: int) -> wp.array:
    """Reinterpret a 1d or 2d array as a 1d array of ``dtype`` elements without copying data.

    Args:
        array: Source array. It must be contiguous unless it already is a 1d array of ``dtype``.
        dtype: Target element type. Its scalar type must match the scalar type of ``array.dtype``.
        expected_scalar_count: Total number of scalars ``array`` is required to hold.

    Returns:
        ``array`` itself if it already is a 1d array of ``dtype``. Otherwise, a new array sharing
        ``array``'s memory (and a view of its gradient, if any), which keeps a reference to
        ``array``.

    Raises:
        ValueError: If the scalar count, scalar type, dimensionality or contiguity of ``array``
            is incompatible with the requested view.
    """
    scalar_count = array.size * type_size(array.dtype)
    if scalar_count != expected_scalar_count:
        raise ValueError(f"Invalid array scalar size, expected {expected_scalar_count}, got {scalar_count}")

    # Nothing to do if the array already has the requested layout
    if array.ndim == 1 and types_equal(array.dtype, dtype):
        return array

    if type_scalar_type(array.dtype) != type_scalar_type(dtype):
        # "expected" is the requested dtype, "got" is the array's dtype (same convention as above)
        raise ValueError(f"Incompatible scalar types, expected {type_repr(dtype)}, got {type_repr(array.dtype)}")

    if array.ndim > 2:
        raise ValueError(f"Incompatible array number of dimensions, expected 1 or 2, got {array.ndim}")

    # Reinterpreting the raw pointer is only valid for densely packed memory
    if not array.is_contiguous:
        raise ValueError("Array must be contiguous")

    vec_length = type_size(dtype)
    vec_count = scalar_count // vec_length
    if vec_count * vec_length != scalar_count:
        raise ValueError(
            f"Array of shape {array.shape} and type {type_repr(array.dtype)} cannot be reshaped to an array of type {type_repr(dtype)}"
        )

    def vec_view(array):
        # Alias the same memory with the new dtype/shape; recurse so the gradient is viewed alike
        return wp.array(
            data=None,
            ptr=array.ptr,
            capacity=array.capacity,
            device=array.device,
            dtype=dtype,
            shape=vec_count,
            grad=None if array.grad is None else vec_view(array.grad),
        )

    view = vec_view(array)
    # Keep the source array alive for as long as the view exists
    view._ref = array
    return view
|
|
2454
|
-
|
|
2455
|
-
|
|
2456
|
-
def bsr_mv(
    A: BsrMatrixOrExpression[BlockType[Rows, Cols, Scalar]],
    x: "Array[Vector[Cols, Scalar] | Scalar]",
    y: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
    alpha: Scalar = 1.0,
    beta: Scalar = 0.0,
    transpose: bool = False,
    work_buffer: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
    tile_size: int = 0,
) -> "Array[Vector[Rows, Scalar] | Scalar]":
    """Perform the sparse matrix-vector product ``y := alpha * A * x + beta * y`` and return ``y``.

    The ``x`` and ``y`` vectors are allowed to alias.

    Args:
        A: Read-only, left matrix operand of the matrix-vector product.
        x: Read-only, right vector operand of the matrix-vector product.
        y: Mutable affine operand and result vector. If ``y`` is not provided, it will be allocated and treated as zero.
        alpha: Uniform scaling factor for ``x``. If zero, ``x`` will not be read and may be left uninitialized.
        beta: Uniform scaling factor for ``y``. If zero, ``y`` will not be read and may be left uninitialized.
        transpose: If ``True``, use the transpose of the matrix ``A``. In this case the result is **non-deterministic**.
        work_buffer: Temporary storage is required if and only if ``x`` and ``y`` are the same vector.
            If provided, the ``work_buffer`` array will be used for this purpose,
            otherwise a temporary allocation will be performed.
            It must have the same data type as ``y`` and at least ``y.size`` elements;
            only its first ``y.size`` elements are used.
        tile_size: If a positive integer, use tiles of this size to compute the matrix-vector product.
            If negative, disable tile-based computation. Defaults to ``0``, which determines whether to
            use tiles using a heuristic based on the matrix shape and number of non-zeros.
            Tiles are not used when ``transpose`` is ``True``.

    Returns:
        The result vector ``y``: the provided array, or a newly allocated one.

    Raises:
        ValueError: If ``A``, ``x`` and ``y`` reside on different devices, if ``work_buffer`` is too
            small or has a different data type than ``y``, or if ``x`` or ``y`` is incompatible
            with the shape of ``A``.
    """

    A, A_scale = _extract_matrix_and_scale(A)
    alpha *= A_scale

    if transpose:
        block_shape = A.block_shape[1], A.block_shape[0]
        nrow, ncol = A.ncol, A.nrow
    else:
        block_shape = A.block_shape
        nrow, ncol = A.nrow, A.ncol

    if y is None:
        # If no output array is provided, allocate one for convenience
        y_vec_len = block_shape[0]
        y_dtype = A.scalar_type if y_vec_len == 1 else wp.vec(length=y_vec_len, dtype=A.scalar_type)
        y = wp.empty(shape=(nrow,), device=A.values.device, dtype=y_dtype, requires_grad=x.requires_grad)
        beta = 0.0

    alpha = A.scalar_type(alpha)
    beta = A.scalar_type(beta)

    device = A.values.device
    if A.values.device != x.device or A.values.device != y.device:
        raise ValueError(
            f"A, x, and y must reside on the same device, got {A.values.device}, {x.device} and {y.device}"
        )

    # Aliasing is detected by comparing the arrays' start pointers
    if x.ptr == y.ptr:
        # Aliasing case, need temporary storage
        if work_buffer is None:
            work_buffer = wp.empty_like(y)
        elif work_buffer.size < y.size:
            raise ValueError(f"Work buffer size is insufficient, needs to be at least {y.size}, got {work_buffer.size}")
        elif not types_equal(work_buffer.dtype, y.dtype):
            raise ValueError(
                f"Work buffer must have same data type as y, {type_repr(y.dtype)} vs {type_repr(work_buffer.dtype)}"
            )
        elif work_buffer.size > y.size:
            # Larger buffers are accepted: only use the leading y.size entries, otherwise the
            # scalar-count check on x below would reject the buffer
            flat_buffer = work_buffer if work_buffer.ndim == 1 else work_buffer.flatten()
            work_buffer = flat_buffer[: y.size]

        # Save old y values before overwriting vector
        wp.copy(dest=work_buffer, src=y, count=y.size)
        x = work_buffer

    try:
        x_view = _vec_array_view(x, A.scalar_type, expected_scalar_count=ncol * block_shape[1])
    except ValueError as err:
        raise ValueError("Incompatible 'x' vector for bsr_mv") from err
    try:
        y_view = _vec_array_view(y, A.scalar_type, expected_scalar_count=nrow * block_shape[0])
    except ValueError as err:
        raise ValueError("Incompatible 'y' vector for bsr_mv") from err

    # heuristic to use tiled version for long rows
    # (roughly: stored scalars per scalar row above 2 * tile_size; A.shape[0] is presumably the
    # scalar row count)
    if tile_size > 0:
        use_tiles = True
    elif tile_size < 0:
        use_tiles = False
    else:
        tile_size = 64
        use_tiles = device.is_cuda and A.nnz * A.block_size > 2 * tile_size * A.shape[0]

    if transpose:
        # The transpose kernel only accumulates into y, so apply beta up front
        if beta.value == 0.0:
            y.zero_()
        elif beta.value != 1.0:
            wp.launch(
                kernel=_bsr_scale_kernel,
                device=y.device,
                dim=y_view.shape[0],
                inputs=[beta, y_view],
            )
        if alpha.value != 0.0:
            wp.launch(
                kernel=make_bsr_mv_transpose_kernel(block_rows=block_shape[1]),
                device=A.values.device,
                dim=(A.nnz, block_shape[0]),
                inputs=[alpha, A.nrow, A.offsets, A.columns, A.scalar_values, x_view, y_view],
            )
    elif use_tiles:
        wp.launch(
            kernel=make_bsr_mv_tiled_kernel(tile_size),
            device=A.values.device,
            dim=(nrow, block_shape[0], tile_size),
            block_dim=tile_size,
            inputs=[alpha, A.offsets, A.columns, A.scalar_values, x_view, beta, y_view],
        )
    else:
        wp.launch(
            kernel=make_bsr_mv_kernel(block_cols=block_shape[1]),
            device=A.values.device,
            dim=(nrow, block_shape[0]),
            inputs=[alpha, A.offsets, A.columns, A.scalar_values, x_view, beta, y_view],
        )

    return y
|
|
16
|
+
# isort: skip_file
|
|
17
|
+
|
|
18
|
+
from warp._src.sparse import BsrMatrix as BsrMatrix
|
|
19
|
+
from warp._src.sparse import bsr_assign as bsr_assign
|
|
20
|
+
from warp._src.sparse import bsr_axpy as bsr_axpy
|
|
21
|
+
from warp._src.sparse import bsr_axpy_work_arrays as bsr_axpy_work_arrays
|
|
22
|
+
from warp._src.sparse import bsr_block_index as bsr_block_index
|
|
23
|
+
from warp._src.sparse import bsr_copy as bsr_copy
|
|
24
|
+
from warp._src.sparse import bsr_diag as bsr_diag
|
|
25
|
+
from warp._src.sparse import bsr_from_triplets as bsr_from_triplets
|
|
26
|
+
from warp._src.sparse import bsr_get_diag as bsr_get_diag
|
|
27
|
+
from warp._src.sparse import bsr_identity as bsr_identity
|
|
28
|
+
from warp._src.sparse import bsr_matrix_t as bsr_matrix_t
|
|
29
|
+
from warp._src.sparse import bsr_mm as bsr_mm
|
|
30
|
+
from warp._src.sparse import bsr_mm_work_arrays as bsr_mm_work_arrays
|
|
31
|
+
from warp._src.sparse import bsr_mv as bsr_mv
|
|
32
|
+
from warp._src.sparse import bsr_row_index as bsr_row_index
|
|
33
|
+
from warp._src.sparse import bsr_scale as bsr_scale
|
|
34
|
+
from warp._src.sparse import bsr_set_diag as bsr_set_diag
|
|
35
|
+
from warp._src.sparse import bsr_set_from_triplets as bsr_set_from_triplets
|
|
36
|
+
from warp._src.sparse import bsr_set_identity as bsr_set_identity
|
|
37
|
+
from warp._src.sparse import bsr_set_transpose as bsr_set_transpose
|
|
38
|
+
from warp._src.sparse import bsr_set_zero as bsr_set_zero
|
|
39
|
+
from warp._src.sparse import bsr_transposed as bsr_transposed
|
|
40
|
+
from warp._src.sparse import bsr_zeros as bsr_zeros
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# TODO: Remove after cleaning up the public API.
|
|
44
|
+
|
|
45
|
+
from warp._src import sparse as _sparse
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def __getattr__(name):
    """Module-level attribute fallback (PEP 562), invoked only when normal lookup of ``name`` fails.

    The lookup is delegated to ``warp._src.utils.get_deprecated_api``, using the private
    ``warp._src.sparse`` module bound here as ``_sparse``. Whatever that helper returns or raises
    for ``name`` is propagated unchanged.
    """
    from warp._src.utils import get_deprecated_api as _lookup_deprecated

    source_module = _sparse
    # NOTE(review): "wp" is presumably the public namespace prefix used in deprecation messages --
    # confirm against get_deprecated_api
    return _lookup_deprecated(source_module, "wp", name)
|