warp-lang 1.0.0b2__py3-none-win_amd64.whl → 1.0.0b6__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- docs/conf.py +17 -5
- examples/env/env_ant.py +1 -1
- examples/env/env_cartpole.py +1 -1
- examples/env/env_humanoid.py +1 -1
- examples/env/env_usd.py +4 -1
- examples/env/environment.py +8 -9
- examples/example_dem.py +34 -33
- examples/example_diffray.py +364 -337
- examples/example_fluid.py +32 -23
- examples/example_jacobian_ik.py +97 -93
- examples/example_marching_cubes.py +6 -16
- examples/example_mesh.py +6 -16
- examples/example_mesh_intersect.py +16 -14
- examples/example_nvdb.py +14 -16
- examples/example_raycast.py +14 -13
- examples/example_raymarch.py +16 -23
- examples/example_render_opengl.py +19 -10
- examples/example_sim_cartpole.py +82 -78
- examples/example_sim_cloth.py +45 -48
- examples/example_sim_fk_grad.py +51 -44
- examples/example_sim_fk_grad_torch.py +47 -40
- examples/example_sim_grad_bounce.py +108 -133
- examples/example_sim_grad_cloth.py +99 -113
- examples/example_sim_granular.py +5 -6
- examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
- examples/example_sim_neo_hookean.py +51 -55
- examples/example_sim_particle_chain.py +4 -4
- examples/example_sim_quadruped.py +126 -81
- examples/example_sim_rigid_chain.py +54 -61
- examples/example_sim_rigid_contact.py +66 -70
- examples/example_sim_rigid_fem.py +3 -3
- examples/example_sim_rigid_force.py +1 -1
- examples/example_sim_rigid_gyroscopic.py +3 -4
- examples/example_sim_rigid_kinematics.py +28 -39
- examples/example_sim_trajopt.py +112 -110
- examples/example_sph.py +9 -8
- examples/example_wave.py +7 -7
- examples/fem/bsr_utils.py +30 -17
- examples/fem/example_apic_fluid.py +85 -69
- examples/fem/example_convection_diffusion.py +97 -93
- examples/fem/example_convection_diffusion_dg.py +142 -149
- examples/fem/example_convection_diffusion_dg0.py +141 -136
- examples/fem/example_deformed_geometry.py +146 -0
- examples/fem/example_diffusion.py +115 -84
- examples/fem/example_diffusion_3d.py +116 -86
- examples/fem/example_diffusion_mgpu.py +102 -79
- examples/fem/example_mixed_elasticity.py +139 -100
- examples/fem/example_navier_stokes.py +175 -162
- examples/fem/example_stokes.py +143 -111
- examples/fem/example_stokes_transfer.py +186 -157
- examples/fem/mesh_utils.py +59 -97
- examples/fem/plot_utils.py +138 -17
- tools/ci/publishing/build_nodes_info.py +54 -0
- warp/__init__.py +4 -3
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +5 -3
- warp/build_dll.py +29 -9
- warp/builtins.py +836 -492
- warp/codegen.py +864 -553
- warp/config.py +3 -1
- warp/context.py +389 -172
- warp/fem/__init__.py +24 -6
- warp/fem/cache.py +318 -25
- warp/fem/dirichlet.py +7 -3
- warp/fem/domain.py +14 -0
- warp/fem/field/__init__.py +30 -38
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +244 -138
- warp/fem/field/restriction.py +8 -6
- warp/fem/field/test.py +127 -59
- warp/fem/field/trial.py +117 -60
- warp/fem/geometry/__init__.py +5 -1
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +24 -1
- warp/fem/geometry/geometry.py +86 -14
- warp/fem/geometry/grid_2d.py +112 -54
- warp/fem/geometry/grid_3d.py +134 -65
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +85 -33
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +451 -115
- warp/fem/geometry/trimesh_2d.py +197 -92
- warp/fem/integrate.py +534 -268
- warp/fem/operator.py +58 -31
- warp/fem/polynomial.py +11 -0
- warp/fem/quadrature/__init__.py +1 -1
- warp/fem/quadrature/pic_quadrature.py +150 -58
- warp/fem/quadrature/quadrature.py +209 -57
- warp/fem/space/__init__.py +230 -53
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +49 -2
- warp/fem/space/function_space.py +90 -39
- warp/fem/space/grid_2d_function_space.py +149 -496
- warp/fem/space/grid_3d_function_space.py +173 -538
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +129 -76
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +46 -34
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +132 -1039
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +104 -742
- warp/fem/types.py +13 -11
- warp/fem/utils.py +335 -60
- warp/native/array.h +120 -34
- warp/native/builtin.h +101 -72
- warp/native/bvh.cpp +73 -325
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +22 -40
- warp/native/clang/clang.cpp +1 -0
- warp/native/crt.h +2 -0
- warp/native/cuda_util.cpp +8 -3
- warp/native/cuda_util.h +1 -0
- warp/native/exports.h +1522 -1243
- warp/native/intersect.h +19 -4
- warp/native/intersect_adj.h +8 -8
- warp/native/mat.h +76 -17
- warp/native/mesh.cpp +33 -108
- warp/native/mesh.cu +114 -18
- warp/native/mesh.h +395 -40
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +44 -34
- warp/native/reduce.cpp +1 -1
- warp/native/sparse.cpp +4 -4
- warp/native/sparse.cu +163 -155
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +18 -14
- warp/native/vec.h +103 -21
- warp/native/warp.cpp +2 -1
- warp/native/warp.cu +28 -3
- warp/native/warp.h +4 -3
- warp/render/render_opengl.py +261 -109
- warp/sim/__init__.py +1 -2
- warp/sim/articulation.py +385 -185
- warp/sim/import_mjcf.py +59 -48
- warp/sim/import_urdf.py +15 -15
- warp/sim/import_usd.py +174 -102
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_xpbd.py +4 -3
- warp/sim/model.py +330 -250
- warp/sim/render.py +1 -1
- warp/sparse.py +625 -152
- warp/stubs.py +341 -309
- warp/tape.py +9 -6
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +94 -74
- warp/tests/test_array.py +82 -101
- warp/tests/test_array_reduce.py +57 -23
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +22 -12
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +18 -18
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +165 -134
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +75 -75
- warp/tests/test_examples.py +237 -0
- warp/tests/test_fabricarray.py +22 -24
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1034 -124
- warp/tests/test_fp16.py +23 -16
- warp/tests/test_func.py +187 -86
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +123 -181
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +35 -34
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +24 -25
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +14 -41
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +517 -2898
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +304 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +60 -22
- warp/tests/test_mesh_query_aabb.py +21 -25
- warp/tests/test_mesh_query_point.py +111 -22
- warp/tests/test_mesh_query_ray.py +12 -24
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +168 -20
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +261 -63
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +268 -63
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +90 -86
- warp/tests/test_transient_module.py +10 -12
- warp/tests/test_types.py +363 -0
- warp/tests/test_utils.py +451 -0
- warp/tests/test_vec.py +354 -2050
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +418 -376
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +291 -0
- warp/tests/unittest_utils.py +342 -0
- warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +589 -0
- warp/types.py +622 -211
- warp/utils.py +54 -393
- warp_lang-1.0.0b6.dist-info/METADATA +238 -0
- warp_lang-1.0.0b6.dist-info/RECORD +409 -0
- {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
- examples/example_cache_management.py +0 -40
- examples/example_multigpu.py +0 -54
- examples/example_struct.py +0 -65
- examples/fem/example_stokes_transfer_3d.py +0 -210
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/fem/field/discrete_field.py +0 -80
- warp/fem/space/nodal_function_space.py +0 -233
- warp/tests/test_all.py +0 -223
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-1.0.0b2.dist-info/METADATA +0 -26
- warp_lang-1.0.0b2.dist-info/RECORD +0 -380
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/sparse.py
CHANGED
|
@@ -1,14 +1,29 @@
|
|
|
1
|
+
from typing import Any, Generic, Optional, Tuple, TypeVar, Union
|
|
2
|
+
|
|
1
3
|
import warp as wp
|
|
2
4
|
import warp.types
|
|
3
5
|
import warp.utils
|
|
6
|
+
from warp.types import Array, Cols, Matrix, Rows, Scalar, Vector
|
|
7
|
+
|
|
8
|
+
# typing hints
|
|
9
|
+
|
|
10
|
+
_BlockType = TypeVar("BlockType")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class _MatrixBlockType(Matrix):
|
|
14
|
+
pass
|
|
4
15
|
|
|
5
|
-
from typing import Tuple, Any, Union
|
|
6
16
|
|
|
17
|
+
class _ScalarBlockType(Generic[Scalar]):
|
|
18
|
+
pass
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
BlockType = Union[_MatrixBlockType[Rows, Cols, Scalar], _ScalarBlockType[Scalar]]
|
|
7
22
|
|
|
8
23
|
_struct_cache = dict()
|
|
9
24
|
|
|
10
25
|
|
|
11
|
-
class BsrMatrix:
|
|
26
|
+
class BsrMatrix(Generic[_BlockType]):
|
|
12
27
|
"""Untyped base class for BSR and CSR matrices.
|
|
13
28
|
|
|
14
29
|
Should not be constructed directly but through functions such as :func:`bsr_zeros`.
|
|
@@ -16,15 +31,15 @@ class BsrMatrix:
|
|
|
16
31
|
Attributes:
|
|
17
32
|
nrow (int): Number of rows of blocks
|
|
18
33
|
ncol (int): Number of columns of blocks
|
|
19
|
-
nnz (int): Number of non-zero blocks: equal to
|
|
20
|
-
offsets (
|
|
21
|
-
columns (
|
|
22
|
-
values (
|
|
34
|
+
nnz (int): Number of non-zero blocks: must be equal to ``offsets[nrow]``, cached on host for convenience
|
|
35
|
+
offsets (Array[int]): Array of size at least ``1 + nrows`` such that the start and end indices of the blocks of row ``r`` are ``offsets[r]`` and ``offsets[r+1]``, respectively.
|
|
36
|
+
columns (Array[int]): Array of size at least equal to ``nnz`` containing block column indices
|
|
37
|
+
values (Array[BlockType]): Array of size at least equal to ``nnz`` containing block values
|
|
23
38
|
"""
|
|
24
39
|
|
|
25
40
|
@property
|
|
26
|
-
def scalar_type(self) ->
|
|
27
|
-
"""Scalar type for
|
|
41
|
+
def scalar_type(self) -> Scalar:
|
|
42
|
+
"""Scalar type for individual block coefficients. For CSR matrices, this is the same as the block type"""
|
|
28
43
|
return warp.types.type_scalar_type(self.values.dtype)
|
|
29
44
|
|
|
30
45
|
@property
|
|
@@ -33,20 +48,25 @@ class BsrMatrix:
|
|
|
33
48
|
return getattr(self.values.dtype, "_shape_", (1, 1))
|
|
34
49
|
|
|
35
50
|
@property
|
|
36
|
-
def block_size(self) ->
|
|
37
|
-
"""Size of the individual blocks, i.e. number of rows per block times number of
|
|
51
|
+
def block_size(self) -> int:
|
|
52
|
+
"""Size of the individual blocks, i.e. number of rows per block times number of columns per block"""
|
|
38
53
|
return warp.types.type_length(self.values.dtype)
|
|
39
54
|
|
|
40
55
|
@property
|
|
41
56
|
def shape(self) -> Tuple[int, int]:
|
|
42
|
-
"""Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/
|
|
57
|
+
"""Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/columns per block"""
|
|
43
58
|
block_shape = self.block_shape
|
|
44
59
|
return (self.nrow * block_shape[0], self.ncol * block_shape[1])
|
|
45
60
|
|
|
46
61
|
|
|
47
|
-
def bsr_matrix_t(dtype:
|
|
62
|
+
def bsr_matrix_t(dtype: BlockType):
|
|
48
63
|
dtype = wp.types.type_to_warp(dtype)
|
|
49
64
|
|
|
65
|
+
if not warp.types.type_is_matrix(dtype) and not dtype in warp.types.scalar_types:
|
|
66
|
+
raise ValueError(
|
|
67
|
+
f"BsrMatrix block type must be either warp matrix or scalar; got {warp.types.type_repr(dtype)}"
|
|
68
|
+
)
|
|
69
|
+
|
|
50
70
|
class BsrMatrixTyped(BsrMatrix):
|
|
51
71
|
nrow: int
|
|
52
72
|
"""Number of rows of blocks"""
|
|
@@ -79,11 +99,23 @@ def bsr_matrix_t(dtype: type):
|
|
|
79
99
|
|
|
80
100
|
|
|
81
101
|
def bsr_zeros(
|
|
82
|
-
rows_of_blocks: int,
|
|
102
|
+
rows_of_blocks: int,
|
|
103
|
+
cols_of_blocks: int,
|
|
104
|
+
block_type: BlockType,
|
|
105
|
+
device: wp.context.Devicelike = None,
|
|
83
106
|
) -> BsrMatrix:
|
|
84
107
|
"""
|
|
85
|
-
Constructs an empty BSR or
|
|
108
|
+
Constructs and returns an empty BSR or CSR matrix with the given shape
|
|
109
|
+
|
|
110
|
+
Args:
|
|
111
|
+
bsr: The BSR or CSR matrix to set to zero
|
|
112
|
+
rows_of_blocks: Number of rows of blocks
|
|
113
|
+
cols_of_blocks: Number of columns of blocks
|
|
114
|
+
block_type: Type of individual blocks. For CSR matrices, this should be a scalar type;
|
|
115
|
+
for BSR matrices, this should be a matrix type (e.g. from :func:`warp.mat`)
|
|
116
|
+
device: Device on which to allocate the matrix arrays
|
|
86
117
|
"""
|
|
118
|
+
|
|
87
119
|
bsr = bsr_matrix_t(block_type)()
|
|
88
120
|
|
|
89
121
|
bsr.nrow = rows_of_blocks
|
|
@@ -110,19 +142,42 @@ def _bsr_ensure_fits(bsr: BsrMatrix, nrow: int = None, nnz: int = None):
|
|
|
110
142
|
bsr.values = wp.empty(shape=(nnz,), dtype=bsr.values.dtype, device=bsr.values.device)
|
|
111
143
|
|
|
112
144
|
|
|
145
|
+
def bsr_set_zero(bsr: BsrMatrix, rows_of_blocks: Optional[int] = None, cols_of_blocks: Optional[int] = None):
|
|
146
|
+
"""
|
|
147
|
+
Sets a BSR matrix to zero, possibly changing its size
|
|
148
|
+
|
|
149
|
+
Args:
|
|
150
|
+
bsr: The BSR or CSR matrix to set to zero
|
|
151
|
+
rows_of_blocks: If not ``None``, the new number of rows of blocks
|
|
152
|
+
cols_of_blocks: If not ``None``, the new number of columns of blocks
|
|
153
|
+
"""
|
|
154
|
+
|
|
155
|
+
if rows_of_blocks is not None:
|
|
156
|
+
bsr.nrow = rows_of_blocks
|
|
157
|
+
if cols_of_blocks is not None:
|
|
158
|
+
bsr.ncol = cols_of_blocks
|
|
159
|
+
bsr.nnz = 0
|
|
160
|
+
_bsr_ensure_fits(bsr)
|
|
161
|
+
bsr.offsets.zero_()
|
|
162
|
+
|
|
163
|
+
|
|
113
164
|
def bsr_set_from_triplets(
|
|
114
|
-
dest: BsrMatrix,
|
|
115
|
-
rows:
|
|
116
|
-
columns:
|
|
117
|
-
values:
|
|
165
|
+
dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
|
|
166
|
+
rows: "Array[int]",
|
|
167
|
+
columns: "Array[int]",
|
|
168
|
+
values: "Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]",
|
|
118
169
|
):
|
|
119
170
|
"""
|
|
120
|
-
Fills a BSR matrix
|
|
171
|
+
Fills a BSR matrix with values defined by coordinate-oriented (COO) triplets, discarding existing blocks.
|
|
121
172
|
|
|
122
|
-
|
|
123
|
-
or a 3d array with data type equal to the `dest` matrix scalar type.
|
|
173
|
+
The first dimension of the three input arrays must match, and determines the number of non-zeros in the constructed matrix.
|
|
124
174
|
|
|
125
|
-
|
|
175
|
+
Args:
|
|
176
|
+
dest: Sparse matrix to populate
|
|
177
|
+
rows: Row index for each non-zero
|
|
178
|
+
columns: Column index for each non-zero
|
|
179
|
+
values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
|
|
180
|
+
to the `dest` matrix's block type, or a 3d array with data type equal to the `dest` matrix's scalar type.
|
|
126
181
|
"""
|
|
127
182
|
|
|
128
183
|
if values.device != columns.device or values.device != rows.device or values.device != dest.values.device:
|
|
@@ -138,7 +193,7 @@ def bsr_set_from_triplets(
|
|
|
138
193
|
elif values.ndim == 3:
|
|
139
194
|
if values.shape[1:] != dest.block_shape:
|
|
140
195
|
raise ValueError(
|
|
141
|
-
f"Last two dimensions in values array ({values.shape[1:]})
|
|
196
|
+
f"Last two dimensions in values array ({values.shape[1:]}) should correspond to matrix block shape {(dest.block_shape)})"
|
|
142
197
|
)
|
|
143
198
|
|
|
144
199
|
if warp.types.type_scalar_type(values.dtype) != dest.scalar_type:
|
|
@@ -150,6 +205,9 @@ def bsr_set_from_triplets(
|
|
|
150
205
|
raise ValueError("Number of dimension for values array should be 1 or 3")
|
|
151
206
|
|
|
152
207
|
nnz = rows.shape[0]
|
|
208
|
+
if nnz == 0:
|
|
209
|
+
bsr_set_zero(dest)
|
|
210
|
+
return
|
|
153
211
|
|
|
154
212
|
# Increase dest array sizes if needed
|
|
155
213
|
_bsr_ensure_fits(dest, nnz=nnz)
|
|
@@ -186,8 +244,8 @@ def bsr_set_from_triplets(
|
|
|
186
244
|
)
|
|
187
245
|
|
|
188
246
|
|
|
189
|
-
def bsr_assign(dest: BsrMatrix, src: BsrMatrix):
|
|
190
|
-
"""Copies the content of the `src` matrix to `dest`,
|
|
247
|
+
def bsr_assign(dest: BsrMatrix[BlockType[Rows, Cols, Scalar]], src: BsrMatrix[BlockType[Rows, Cols, Any]]):
|
|
248
|
+
"""Copies the content of the `src` matrix to `dest`, casting the block values if the two matrices use distinct scalar types."""
|
|
191
249
|
|
|
192
250
|
if dest.values.device != src.values.device:
|
|
193
251
|
raise ValueError("Source and destination matrices must reside on the same device")
|
|
@@ -207,8 +265,12 @@ def bsr_assign(dest: BsrMatrix, src: BsrMatrix):
|
|
|
207
265
|
warp.utils.array_cast(out_array=dest.values, in_array=src.values, count=src.nnz)
|
|
208
266
|
|
|
209
267
|
|
|
210
|
-
def bsr_copy(A: BsrMatrix, scalar_type=None):
|
|
211
|
-
"""Returns a copy of matrix A
|
|
268
|
+
def bsr_copy(A: BsrMatrix, scalar_type: Optional[Scalar] = None):
|
|
269
|
+
"""Returns a copy of matrix ``A``, possibly changing its scalar type.
|
|
270
|
+
|
|
271
|
+
Args:
|
|
272
|
+
scalar_type: If provided, the returned matrix will use this scalar type instead of the one from `A`.
|
|
273
|
+
"""
|
|
212
274
|
if scalar_type is None:
|
|
213
275
|
block_type = A.values.dtype
|
|
214
276
|
elif A.block_shape == (1, 1):
|
|
@@ -221,7 +283,7 @@ def bsr_copy(A: BsrMatrix, scalar_type=None):
|
|
|
221
283
|
return copy
|
|
222
284
|
|
|
223
285
|
|
|
224
|
-
def bsr_set_transpose(dest: BsrMatrix, src: BsrMatrix):
|
|
286
|
+
def bsr_set_transpose(dest: BsrMatrix[BlockType[Cols, Rows, Scalar]], src: BsrMatrix[BlockType[Rows, Cols, Scalar]]):
|
|
225
287
|
"""Assigns the transposed matrix `src` to matrix `dest`"""
|
|
226
288
|
|
|
227
289
|
if dest.values.device != src.values.device:
|
|
@@ -230,10 +292,7 @@ def bsr_set_transpose(dest: BsrMatrix, src: BsrMatrix):
|
|
|
230
292
|
if dest.scalar_type != src.scalar_type:
|
|
231
293
|
raise ValueError("All arguments must have the same scalar type")
|
|
232
294
|
|
|
233
|
-
|
|
234
|
-
transpose_block_shape = (1, 1)
|
|
235
|
-
else:
|
|
236
|
-
transpose_block_shape = src.block_shape[::-1]
|
|
295
|
+
transpose_block_shape = src.block_shape[::-1]
|
|
237
296
|
|
|
238
297
|
if dest.block_shape != transpose_block_shape:
|
|
239
298
|
raise ValueError(f"Destination block shape must be {transpose_block_shape}")
|
|
@@ -242,6 +301,9 @@ def bsr_set_transpose(dest: BsrMatrix, src: BsrMatrix):
|
|
|
242
301
|
dest.ncol = src.nrow
|
|
243
302
|
dest.nnz = src.nnz
|
|
244
303
|
|
|
304
|
+
if src.nnz == 0:
|
|
305
|
+
return
|
|
306
|
+
|
|
245
307
|
# Increase dest array sizes if needed
|
|
246
308
|
_bsr_ensure_fits(dest)
|
|
247
309
|
|
|
@@ -301,27 +363,33 @@ def _bsr_get_diag_kernel(
|
|
|
301
363
|
end = A_offsets[row + 1]
|
|
302
364
|
|
|
303
365
|
diag = wp.lower_bound(A_columns, beg, end, row)
|
|
304
|
-
if
|
|
305
|
-
|
|
366
|
+
if diag < end:
|
|
367
|
+
if A_columns[diag] == row:
|
|
368
|
+
out[row] = A_values[diag]
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
def bsr_get_diag(A: BsrMatrix[_BlockType], out: "Optional[Array[BlockType]]" = None) -> "Array[BlockType]":
|
|
372
|
+
"""Returns the array of blocks that constitute the diagonal of a sparse matrix.
|
|
306
373
|
|
|
374
|
+
Args:
|
|
375
|
+
A: the sparse matrix from which to extract the diagonal
|
|
376
|
+
out: if provided, the array into which to store the diagonal blocks
|
|
377
|
+
"""
|
|
307
378
|
|
|
308
|
-
|
|
309
|
-
"""Returns the block diagonal of a square sparse matrix"""
|
|
310
|
-
if A.nrow != A.ncol:
|
|
311
|
-
raise ValueError("bsr_get_diag is only available for square sparse matrices")
|
|
379
|
+
dim = min(A.nrow, A.ncol)
|
|
312
380
|
|
|
313
381
|
if out is None:
|
|
314
|
-
out = wp.zeros(shape=(
|
|
382
|
+
out = wp.zeros(shape=(dim,), dtype=A.values.dtype, device=A.values.device)
|
|
315
383
|
else:
|
|
316
384
|
if out.dtype != A.values.dtype:
|
|
317
385
|
raise ValueError(f"Output array must have type {A.values.dtype}")
|
|
318
386
|
if out.device != A.values.device:
|
|
319
387
|
raise ValueError(f"Output array must reside on device {A.values.device}")
|
|
320
|
-
if out.shape[0] <
|
|
321
|
-
raise ValueError(f"Output array must be of length at least {
|
|
388
|
+
if out.shape[0] < dim:
|
|
389
|
+
raise ValueError(f"Output array must be of length at least {dim}")
|
|
322
390
|
|
|
323
391
|
wp.launch(
|
|
324
|
-
kernel=_bsr_get_diag_kernel, dim=
|
|
392
|
+
kernel=_bsr_get_diag_kernel, dim=dim, device=A.values.device, inputs=[A.offsets, A.columns, A.values, out]
|
|
325
393
|
)
|
|
326
394
|
|
|
327
395
|
return out
|
|
@@ -329,40 +397,205 @@ def bsr_get_diag(A: BsrMatrix, out: wp.array = None):
|
|
|
329
397
|
|
|
330
398
|
@wp.kernel
|
|
331
399
|
def _bsr_set_diag_kernel(
|
|
400
|
+
diag: wp.array(dtype=Any),
|
|
401
|
+
A_offsets: wp.array(dtype=int),
|
|
402
|
+
A_columns: wp.array(dtype=int),
|
|
403
|
+
A_values: wp.array(dtype=Any),
|
|
404
|
+
):
|
|
405
|
+
row = wp.tid()
|
|
406
|
+
A_offsets[row + 1] = row + 1
|
|
407
|
+
A_columns[row] = row
|
|
408
|
+
A_values[row] = diag[row]
|
|
409
|
+
|
|
410
|
+
if row == 0:
|
|
411
|
+
A_offsets[0] = 0
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
@wp.kernel
|
|
415
|
+
def _bsr_set_diag_constant_kernel(
|
|
416
|
+
diag_value: Any,
|
|
332
417
|
A_offsets: wp.array(dtype=int),
|
|
333
418
|
A_columns: wp.array(dtype=int),
|
|
419
|
+
A_values: wp.array(dtype=Any),
|
|
334
420
|
):
|
|
335
421
|
row = wp.tid()
|
|
336
422
|
A_offsets[row + 1] = row + 1
|
|
337
423
|
A_columns[row] = row
|
|
424
|
+
A_values[row] = diag_value
|
|
338
425
|
|
|
339
426
|
if row == 0:
|
|
340
427
|
A_offsets[0] = 0
|
|
341
428
|
|
|
342
429
|
|
|
343
|
-
def bsr_set_diag(
|
|
344
|
-
|
|
430
|
+
def bsr_set_diag(
|
|
431
|
+
A: BsrMatrix[BlockType],
|
|
432
|
+
diag: "Union[BlockType, Array[BlockType]]",
|
|
433
|
+
rows_of_blocks: Optional[int] = None,
|
|
434
|
+
cols_of_blocks: Optional[int] = None,
|
|
435
|
+
):
|
|
436
|
+
"""Sets `A` as a block-diagonal matrix
|
|
437
|
+
|
|
438
|
+
Args:
|
|
439
|
+
A: the sparse matrix to modify
|
|
440
|
+
diag: Either a warp array of type ``A.values.dtype``, in which case each element will define one block of the diagonal,
|
|
441
|
+
or a constant value of type ``A.values.dtype``, in which case it will get assigned to all diagonal blocks.
|
|
442
|
+
rows_of_blocks: If not ``None``, the new number of rows of blocks
|
|
443
|
+
cols_of_blocks: If not ``None``, the new number of columns of blocks
|
|
444
|
+
|
|
445
|
+
The shape of the matrix will be defined by one of the following, in that order:
|
|
446
|
+
- `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal.
|
|
447
|
+
- the first dimension of `diag`, if `diag` is an array
|
|
448
|
+
- the current dimensions of `A` otherwise
|
|
449
|
+
"""
|
|
450
|
+
|
|
451
|
+
if rows_of_blocks is None and cols_of_blocks is not None:
|
|
452
|
+
rows_of_blocks = cols_of_blocks
|
|
453
|
+
if cols_of_blocks is None and rows_of_blocks is not None:
|
|
454
|
+
cols_of_blocks = rows_of_blocks
|
|
455
|
+
|
|
456
|
+
if warp.types.is_array(diag):
|
|
457
|
+
if rows_of_blocks is None:
|
|
458
|
+
rows_of_blocks = diag.shape[0]
|
|
459
|
+
cols_of_blocks = diag.shape[0]
|
|
460
|
+
|
|
461
|
+
if rows_of_blocks is not None:
|
|
462
|
+
A.nrow = rows_of_blocks
|
|
463
|
+
A.ncol = cols_of_blocks
|
|
464
|
+
|
|
465
|
+
A.nnz = min(A.nrow, A.ncol)
|
|
466
|
+
_bsr_ensure_fits(A)
|
|
467
|
+
|
|
468
|
+
if warp.types.is_array(diag):
|
|
469
|
+
wp.launch(
|
|
470
|
+
kernel=_bsr_set_diag_kernel,
|
|
471
|
+
dim=A.nnz,
|
|
472
|
+
device=A.values.device,
|
|
473
|
+
inputs=[diag, A.offsets, A.columns, A.values],
|
|
474
|
+
)
|
|
475
|
+
else:
|
|
476
|
+
if not warp.types.type_is_value(type(diag)):
|
|
477
|
+
# Cast to launchable type
|
|
478
|
+
diag = A.values.dtype(diag)
|
|
479
|
+
wp.launch(
|
|
480
|
+
kernel=_bsr_set_diag_constant_kernel,
|
|
481
|
+
dim=A.nnz,
|
|
482
|
+
device=A.values.device,
|
|
483
|
+
inputs=[diag, A.offsets, A.columns, A.values],
|
|
484
|
+
)
|
|
345
485
|
|
|
346
|
-
A.nrow = diag.shape[0]
|
|
347
|
-
A.ncol = diag.shape[0]
|
|
348
|
-
A.nnz = diag.shape[0]
|
|
349
486
|
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
487
|
+
def bsr_diag(
|
|
488
|
+
diag: "Union[BlockType, Array[BlockType]]",
|
|
489
|
+
rows_of_blocks: Optional[int] = None,
|
|
490
|
+
cols_of_blocks: Optional[int] = None,
|
|
491
|
+
) -> BsrMatrix["BlockType"]:
|
|
492
|
+
"""Creates and returns a block-diagonal BSR matrix from a given block value or array of block values.
|
|
355
493
|
|
|
356
|
-
|
|
494
|
+
Args:
|
|
495
|
+
diag: Either a warp array of type ``A.values.dtype``, in which case each element will define one block of the diagonal,
|
|
496
|
+
or a constant value of type ``A.values.dtype``, in which case it will get assigned to all diagonal blocks.
|
|
497
|
+
rows_of_blocks: If not ``None``, the new number of rows of blocks
|
|
498
|
+
cols_of_blocks: If not ``None``, the new number of columns of blocks
|
|
357
499
|
|
|
500
|
+
The shape of the matrix will be defined by one of the following, in that order:
|
|
501
|
+
- `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal.
|
|
502
|
+
- the first dimension of `diag`, if `diag` is an array
|
|
503
|
+
"""
|
|
504
|
+
|
|
505
|
+
if rows_of_blocks is None and cols_of_blocks is not None:
|
|
506
|
+
rows_of_blocks = cols_of_blocks
|
|
507
|
+
if cols_of_blocks is None and rows_of_blocks is not None:
|
|
508
|
+
cols_of_blocks = rows_of_blocks
|
|
509
|
+
|
|
510
|
+
if warp.types.is_array(diag):
|
|
511
|
+
if rows_of_blocks is None:
|
|
512
|
+
rows_of_blocks = diag.shape[0]
|
|
513
|
+
cols_of_blocks = diag.shape[0]
|
|
514
|
+
|
|
515
|
+
A = bsr_zeros(
|
|
516
|
+
rows_of_blocks,
|
|
517
|
+
cols_of_blocks,
|
|
518
|
+
block_type=diag.dtype,
|
|
519
|
+
device=diag.device,
|
|
520
|
+
)
|
|
521
|
+
else:
|
|
522
|
+
if rows_of_blocks is None:
|
|
523
|
+
raise ValueError(
|
|
524
|
+
"rows_of_blocks and/or cols_of_blocks must be provided for constructing a diagonal matrix with uniform diagonal"
|
|
525
|
+
)
|
|
526
|
+
|
|
527
|
+
block_type = type(diag)
|
|
528
|
+
if not warp.types.type_is_matrix(block_type) and len(getattr(diag, "shape", ())) == 2:
|
|
529
|
+
block_type = wp.mat(shape=diag.shape, dtype=diag.dtype)
|
|
530
|
+
|
|
531
|
+
A = bsr_zeros(
|
|
532
|
+
rows_of_blocks,
|
|
533
|
+
cols_of_blocks,
|
|
534
|
+
block_type=block_type,
|
|
535
|
+
)
|
|
358
536
|
|
|
359
|
-
def bsr_diag(diag: wp.array):
|
|
360
|
-
"""Creates a square block-diagonal BSR matrix from the values array `diag`"""
|
|
361
|
-
A = bsr_zeros(rows_of_blocks=diag.shape[0], cols_of_blocks=diag.shape[0], block_type=diag.dtype, device=diag.device)
|
|
362
537
|
bsr_set_diag(A, diag)
|
|
363
538
|
return A
|
|
364
539
|
|
|
365
540
|
|
|
541
|
+
def bsr_set_identity(A: BsrMatrix, rows_of_blocks: Optional[int] = None):
|
|
542
|
+
"""Sets `A` as the identity matrix
|
|
543
|
+
|
|
544
|
+
Args:
|
|
545
|
+
A: the sparse matrix to modify
|
|
546
|
+
rows_of_blocks: if provided, the matrix will be resized as a square matrix with `rows_of_blocks` rows and columns.
|
|
547
|
+
"""
|
|
548
|
+
|
|
549
|
+
if A.block_shape == (1, 1):
|
|
550
|
+
identity = A.scalar_type(1.0)
|
|
551
|
+
else:
|
|
552
|
+
from numpy import eye
|
|
553
|
+
|
|
554
|
+
identity = eye(A.block_shape[0])
|
|
555
|
+
|
|
556
|
+
bsr_set_diag(A, diag=identity, rows_of_blocks=rows_of_blocks, cols_of_blocks=rows_of_blocks)
|
|
557
|
+
|
|
558
|
+
|
|
559
|
+
def bsr_identity(
|
|
560
|
+
rows_of_blocks: int, block_type: BlockType[Rows, Rows, Scalar], device: wp.context.Devicelike = None
|
|
561
|
+
) -> BsrMatrix[BlockType[Rows, Rows, Scalar]]:
|
|
562
|
+
"""Creates and returns a square identity matrix.
|
|
563
|
+
|
|
564
|
+
Args:
|
|
565
|
+
rows_of_blocks: Number of rows and columns of blocks in the created matrix.
|
|
566
|
+
block_type: Block type for the newly created matrix -- must be square
|
|
567
|
+
device: Device onto which to allocate the data arrays
|
|
568
|
+
"""
|
|
569
|
+
A = bsr_zeros(rows_of_blocks=rows_of_blocks, cols_of_blocks=rows_of_blocks, block_type=block_type, device=device)
|
|
570
|
+
bsr_set_identity(A)
|
|
571
|
+
return A
|
|
572
|
+
|
|
573
|
+
|
|
574
|
+
@wp.kernel
|
|
575
|
+
def _bsr_scale_kernel(
|
|
576
|
+
alpha: Any,
|
|
577
|
+
values: wp.array(dtype=Any),
|
|
578
|
+
):
|
|
579
|
+
values[wp.tid()] = alpha * values[wp.tid()]
|
|
580
|
+
|
|
581
|
+
|
|
582
|
+
def bsr_scale(x: BsrMatrix, alpha: Scalar) -> BsrMatrix:
    """
    Scales the BSR matrix `x` in place, i.e. performs ``x := alpha * x``, and returns `x`.

    Nothing is launched when `alpha` is one or when `x` has no stored blocks; a zero `alpha`
    is handled by :func:`bsr_set_zero` instead of a multiplication.
    """

    # Unit scale or no stored blocks: the matrix is already the expected result.
    needs_scaling = alpha != 1.0 and x.nnz > 0
    if not needs_scaling:
        return x

    if alpha == 0.0:
        bsr_set_zero(x)
        return x

    # The kernel expects the factor in the scalar type of the matrix.
    scale = alpha if isinstance(alpha, x.scalar_type) else x.scalar_type(alpha)

    wp.launch(kernel=_bsr_scale_kernel, dim=x.nnz, device=x.values.device, inputs=[scale, x.values])

    return x
|
|
597
|
+
|
|
598
|
+
|
|
366
599
|
@wp.kernel
|
|
367
600
|
def _bsr_get_block_row(dest_offset: int, bsr_offsets: wp.array(dtype=int), rows: wp.array(dtype=int)):
|
|
368
601
|
i = wp.tid()
|
|
@@ -393,16 +626,75 @@ def _bsr_axpy_add_block(
|
|
|
393
626
|
dst_values[block] = dst_values[block] + scale * src_values[i]
|
|
394
627
|
|
|
395
628
|
|
|
396
|
-
|
|
629
|
+
class bsr_axpy_work_arrays:
    """Opaque structure for persisting :func:`bsr_axpy` temporary work buffers across calls

    Buffers are created lazily by :meth:`_allocate` and only re-created when the target device
    changes, when they are too small, or (for the saved-values buffer) when the block data type
    of the destination matrix changes.
    """

    def __init__(self):
        self._reset(None)

    def _reset(self, device):
        """Binds the cache to `device` and forgets every previously allocated buffer."""
        self.device = device
        self._sum_rows = None
        self._sum_cols = None
        self._old_y_values = None
        # Never allocated by this class; kept so the attribute set stays unchanged.
        self._old_x_values = None

    def _allocate(self, device, y: BsrMatrix, sum_nnz: int):
        """Makes sure the work buffers are usable for a sum writing into `y`.

        Args:
            device: Device on which the buffers must live; a change of device drops all buffers.
            y: Destination matrix; its ``nnz`` and values data type size the saved-values buffer.
            sum_nnz: Minimum number of entries required in the row and column index buffers.
        """
        if self.device != device:
            self._reset(device)

        # Shapes are passed as explicit 1-tuples, consistent with bsr_mm_work_arrays.
        if self._sum_rows is None or self._sum_rows.size < sum_nnz:
            self._sum_rows = wp.empty(shape=(sum_nnz,), dtype=int, device=self.device)
        if self._sum_cols is None or self._sum_cols.size < sum_nnz:
            self._sum_cols = wp.empty(shape=(sum_nnz,), dtype=int, device=self.device)

        # A large-enough buffer left over from a matrix with another block type cannot receive
        # a copy of y.values, so the data type is part of the reuse condition.
        if (
            self._old_y_values is None
            or self._old_y_values.size < y.nnz
            or not wp.types.types_equal(self._old_y_values.dtype, y.values.dtype)
        ):
            self._old_y_values = wp.empty(shape=(y.nnz,), dtype=y.values.dtype, device=self.device)
|
|
653
|
+
|
|
654
|
+
|
|
655
|
+
def bsr_axpy(
|
|
656
|
+
x: BsrMatrix[BlockType[Rows, Cols, Scalar]],
|
|
657
|
+
y: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
|
|
658
|
+
alpha: Scalar = 1.0,
|
|
659
|
+
beta: Scalar = 1.0,
|
|
660
|
+
work_arrays: Optional[bsr_axpy_work_arrays] = None,
|
|
661
|
+
) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
|
|
397
662
|
"""
|
|
398
|
-
Performs the
|
|
663
|
+
Performs the sparse matrix addition ``y := alpha * x + beta * y`` on BSR matrices `x` and `y` and returns `y`.
|
|
664
|
+
|
|
665
|
+
The `x` and `y` matrices are allowed to alias.
|
|
666
|
+
|
|
667
|
+
Args:
|
|
668
|
+
x: Read-only right-hand-side.
|
|
669
|
+
y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero.
|
|
670
|
+
alpha: Uniform scaling factor for `x`
|
|
671
|
+
beta: Uniform scaling factor for `y`
|
|
672
|
+
work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_axpy_work_arrays` in `work_arrays`.
|
|
399
673
|
"""
|
|
400
674
|
|
|
401
675
|
if y is None:
|
|
402
|
-
|
|
676
|
+
# If no output matrix is provided, allocate it for convenience
|
|
677
|
+
y = bsr_zeros(x.nrow, x.ncol, block_type=x.values.dtype, device=x.values.device)
|
|
403
678
|
beta = 0.0
|
|
404
679
|
|
|
405
|
-
|
|
680
|
+
# Handle easy cases first
|
|
681
|
+
if beta == 0.0 or y.nnz == 0:
|
|
682
|
+
bsr_assign(src=x, dest=y)
|
|
683
|
+
return bsr_scale(y, alpha=alpha)
|
|
684
|
+
|
|
685
|
+
if alpha == 0.0 or x.nnz == 0:
|
|
686
|
+
return bsr_scale(y, alpha=beta)
|
|
687
|
+
|
|
688
|
+
if not isinstance(alpha, y.scalar_type):
|
|
689
|
+
alpha = y.scalar_type(alpha)
|
|
690
|
+
if not isinstance(beta, y.scalar_type):
|
|
691
|
+
beta = y.scalar_type(beta)
|
|
692
|
+
|
|
693
|
+
if x == y:
|
|
694
|
+
# Aliasing case
|
|
695
|
+
return bsr_scale(y, alpha=alpha.value + beta.value)
|
|
696
|
+
|
|
697
|
+
# General case
|
|
406
698
|
|
|
407
699
|
if x.values.device != y.values.device:
|
|
408
700
|
raise ValueError("All arguments must reside on the same device")
|
|
@@ -413,20 +705,21 @@ def bsr_axpy(x: BsrMatrix, y: BsrMatrix, alpha: float = 1.0, beta: float = 1.0):
|
|
|
413
705
|
if x.nrow != y.nrow or x.ncol != y.ncol:
|
|
414
706
|
raise ValueError("Matrices must have the same number of rows and columns")
|
|
415
707
|
|
|
416
|
-
|
|
417
|
-
|
|
708
|
+
if work_arrays is None:
|
|
709
|
+
work_arrays = bsr_axpy_work_arrays()
|
|
418
710
|
|
|
419
711
|
sum_nnz = x.nnz + y.nnz
|
|
420
|
-
|
|
421
|
-
|
|
712
|
+
device = y.values.device
|
|
713
|
+
work_arrays._allocate(device, y, sum_nnz)
|
|
714
|
+
|
|
715
|
+
wp.copy(work_arrays._sum_cols, y.columns, 0, 0, y.nnz)
|
|
716
|
+
wp.launch(kernel=_bsr_get_block_row, device=device, dim=y.nnz, inputs=[0, y.offsets, work_arrays._sum_rows])
|
|
422
717
|
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
wp.launch(kernel=_bsr_get_block_row, device=device, dim=y.nnz, inputs=[0, y.offsets, sum_rows])
|
|
718
|
+
wp.copy(work_arrays._sum_cols, x.columns, y.nnz, 0, x.nnz)
|
|
719
|
+
wp.launch(kernel=_bsr_get_block_row, device=device, dim=x.nnz, inputs=[y.nnz, x.offsets, work_arrays._sum_rows])
|
|
426
720
|
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
wp.launch(kernel=_bsr_get_block_row, device=device, dim=x.nnz, inputs=[y.nnz, x.offsets, sum_rows])
|
|
721
|
+
# Save old y values before overwriting matrix
|
|
722
|
+
wp.copy(dest=work_arrays._old_y_values, src=y.values, count=y.nnz)
|
|
430
723
|
|
|
431
724
|
# Increase dest array sizes if needed
|
|
432
725
|
if y.columns.shape[0] < sum_nnz:
|
|
@@ -439,37 +732,55 @@ def bsr_axpy(x: BsrMatrix, y: BsrMatrix, alpha: float = 1.0, beta: float = 1.0):
|
|
|
439
732
|
else:
|
|
440
733
|
native_func = runtime.core.bsr_matrix_from_triplets_float_device
|
|
441
734
|
|
|
442
|
-
|
|
735
|
+
old_y_nnz = y.nnz
|
|
736
|
+
y.nnz = native_func(
|
|
443
737
|
y.block_shape[0],
|
|
444
738
|
y.block_shape[1],
|
|
445
739
|
y.nrow,
|
|
446
740
|
sum_nnz,
|
|
447
|
-
|
|
448
|
-
|
|
741
|
+
work_arrays._sum_rows.ptr,
|
|
742
|
+
work_arrays._sum_cols.ptr,
|
|
449
743
|
0,
|
|
450
744
|
y.offsets.ptr,
|
|
451
745
|
y.columns.ptr,
|
|
452
746
|
0,
|
|
453
747
|
)
|
|
454
748
|
|
|
455
|
-
|
|
749
|
+
_bsr_ensure_fits(y)
|
|
750
|
+
y.values.zero_()
|
|
456
751
|
|
|
457
752
|
wp.launch(
|
|
458
753
|
kernel=_bsr_axpy_add_block,
|
|
459
754
|
device=device,
|
|
460
|
-
dim=
|
|
461
|
-
inputs=[
|
|
755
|
+
dim=old_y_nnz,
|
|
756
|
+
inputs=[
|
|
757
|
+
0,
|
|
758
|
+
beta,
|
|
759
|
+
work_arrays._sum_rows,
|
|
760
|
+
work_arrays._sum_cols,
|
|
761
|
+
y.offsets,
|
|
762
|
+
y.columns,
|
|
763
|
+
work_arrays._old_y_values,
|
|
764
|
+
y.values,
|
|
765
|
+
],
|
|
462
766
|
)
|
|
767
|
+
|
|
463
768
|
wp.launch(
|
|
464
769
|
kernel=_bsr_axpy_add_block,
|
|
465
770
|
device=device,
|
|
466
771
|
dim=x.nnz,
|
|
467
|
-
inputs=[
|
|
772
|
+
inputs=[
|
|
773
|
+
old_y_nnz,
|
|
774
|
+
alpha,
|
|
775
|
+
work_arrays._sum_rows,
|
|
776
|
+
work_arrays._sum_cols,
|
|
777
|
+
y.offsets,
|
|
778
|
+
y.columns,
|
|
779
|
+
x.values,
|
|
780
|
+
y.values,
|
|
781
|
+
],
|
|
468
782
|
)
|
|
469
783
|
|
|
470
|
-
y.values = sum_values
|
|
471
|
-
y.nnz = sum_nnz
|
|
472
|
-
|
|
473
784
|
return y
|
|
474
785
|
|
|
475
786
|
|
|
@@ -555,23 +866,77 @@ def _bsr_mm_compute_values(
|
|
|
555
866
|
mm_values[mm_block] = mm_values[mm_block] + ax_val * y_values[y_block]
|
|
556
867
|
|
|
557
868
|
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
def
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
869
|
+
class bsr_mm_work_arrays:
    """Opaque structure for persisting :func:`bsr_mm` temporary work buffers across calls

    Allocation happens in two stages: :meth:`_allocate_stage_1` covers everything that can be
    sized up front, :meth:`_allocate_stage_2` covers the buffers sized by the unmerged block count.
    """

    def __init__(self):
        self._reset(None)

    def _reset(self, device):
        """Binds the cache to `device` and forgets every previously allocated buffer."""
        self.device = device
        self._pinned_count_buffer = None
        self._mm_row_counts = None
        self._mm_rows = None
        self._mm_cols = None
        self._old_z_values = None
        self._old_z_offsets = None
        self._old_z_columns = None

    def _allocate_stage_1(self, device, z: BsrMatrix, copied_z_nnz: int, z_aliasing: bool):
        """Prepares the buffers whose size does not depend on any computation.

        Args:
            device: Device on which the buffers must live; a change of device drops all buffers.
            z: Destination matrix, used to size and type the buffers.
            copied_z_nnz: Number of `z` blocks that will be saved; zero skips the saved-values buffer.
            z_aliasing: Whether the offsets and columns of `z` must be saved as well.
        """
        if self.device != device:
            self._reset(device)

        # Allocations that do not depend on any computation
        if self.device.is_cuda:
            # Single-element pinned host buffer, used to read a count back from the device
            if self._pinned_count_buffer is None:
                self._pinned_count_buffer = wp.empty(shape=(1,), dtype=int, pinned=True, device="cpu")

        if self._mm_row_counts is None or self._mm_row_counts.size < z.nrow + 1:
            self._mm_row_counts = wp.empty(shape=(z.nrow + 1,), dtype=int, device=self.device)

        if copied_z_nnz > 0:
            # A large-enough buffer left over from a matrix with another block type cannot hold
            # a copy of z.values, so the data type is part of the reuse condition.
            if (
                self._old_z_values is None
                or self._old_z_values.size < copied_z_nnz
                or not wp.types.types_equal(self._old_z_values.dtype, z.values.dtype)
            ):
                self._old_z_values = wp.empty(shape=(copied_z_nnz,), dtype=z.values.dtype, device=self.device)

        if z_aliasing:
            if self._old_z_columns is None or self._old_z_columns.size < z.nnz:
                self._old_z_columns = wp.empty(shape=(z.nnz,), dtype=z.columns.dtype, device=self.device)
            if self._old_z_offsets is None or self._old_z_offsets.size < z.nrow + 1:
                self._old_z_offsets = wp.empty(shape=(z.nrow + 1,), dtype=z.offsets.dtype, device=self.device)

    def _allocate_stage_2(self, mm_nnz: int):
        """Prepares the row and column index buffers for `mm_nnz` unmerged blocks.

        Relies on ``self.device`` having been set by a previous call to :meth:`_allocate_stage_1`.
        """
        # Allocations that depend on unmerged nnz estimate
        if self._mm_rows is None or self._mm_rows.size < mm_nnz:
            self._mm_rows = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
        if self._mm_cols is None or self._mm_cols.size < mm_nnz:
            self._mm_cols = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
|
|
913
|
+
|
|
914
|
+
|
|
915
|
+
def bsr_mm(
|
|
916
|
+
x: BsrMatrix[BlockType[Rows, Any, Scalar]],
|
|
917
|
+
y: BsrMatrix[BlockType[Any, Cols, Scalar]],
|
|
918
|
+
z: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
|
|
919
|
+
alpha: Scalar = 1.0,
|
|
920
|
+
beta: Scalar = 0.0,
|
|
921
|
+
work_arrays: Optional[bsr_mm_work_arrays] = None,
|
|
922
|
+
) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
|
|
570
923
|
"""
|
|
571
|
-
Performs the
|
|
924
|
+
Performs the sparse matrix-matrix multiplication ``z := alpha * x * y + beta * z`` on BSR matrices `x`, `y` and `z`, and returns `z`.
|
|
925
|
+
|
|
926
|
+
The `x`, `y` and `z` matrices are allowed to alias.
|
|
927
|
+
If the matrix `z` is not provided as input, it will be allocated and treated as zero.
|
|
928
|
+
|
|
929
|
+
Args:
|
|
930
|
+
x: Read-only left factor of the matrix-matrix product.
|
|
931
|
+
y: Read-only right factor of the matrix-matrix product.
|
|
932
|
+
z: Mutable left-hand-side. If `z` is not provided, it will be allocated and treated as zero.
|
|
933
|
+
alpha: Uniform scaling factor for the ``x * y`` product
|
|
934
|
+
beta: Uniform scaling factor for `z`
|
|
935
|
+
work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_mm_work_arrays` in `work_arrays`.
|
|
572
936
|
"""
|
|
573
937
|
|
|
574
938
|
if z is None:
|
|
939
|
+
# If no output matrix is provided, allocate it for convenience
|
|
575
940
|
z_block_shape = (x.block_shape[0], y.block_shape[1])
|
|
576
941
|
if z_block_shape == (1, 1):
|
|
577
942
|
z_block_type = x.scalar_type
|
|
@@ -586,52 +951,85 @@ def bsr_mm(x: BsrMatrix, y: BsrMatrix, z: BsrMatrix = None, alpha: float = 1.0,
|
|
|
586
951
|
if x.scalar_type != y.scalar_type or x.scalar_type != z.scalar_type:
|
|
587
952
|
raise ValueError("Matrices must have the same scalar type")
|
|
588
953
|
|
|
589
|
-
if
|
|
590
|
-
|
|
954
|
+
if (
|
|
955
|
+
x.block_shape[0] != z.block_shape[0]
|
|
956
|
+
or y.block_shape[1] != z.block_shape[1]
|
|
957
|
+
or x.block_shape[1] != y.block_shape[0]
|
|
958
|
+
):
|
|
959
|
+
raise ValueError("Incompatible block sizes for matrix multiplication")
|
|
591
960
|
|
|
592
|
-
if x.nrow != z.nrow or z.ncol != y.ncol:
|
|
961
|
+
if x.nrow != z.nrow or z.ncol != y.ncol or x.ncol != y.nrow:
|
|
593
962
|
raise ValueError("Incompatible number of rows/columns for matrix multiplication")
|
|
594
963
|
|
|
595
964
|
device = z.values.device
|
|
596
965
|
|
|
597
|
-
alpha
|
|
598
|
-
|
|
966
|
+
if alpha == 0.0 or x.nnz == 0 or y.nnz == 0:
|
|
967
|
+
# Easy case
|
|
968
|
+
return bsr_scale(z, beta)
|
|
969
|
+
|
|
970
|
+
if not isinstance(alpha, z.scalar_type):
|
|
971
|
+
alpha = z.scalar_type(alpha)
|
|
972
|
+
if not isinstance(beta, z.scalar_type):
|
|
973
|
+
beta = z.scalar_type(beta)
|
|
974
|
+
|
|
975
|
+
if work_arrays is None:
|
|
976
|
+
work_arrays = bsr_mm_work_arrays()
|
|
977
|
+
|
|
978
|
+
z_aliasing = z == x or z == y
|
|
979
|
+
copied_z_nnz = z.nnz if beta != 0.0 or z_aliasing else 0
|
|
980
|
+
|
|
981
|
+
work_arrays._allocate_stage_1(device, z, copied_z_nnz, z_aliasing)
|
|
599
982
|
|
|
600
983
|
# Prefix sum of number of (unmerged) mm blocks per row
|
|
601
|
-
mm_row_counts = wp.empty(shape=(z.nrow + 1,), dtype=int, device=device)
|
|
602
984
|
wp.launch(
|
|
603
985
|
kernel=_bsr_mm_count_coeffs,
|
|
604
986
|
device=device,
|
|
605
987
|
dim=z.nrow,
|
|
606
|
-
inputs=[
|
|
988
|
+
inputs=[copied_z_nnz, x.offsets, x.columns, y.offsets, work_arrays._mm_row_counts],
|
|
607
989
|
)
|
|
608
|
-
warp.utils.array_scan(
|
|
990
|
+
warp.utils.array_scan(work_arrays._mm_row_counts, work_arrays._mm_row_counts)
|
|
609
991
|
|
|
610
992
|
# Get back total counts on host
|
|
611
993
|
if device.is_cuda:
|
|
612
|
-
|
|
613
|
-
wp.
|
|
614
|
-
|
|
615
|
-
mm_nnz = int(mm_tot_count.numpy()[0])
|
|
994
|
+
wp.copy(dest=work_arrays._pinned_count_buffer, src=work_arrays._mm_row_counts, src_offset=z.nrow, count=1)
|
|
995
|
+
wp.synchronize_stream(wp.get_stream(device))
|
|
996
|
+
mm_nnz = int(work_arrays._pinned_count_buffer.numpy()[0])
|
|
616
997
|
else:
|
|
617
|
-
mm_nnz = int(
|
|
998
|
+
mm_nnz = int(work_arrays._mm_row_counts.numpy()[z.nrow])
|
|
618
999
|
|
|
619
|
-
|
|
620
|
-
mm_cols = wp.empty(shape=(mm_nnz), dtype=int, device=device)
|
|
1000
|
+
work_arrays._allocate_stage_2(mm_nnz)
|
|
621
1001
|
|
|
622
|
-
#
|
|
623
|
-
|
|
624
|
-
|
|
1002
|
+
# If z has a non-zero scale, save current data before overwriting it
|
|
1003
|
+
if copied_z_nnz > 0:
|
|
1004
|
+
# Copy z row and column indices
|
|
1005
|
+
wp.copy(dest=work_arrays._mm_cols, src=z.columns, count=copied_z_nnz)
|
|
1006
|
+
wp.launch(
|
|
1007
|
+
kernel=_bsr_get_block_row, device=device, dim=copied_z_nnz, inputs=[0, z.offsets, work_arrays._mm_rows]
|
|
1008
|
+
)
|
|
1009
|
+
# Save current z values in temporary buffer
|
|
1010
|
+
wp.copy(src=z.values, dest=work_arrays._old_z_values, count=copied_z_nnz)
|
|
1011
|
+
if z_aliasing:
|
|
1012
|
+
# If z is aliasing with x or y, need to save topology as well
|
|
1013
|
+
wp.copy(src=z.columns, dest=work_arrays._old_z_columns, count=copied_z_nnz)
|
|
1014
|
+
wp.copy(src=z.offsets, dest=work_arrays._old_z_offsets, count=z.nrow + 1)
|
|
625
1015
|
|
|
626
1016
|
# Fill unmerged mm blocks rows and columns
|
|
627
1017
|
wp.launch(
|
|
628
1018
|
kernel=_bsr_mm_list_coeffs,
|
|
629
1019
|
device=device,
|
|
630
1020
|
dim=z.nrow,
|
|
631
|
-
inputs=[
|
|
1021
|
+
inputs=[
|
|
1022
|
+
x.offsets,
|
|
1023
|
+
x.columns,
|
|
1024
|
+
y.offsets,
|
|
1025
|
+
y.columns,
|
|
1026
|
+
work_arrays._mm_row_counts,
|
|
1027
|
+
work_arrays._mm_rows,
|
|
1028
|
+
work_arrays._mm_cols,
|
|
1029
|
+
],
|
|
632
1030
|
)
|
|
633
1031
|
|
|
634
|
-
# Increase dest array
|
|
1032
|
+
# Increase dest array size if needed
|
|
635
1033
|
if z.columns.shape[0] < mm_nnz:
|
|
636
1034
|
z.columns = wp.empty(shape=(mm_nnz,), dtype=int, device=device)
|
|
637
1035
|
|
|
@@ -642,45 +1040,66 @@ def bsr_mm(x: BsrMatrix, y: BsrMatrix, z: BsrMatrix = None, alpha: float = 1.0,
|
|
|
642
1040
|
else:
|
|
643
1041
|
native_func = runtime.core.bsr_matrix_from_triplets_float_device
|
|
644
1042
|
|
|
645
|
-
|
|
1043
|
+
z.nnz = native_func(
|
|
646
1044
|
z.block_shape[0],
|
|
647
1045
|
z.block_shape[1],
|
|
648
1046
|
z.nrow,
|
|
649
1047
|
mm_nnz,
|
|
650
|
-
|
|
651
|
-
|
|
1048
|
+
work_arrays._mm_rows.ptr,
|
|
1049
|
+
work_arrays._mm_cols.ptr,
|
|
652
1050
|
0,
|
|
653
1051
|
z.offsets.ptr,
|
|
654
1052
|
z.columns.ptr,
|
|
655
1053
|
0,
|
|
656
1054
|
)
|
|
657
1055
|
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
1056
|
+
_bsr_ensure_fits(z)
|
|
1057
|
+
z.values.zero_()
|
|
1058
|
+
|
|
1059
|
+
if copied_z_nnz > 0:
|
|
1060
|
+
# Add back original z values
|
|
1061
|
+
wp.launch(
|
|
1062
|
+
kernel=_bsr_axpy_add_block,
|
|
1063
|
+
device=device,
|
|
1064
|
+
dim=copied_z_nnz,
|
|
1065
|
+
inputs=[
|
|
1066
|
+
0,
|
|
1067
|
+
beta,
|
|
1068
|
+
work_arrays._mm_rows,
|
|
1069
|
+
work_arrays._mm_cols,
|
|
1070
|
+
z.offsets,
|
|
1071
|
+
z.columns,
|
|
1072
|
+
work_arrays._old_z_values,
|
|
1073
|
+
z.values,
|
|
1074
|
+
],
|
|
1075
|
+
)
|
|
671
1076
|
|
|
672
1077
|
# Add mm blocks to z values
|
|
673
|
-
|
|
674
|
-
|
|
1078
|
+
if (warp.types.type_is_matrix(x.values.dtype) or warp.types.type_is_matrix(y.values.dtype)) and not (
|
|
1079
|
+
warp.types.type_is_matrix(z.values.dtype)
|
|
1080
|
+
):
|
|
675
1081
|
# Result block type is scalar, but operands are matrices
|
|
676
1082
|
# Cast result to (1x1) matrix to perform multiplication
|
|
677
|
-
mm_values =
|
|
1083
|
+
mm_values = z.values.view(wp.types.matrix(shape=(1, 1), dtype=z.scalar_type))
|
|
1084
|
+
else:
|
|
1085
|
+
mm_values = z.values
|
|
678
1086
|
|
|
679
1087
|
wp.launch(
|
|
680
1088
|
kernel=_bsr_mm_compute_values,
|
|
681
1089
|
device=device,
|
|
682
1090
|
dim=z.nrow,
|
|
683
|
-
inputs=[
|
|
1091
|
+
inputs=[
|
|
1092
|
+
alpha,
|
|
1093
|
+
work_arrays._old_z_offsets if x == z else x.offsets,
|
|
1094
|
+
work_arrays._old_z_columns if x == z else x.columns,
|
|
1095
|
+
work_arrays._old_z_values if x == z else x.values,
|
|
1096
|
+
work_arrays._old_z_offsets if y == z else y.offsets,
|
|
1097
|
+
work_arrays._old_z_columns if y == z else y.columns,
|
|
1098
|
+
work_arrays._old_z_values if y == z else y.values,
|
|
1099
|
+
z.offsets,
|
|
1100
|
+
z.columns,
|
|
1101
|
+
mm_values,
|
|
1102
|
+
],
|
|
684
1103
|
)
|
|
685
1104
|
|
|
686
1105
|
return z
|
|
@@ -697,44 +1116,96 @@ def _bsr_mv_kernel(
|
|
|
697
1116
|
y: wp.array(dtype=Any),
|
|
698
1117
|
):
|
|
699
1118
|
row = wp.tid()
|
|
700
|
-
beg = A_offsets[row]
|
|
701
|
-
end = A_offsets[row + 1]
|
|
702
1119
|
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
v = v + A_values[block] * x[A_columns[block]]
|
|
1120
|
+
# zero-initialize with type of y elements
|
|
1121
|
+
scalar_zero = type(alpha)(0)
|
|
1122
|
+
v = y.dtype(scalar_zero)
|
|
707
1123
|
|
|
708
|
-
|
|
1124
|
+
if alpha != scalar_zero:
|
|
1125
|
+
beg = A_offsets[row]
|
|
1126
|
+
end = A_offsets[row + 1]
|
|
1127
|
+
for block in range(beg, end):
|
|
1128
|
+
v += A_values[block] * x[A_columns[block]]
|
|
1129
|
+
v *= alpha
|
|
709
1130
|
|
|
1131
|
+
if beta != scalar_zero:
|
|
1132
|
+
v += beta * y[row]
|
|
710
1133
|
|
|
711
|
-
|
|
1134
|
+
y[row] = v
|
|
1135
|
+
|
|
1136
|
+
|
|
1137
|
+
def bsr_mv(
|
|
1138
|
+
A: BsrMatrix[BlockType[Rows, Cols, Scalar]],
|
|
1139
|
+
x: "Array[Vector[Cols, Scalar] | Scalar]",
|
|
1140
|
+
y: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
|
|
1141
|
+
alpha: Scalar = 1.0,
|
|
1142
|
+
beta: Scalar = 0.0,
|
|
1143
|
+
work_buffer: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
|
|
1144
|
+
) -> "Array[Vector[Rows, Scalar] | Scalar]":
|
|
712
1145
|
"""
|
|
713
|
-
|
|
1146
|
+
Performs the sparse matrix-vector product ``y := alpha * A * x + beta * y`` and returns `y`.
|
|
1147
|
+
|
|
1148
|
+
The `x` and `y` vectors are allowed to alias.
|
|
1149
|
+
|
|
1150
|
+
Args:
|
|
1151
|
+
A: Read-only, left matrix factor of the matrix-vector product.
|
|
1152
|
+
x: Read-only, right vector factor of the matrix-vector product.
|
|
1153
|
+
y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero.
|
|
1154
|
+
alpha: Uniform scaling factor for `x`. If zero, `x` will not be read and may be left uninitialized.
|
|
1155
|
+
beta: Uniform scaling factor for `y`. If zero, `y` will not be read and may be left uninitialized.
|
|
1156
|
+
work_buffer: Temporary storage is required if and only if `x` and `y` are the same vector. If provided the `work_buffer` array
|
|
1157
|
+
will be used for this purpose, otherwise a temporary allocation will be performed.
|
|
714
1158
|
"""
|
|
715
|
-
alpha = A.scalar_type(alpha)
|
|
716
|
-
beta = A.scalar_type(beta)
|
|
717
1159
|
|
|
718
|
-
|
|
719
|
-
|
|
1160
|
+
if y is None:
|
|
1161
|
+
# If no output array is provided, allocate one for convenience
|
|
1162
|
+
y_vec_len = A.block_shape[0]
|
|
1163
|
+
y_dtype = A.scalar_type if y_vec_len == 1 else wp.vec(length=y_vec_len, dtype=A.scalar_type)
|
|
1164
|
+
y = wp.empty(shape=(A.nrow,), device=A.values.device, dtype=y_dtype)
|
|
1165
|
+
y.zero_()
|
|
1166
|
+
beta = 0.0
|
|
1167
|
+
|
|
1168
|
+
if not isinstance(alpha, A.scalar_type):
|
|
1169
|
+
alpha = A.scalar_type(alpha)
|
|
1170
|
+
if not isinstance(beta, A.scalar_type):
|
|
1171
|
+
beta = A.scalar_type(beta)
|
|
720
1172
|
|
|
721
1173
|
if A.values.device != x.device or A.values.device != y.device:
|
|
722
|
-
raise ValueError("A, x and y must
|
|
1174
|
+
raise ValueError("A, x and y must reside on the same device")
|
|
723
1175
|
|
|
724
1176
|
if x.shape[0] != A.ncol:
|
|
725
1177
|
raise ValueError("Number of columns of A must match number of rows of x")
|
|
726
1178
|
if y.shape[0] != A.nrow:
|
|
727
1179
|
raise ValueError("Number of rows of A must match number of rows of y")
|
|
728
1180
|
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
1181
|
+
if x == y:
|
|
1182
|
+
# Aliasing case, need temporary storage
|
|
1183
|
+
if work_buffer is None:
|
|
1184
|
+
work_buffer = wp.empty_like(y)
|
|
1185
|
+
elif work_buffer.size < y.size:
|
|
1186
|
+
raise ValueError(f"Work buffer size is insufficient, needs to be at least {y.size}")
|
|
1187
|
+
elif not wp.types.types_equal(work_buffer.dtype, y.dtype):
|
|
1188
|
+
raise ValueError(f"Work buffer must have same data type as y, {wp.types.type_repr(y.dtype)}")
|
|
1189
|
+
|
|
1190
|
+
# Save old y values before overwriting vector
|
|
1191
|
+
wp.copy(dest=work_buffer, src=y, count=y.size)
|
|
1192
|
+
x = work_buffer
|
|
1193
|
+
|
|
1194
|
+
# Promote scalar vectors to length-1 vecs and conversely
|
|
1195
|
+
if warp.types.type_is_matrix(A.values.dtype):
|
|
1196
|
+
if A.block_shape[0] == 1:
|
|
733
1197
|
if y.dtype == A.scalar_type:
|
|
734
1198
|
y = y.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
|
|
735
|
-
if block_shape[1] == 1:
|
|
1199
|
+
if A.block_shape[1] == 1:
|
|
736
1200
|
if x.dtype == A.scalar_type:
|
|
737
1201
|
x = x.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
|
|
1202
|
+
else:
|
|
1203
|
+
if A.block_shape[0] == 1:
|
|
1204
|
+
if y.dtype != A.scalar_type:
|
|
1205
|
+
y = y.view(dtype=A.scalar_type)
|
|
1206
|
+
if A.block_shape[1] == 1:
|
|
1207
|
+
if x.dtype != A.scalar_type:
|
|
1208
|
+
x = x.view(dtype=A.scalar_type)
|
|
738
1209
|
|
|
739
1210
|
wp.launch(
|
|
740
1211
|
kernel=_bsr_mv_kernel,
|
|
@@ -742,3 +1213,5 @@ def bsr_mv(A: BsrMatrix, x: wp.array, y: wp.array, alpha: float = 1.0, beta: flo
|
|
|
742
1213
|
dim=A.nrow,
|
|
743
1214
|
inputs=[alpha, A.offsets, A.columns, A.values, x, beta, y],
|
|
744
1215
|
)
|
|
1216
|
+
|
|
1217
|
+
return y
|