warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +10 -4
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +5 -3
- warp/build_dll.py +29 -9
- warp/builtins.py +868 -507
- warp/codegen.py +1074 -638
- warp/config.py +3 -3
- warp/constants.py +6 -0
- warp/context.py +715 -222
- warp/fabric.py +326 -0
- warp/fem/__init__.py +27 -0
- warp/fem/cache.py +389 -0
- warp/fem/dirichlet.py +181 -0
- warp/fem/domain.py +263 -0
- warp/fem/field/__init__.py +101 -0
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +299 -0
- warp/fem/field/restriction.py +21 -0
- warp/fem/field/test.py +181 -0
- warp/fem/field/trial.py +183 -0
- warp/fem/geometry/__init__.py +19 -0
- warp/fem/geometry/closest_point.py +70 -0
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +744 -0
- warp/fem/geometry/geometry.py +186 -0
- warp/fem/geometry/grid_2d.py +373 -0
- warp/fem/geometry/grid_3d.py +435 -0
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +376 -0
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +840 -0
- warp/fem/geometry/trimesh_2d.py +577 -0
- warp/fem/integrate.py +1616 -0
- warp/fem/operator.py +191 -0
- warp/fem/polynomial.py +213 -0
- warp/fem/quadrature/__init__.py +2 -0
- warp/fem/quadrature/pic_quadrature.py +245 -0
- warp/fem/quadrature/quadrature.py +294 -0
- warp/fem/space/__init__.py +292 -0
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +236 -0
- warp/fem/space/function_space.py +145 -0
- warp/fem/space/grid_2d_function_space.py +267 -0
- warp/fem/space/grid_3d_function_space.py +306 -0
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +350 -0
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +160 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +292 -0
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +221 -0
- warp/fem/types.py +77 -0
- warp/fem/utils.py +495 -0
- warp/native/array.h +147 -44
- warp/native/builtin.h +122 -149
- warp/native/bvh.cpp +73 -325
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +34 -43
- warp/native/clang/clang.cpp +13 -8
- warp/native/crt.h +2 -0
- warp/native/cuda_crt.h +5 -0
- warp/native/cuda_util.cpp +15 -3
- warp/native/cuda_util.h +3 -1
- warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
- warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
- warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
- warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
- warp/native/cutlass/tools/library/scripts/library.py +799 -0
- warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
- warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
- warp/native/cutlass/tools/library/scripts/rt.py +796 -0
- warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
- warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
- warp/native/cutlass_gemm.cu +5 -3
- warp/native/exports.h +1240 -952
- warp/native/fabric.h +228 -0
- warp/native/hashgrid.cpp +4 -4
- warp/native/hashgrid.h +22 -2
- warp/native/intersect.h +22 -7
- warp/native/intersect_adj.h +8 -8
- warp/native/intersect_tri.h +1 -1
- warp/native/marching.cu +157 -161
- warp/native/mat.h +80 -19
- warp/native/matnn.h +2 -2
- warp/native/mesh.cpp +33 -108
- warp/native/mesh.cu +114 -23
- warp/native/mesh.h +446 -46
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +45 -35
- warp/native/range.h +6 -2
- warp/native/reduce.cpp +1 -1
- warp/native/reduce.cu +10 -12
- warp/native/runlength_encode.cu +6 -10
- warp/native/scan.cu +8 -11
- warp/native/sparse.cpp +4 -4
- warp/native/sparse.cu +164 -154
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +14 -30
- warp/native/vec.h +107 -23
- warp/native/volume.h +120 -0
- warp/native/warp.cpp +560 -30
- warp/native/warp.cu +431 -44
- warp/native/warp.h +13 -4
- warp/optim/__init__.py +1 -0
- warp/optim/linear.py +922 -0
- warp/optim/sgd.py +92 -0
- warp/render/render_opengl.py +335 -119
- warp/render/render_usd.py +11 -11
- warp/sim/__init__.py +2 -2
- warp/sim/articulation.py +385 -185
- warp/sim/collide.py +8 -0
- warp/sim/import_mjcf.py +297 -106
- warp/sim/import_urdf.py +389 -210
- warp/sim/import_usd.py +198 -97
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_euler.py +14 -8
- warp/sim/integrator_xpbd.py +158 -16
- warp/sim/model.py +795 -291
- warp/sim/render.py +3 -3
- warp/sim/utils.py +3 -0
- warp/sparse.py +640 -150
- warp/stubs.py +606 -267
- warp/tape.py +61 -10
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +128 -74
- warp/tests/test_array.py +212 -97
- warp/tests/test_array_reduce.py +57 -23
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +99 -0
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +42 -18
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +208 -130
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +75 -75
- warp/tests/test_examples.py +277 -0
- warp/tests/test_fabricarray.py +955 -0
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1271 -0
- warp/tests/test_fp16.py +53 -19
- warp/tests/test_func.py +187 -86
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +178 -109
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +52 -37
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +32 -31
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +14 -41
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_linear_solvers.py +154 -0
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +517 -2898
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +305 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +71 -14
- warp/tests/test_mesh_query_aabb.py +41 -25
- warp/tests/test_mesh_query_point.py +140 -22
- warp/tests/test_mesh_query_ray.py +39 -22
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +168 -20
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +261 -63
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +268 -63
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +118 -89
- warp/tests/test_transient_module.py +12 -13
- warp/tests/test_types.py +614 -0
- warp/tests/test_utils.py +494 -0
- warp/tests/test_vec.py +354 -2050
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +457 -293
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +341 -0
- warp/tests/unittest_utils.py +568 -0
- warp/tests/unused_test_misc.py +71 -0
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +549 -0
- warp/torch.py +9 -6
- warp/types.py +1089 -366
- warp/utils.py +93 -387
- warp_lang-0.11.0.dist-info/METADATA +238 -0
- warp_lang-0.11.0.dist-info/RECORD +332 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
- warp/tests/test_all.py +0 -219
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-0.10.1.dist-info/METADATA +0 -21
- warp_lang-0.10.1.dist-info/RECORD +0 -188
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/sparse.py
CHANGED
|
@@ -1,14 +1,29 @@
|
|
|
1
|
+
from typing import Any, Generic, Optional, Tuple, TypeVar, Union
|
|
2
|
+
|
|
1
3
|
import warp as wp
|
|
2
4
|
import warp.types
|
|
3
5
|
import warp.utils
|
|
6
|
+
from warp.types import Array, Cols, Matrix, Rows, Scalar, Vector
|
|
7
|
+
|
|
8
|
+
# typing hints
|
|
9
|
+
|
|
10
|
+
_BlockType = TypeVar("BlockType")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class _MatrixBlockType(Matrix):
|
|
14
|
+
pass
|
|
4
15
|
|
|
5
|
-
from typing import Tuple, Any, Union
|
|
6
16
|
|
|
17
|
+
class _ScalarBlockType(Generic[Scalar]):
|
|
18
|
+
pass
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
BlockType = Union[_MatrixBlockType[Rows, Cols, Scalar], _ScalarBlockType[Scalar]]
|
|
7
22
|
|
|
8
23
|
_struct_cache = dict()
|
|
9
24
|
|
|
10
25
|
|
|
11
|
-
class BsrMatrix:
|
|
26
|
+
class BsrMatrix(Generic[_BlockType]):
|
|
12
27
|
"""Untyped base class for BSR and CSR matrices.
|
|
13
28
|
|
|
14
29
|
Should not be constructed directly but through functions such as :func:`bsr_zeros`.
|
|
@@ -16,15 +31,15 @@ class BsrMatrix:
|
|
|
16
31
|
Attributes:
|
|
17
32
|
nrow (int): Number of rows of blocks
|
|
18
33
|
ncol (int): Number of columns of blocks
|
|
19
|
-
nnz (int): Number of non-zero blocks: equal to
|
|
20
|
-
offsets (
|
|
21
|
-
columns (
|
|
22
|
-
values (
|
|
34
|
+
nnz (int): Number of non-zero blocks: must be equal to ``offsets[nrow-1]``, cached on host for convenience
|
|
35
|
+
offsets (Array[int]): Array of size at least ``1 + nrows`` such that the start and end indices of the blocks of row ``r`` are ``offsets[r]`` and ``offsets[r+1]``, respectively.
|
|
36
|
+
columns (Array[int]): Array of size at least equal to ``nnz`` containing block column indices
|
|
37
|
+
values (Array[BlockType]): Array of size at least equal to ``nnz`` containing block values
|
|
23
38
|
"""
|
|
24
39
|
|
|
25
40
|
@property
|
|
26
|
-
def scalar_type(self) ->
|
|
27
|
-
"""Scalar type for
|
|
41
|
+
def scalar_type(self) -> Scalar:
|
|
42
|
+
"""Scalar type for individual block coefficients. For CSR matrices, this is the same as the block type"""
|
|
28
43
|
return warp.types.type_scalar_type(self.values.dtype)
|
|
29
44
|
|
|
30
45
|
@property
|
|
@@ -33,20 +48,35 @@ class BsrMatrix:
|
|
|
33
48
|
return getattr(self.values.dtype, "_shape_", (1, 1))
|
|
34
49
|
|
|
35
50
|
@property
|
|
36
|
-
def block_size(self) ->
|
|
37
|
-
"""Size of the individual blocks, i.e. number of rows per block times number of
|
|
51
|
+
def block_size(self) -> int:
|
|
52
|
+
"""Size of the individual blocks, i.e. number of rows per block times number of columns per block"""
|
|
38
53
|
return warp.types.type_length(self.values.dtype)
|
|
39
54
|
|
|
40
55
|
@property
|
|
41
56
|
def shape(self) -> Tuple[int, int]:
|
|
42
|
-
"""Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/
|
|
57
|
+
"""Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/columns per block"""
|
|
43
58
|
block_shape = self.block_shape
|
|
44
59
|
return (self.nrow * block_shape[0], self.ncol * block_shape[1])
|
|
45
60
|
|
|
61
|
+
@property
|
|
62
|
+
def dtype(self) -> type:
|
|
63
|
+
"""Data type for individual block values"""
|
|
64
|
+
return self.values.dtype
|
|
46
65
|
|
|
47
|
-
|
|
66
|
+
@property
|
|
67
|
+
def device(self) -> wp.context.Device:
|
|
68
|
+
"""Device on which offsets, columns and values are allocated -- assumed to be the same for all three arrays """
|
|
69
|
+
return self.values.device
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def bsr_matrix_t(dtype: BlockType):
|
|
48
73
|
dtype = wp.types.type_to_warp(dtype)
|
|
49
74
|
|
|
75
|
+
if not warp.types.type_is_matrix(dtype) and not dtype in warp.types.scalar_types:
|
|
76
|
+
raise ValueError(
|
|
77
|
+
f"BsrMatrix block type must be either warp matrix or scalar; got {warp.types.type_repr(dtype)}"
|
|
78
|
+
)
|
|
79
|
+
|
|
50
80
|
class BsrMatrixTyped(BsrMatrix):
|
|
51
81
|
nrow: int
|
|
52
82
|
"""Number of rows of blocks"""
|
|
@@ -79,11 +109,23 @@ def bsr_matrix_t(dtype: type):
|
|
|
79
109
|
|
|
80
110
|
|
|
81
111
|
def bsr_zeros(
|
|
82
|
-
rows_of_blocks: int,
|
|
112
|
+
rows_of_blocks: int,
|
|
113
|
+
cols_of_blocks: int,
|
|
114
|
+
block_type: BlockType,
|
|
115
|
+
device: wp.context.Devicelike = None,
|
|
83
116
|
) -> BsrMatrix:
|
|
84
117
|
"""
|
|
85
|
-
Constructs an empty BSR or
|
|
118
|
+
Constructs and returns an empty BSR or CSR matrix with the given shape
|
|
119
|
+
|
|
120
|
+
Args:
|
|
121
|
+
bsr: The BSR or CSR matrix to set to zero
|
|
122
|
+
rows_of_blocks: Number of rows of blocks
|
|
123
|
+
cols_of_blocks: Number of columns of blocks
|
|
124
|
+
block_type: Type of individual blocks. For CSR matrices, this should be a scalar type;
|
|
125
|
+
for BSR matrices, this should be a matrix type (e.g. from :func:`warp.mat`)
|
|
126
|
+
device: Device on which to allocate the matrix arrays
|
|
86
127
|
"""
|
|
128
|
+
|
|
87
129
|
bsr = bsr_matrix_t(block_type)()
|
|
88
130
|
|
|
89
131
|
bsr.nrow = rows_of_blocks
|
|
@@ -110,19 +152,42 @@ def _bsr_ensure_fits(bsr: BsrMatrix, nrow: int = None, nnz: int = None):
|
|
|
110
152
|
bsr.values = wp.empty(shape=(nnz,), dtype=bsr.values.dtype, device=bsr.values.device)
|
|
111
153
|
|
|
112
154
|
|
|
155
|
+
def bsr_set_zero(bsr: BsrMatrix, rows_of_blocks: Optional[int] = None, cols_of_blocks: Optional[int] = None):
|
|
156
|
+
"""
|
|
157
|
+
Sets a BSR matrix to zero, possibly changing its size
|
|
158
|
+
|
|
159
|
+
Args:
|
|
160
|
+
bsr: The BSR or CSR matrix to set to zero
|
|
161
|
+
rows_of_blocks: If not ``None``, the new number of rows of blocks
|
|
162
|
+
cols_of_blocks: If not ``None``, the new number of columns of blocks
|
|
163
|
+
"""
|
|
164
|
+
|
|
165
|
+
if rows_of_blocks is not None:
|
|
166
|
+
bsr.nrow = rows_of_blocks
|
|
167
|
+
if cols_of_blocks is not None:
|
|
168
|
+
bsr.ncol = cols_of_blocks
|
|
169
|
+
bsr.nnz = 0
|
|
170
|
+
_bsr_ensure_fits(bsr)
|
|
171
|
+
bsr.offsets.zero_()
|
|
172
|
+
|
|
173
|
+
|
|
113
174
|
def bsr_set_from_triplets(
|
|
114
|
-
dest: BsrMatrix,
|
|
115
|
-
rows:
|
|
116
|
-
columns:
|
|
117
|
-
values:
|
|
175
|
+
dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
|
|
176
|
+
rows: "Array[int]",
|
|
177
|
+
columns: "Array[int]",
|
|
178
|
+
values: "Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]",
|
|
118
179
|
):
|
|
119
180
|
"""
|
|
120
|
-
Fills a BSR matrix
|
|
181
|
+
Fills a BSR matrix with values defined by coordinate-oriented (COO) triplets, discarding existing blocks.
|
|
121
182
|
|
|
122
|
-
|
|
123
|
-
or a 3d array with data type equal to the `dest` matrix scalar type.
|
|
183
|
+
The first dimension of the three input arrays must match, and determines the number of non-zeros in the constructed matrix.
|
|
124
184
|
|
|
125
|
-
|
|
185
|
+
Args:
|
|
186
|
+
dest: Sparse matrix to populate
|
|
187
|
+
rows: Row index for each non-zero
|
|
188
|
+
columns: Column index for each non-zero
|
|
189
|
+
values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
|
|
190
|
+
to the `dest` matrix's block type, or a 3d array with data type equal to the `dest` matrix's scalar type.
|
|
126
191
|
"""
|
|
127
192
|
|
|
128
193
|
if values.device != columns.device or values.device != rows.device or values.device != dest.values.device:
|
|
@@ -138,7 +203,7 @@ def bsr_set_from_triplets(
|
|
|
138
203
|
elif values.ndim == 3:
|
|
139
204
|
if values.shape[1:] != dest.block_shape:
|
|
140
205
|
raise ValueError(
|
|
141
|
-
f"Last two dimensions in values array ({values.shape[1:]})
|
|
206
|
+
f"Last two dimensions in values array ({values.shape[1:]}) should correspond to matrix block shape {(dest.block_shape)})"
|
|
142
207
|
)
|
|
143
208
|
|
|
144
209
|
if warp.types.type_scalar_type(values.dtype) != dest.scalar_type:
|
|
@@ -150,6 +215,9 @@ def bsr_set_from_triplets(
|
|
|
150
215
|
raise ValueError("Number of dimension for values array should be 1 or 3")
|
|
151
216
|
|
|
152
217
|
nnz = rows.shape[0]
|
|
218
|
+
if nnz == 0:
|
|
219
|
+
bsr_set_zero(dest)
|
|
220
|
+
return
|
|
153
221
|
|
|
154
222
|
# Increase dest array sizes if needed
|
|
155
223
|
_bsr_ensure_fits(dest, nnz=nnz)
|
|
@@ -186,8 +254,8 @@ def bsr_set_from_triplets(
|
|
|
186
254
|
)
|
|
187
255
|
|
|
188
256
|
|
|
189
|
-
def bsr_assign(dest: BsrMatrix, src: BsrMatrix):
|
|
190
|
-
"""Copies the content of the `src` matrix to `dest`,
|
|
257
|
+
def bsr_assign(dest: BsrMatrix[BlockType[Rows, Cols, Scalar]], src: BsrMatrix[BlockType[Rows, Cols, Any]]):
|
|
258
|
+
"""Copies the content of the `src` matrix to `dest`, casting the block values if the two matrices use distinct scalar types."""
|
|
191
259
|
|
|
192
260
|
if dest.values.device != src.values.device:
|
|
193
261
|
raise ValueError("Source and destination matrices must reside on the same device")
|
|
@@ -202,13 +270,17 @@ def bsr_assign(dest: BsrMatrix, src: BsrMatrix):
|
|
|
202
270
|
_bsr_ensure_fits(dest)
|
|
203
271
|
|
|
204
272
|
wp.copy(dest=dest.offsets, src=src.offsets, count=src.nrow + 1)
|
|
205
|
-
|
|
273
|
+
if src.nnz > 0:
|
|
274
|
+
wp.copy(dest=dest.columns, src=src.columns, count=src.nnz)
|
|
275
|
+
warp.utils.array_cast(out_array=dest.values, in_array=src.values, count=src.nnz)
|
|
206
276
|
|
|
207
|
-
warp.utils.array_cast(out_array=dest.values, in_array=src.values, count=src.nnz)
|
|
208
277
|
|
|
278
|
+
def bsr_copy(A: BsrMatrix, scalar_type: Optional[Scalar] = None):
|
|
279
|
+
"""Returns a copy of matrix ``A``, possibly changing its scalar type.
|
|
209
280
|
|
|
210
|
-
|
|
211
|
-
|
|
281
|
+
Args:
|
|
282
|
+
scalar_type: If provided, the returned matrix will use this scalar type instead of the one from `A`.
|
|
283
|
+
"""
|
|
212
284
|
if scalar_type is None:
|
|
213
285
|
block_type = A.values.dtype
|
|
214
286
|
elif A.block_shape == (1, 1):
|
|
@@ -221,7 +293,7 @@ def bsr_copy(A: BsrMatrix, scalar_type=None):
|
|
|
221
293
|
return copy
|
|
222
294
|
|
|
223
295
|
|
|
224
|
-
def bsr_set_transpose(dest: BsrMatrix, src: BsrMatrix):
|
|
296
|
+
def bsr_set_transpose(dest: BsrMatrix[BlockType[Cols, Rows, Scalar]], src: BsrMatrix[BlockType[Rows, Cols, Scalar]]):
|
|
225
297
|
"""Assigns the transposed matrix `src` to matrix `dest`"""
|
|
226
298
|
|
|
227
299
|
if dest.values.device != src.values.device:
|
|
@@ -230,10 +302,7 @@ def bsr_set_transpose(dest: BsrMatrix, src: BsrMatrix):
|
|
|
230
302
|
if dest.scalar_type != src.scalar_type:
|
|
231
303
|
raise ValueError("All arguments must have the same scalar type")
|
|
232
304
|
|
|
233
|
-
|
|
234
|
-
transpose_block_shape = (1, 1)
|
|
235
|
-
else:
|
|
236
|
-
transpose_block_shape = src.block_shape[::-1]
|
|
305
|
+
transpose_block_shape = src.block_shape[::-1]
|
|
237
306
|
|
|
238
307
|
if dest.block_shape != transpose_block_shape:
|
|
239
308
|
raise ValueError(f"Destination block shape must be {transpose_block_shape}")
|
|
@@ -242,6 +311,9 @@ def bsr_set_transpose(dest: BsrMatrix, src: BsrMatrix):
|
|
|
242
311
|
dest.ncol = src.nrow
|
|
243
312
|
dest.nnz = src.nnz
|
|
244
313
|
|
|
314
|
+
if src.nnz == 0:
|
|
315
|
+
return
|
|
316
|
+
|
|
245
317
|
# Increase dest array sizes if needed
|
|
246
318
|
_bsr_ensure_fits(dest)
|
|
247
319
|
|
|
@@ -301,27 +373,33 @@ def _bsr_get_diag_kernel(
|
|
|
301
373
|
end = A_offsets[row + 1]
|
|
302
374
|
|
|
303
375
|
diag = wp.lower_bound(A_columns, beg, end, row)
|
|
304
|
-
if
|
|
305
|
-
|
|
376
|
+
if diag < end:
|
|
377
|
+
if A_columns[diag] == row:
|
|
378
|
+
out[row] = A_values[diag]
|
|
379
|
+
|
|
306
380
|
|
|
381
|
+
def bsr_get_diag(A: BsrMatrix[_BlockType], out: "Optional[Array[BlockType]]" = None) -> "Array[BlockType]":
|
|
382
|
+
"""Returns the array of blocks that constitute the diagonal of a sparse matrix.
|
|
383
|
+
|
|
384
|
+
Args:
|
|
385
|
+
A: the sparse matrix from which to extract the diagonal
|
|
386
|
+
out: if provided, the array into which to store the diagonal blocks
|
|
387
|
+
"""
|
|
307
388
|
|
|
308
|
-
|
|
309
|
-
"""Returns the block diagonal of a square sparse matrix"""
|
|
310
|
-
if A.nrow != A.ncol:
|
|
311
|
-
raise ValueError("bsr_get_diag is only available for square sparse matrices")
|
|
389
|
+
dim = min(A.nrow, A.ncol)
|
|
312
390
|
|
|
313
391
|
if out is None:
|
|
314
|
-
out = wp.zeros(shape=(
|
|
392
|
+
out = wp.zeros(shape=(dim,), dtype=A.values.dtype, device=A.values.device)
|
|
315
393
|
else:
|
|
316
394
|
if out.dtype != A.values.dtype:
|
|
317
395
|
raise ValueError(f"Output array must have type {A.values.dtype}")
|
|
318
396
|
if out.device != A.values.device:
|
|
319
397
|
raise ValueError(f"Output array must reside on device {A.values.device}")
|
|
320
|
-
if out.shape[0] <
|
|
321
|
-
raise ValueError(f"Output array must be of length at least {
|
|
398
|
+
if out.shape[0] < dim:
|
|
399
|
+
raise ValueError(f"Output array must be of length at least {dim}")
|
|
322
400
|
|
|
323
401
|
wp.launch(
|
|
324
|
-
kernel=_bsr_get_diag_kernel, dim=
|
|
402
|
+
kernel=_bsr_get_diag_kernel, dim=dim, device=A.values.device, inputs=[A.offsets, A.columns, A.values, out]
|
|
325
403
|
)
|
|
326
404
|
|
|
327
405
|
return out
|
|
@@ -329,40 +407,205 @@ def bsr_get_diag(A: BsrMatrix, out: wp.array = None):
|
|
|
329
407
|
|
|
330
408
|
@wp.kernel
|
|
331
409
|
def _bsr_set_diag_kernel(
|
|
410
|
+
diag: wp.array(dtype=Any),
|
|
411
|
+
A_offsets: wp.array(dtype=int),
|
|
412
|
+
A_columns: wp.array(dtype=int),
|
|
413
|
+
A_values: wp.array(dtype=Any),
|
|
414
|
+
):
|
|
415
|
+
row = wp.tid()
|
|
416
|
+
A_offsets[row + 1] = row + 1
|
|
417
|
+
A_columns[row] = row
|
|
418
|
+
A_values[row] = diag[row]
|
|
419
|
+
|
|
420
|
+
if row == 0:
|
|
421
|
+
A_offsets[0] = 0
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
@wp.kernel
|
|
425
|
+
def _bsr_set_diag_constant_kernel(
|
|
426
|
+
diag_value: Any,
|
|
332
427
|
A_offsets: wp.array(dtype=int),
|
|
333
428
|
A_columns: wp.array(dtype=int),
|
|
429
|
+
A_values: wp.array(dtype=Any),
|
|
334
430
|
):
|
|
335
431
|
row = wp.tid()
|
|
336
432
|
A_offsets[row + 1] = row + 1
|
|
337
433
|
A_columns[row] = row
|
|
434
|
+
A_values[row] = diag_value
|
|
338
435
|
|
|
339
436
|
if row == 0:
|
|
340
437
|
A_offsets[0] = 0
|
|
341
438
|
|
|
342
439
|
|
|
343
|
-
def bsr_set_diag(
|
|
344
|
-
|
|
440
|
+
def bsr_set_diag(
|
|
441
|
+
A: BsrMatrix[BlockType],
|
|
442
|
+
diag: "Union[BlockType, Array[BlockType]]",
|
|
443
|
+
rows_of_blocks: Optional[int] = None,
|
|
444
|
+
cols_of_blocks: Optional[int] = None,
|
|
445
|
+
):
|
|
446
|
+
"""Sets `A` as a block-diagonal matrix
|
|
447
|
+
|
|
448
|
+
Args:
|
|
449
|
+
A: the sparse matrix to modify
|
|
450
|
+
diag: Either a warp array of type ``A.values.dtype``, in which case each element will define one block of the diagonal,
|
|
451
|
+
or a constant value of type ``A.values.dtype``, in which case it will get assigned to all diagonal blocks.
|
|
452
|
+
rows_of_blocks: If not ``None``, the new number of rows of blocks
|
|
453
|
+
cols_of_blocks: If not ``None``, the new number of columns of blocks
|
|
454
|
+
|
|
455
|
+
The shape of the matrix will be defined by one of the following, in that order:
|
|
456
|
+
- `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal.
|
|
457
|
+
- the first dimension of `diag`, if `diag` is an array
|
|
458
|
+
- the current dimensions of `A` otherwise
|
|
459
|
+
"""
|
|
460
|
+
|
|
461
|
+
if rows_of_blocks is None and cols_of_blocks is not None:
|
|
462
|
+
rows_of_blocks = cols_of_blocks
|
|
463
|
+
if cols_of_blocks is None and rows_of_blocks is not None:
|
|
464
|
+
cols_of_blocks = rows_of_blocks
|
|
465
|
+
|
|
466
|
+
if warp.types.is_array(diag):
|
|
467
|
+
if rows_of_blocks is None:
|
|
468
|
+
rows_of_blocks = diag.shape[0]
|
|
469
|
+
cols_of_blocks = diag.shape[0]
|
|
470
|
+
|
|
471
|
+
if rows_of_blocks is not None:
|
|
472
|
+
A.nrow = rows_of_blocks
|
|
473
|
+
A.ncol = cols_of_blocks
|
|
474
|
+
|
|
475
|
+
A.nnz = min(A.nrow, A.ncol)
|
|
476
|
+
_bsr_ensure_fits(A)
|
|
477
|
+
|
|
478
|
+
if warp.types.is_array(diag):
|
|
479
|
+
wp.launch(
|
|
480
|
+
kernel=_bsr_set_diag_kernel,
|
|
481
|
+
dim=A.nnz,
|
|
482
|
+
device=A.values.device,
|
|
483
|
+
inputs=[diag, A.offsets, A.columns, A.values],
|
|
484
|
+
)
|
|
485
|
+
else:
|
|
486
|
+
if not warp.types.type_is_value(type(diag)):
|
|
487
|
+
# Cast to launchable type
|
|
488
|
+
diag = A.values.dtype(diag)
|
|
489
|
+
wp.launch(
|
|
490
|
+
kernel=_bsr_set_diag_constant_kernel,
|
|
491
|
+
dim=A.nnz,
|
|
492
|
+
device=A.values.device,
|
|
493
|
+
inputs=[diag, A.offsets, A.columns, A.values],
|
|
494
|
+
)
|
|
345
495
|
|
|
346
|
-
A.nrow = diag.shape[0]
|
|
347
|
-
A.ncol = diag.shape[0]
|
|
348
|
-
A.nnz = diag.shape[0]
|
|
349
496
|
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
497
|
+
def bsr_diag(
|
|
498
|
+
diag: "Union[BlockType, Array[BlockType]]",
|
|
499
|
+
rows_of_blocks: Optional[int] = None,
|
|
500
|
+
cols_of_blocks: Optional[int] = None,
|
|
501
|
+
) -> BsrMatrix["BlockType"]:
|
|
502
|
+
"""Creates and returns a block-diagonal BSR matrix from a given block value or array of block values.
|
|
355
503
|
|
|
356
|
-
|
|
504
|
+
Args:
|
|
505
|
+
diag: Either a warp array of block values, in which case each element will define one block of the diagonal,
|
|
506
|
+
or a constant block value, in which case it will get assigned to all diagonal blocks.
|
|
507
|
+
rows_of_blocks: If not ``None``, the new number of rows of blocks
|
|
508
|
+
cols_of_blocks: If not ``None``, the new number of columns of blocks
|
|
357
509
|
|
|
510
|
+
The shape of the matrix will be defined by one of the following, in that order:
|
|
511
|
+
- `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal.
|
|
512
|
+
- the first dimension of `diag`, if `diag` is an array
|
|
513
|
+
"""
|
|
514
|
+
|
|
515
|
+
if rows_of_blocks is None and cols_of_blocks is not None:
|
|
516
|
+
rows_of_blocks = cols_of_blocks
|
|
517
|
+
if cols_of_blocks is None and rows_of_blocks is not None:
|
|
518
|
+
cols_of_blocks = rows_of_blocks
|
|
519
|
+
|
|
520
|
+
if warp.types.is_array(diag):
|
|
521
|
+
if rows_of_blocks is None:
|
|
522
|
+
rows_of_blocks = diag.shape[0]
|
|
523
|
+
cols_of_blocks = diag.shape[0]
|
|
524
|
+
|
|
525
|
+
A = bsr_zeros(
|
|
526
|
+
rows_of_blocks,
|
|
527
|
+
cols_of_blocks,
|
|
528
|
+
block_type=diag.dtype,
|
|
529
|
+
device=diag.device,
|
|
530
|
+
)
|
|
531
|
+
else:
|
|
532
|
+
if rows_of_blocks is None:
|
|
533
|
+
raise ValueError(
|
|
534
|
+
"rows_of_blocks and/or cols_of_blocks must be provided for constructing a diagonal matrix with uniform diagonal"
|
|
535
|
+
)
|
|
536
|
+
|
|
537
|
+
block_type = type(diag)
|
|
538
|
+
if not warp.types.type_is_matrix(block_type) and len(getattr(diag, "shape", ())) == 2:
|
|
539
|
+
block_type = wp.mat(shape=diag.shape, dtype=diag.dtype)
|
|
540
|
+
|
|
541
|
+
A = bsr_zeros(
|
|
542
|
+
rows_of_blocks,
|
|
543
|
+
cols_of_blocks,
|
|
544
|
+
block_type=block_type,
|
|
545
|
+
)
|
|
358
546
|
|
|
359
|
-
def bsr_diag(diag: wp.array):
|
|
360
|
-
"""Creates a square block-diagonal BSR matrix from the values array `diag`"""
|
|
361
|
-
A = bsr_zeros(rows_of_blocks=diag.shape[0], cols_of_blocks=diag.shape[0], block_type=diag.dtype, device=diag.device)
|
|
362
547
|
bsr_set_diag(A, diag)
|
|
363
548
|
return A
|
|
364
549
|
|
|
365
550
|
|
|
551
|
+
def bsr_set_identity(A: BsrMatrix, rows_of_blocks: Optional[int] = None):
|
|
552
|
+
"""Sets `A` as the identity matrix
|
|
553
|
+
|
|
554
|
+
Args:
|
|
555
|
+
A: the sparse matrix to modify
|
|
556
|
+
rows_of_blocks: if provided, the matrix will be resized as a square matrix with `rows_of_blocks` rows and columns.
|
|
557
|
+
"""
|
|
558
|
+
|
|
559
|
+
if A.block_shape == (1, 1):
|
|
560
|
+
identity = A.scalar_type(1.0)
|
|
561
|
+
else:
|
|
562
|
+
from numpy import eye
|
|
563
|
+
|
|
564
|
+
identity = eye(A.block_shape[0])
|
|
565
|
+
|
|
566
|
+
bsr_set_diag(A, diag=identity, rows_of_blocks=rows_of_blocks, cols_of_blocks=rows_of_blocks)
|
|
567
|
+
|
|
568
|
+
|
|
569
|
+
def bsr_identity(
|
|
570
|
+
rows_of_blocks: int, block_type: BlockType[Rows, Rows, Scalar], device: wp.context.Devicelike = None
|
|
571
|
+
) -> BsrMatrix[BlockType[Rows, Rows, Scalar]]:
|
|
572
|
+
"""Creates and returns a square identity matrix.
|
|
573
|
+
|
|
574
|
+
Args:
|
|
575
|
+
rows_of_blocks: Number of rows and columns of blocks in the created matrix.
|
|
576
|
+
block_type: Block type for the newly created matrix -- must be square
|
|
577
|
+
device: Device onto which to allocate the data arrays
|
|
578
|
+
"""
|
|
579
|
+
A = bsr_zeros(rows_of_blocks=rows_of_blocks, cols_of_blocks=rows_of_blocks, block_type=block_type, device=device)
|
|
580
|
+
bsr_set_identity(A)
|
|
581
|
+
return A
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
@wp.kernel
|
|
585
|
+
def _bsr_scale_kernel(
|
|
586
|
+
alpha: Any,
|
|
587
|
+
values: wp.array(dtype=Any),
|
|
588
|
+
):
|
|
589
|
+
values[wp.tid()] = alpha * values[wp.tid()]
|
|
590
|
+
|
|
591
|
+
|
|
592
|
+
def bsr_scale(x: BsrMatrix, alpha: Scalar) -> BsrMatrix:
|
|
593
|
+
"""
|
|
594
|
+
Performs the operation ``x := alpha * x`` on BSR matrix `x` and returns `x`
|
|
595
|
+
"""
|
|
596
|
+
|
|
597
|
+
if alpha != 1.0 and x.nnz > 0:
|
|
598
|
+
if alpha == 0.0:
|
|
599
|
+
bsr_set_zero(x)
|
|
600
|
+
else:
|
|
601
|
+
if not isinstance(alpha, x.scalar_type):
|
|
602
|
+
alpha = x.scalar_type(alpha)
|
|
603
|
+
|
|
604
|
+
wp.launch(kernel=_bsr_scale_kernel, dim=x.nnz, device=x.values.device, inputs=[alpha, x.values])
|
|
605
|
+
|
|
606
|
+
return x
|
|
607
|
+
|
|
608
|
+
|
|
366
609
|
@wp.kernel
|
|
367
610
|
def _bsr_get_block_row(dest_offset: int, bsr_offsets: wp.array(dtype=int), rows: wp.array(dtype=int)):
|
|
368
611
|
i = wp.tid()
|
|
@@ -393,16 +636,75 @@ def _bsr_axpy_add_block(
|
|
|
393
636
|
dst_values[block] = dst_values[block] + scale * src_values[i]
|
|
394
637
|
|
|
395
638
|
|
|
396
|
-
|
|
639
|
+
class bsr_axpy_work_arrays:
|
|
640
|
+
"""Opaque structure for persisting :func:`bsr_axpy` temporary work buffers across calls"""
|
|
641
|
+
|
|
642
|
+
def __init__(self):
|
|
643
|
+
self._reset(None)
|
|
644
|
+
|
|
645
|
+
def _reset(self, device):
|
|
646
|
+
self.device = device
|
|
647
|
+
self._sum_rows = None
|
|
648
|
+
self._sum_cols = None
|
|
649
|
+
self._old_y_values = None
|
|
650
|
+
self._old_x_values = None
|
|
651
|
+
|
|
652
|
+
def _allocate(self, device, y: BsrMatrix, sum_nnz: int):
|
|
653
|
+
if self.device != device:
|
|
654
|
+
self._reset(device)
|
|
655
|
+
|
|
656
|
+
if self._sum_rows is None or self._sum_rows.size < sum_nnz:
|
|
657
|
+
self._sum_rows = wp.empty(shape=(sum_nnz), dtype=int, device=self.device)
|
|
658
|
+
if self._sum_cols is None or self._sum_cols.size < sum_nnz:
|
|
659
|
+
self._sum_cols = wp.empty(shape=(sum_nnz), dtype=int, device=self.device)
|
|
660
|
+
|
|
661
|
+
if self._old_y_values is None or self._old_y_values.size < y.nnz:
|
|
662
|
+
self._old_y_values = wp.empty(shape=(y.nnz), dtype=y.values.dtype, device=self.device)
|
|
663
|
+
|
|
664
|
+
|
|
665
|
+
def bsr_axpy(
|
|
666
|
+
x: BsrMatrix[BlockType[Rows, Cols, Scalar]],
|
|
667
|
+
y: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
|
|
668
|
+
alpha: Scalar = 1.0,
|
|
669
|
+
beta: Scalar = 1.0,
|
|
670
|
+
work_arrays: Optional[bsr_axpy_work_arrays] = None,
|
|
671
|
+
) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
|
|
397
672
|
"""
|
|
398
|
-
Performs the
|
|
673
|
+
Performs the sparse matrix addition ``y := alpha * x + beta * y`` on BSR matrices `x` and `y` and returns `y`.
|
|
674
|
+
|
|
675
|
+
The `x` and `y` matrices are allowed to alias.
|
|
676
|
+
|
|
677
|
+
Args:
|
|
678
|
+
x: Read-only right-hand-side.
|
|
679
|
+
y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero.
|
|
680
|
+
alpha: Uniform scaling factor for `x`
|
|
681
|
+
beta: Uniform scaling factor for `y`
|
|
682
|
+
work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_axpy_work_arrays` in `work_arrays`.
|
|
399
683
|
"""
|
|
400
684
|
|
|
401
685
|
if y is None:
|
|
402
|
-
|
|
686
|
+
# If no output matrix is provided, allocate it for convenience
|
|
687
|
+
y = bsr_zeros(x.nrow, x.ncol, block_type=x.values.dtype, device=x.values.device)
|
|
403
688
|
beta = 0.0
|
|
404
689
|
|
|
405
|
-
|
|
690
|
+
# Handle easy cases first
|
|
691
|
+
if beta == 0.0 or y.nnz == 0:
|
|
692
|
+
bsr_assign(src=x, dest=y)
|
|
693
|
+
return bsr_scale(y, alpha=alpha)
|
|
694
|
+
|
|
695
|
+
if alpha == 0.0 or x.nnz == 0:
|
|
696
|
+
return bsr_scale(y, alpha=beta)
|
|
697
|
+
|
|
698
|
+
if not isinstance(alpha, y.scalar_type):
|
|
699
|
+
alpha = y.scalar_type(alpha)
|
|
700
|
+
if not isinstance(beta, y.scalar_type):
|
|
701
|
+
beta = y.scalar_type(beta)
|
|
702
|
+
|
|
703
|
+
if x == y:
|
|
704
|
+
# Aliasing case
|
|
705
|
+
return bsr_scale(y, alpha=alpha.value + beta.value)
|
|
706
|
+
|
|
707
|
+
# General case
|
|
406
708
|
|
|
407
709
|
if x.values.device != y.values.device:
|
|
408
710
|
raise ValueError("All arguments must reside on the same device")
|
|
@@ -413,20 +715,21 @@ def bsr_axpy(x: BsrMatrix, y: BsrMatrix, alpha: float = 1.0, beta: float = 1.0):
|
|
|
413
715
|
if x.nrow != y.nrow or x.ncol != y.ncol:
|
|
414
716
|
raise ValueError("Matrices must have the same number of rows and columns")
|
|
415
717
|
|
|
416
|
-
|
|
417
|
-
|
|
718
|
+
if work_arrays is None:
|
|
719
|
+
work_arrays = bsr_axpy_work_arrays()
|
|
418
720
|
|
|
419
721
|
sum_nnz = x.nnz + y.nnz
|
|
420
|
-
|
|
421
|
-
|
|
722
|
+
device = y.values.device
|
|
723
|
+
work_arrays._allocate(device, y, sum_nnz)
|
|
724
|
+
|
|
725
|
+
wp.copy(work_arrays._sum_cols, y.columns, 0, 0, y.nnz)
|
|
726
|
+
wp.launch(kernel=_bsr_get_block_row, device=device, dim=y.nnz, inputs=[0, y.offsets, work_arrays._sum_rows])
|
|
422
727
|
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
wp.launch(kernel=_bsr_get_block_row, device=device, dim=y.nnz, inputs=[0, y.offsets, sum_rows])
|
|
728
|
+
wp.copy(work_arrays._sum_cols, x.columns, y.nnz, 0, x.nnz)
|
|
729
|
+
wp.launch(kernel=_bsr_get_block_row, device=device, dim=x.nnz, inputs=[y.nnz, x.offsets, work_arrays._sum_rows])
|
|
426
730
|
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
wp.launch(kernel=_bsr_get_block_row, device=device, dim=x.nnz, inputs=[y.nnz, x.offsets, sum_rows])
|
|
731
|
+
# Save old y values before overwriting matrix
|
|
732
|
+
wp.copy(dest=work_arrays._old_y_values, src=y.values, count=y.nnz)
|
|
430
733
|
|
|
431
734
|
# Increase dest array sizes if needed
|
|
432
735
|
if y.columns.shape[0] < sum_nnz:
|
|
@@ -439,37 +742,55 @@ def bsr_axpy(x: BsrMatrix, y: BsrMatrix, alpha: float = 1.0, beta: float = 1.0):
|
|
|
439
742
|
else:
|
|
440
743
|
native_func = runtime.core.bsr_matrix_from_triplets_float_device
|
|
441
744
|
|
|
442
|
-
|
|
745
|
+
old_y_nnz = y.nnz
|
|
746
|
+
y.nnz = native_func(
|
|
443
747
|
y.block_shape[0],
|
|
444
748
|
y.block_shape[1],
|
|
445
749
|
y.nrow,
|
|
446
750
|
sum_nnz,
|
|
447
|
-
|
|
448
|
-
|
|
751
|
+
work_arrays._sum_rows.ptr,
|
|
752
|
+
work_arrays._sum_cols.ptr,
|
|
449
753
|
0,
|
|
450
754
|
y.offsets.ptr,
|
|
451
755
|
y.columns.ptr,
|
|
452
756
|
0,
|
|
453
757
|
)
|
|
454
758
|
|
|
455
|
-
|
|
759
|
+
_bsr_ensure_fits(y)
|
|
760
|
+
y.values.zero_()
|
|
456
761
|
|
|
457
762
|
wp.launch(
|
|
458
763
|
kernel=_bsr_axpy_add_block,
|
|
459
764
|
device=device,
|
|
460
|
-
dim=
|
|
461
|
-
inputs=[
|
|
765
|
+
dim=old_y_nnz,
|
|
766
|
+
inputs=[
|
|
767
|
+
0,
|
|
768
|
+
beta,
|
|
769
|
+
work_arrays._sum_rows,
|
|
770
|
+
work_arrays._sum_cols,
|
|
771
|
+
y.offsets,
|
|
772
|
+
y.columns,
|
|
773
|
+
work_arrays._old_y_values,
|
|
774
|
+
y.values,
|
|
775
|
+
],
|
|
462
776
|
)
|
|
777
|
+
|
|
463
778
|
wp.launch(
|
|
464
779
|
kernel=_bsr_axpy_add_block,
|
|
465
780
|
device=device,
|
|
466
781
|
dim=x.nnz,
|
|
467
|
-
inputs=[
|
|
782
|
+
inputs=[
|
|
783
|
+
old_y_nnz,
|
|
784
|
+
alpha,
|
|
785
|
+
work_arrays._sum_rows,
|
|
786
|
+
work_arrays._sum_cols,
|
|
787
|
+
y.offsets,
|
|
788
|
+
y.columns,
|
|
789
|
+
x.values,
|
|
790
|
+
y.values,
|
|
791
|
+
],
|
|
468
792
|
)
|
|
469
793
|
|
|
470
|
-
y.values = sum_values
|
|
471
|
-
y.nnz = sum_nnz
|
|
472
|
-
|
|
473
794
|
return y
|
|
474
795
|
|
|
475
796
|
|
|
@@ -555,23 +876,77 @@ def _bsr_mm_compute_values(
|
|
|
555
876
|
mm_values[mm_block] = mm_values[mm_block] + ax_val * y_values[y_block]
|
|
556
877
|
|
|
557
878
|
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
def
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
879
|
+
class bsr_mm_work_arrays:
|
|
880
|
+
"""Opaque structure for persisting :func:`bsr_mm` temporary work buffers across calls"""
|
|
881
|
+
|
|
882
|
+
def __init__(self):
|
|
883
|
+
self._reset(None)
|
|
884
|
+
|
|
885
|
+
def _reset(self, device):
|
|
886
|
+
self.device = device
|
|
887
|
+
self._pinned_count_buffer = None
|
|
888
|
+
self._mm_row_counts = None
|
|
889
|
+
self._mm_rows = None
|
|
890
|
+
self._mm_cols = None
|
|
891
|
+
self._old_z_values = None
|
|
892
|
+
self._old_z_offsets = None
|
|
893
|
+
self._old_z_columns = None
|
|
894
|
+
|
|
895
|
+
def _allocate_stage_1(self, device, z: BsrMatrix, copied_z_nnz: int, z_aliasing: bool):
|
|
896
|
+
if self.device != device:
|
|
897
|
+
self._reset(device)
|
|
898
|
+
|
|
899
|
+
# Allocations that do not depend on any computation
|
|
900
|
+
if self.device.is_cuda:
|
|
901
|
+
if self._pinned_count_buffer is None:
|
|
902
|
+
self._pinned_count_buffer = wp.empty(shape=(1,), dtype=int, pinned=True, device="cpu")
|
|
903
|
+
|
|
904
|
+
if self._mm_row_counts is None or self._mm_row_counts.size < z.nrow + 1:
|
|
905
|
+
self._mm_row_counts = wp.empty(shape=(z.nrow + 1,), dtype=int, device=self.device)
|
|
906
|
+
|
|
907
|
+
if copied_z_nnz > 0:
|
|
908
|
+
if self._old_z_values is None or self._old_z_values.size < copied_z_nnz:
|
|
909
|
+
self._old_z_values = wp.empty(shape=(copied_z_nnz,), dtype=z.values.dtype, device=self.device)
|
|
910
|
+
|
|
911
|
+
if z_aliasing:
|
|
912
|
+
if self._old_z_columns is None or self._old_z_columns.size < z.nnz:
|
|
913
|
+
self._old_z_columns = wp.empty(shape=(z.nnz,), dtype=z.columns.dtype, device=self.device)
|
|
914
|
+
if self._old_z_offsets is None or self._old_z_offsets.size < z.nrow + 1:
|
|
915
|
+
self._old_z_offsets = wp.empty(shape=(z.nrow + 1,), dtype=z.offsets.dtype, device=self.device)
|
|
916
|
+
|
|
917
|
+
def _allocate_stage_2(self, mm_nnz: int):
|
|
918
|
+
# Allocations that depend on unmerged nnz estimate
|
|
919
|
+
if self._mm_rows is None or self._mm_rows.size < mm_nnz:
|
|
920
|
+
self._mm_rows = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
|
|
921
|
+
if self._mm_cols is None or self._mm_cols.size < mm_nnz:
|
|
922
|
+
self._mm_cols = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
|
|
923
|
+
|
|
924
|
+
|
|
925
|
+
def bsr_mm(
|
|
926
|
+
x: BsrMatrix[BlockType[Rows, Any, Scalar]],
|
|
927
|
+
y: BsrMatrix[BlockType[Any, Cols, Scalar]],
|
|
928
|
+
z: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
|
|
929
|
+
alpha: Scalar = 1.0,
|
|
930
|
+
beta: Scalar = 0.0,
|
|
931
|
+
work_arrays: Optional[bsr_mm_work_arrays] = None,
|
|
932
|
+
) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
|
|
570
933
|
"""
|
|
571
|
-
Performs the
|
|
934
|
+
Performs the sparse matrix-matrix multiplication ``z := alpha * x * y + beta * z`` on BSR matrices `x`, `y` and `z`, and returns `z`.
|
|
935
|
+
|
|
936
|
+
The `x`, `y` and `z` matrices are allowed to alias.
|
|
937
|
+
If the matrix `z` is not provided as input, it will be allocated and treated as zero.
|
|
938
|
+
|
|
939
|
+
Args:
|
|
940
|
+
x: Read-only left factor of the matrix-matrix product.
|
|
941
|
+
y: Read-only right factor of the matrix-matrix product.
|
|
942
|
+
z: Mutable left-hand-side. If `z` is not provided, it will be allocated and treated as zero.
|
|
943
|
+
alpha: Uniform scaling factor for the ``x * y`` product
|
|
944
|
+
beta: Uniform scaling factor for `z`
|
|
945
|
+
work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_mm_work_arrays` in `work_arrays`.
|
|
572
946
|
"""
|
|
573
947
|
|
|
574
948
|
if z is None:
|
|
949
|
+
# If no output matrix is provided, allocate it for convenience
|
|
575
950
|
z_block_shape = (x.block_shape[0], y.block_shape[1])
|
|
576
951
|
if z_block_shape == (1, 1):
|
|
577
952
|
z_block_type = x.scalar_type
|
|
@@ -586,52 +961,85 @@ def bsr_mm(x: BsrMatrix, y: BsrMatrix, z: BsrMatrix = None, alpha: float = 1.0,
|
|
|
586
961
|
if x.scalar_type != y.scalar_type or x.scalar_type != z.scalar_type:
|
|
587
962
|
raise ValueError("Matrices must have the same scalar type")
|
|
588
963
|
|
|
589
|
-
if
|
|
590
|
-
|
|
964
|
+
if (
|
|
965
|
+
x.block_shape[0] != z.block_shape[0]
|
|
966
|
+
or y.block_shape[1] != z.block_shape[1]
|
|
967
|
+
or x.block_shape[1] != y.block_shape[0]
|
|
968
|
+
):
|
|
969
|
+
raise ValueError("Incompatible block sizes for matrix multiplication")
|
|
591
970
|
|
|
592
|
-
if x.nrow != z.nrow or z.ncol != y.ncol:
|
|
971
|
+
if x.nrow != z.nrow or z.ncol != y.ncol or x.ncol != y.nrow:
|
|
593
972
|
raise ValueError("Incompatible number of rows/columns for matrix multiplication")
|
|
594
973
|
|
|
595
974
|
device = z.values.device
|
|
596
975
|
|
|
597
|
-
alpha
|
|
598
|
-
|
|
976
|
+
if alpha == 0.0 or x.nnz == 0 or y.nnz == 0:
|
|
977
|
+
# Easy case
|
|
978
|
+
return bsr_scale(z, beta)
|
|
979
|
+
|
|
980
|
+
if not isinstance(alpha, z.scalar_type):
|
|
981
|
+
alpha = z.scalar_type(alpha)
|
|
982
|
+
if not isinstance(beta, z.scalar_type):
|
|
983
|
+
beta = z.scalar_type(beta)
|
|
984
|
+
|
|
985
|
+
if work_arrays is None:
|
|
986
|
+
work_arrays = bsr_mm_work_arrays()
|
|
987
|
+
|
|
988
|
+
z_aliasing = z == x or z == y
|
|
989
|
+
copied_z_nnz = z.nnz if beta != 0.0 or z_aliasing else 0
|
|
990
|
+
|
|
991
|
+
work_arrays._allocate_stage_1(device, z, copied_z_nnz, z_aliasing)
|
|
599
992
|
|
|
600
993
|
# Prefix sum of number of (unmerged) mm blocks per row
|
|
601
|
-
mm_row_counts = wp.empty(shape=(z.nrow + 1,), dtype=int, device=device)
|
|
602
994
|
wp.launch(
|
|
603
995
|
kernel=_bsr_mm_count_coeffs,
|
|
604
996
|
device=device,
|
|
605
997
|
dim=z.nrow,
|
|
606
|
-
inputs=[
|
|
998
|
+
inputs=[copied_z_nnz, x.offsets, x.columns, y.offsets, work_arrays._mm_row_counts],
|
|
607
999
|
)
|
|
608
|
-
warp.utils.array_scan(
|
|
1000
|
+
warp.utils.array_scan(work_arrays._mm_row_counts, work_arrays._mm_row_counts)
|
|
609
1001
|
|
|
610
1002
|
# Get back total counts on host
|
|
611
1003
|
if device.is_cuda:
|
|
612
|
-
|
|
613
|
-
wp.
|
|
614
|
-
|
|
615
|
-
mm_nnz = int(mm_tot_count.numpy()[0])
|
|
1004
|
+
wp.copy(dest=work_arrays._pinned_count_buffer, src=work_arrays._mm_row_counts, src_offset=z.nrow, count=1)
|
|
1005
|
+
wp.synchronize_stream(wp.get_stream(device))
|
|
1006
|
+
mm_nnz = int(work_arrays._pinned_count_buffer.numpy()[0])
|
|
616
1007
|
else:
|
|
617
|
-
mm_nnz = int(
|
|
1008
|
+
mm_nnz = int(work_arrays._mm_row_counts.numpy()[z.nrow])
|
|
618
1009
|
|
|
619
|
-
|
|
620
|
-
mm_cols = wp.empty(shape=(mm_nnz), dtype=int, device=device)
|
|
1010
|
+
work_arrays._allocate_stage_2(mm_nnz)
|
|
621
1011
|
|
|
622
|
-
#
|
|
623
|
-
|
|
624
|
-
|
|
1012
|
+
# If z has a non-zero scale or aliases x or y, save current data before overwriting it
|
|
1013
|
+
if copied_z_nnz > 0:
|
|
1014
|
+
# Copy z row and column indices
|
|
1015
|
+
wp.copy(dest=work_arrays._mm_cols, src=z.columns, count=copied_z_nnz)
|
|
1016
|
+
wp.launch(
|
|
1017
|
+
kernel=_bsr_get_block_row, device=device, dim=copied_z_nnz, inputs=[0, z.offsets, work_arrays._mm_rows]
|
|
1018
|
+
)
|
|
1019
|
+
# Save current z values in temporary buffer
|
|
1020
|
+
wp.copy(src=z.values, dest=work_arrays._old_z_values, count=copied_z_nnz)
|
|
1021
|
+
if z_aliasing:
|
|
1022
|
+
# If z is aliasing with x or y, need to save topology as well
|
|
1023
|
+
wp.copy(src=z.columns, dest=work_arrays._old_z_columns, count=copied_z_nnz)
|
|
1024
|
+
wp.copy(src=z.offsets, dest=work_arrays._old_z_offsets, count=z.nrow + 1)
|
|
625
1025
|
|
|
626
1026
|
# Fill unmerged mm blocks rows and columns
|
|
627
1027
|
wp.launch(
|
|
628
1028
|
kernel=_bsr_mm_list_coeffs,
|
|
629
1029
|
device=device,
|
|
630
1030
|
dim=z.nrow,
|
|
631
|
-
inputs=[
|
|
1031
|
+
inputs=[
|
|
1032
|
+
x.offsets,
|
|
1033
|
+
x.columns,
|
|
1034
|
+
y.offsets,
|
|
1035
|
+
y.columns,
|
|
1036
|
+
work_arrays._mm_row_counts,
|
|
1037
|
+
work_arrays._mm_rows,
|
|
1038
|
+
work_arrays._mm_cols,
|
|
1039
|
+
],
|
|
632
1040
|
)
|
|
633
1041
|
|
|
634
|
-
# Increase dest array
|
|
1042
|
+
# Increase dest array size if needed
|
|
635
1043
|
if z.columns.shape[0] < mm_nnz:
|
|
636
1044
|
z.columns = wp.empty(shape=(mm_nnz,), dtype=int, device=device)
|
|
637
1045
|
|
|
@@ -642,40 +1050,68 @@ def bsr_mm(x: BsrMatrix, y: BsrMatrix, z: BsrMatrix = None, alpha: float = 1.0,
|
|
|
642
1050
|
else:
|
|
643
1051
|
native_func = runtime.core.bsr_matrix_from_triplets_float_device
|
|
644
1052
|
|
|
645
|
-
|
|
1053
|
+
z.nnz = native_func(
|
|
646
1054
|
z.block_shape[0],
|
|
647
1055
|
z.block_shape[1],
|
|
648
1056
|
z.nrow,
|
|
649
1057
|
mm_nnz,
|
|
650
|
-
|
|
651
|
-
|
|
1058
|
+
work_arrays._mm_rows.ptr,
|
|
1059
|
+
work_arrays._mm_cols.ptr,
|
|
652
1060
|
0,
|
|
653
1061
|
z.offsets.ptr,
|
|
654
1062
|
z.columns.ptr,
|
|
655
1063
|
0,
|
|
656
1064
|
)
|
|
657
1065
|
|
|
658
|
-
|
|
1066
|
+
_bsr_ensure_fits(z)
|
|
1067
|
+
z.values.zero_()
|
|
1068
|
+
|
|
1069
|
+
if copied_z_nnz > 0:
|
|
1070
|
+
# Add back original z values
|
|
1071
|
+
wp.launch(
|
|
1072
|
+
kernel=_bsr_axpy_add_block,
|
|
1073
|
+
device=device,
|
|
1074
|
+
dim=copied_z_nnz,
|
|
1075
|
+
inputs=[
|
|
1076
|
+
0,
|
|
1077
|
+
beta,
|
|
1078
|
+
work_arrays._mm_rows,
|
|
1079
|
+
work_arrays._mm_cols,
|
|
1080
|
+
z.offsets,
|
|
1081
|
+
z.columns,
|
|
1082
|
+
work_arrays._old_z_values,
|
|
1083
|
+
z.values,
|
|
1084
|
+
],
|
|
1085
|
+
)
|
|
659
1086
|
|
|
660
|
-
#
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
1087
|
+
# Add mm blocks to z values
|
|
1088
|
+
if (warp.types.type_is_matrix(x.values.dtype) or warp.types.type_is_matrix(y.values.dtype)) and not (
|
|
1089
|
+
warp.types.type_is_matrix(z.values.dtype)
|
|
1090
|
+
):
|
|
1091
|
+
# Result block type is scalar, but operands are matrices
|
|
1092
|
+
# Cast result to (1x1) matrix to perform multiplication
|
|
1093
|
+
mm_values = z.values.view(wp.types.matrix(shape=(1, 1), dtype=z.scalar_type))
|
|
1094
|
+
else:
|
|
1095
|
+
mm_values = z.values
|
|
667
1096
|
|
|
668
|
-
# Add mm blocks
|
|
669
1097
|
wp.launch(
|
|
670
1098
|
kernel=_bsr_mm_compute_values,
|
|
671
1099
|
device=device,
|
|
672
1100
|
dim=z.nrow,
|
|
673
|
-
inputs=[
|
|
1101
|
+
inputs=[
|
|
1102
|
+
alpha,
|
|
1103
|
+
work_arrays._old_z_offsets if x == z else x.offsets,
|
|
1104
|
+
work_arrays._old_z_columns if x == z else x.columns,
|
|
1105
|
+
work_arrays._old_z_values if x == z else x.values,
|
|
1106
|
+
work_arrays._old_z_offsets if y == z else y.offsets,
|
|
1107
|
+
work_arrays._old_z_columns if y == z else y.columns,
|
|
1108
|
+
work_arrays._old_z_values if y == z else y.values,
|
|
1109
|
+
z.offsets,
|
|
1110
|
+
z.columns,
|
|
1111
|
+
mm_values,
|
|
1112
|
+
],
|
|
674
1113
|
)
|
|
675
1114
|
|
|
676
|
-
z.values = mm_values
|
|
677
|
-
z.nnz = mm_nnz
|
|
678
|
-
|
|
679
1115
|
return z
|
|
680
1116
|
|
|
681
1117
|
|
|
@@ -690,44 +1126,96 @@ def _bsr_mv_kernel(
|
|
|
690
1126
|
y: wp.array(dtype=Any),
|
|
691
1127
|
):
|
|
692
1128
|
row = wp.tid()
|
|
693
|
-
beg = A_offsets[row]
|
|
694
|
-
end = A_offsets[row + 1]
|
|
695
1129
|
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
v = v + A_values[block] * x[A_columns[block]]
|
|
1130
|
+
# zero-initialize with type of y elements
|
|
1131
|
+
scalar_zero = type(alpha)(0)
|
|
1132
|
+
v = y.dtype(scalar_zero)
|
|
700
1133
|
|
|
701
|
-
|
|
1134
|
+
if alpha != scalar_zero:
|
|
1135
|
+
beg = A_offsets[row]
|
|
1136
|
+
end = A_offsets[row + 1]
|
|
1137
|
+
for block in range(beg, end):
|
|
1138
|
+
v += A_values[block] * x[A_columns[block]]
|
|
1139
|
+
v *= alpha
|
|
702
1140
|
|
|
1141
|
+
if beta != scalar_zero:
|
|
1142
|
+
v += beta * y[row]
|
|
703
1143
|
|
|
704
|
-
|
|
1144
|
+
y[row] = v
|
|
1145
|
+
|
|
1146
|
+
|
|
1147
|
+
def bsr_mv(
|
|
1148
|
+
A: BsrMatrix[BlockType[Rows, Cols, Scalar]],
|
|
1149
|
+
x: "Array[Vector[Cols, Scalar] | Scalar]",
|
|
1150
|
+
y: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
|
|
1151
|
+
alpha: Scalar = 1.0,
|
|
1152
|
+
beta: Scalar = 0.0,
|
|
1153
|
+
work_buffer: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
|
|
1154
|
+
) -> "Array[Vector[Rows, Scalar] | Scalar]":
|
|
705
1155
|
"""
|
|
706
|
-
|
|
1156
|
+
Performs the sparse matrix-vector product ``y := alpha * A * x + beta * y`` and returns `y`.
|
|
1157
|
+
|
|
1158
|
+
The `x` and `y` vectors are allowed to alias.
|
|
1159
|
+
|
|
1160
|
+
Args:
|
|
1161
|
+
A: Read-only, left matrix factor of the matrix-vector product.
|
|
1162
|
+
x: Read-only, right vector factor of the matrix-vector product.
|
|
1163
|
+
y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero.
|
|
1164
|
+
alpha: Uniform scaling factor for `x`. If zero, `x` will not be read and may be left uninitialized.
|
|
1165
|
+
beta: Uniform scaling factor for `y`. If zero, `y` will not be read and may be left uninitialized.
|
|
1166
|
+
work_buffer: Temporary storage is required if and only if `x` and `y` are the same vector. If provided, the `work_buffer` array
|
|
1167
|
+
will be used for this purpose, otherwise a temporary allocation will be performed.
|
|
707
1168
|
"""
|
|
708
|
-
alpha = A.scalar_type(alpha)
|
|
709
|
-
beta = A.scalar_type(beta)
|
|
710
1169
|
|
|
711
|
-
|
|
712
|
-
|
|
1170
|
+
if y is None:
|
|
1171
|
+
# If no output array is provided, allocate one for convenience
|
|
1172
|
+
y_vec_len = A.block_shape[0]
|
|
1173
|
+
y_dtype = A.scalar_type if y_vec_len == 1 else wp.vec(length=y_vec_len, dtype=A.scalar_type)
|
|
1174
|
+
y = wp.empty(shape=(A.nrow,), device=A.values.device, dtype=y_dtype)
|
|
1175
|
+
y.zero_()
|
|
1176
|
+
beta = 0.0
|
|
1177
|
+
|
|
1178
|
+
if not isinstance(alpha, A.scalar_type):
|
|
1179
|
+
alpha = A.scalar_type(alpha)
|
|
1180
|
+
if not isinstance(beta, A.scalar_type):
|
|
1181
|
+
beta = A.scalar_type(beta)
|
|
713
1182
|
|
|
714
1183
|
if A.values.device != x.device or A.values.device != y.device:
|
|
715
|
-
raise ValueError("A, x and y must
|
|
1184
|
+
raise ValueError("A, x and y must reside on the same device")
|
|
716
1185
|
|
|
717
1186
|
if x.shape[0] != A.ncol:
|
|
718
1187
|
raise ValueError("Number of columns of A must match number of rows of x")
|
|
719
1188
|
if y.shape[0] != A.nrow:
|
|
720
1189
|
raise ValueError("Number of rows of A must match number of rows of y")
|
|
721
1190
|
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
1191
|
+
if x == y:
|
|
1192
|
+
# Aliasing case, need temporary storage
|
|
1193
|
+
if work_buffer is None:
|
|
1194
|
+
work_buffer = wp.empty_like(y)
|
|
1195
|
+
elif work_buffer.size < y.size:
|
|
1196
|
+
raise ValueError(f"Work buffer size is insufficient, needs to be at least {y.size}")
|
|
1197
|
+
elif not wp.types.types_equal(work_buffer.dtype, y.dtype):
|
|
1198
|
+
raise ValueError(f"Work buffer must have same data type as y, {wp.types.type_repr(y.dtype)}")
|
|
1199
|
+
|
|
1200
|
+
# Save old y values before overwriting vector
|
|
1201
|
+
wp.copy(dest=work_buffer, src=y, count=y.size)
|
|
1202
|
+
x = work_buffer
|
|
1203
|
+
|
|
1204
|
+
# Promote scalar vectors to length-1 vecs and conversely
|
|
1205
|
+
if warp.types.type_is_matrix(A.values.dtype):
|
|
1206
|
+
if A.block_shape[0] == 1:
|
|
726
1207
|
if y.dtype == A.scalar_type:
|
|
727
1208
|
y = y.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
|
|
728
|
-
if block_shape[1] == 1:
|
|
1209
|
+
if A.block_shape[1] == 1:
|
|
729
1210
|
if x.dtype == A.scalar_type:
|
|
730
1211
|
x = x.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
|
|
1212
|
+
else:
|
|
1213
|
+
if A.block_shape[0] == 1:
|
|
1214
|
+
if y.dtype != A.scalar_type:
|
|
1215
|
+
y = y.view(dtype=A.scalar_type)
|
|
1216
|
+
if A.block_shape[1] == 1:
|
|
1217
|
+
if x.dtype != A.scalar_type:
|
|
1218
|
+
x = x.view(dtype=A.scalar_type)
|
|
731
1219
|
|
|
732
1220
|
wp.launch(
|
|
733
1221
|
kernel=_bsr_mv_kernel,
|
|
@@ -735,3 +1223,5 @@ def bsr_mv(A: BsrMatrix, x: wp.array, y: wp.array, alpha: float = 1.0, beta: flo
|
|
|
735
1223
|
dim=A.nrow,
|
|
736
1224
|
inputs=[alpha, A.offsets, A.columns, A.values, x, beta, y],
|
|
737
1225
|
)
|
|
1226
|
+
|
|
1227
|
+
return y
|