warp-lang 1.6.2__py3-none-macosx_10_13_universal2.whl → 1.7.1__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +7 -1
- warp/autograd.py +12 -2
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +410 -0
- warp/build_dll.py +6 -14
- warp/builtins.py +463 -372
- warp/codegen.py +196 -124
- warp/config.py +42 -6
- warp/context.py +496 -271
- warp/dlpack.py +8 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_cloth.py +1 -1
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/distributed/example_jacobi_mpi.py +507 -0
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +2 -2
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_magnetostatics.py +6 -6
- warp/examples/fem/utils.py +9 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/tile/example_tile_matmul.py +2 -4
- warp/fem/__init__.py +11 -1
- warp/fem/adaptivity.py +4 -4
- warp/fem/field/field.py +11 -1
- warp/fem/field/nodal_field.py +56 -88
- warp/fem/field/virtual.py +62 -23
- warp/fem/geometry/adaptive_nanogrid.py +16 -13
- warp/fem/geometry/closest_point.py +1 -1
- warp/fem/geometry/deformed_geometry.py +5 -2
- warp/fem/geometry/geometry.py +5 -0
- warp/fem/geometry/grid_2d.py +12 -12
- warp/fem/geometry/grid_3d.py +12 -15
- warp/fem/geometry/hexmesh.py +5 -7
- warp/fem/geometry/nanogrid.py +9 -11
- warp/fem/geometry/quadmesh.py +13 -13
- warp/fem/geometry/tetmesh.py +3 -4
- warp/fem/geometry/trimesh.py +7 -20
- warp/fem/integrate.py +262 -93
- warp/fem/linalg.py +5 -5
- warp/fem/quadrature/pic_quadrature.py +37 -22
- warp/fem/quadrature/quadrature.py +194 -25
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +4 -2
- warp/fem/space/basis_space.py +25 -18
- warp/fem/space/hexmesh_function_space.py +2 -2
- warp/fem/space/partition.py +6 -2
- warp/fem/space/quadmesh_function_space.py +8 -8
- warp/fem/space/shape/cube_shape_function.py +23 -23
- warp/fem/space/shape/square_shape_function.py +12 -12
- warp/fem/space/shape/triangle_shape_function.py +1 -1
- warp/fem/space/tetmesh_function_space.py +3 -3
- warp/fem/space/trimesh_function_space.py +2 -2
- warp/fem/utils.py +12 -6
- warp/jax.py +14 -1
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -29
- warp/jax_experimental/ffi.py +702 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +89 -0
- warp/native/array.h +13 -0
- warp/native/builtin.h +29 -3
- warp/native/bvh.cpp +3 -1
- warp/native/bvh.cu +42 -14
- warp/native/bvh.h +2 -1
- warp/native/clang/clang.cpp +30 -3
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -0
- warp/native/exports.h +68 -63
- warp/native/intersect.h +26 -26
- warp/native/intersect_adj.h +33 -33
- warp/native/marching.cu +1 -1
- warp/native/mat.h +513 -9
- warp/native/mesh.h +10 -10
- warp/native/quat.h +99 -11
- warp/native/rand.h +6 -0
- warp/native/sort.cpp +122 -59
- warp/native/sort.cu +152 -15
- warp/native/sort.h +8 -1
- warp/native/sparse.cpp +43 -22
- warp/native/sparse.cu +52 -17
- warp/native/svd.h +116 -0
- warp/native/tile.h +312 -116
- warp/native/tile_reduce.h +46 -3
- warp/native/vec.h +68 -7
- warp/native/volume.cpp +85 -113
- warp/native/volume_builder.cu +25 -10
- warp/native/volume_builder.h +6 -0
- warp/native/warp.cpp +5 -6
- warp/native/warp.cu +100 -11
- warp/native/warp.h +19 -10
- warp/optim/linear.py +10 -10
- warp/render/render_opengl.py +19 -17
- warp/render/render_usd.py +93 -3
- warp/sim/articulation.py +4 -4
- warp/sim/collide.py +32 -19
- warp/sim/import_mjcf.py +449 -155
- warp/sim/import_urdf.py +32 -12
- warp/sim/inertia.py +189 -156
- warp/sim/integrator_euler.py +8 -5
- warp/sim/integrator_featherstone.py +3 -10
- warp/sim/integrator_vbd.py +207 -2
- warp/sim/integrator_xpbd.py +8 -5
- warp/sim/model.py +71 -25
- warp/sim/render.py +4 -0
- warp/sim/utils.py +2 -2
- warp/sparse.py +642 -555
- warp/stubs.py +217 -20
- warp/tests/__main__.py +0 -15
- warp/tests/assets/torus.usda +1 -1
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
- warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
- warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
- warp/tests/{test_collision.py → sim/test_collision.py} +236 -205
- warp/tests/sim/test_inertia.py +161 -0
- warp/tests/{test_model.py → sim/test_model.py} +40 -0
- warp/tests/{flaky_test_sim_grad.py → sim/test_sim_grad.py} +4 -0
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/sim/test_xpbd.py +399 -0
- warp/tests/test_bool.py +1 -1
- warp/tests/test_codegen.py +24 -3
- warp/tests/test_examples.py +40 -38
- warp/tests/test_fem.py +98 -14
- warp/tests/test_linear_solvers.py +0 -11
- warp/tests/test_mat.py +577 -156
- warp/tests/test_mat_scalar_ops.py +4 -4
- warp/tests/test_overwrite.py +0 -60
- warp/tests/test_quat.py +356 -151
- warp/tests/test_rand.py +44 -37
- warp/tests/test_sparse.py +47 -6
- warp/tests/test_spatial.py +75 -0
- warp/tests/test_static.py +1 -1
- warp/tests/test_utils.py +84 -4
- warp/tests/test_vec.py +336 -178
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +98 -1
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
- warp/tests/unittest_serial.py +1 -0
- warp/tests/unittest_suites.py +45 -62
- warp/tests/unittest_utils.py +2 -1
- warp/thirdparty/unittest_parallel.py +3 -1
- warp/types.py +175 -666
- warp/utils.py +137 -72
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/METADATA +46 -12
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/RECORD +184 -171
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/WHEEL +1 -1
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info/licenses}/LICENSE.md +0 -26
- warp/examples/optim/example_walker.py +0 -317
- warp/native/cutlass_gemm.cpp +0 -43
- warp/native/cutlass_gemm.cu +0 -382
- warp/tests/test_matmul.py +0 -511
- warp/tests/test_matmul_lite.py +0 -411
- warp/tests/test_vbd.py +0 -386
- warp/tests/unused_test_misc.py +0 -77
- /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
- /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
- /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
- /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
- /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
- /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
- /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
- /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
- /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
- /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
- /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
- /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
- /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
- /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
- /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
- /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
- /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/top_level.txt +0 -0
warp/sparse.py
CHANGED
|
@@ -19,7 +19,21 @@ from typing import Any, Generic, Optional, Tuple, TypeVar, Union
|
|
|
19
19
|
import warp as wp
|
|
20
20
|
import warp.types
|
|
21
21
|
import warp.utils
|
|
22
|
-
from warp.types import
|
|
22
|
+
from warp.types import (
|
|
23
|
+
Array,
|
|
24
|
+
Cols,
|
|
25
|
+
Rows,
|
|
26
|
+
Scalar,
|
|
27
|
+
Vector,
|
|
28
|
+
is_array,
|
|
29
|
+
scalar_types,
|
|
30
|
+
type_is_matrix,
|
|
31
|
+
type_length,
|
|
32
|
+
type_repr,
|
|
33
|
+
type_scalar_type,
|
|
34
|
+
type_to_warp,
|
|
35
|
+
types_equal,
|
|
36
|
+
)
|
|
23
37
|
|
|
24
38
|
# typing hints
|
|
25
39
|
|
|
@@ -45,50 +59,89 @@ class BsrMatrix(Generic[_BlockType]):
|
|
|
45
59
|
Should not be constructed directly but through functions such as :func:`bsr_zeros`.
|
|
46
60
|
|
|
47
61
|
Attributes:
|
|
48
|
-
nrow (int): Number of rows of blocks
|
|
49
|
-
ncol (int): Number of columns of blocks
|
|
50
|
-
nnz (int): Upper bound for the number of non-zero blocks, used for
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
62
|
+
nrow (int): Number of rows of blocks.
|
|
63
|
+
ncol (int): Number of columns of blocks.
|
|
64
|
+
nnz (int): Upper bound for the number of non-zero blocks, used for
|
|
65
|
+
dimensioning launches. The exact number is at ``offsets[nrow-1]``.
|
|
66
|
+
See also :meth:`nnz_sync`.
|
|
67
|
+
offsets (Array[int]): Array of size at least ``1 + nrow`` such that the
|
|
68
|
+
start and end indices of the blocks of row ``r`` are ``offsets[r]``
|
|
69
|
+
and ``offsets[r+1]``, respectively.
|
|
70
|
+
columns (Array[int]): Array of size at least equal to ``nnz`` containing
|
|
71
|
+
block column indices.
|
|
72
|
+
values (Array[BlockType]): Array of size at least equal to ``nnz``
|
|
73
|
+
containing block values.
|
|
54
74
|
"""
|
|
55
75
|
|
|
56
76
|
@property
|
|
57
77
|
def scalar_type(self) -> Scalar:
|
|
58
|
-
"""Scalar type for individual block coefficients. For CSR matrices, this is the same as the block type"""
|
|
59
|
-
return
|
|
78
|
+
"""Scalar type for individual block coefficients. For CSR matrices, this is the same as the block type."""
|
|
79
|
+
return type_scalar_type(self.values.dtype)
|
|
60
80
|
|
|
61
81
|
@property
|
|
62
82
|
def block_shape(self) -> Tuple[int, int]:
|
|
63
|
-
"""Shape of the individual blocks"""
|
|
83
|
+
"""Shape of the individual blocks."""
|
|
64
84
|
return getattr(self.values.dtype, "_shape_", (1, 1))
|
|
65
85
|
|
|
66
86
|
@property
|
|
67
87
|
def block_size(self) -> int:
|
|
68
|
-
"""Size of the individual blocks, i.e. number of rows per block times number of columns per block"""
|
|
69
|
-
return
|
|
88
|
+
"""Size of the individual blocks, i.e. number of rows per block times number of columns per block."""
|
|
89
|
+
return type_length(self.values.dtype)
|
|
70
90
|
|
|
71
91
|
@property
|
|
72
92
|
def shape(self) -> Tuple[int, int]:
|
|
73
|
-
"""Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/columns per block"""
|
|
93
|
+
"""Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/columns per block."""
|
|
74
94
|
block_shape = self.block_shape
|
|
75
95
|
return (self.nrow * block_shape[0], self.ncol * block_shape[1])
|
|
76
96
|
|
|
77
97
|
@property
|
|
78
98
|
def dtype(self) -> type:
|
|
79
|
-
"""Data type for individual block values"""
|
|
99
|
+
"""Data type for individual block values."""
|
|
80
100
|
return self.values.dtype
|
|
81
101
|
|
|
82
102
|
@property
|
|
83
103
|
def device(self) -> wp.context.Device:
|
|
84
|
-
"""Device on which offsets
|
|
104
|
+
"""Device on which ``offsets``, ``columns``, and ``values`` are allocated -- assumed to be the same for all three arrays."""
|
|
85
105
|
return self.values.device
|
|
86
106
|
|
|
107
|
+
@property
|
|
108
|
+
def scalar_values(self) -> wp.array:
|
|
109
|
+
"""Accesses the ``values`` array as a 3d scalar array."""
|
|
110
|
+
if self.block_shape == (1, 1):
|
|
111
|
+
return self.values.reshape((self.nnz, 1, 1))
|
|
112
|
+
|
|
113
|
+
def _as_3d_array(arr):
|
|
114
|
+
return wp.array(
|
|
115
|
+
ptr=arr.ptr,
|
|
116
|
+
capacity=arr.capacity,
|
|
117
|
+
device=arr.device,
|
|
118
|
+
dtype=self.scalar_type,
|
|
119
|
+
shape=(self.nnz, *self.block_shape),
|
|
120
|
+
grad=None if arr.grad is None else _as_3d_array(arr.grad),
|
|
121
|
+
)
|
|
122
|
+
|
|
123
|
+
values_view = _as_3d_array(self.values)
|
|
124
|
+
values_view._ref = self.values # keep ref in case we're garbage collected
|
|
125
|
+
return values_view
|
|
126
|
+
|
|
127
|
+
def uncompress_rows(self, out: wp.array = None) -> wp.array:
|
|
128
|
+
"""Compute the row index for each non-zero block from the compressed row offsets."""
|
|
129
|
+
if out is None:
|
|
130
|
+
out = wp.empty(self.nnz, dtype=int, device=self.device)
|
|
131
|
+
|
|
132
|
+
wp.launch(
|
|
133
|
+
kernel=_bsr_get_block_row,
|
|
134
|
+
device=self.device,
|
|
135
|
+
dim=self.nnz,
|
|
136
|
+
inputs=[self.nrow, self.offsets, out],
|
|
137
|
+
)
|
|
138
|
+
return out
|
|
139
|
+
|
|
87
140
|
def nnz_sync(self):
|
|
88
|
-
"""
|
|
89
|
-
and
|
|
141
|
+
"""Ensure that any ongoing transfer of the exact nnz number from the device offsets array to the host has completed
|
|
142
|
+
and update the nnz upper bound.
|
|
90
143
|
|
|
91
|
-
See also :meth:`copy_nnz_async
|
|
144
|
+
See also :meth:`copy_nnz_async`.
|
|
92
145
|
"""
|
|
93
146
|
|
|
94
147
|
if self._is_nnz_transfer_setup():
|
|
@@ -99,10 +152,11 @@ class BsrMatrix(Generic[_BlockType]):
|
|
|
99
152
|
|
|
100
153
|
def copy_nnz_async(self, known_nnz: int = None):
|
|
101
154
|
"""
|
|
102
|
-
|
|
155
|
+
Start the asynchronous transfer of the exact nnz from the device offsets array to host and records an event for completion.
|
|
156
|
+
|
|
103
157
|
Needs to be called whenever the offsets array has been modified from outside ``warp.sparse``.
|
|
104
158
|
|
|
105
|
-
See also :meth:`nnz_sync
|
|
159
|
+
See also :meth:`nnz_sync`.
|
|
106
160
|
"""
|
|
107
161
|
if known_nnz is not None:
|
|
108
162
|
self.nnz = int(known_nnz)
|
|
@@ -186,35 +240,33 @@ class BsrMatrix(Generic[_BlockType]):
|
|
|
186
240
|
return _BsrScalingExpression(self, -1.0)
|
|
187
241
|
|
|
188
242
|
def transpose(self):
|
|
189
|
-
"""
|
|
243
|
+
"""Return a transposed copy of this matrix."""
|
|
190
244
|
return bsr_transposed(self)
|
|
191
245
|
|
|
192
246
|
|
|
193
247
|
def bsr_matrix_t(dtype: BlockType):
|
|
194
|
-
dtype =
|
|
248
|
+
dtype = type_to_warp(dtype)
|
|
195
249
|
|
|
196
|
-
if not
|
|
197
|
-
raise ValueError(
|
|
198
|
-
f"BsrMatrix block type must be either warp matrix or scalar; got {warp.types.type_repr(dtype)}"
|
|
199
|
-
)
|
|
250
|
+
if not type_is_matrix(dtype) and dtype not in scalar_types:
|
|
251
|
+
raise ValueError(f"BsrMatrix block type must be either warp matrix or scalar; got {type_repr(dtype)}")
|
|
200
252
|
|
|
201
253
|
class BsrMatrixTyped(BsrMatrix):
|
|
202
254
|
nrow: int
|
|
203
|
-
"""Number of rows of blocks"""
|
|
255
|
+
"""Number of rows of blocks."""
|
|
204
256
|
ncol: int
|
|
205
|
-
"""Number of columns of blocks"""
|
|
257
|
+
"""Number of columns of blocks."""
|
|
206
258
|
nnz: int
|
|
207
|
-
"""Upper bound for the number of non-zeros"""
|
|
259
|
+
"""Upper bound for the number of non-zeros."""
|
|
208
260
|
offsets: wp.array(dtype=int)
|
|
209
|
-
"""Array of size at least 1 +
|
|
261
|
+
"""Array of size at least ``1 + nrow``."""
|
|
210
262
|
columns: wp.array(dtype=int)
|
|
211
|
-
"""Array of size at least equal to nnz"""
|
|
263
|
+
"""Array of size at least equal to ``nnz``."""
|
|
212
264
|
values: wp.array(dtype=dtype)
|
|
213
265
|
|
|
214
266
|
module = wp.get_module(BsrMatrix.__module__)
|
|
215
267
|
|
|
216
268
|
if hasattr(dtype, "_shape_"):
|
|
217
|
-
type_str = f"{
|
|
269
|
+
type_str = f"{type_scalar_type(dtype).__name__}_{dtype._shape_[0]}_{dtype._shape_[1]}"
|
|
218
270
|
else:
|
|
219
271
|
type_str = dtype.__name__
|
|
220
272
|
key = f"{BsrMatrix.__qualname__}_{type_str}"
|
|
@@ -235,16 +287,16 @@ def bsr_zeros(
|
|
|
235
287
|
block_type: BlockType,
|
|
236
288
|
device: wp.context.Devicelike = None,
|
|
237
289
|
) -> BsrMatrix:
|
|
238
|
-
"""
|
|
239
|
-
Constructs and returns an empty BSR or CSR matrix with the given shape
|
|
290
|
+
"""Construct and return an empty BSR or CSR matrix with the given shape.
|
|
240
291
|
|
|
241
292
|
Args:
|
|
242
|
-
bsr: The BSR or CSR matrix to set to zero
|
|
243
|
-
rows_of_blocks: Number of rows of blocks
|
|
244
|
-
cols_of_blocks: Number of columns of blocks
|
|
245
|
-
block_type: Type of individual blocks.
|
|
246
|
-
|
|
247
|
-
|
|
293
|
+
bsr: The BSR or CSR matrix to set to zero.
|
|
294
|
+
rows_of_blocks: Number of rows of blocks.
|
|
295
|
+
cols_of_blocks: Number of columns of blocks.
|
|
296
|
+
block_type: Type of individual blocks.
|
|
297
|
+
For CSR matrices, this should be a scalar type.
|
|
298
|
+
For BSR matrices, this should be a matrix type (e.g. from :func:`warp.mat`).
|
|
299
|
+
device: Device on which to allocate the matrix arrays.
|
|
248
300
|
"""
|
|
249
301
|
|
|
250
302
|
bsr = bsr_matrix_t(block_type)()
|
|
@@ -281,13 +333,12 @@ def bsr_set_zero(
|
|
|
281
333
|
rows_of_blocks: Optional[int] = None,
|
|
282
334
|
cols_of_blocks: Optional[int] = None,
|
|
283
335
|
):
|
|
284
|
-
"""
|
|
285
|
-
Sets a BSR matrix to zero, possibly changing its size
|
|
336
|
+
"""Set a BSR matrix to zero, possibly changing its size.
|
|
286
337
|
|
|
287
338
|
Args:
|
|
288
|
-
bsr: The BSR or CSR matrix to set to zero
|
|
289
|
-
rows_of_blocks: If not ``None``, the new number of rows of blocks
|
|
290
|
-
cols_of_blocks: If not ``None``, the new number of columns of blocks
|
|
339
|
+
bsr: The BSR or CSR matrix to set to zero.
|
|
340
|
+
rows_of_blocks: If not ``None``, the new number of rows of blocks.
|
|
341
|
+
cols_of_blocks: If not ``None``, the new number of columns of blocks.
|
|
291
342
|
"""
|
|
292
343
|
|
|
293
344
|
if rows_of_blocks is not None:
|
|
@@ -304,46 +355,55 @@ def bsr_set_from_triplets(
|
|
|
304
355
|
dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
|
|
305
356
|
rows: "Array[int]",
|
|
306
357
|
columns: "Array[int]",
|
|
307
|
-
values: "Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]",
|
|
358
|
+
values: Optional["Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]"] = None,
|
|
308
359
|
prune_numerical_zeros: bool = True,
|
|
360
|
+
masked: bool = False,
|
|
309
361
|
):
|
|
310
|
-
"""
|
|
311
|
-
Fills a BSR matrix with values defined by coordinate-oriented (COO) triplets, discarding existing blocks.
|
|
362
|
+
"""Fill a BSR matrix with values defined by coordinate-oriented (COO) triplets, discarding existing blocks.
|
|
312
363
|
|
|
313
364
|
The first dimension of the three input arrays must match and indicates the number of COO triplets.
|
|
314
365
|
|
|
315
366
|
Args:
|
|
316
|
-
dest: Sparse matrix to populate
|
|
317
|
-
rows: Row index for each non-zero
|
|
318
|
-
columns: Columns index for each non-zero
|
|
367
|
+
dest: Sparse matrix to populate.
|
|
368
|
+
rows: Row index for each non-zero.
|
|
369
|
+
columns: Columns index for each non-zero.
|
|
319
370
|
values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
|
|
320
|
-
to the
|
|
321
|
-
|
|
371
|
+
to the ``dest`` matrix's block type, or a 3d array with data type equal to the ``dest`` matrix's scalar type.
|
|
372
|
+
If ``None``, the values array of the resulting matrix will be allocated but uninitialized.
|
|
373
|
+
prune_numerical_zeros: If ``True``, will ignore the zero-valued blocks.
|
|
374
|
+
masked: If ``True``, ignore blocks that are not existing non-zeros of ``dest``.
|
|
322
375
|
"""
|
|
323
376
|
|
|
324
|
-
if
|
|
377
|
+
if rows.device != columns.device or rows.device != dest.device:
|
|
325
378
|
raise ValueError("All arguments must reside on the same device")
|
|
326
379
|
|
|
327
|
-
if
|
|
380
|
+
if rows.shape[0] != columns.shape[0]:
|
|
328
381
|
raise ValueError("All triplet arrays must have the same length")
|
|
329
382
|
|
|
330
383
|
# Accept either array1d(dtype) or contiguous array3d(scalar_type) as values
|
|
331
|
-
if values
|
|
332
|
-
if values.
|
|
333
|
-
raise ValueError("
|
|
334
|
-
|
|
335
|
-
if values.shape[
|
|
336
|
-
raise ValueError(
|
|
337
|
-
|
|
338
|
-
|
|
384
|
+
if values is not None:
|
|
385
|
+
if values.device != rows.device:
|
|
386
|
+
raise ValueError("All arguments must reside on the same device")
|
|
387
|
+
|
|
388
|
+
if values.shape[0] != rows.shape[0]:
|
|
389
|
+
raise ValueError("All triplet arrays must have the same length")
|
|
390
|
+
|
|
391
|
+
if values.ndim == 1:
|
|
392
|
+
if values.dtype != dest.values.dtype:
|
|
393
|
+
raise ValueError("Values array type must correspond to that of dest matrix")
|
|
394
|
+
elif values.ndim == 3:
|
|
395
|
+
if values.shape[1:] != dest.block_shape:
|
|
396
|
+
raise ValueError(
|
|
397
|
+
f"Last two dimensions in values array ({values.shape[1:]}) should correspond to matrix block shape {(dest.block_shape)})"
|
|
398
|
+
)
|
|
339
399
|
|
|
340
|
-
|
|
341
|
-
|
|
400
|
+
if type_scalar_type(values.dtype) != dest.scalar_type:
|
|
401
|
+
raise ValueError("Scalar type of values array should correspond to that of matrix")
|
|
342
402
|
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
403
|
+
if not values.is_contiguous:
|
|
404
|
+
raise ValueError("Multi-dimensional values array should be contiguous")
|
|
405
|
+
else:
|
|
406
|
+
raise ValueError("Number of dimension for values array should be 1 or 3")
|
|
347
407
|
|
|
348
408
|
nnz = rows.shape[0]
|
|
349
409
|
if nnz == 0:
|
|
@@ -351,7 +411,8 @@ def bsr_set_from_triplets(
|
|
|
351
411
|
return
|
|
352
412
|
|
|
353
413
|
# Increase dest array sizes if needed
|
|
354
|
-
|
|
414
|
+
if not masked:
|
|
415
|
+
_bsr_ensure_fits(dest, nnz=nnz)
|
|
355
416
|
|
|
356
417
|
device = dest.values.device
|
|
357
418
|
scalar_type = dest.scalar_type
|
|
@@ -381,16 +442,51 @@ def bsr_set_from_triplets(
|
|
|
381
442
|
nnz,
|
|
382
443
|
ctypes.cast(rows.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
383
444
|
ctypes.cast(columns.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
384
|
-
ctypes.cast(values.ptr, ctypes.c_void_p),
|
|
445
|
+
None if values is None else ctypes.cast(values.ptr, ctypes.c_void_p),
|
|
385
446
|
prune_numerical_zeros,
|
|
447
|
+
masked,
|
|
386
448
|
ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
387
449
|
ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
388
|
-
ctypes.cast(dest.values.ptr, ctypes.c_void_p),
|
|
450
|
+
None if values is None else ctypes.cast(dest.values.ptr, ctypes.c_void_p),
|
|
389
451
|
ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
390
452
|
nnz_event,
|
|
391
453
|
)
|
|
392
454
|
|
|
393
455
|
|
|
456
|
+
def bsr_from_triplets(
|
|
457
|
+
rows_of_blocks: int,
|
|
458
|
+
cols_of_blocks: int,
|
|
459
|
+
rows: "Array[int]",
|
|
460
|
+
columns: "Array[int]",
|
|
461
|
+
values: "Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]",
|
|
462
|
+
prune_numerical_zeros: bool = True,
|
|
463
|
+
):
|
|
464
|
+
"""Constructs a BSR matrix with values defined by coordinate-oriented (COO) triplets.
|
|
465
|
+
|
|
466
|
+
The first dimension of the three input arrays must match and indicates the number of COO triplets.
|
|
467
|
+
|
|
468
|
+
Args:
|
|
469
|
+
rows_of_blocks: Number of rows of blocks.
|
|
470
|
+
cols_of_blocks: Number of columns of blocks.
|
|
471
|
+
rows: Row index for each non-zero.
|
|
472
|
+
columns: Columns index for each non-zero.
|
|
473
|
+
values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
|
|
474
|
+
to the ``dest`` matrix's block type, or a 3d array with data type equal to the ``dest`` matrix's scalar type.
|
|
475
|
+
prune_numerical_zeros: If ``True``, will ignore the zero-valued blocks.
|
|
476
|
+
"""
|
|
477
|
+
|
|
478
|
+
if values.ndim == 3:
|
|
479
|
+
block_type = wp.mat(shape=values.shape[1:], dtype=values.dtype)
|
|
480
|
+
else:
|
|
481
|
+
block_type = values.dtype
|
|
482
|
+
|
|
483
|
+
A = bsr_zeros(
|
|
484
|
+
rows_of_blocks=rows_of_blocks, cols_of_blocks=cols_of_blocks, block_type=block_type, device=values.device
|
|
485
|
+
)
|
|
486
|
+
bsr_set_from_triplets(A, rows, columns, values, prune_numerical_zeros=prune_numerical_zeros)
|
|
487
|
+
return A
|
|
488
|
+
|
|
489
|
+
|
|
394
490
|
class _BsrExpression(Generic[_BlockType]):
|
|
395
491
|
pass
|
|
396
492
|
|
|
@@ -501,96 +597,73 @@ def _extract_matrix_and_scale(bsr: BsrMatrixOrExpression):
|
|
|
501
597
|
raise ValueError("Argument cannot be interpreted as a BsrMatrix")
|
|
502
598
|
|
|
503
599
|
|
|
504
|
-
@wp.
|
|
505
|
-
def
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
dest_offsets: wp.array(dtype=int),
|
|
600
|
+
@wp.func
|
|
601
|
+
def _bsr_row_index(
|
|
602
|
+
offsets: wp.array(dtype=int),
|
|
603
|
+
row_count: int,
|
|
604
|
+
block: int,
|
|
510
605
|
):
|
|
511
|
-
row
|
|
512
|
-
|
|
513
|
-
base_offset = src_offsets[row] * row_factor * col_factor
|
|
514
|
-
row_count = src_offsets[1 + row] - src_offsets[row]
|
|
606
|
+
"""Index of the row containing a block, or -1 if non-existing."""
|
|
607
|
+
return wp.where(block < offsets[row_count], wp.lower_bound(offsets, 0, row_count + 1, block + 1), 0) - 1
|
|
515
608
|
|
|
516
|
-
for k in range(row_factor):
|
|
517
|
-
dest_offsets[1 + k + row_factor * row] = base_offset + row_count * col_factor * (k + 1)
|
|
518
609
|
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
structure_only: wp.bool,
|
|
526
|
-
scale: Any,
|
|
527
|
-
row_factor: int,
|
|
528
|
-
col_factor: int,
|
|
529
|
-
dest_row_count: int,
|
|
530
|
-
src_offsets: wp.array(dtype=int),
|
|
531
|
-
src_columns: wp.array(dtype=int),
|
|
532
|
-
src_values: wp.array3d(dtype=Any),
|
|
533
|
-
dest_offsets: wp.array(dtype=int),
|
|
534
|
-
dest_columns: wp.array(dtype=int),
|
|
535
|
-
dest_values: wp.array3d(dtype=Any),
|
|
610
|
+
@wp.func
|
|
611
|
+
def _bsr_block_index(
|
|
612
|
+
row: int,
|
|
613
|
+
col: int,
|
|
614
|
+
bsr_offsets: wp.array(dtype=int),
|
|
615
|
+
bsr_columns: wp.array(dtype=int),
|
|
536
616
|
):
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
return
|
|
541
|
-
|
|
542
|
-
dest_row = wp.lower_bound(dest_offsets, 0, dest_row_count + 1, dest_block + 1) - 1
|
|
543
|
-
src_row = dest_row // row_factor
|
|
544
|
-
|
|
545
|
-
dest_col_in_row = dest_block - dest_offsets[dest_row]
|
|
546
|
-
src_col_in_row = dest_col_in_row // col_factor
|
|
547
|
-
|
|
548
|
-
src_block = src_offsets[src_row] + src_col_in_row
|
|
617
|
+
"""Index of the block at block-coordinates (row, col), or -1 if non-existing.
|
|
618
|
+
Assumes bsr_columns is sorted.
|
|
619
|
+
"""
|
|
549
620
|
|
|
550
|
-
|
|
551
|
-
|
|
621
|
+
if row < 0:
|
|
622
|
+
return -1
|
|
552
623
|
|
|
553
|
-
|
|
554
|
-
|
|
624
|
+
mask_row_beg = bsr_offsets[row]
|
|
625
|
+
mask_row_end = bsr_offsets[row + 1]
|
|
555
626
|
|
|
556
|
-
|
|
627
|
+
if mask_row_beg == mask_row_end:
|
|
628
|
+
return -1
|
|
557
629
|
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
src_base_j = split_col * dest_cols_per_block
|
|
561
|
-
for i in range(dest_rows_per_block):
|
|
562
|
-
for j in range(dest_cols_per_block):
|
|
563
|
-
dest_values[dest_block, i, j] = dest_values.dtype(
|
|
564
|
-
scale * src_values[src_block, i + src_base_i, j + src_base_j]
|
|
565
|
-
)
|
|
630
|
+
block_index = wp.lower_bound(bsr_columns, mask_row_beg, mask_row_end, col)
|
|
631
|
+
return wp.where(bsr_columns[block_index] == col, block_index, -1)
|
|
566
632
|
|
|
567
633
|
|
|
568
|
-
@wp.kernel
|
|
569
|
-
def
|
|
570
|
-
|
|
571
|
-
|
|
634
|
+
@wp.kernel(enable_backward=False)
|
|
635
|
+
def _bsr_assign_list_blocks(
|
|
636
|
+
src_subrows: int,
|
|
637
|
+
src_subcols: int,
|
|
638
|
+
dest_subrows: int,
|
|
639
|
+
dest_subcols: int,
|
|
572
640
|
src_row_count: int,
|
|
573
641
|
src_offsets: wp.array(dtype=int),
|
|
574
642
|
src_columns: wp.array(dtype=int),
|
|
575
643
|
dest_rows: wp.array(dtype=int),
|
|
576
644
|
dest_cols: wp.array(dtype=int),
|
|
577
645
|
):
|
|
578
|
-
block = wp.tid()
|
|
646
|
+
block, subrow, subcol = wp.tid()
|
|
647
|
+
dest_block = (block * src_subcols + subcol) * src_subrows + subrow
|
|
579
648
|
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
649
|
+
row = _bsr_row_index(src_offsets, src_row_count, block)
|
|
650
|
+
if row == -1:
|
|
651
|
+
dest_rows[dest_block] = row # invalid
|
|
652
|
+
dest_cols[dest_block] = row
|
|
583
653
|
else:
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
654
|
+
dest_subrow = row * src_subrows + subrow
|
|
655
|
+
dest_subcol = src_columns[block] * src_subcols + subcol
|
|
656
|
+
dest_rows[dest_block] = dest_subrow // dest_subrows
|
|
657
|
+
dest_cols[dest_block] = dest_subcol // dest_subcols
|
|
587
658
|
|
|
588
659
|
|
|
589
660
|
@wp.kernel
|
|
590
|
-
def
|
|
661
|
+
def _bsr_assign_copy_blocks(
|
|
591
662
|
scale: Any,
|
|
592
|
-
|
|
593
|
-
|
|
663
|
+
src_subrows: int,
|
|
664
|
+
src_subcols: int,
|
|
665
|
+
dest_subrows: int,
|
|
666
|
+
dest_subcols: int,
|
|
594
667
|
src_row_count: int,
|
|
595
668
|
src_offsets: wp.array(dtype=int),
|
|
596
669
|
src_columns: wp.array(dtype=int),
|
|
@@ -600,61 +673,58 @@ def _bsr_assign_merge_blocks(
|
|
|
600
673
|
dest_values: wp.array3d(dtype=Any),
|
|
601
674
|
):
|
|
602
675
|
src_block = wp.tid()
|
|
676
|
+
src_block, subrow, subcol = wp.tid()
|
|
603
677
|
|
|
604
|
-
|
|
678
|
+
src_row = _bsr_row_index(src_offsets, src_row_count, src_block)
|
|
679
|
+
if src_row == -1:
|
|
605
680
|
return
|
|
606
681
|
|
|
607
|
-
src_row = wp.lower_bound(src_offsets, 0, src_row_count + 1, src_block + 1) - 1
|
|
608
682
|
src_col = src_columns[src_block]
|
|
609
683
|
|
|
610
|
-
|
|
611
|
-
|
|
684
|
+
dest_subrow = src_row * src_subrows + subrow
|
|
685
|
+
dest_subcol = src_col * src_subcols + subcol
|
|
686
|
+
dest_row = dest_subrow // dest_subrows
|
|
687
|
+
dest_col = dest_subcol // dest_subcols
|
|
612
688
|
|
|
613
|
-
dest_block =
|
|
689
|
+
dest_block = _bsr_block_index(dest_row, dest_col, dest_offsets, dest_columns)
|
|
690
|
+
if dest_block == -1:
|
|
691
|
+
return
|
|
614
692
|
|
|
615
|
-
|
|
616
|
-
|
|
693
|
+
split_row = dest_subrow - dest_subrows * dest_row
|
|
694
|
+
split_col = dest_subcol - dest_subcols * dest_col
|
|
617
695
|
|
|
618
|
-
|
|
619
|
-
|
|
696
|
+
rows_per_subblock = src_values.shape[1] // src_subrows
|
|
697
|
+
cols_per_subblock = src_values.shape[2] // src_subcols
|
|
620
698
|
|
|
621
|
-
dest_base_i = split_row *
|
|
622
|
-
dest_base_j = split_col *
|
|
699
|
+
dest_base_i = split_row * rows_per_subblock
|
|
700
|
+
dest_base_j = split_col * cols_per_subblock
|
|
623
701
|
|
|
624
|
-
|
|
625
|
-
|
|
702
|
+
src_base_i = subrow * rows_per_subblock
|
|
703
|
+
src_base_j = subcol * cols_per_subblock
|
|
704
|
+
|
|
705
|
+
for i in range(rows_per_subblock):
|
|
706
|
+
for j in range(cols_per_subblock):
|
|
626
707
|
dest_values[dest_block, i + dest_base_i, j + dest_base_j] = dest_values.dtype(
|
|
627
|
-
scale * src_values[src_block, i, j]
|
|
708
|
+
scale * src_values[src_block, i + src_base_i, j + src_base_j]
|
|
628
709
|
)
|
|
629
710
|
|
|
630
711
|
|
|
631
|
-
def _bsr_values_as_3d_array(A: BsrMatrix) -> wp.array:
|
|
632
|
-
if A.block_shape == (1, 1):
|
|
633
|
-
return A.values.reshape((A.values.shape[0], 1, 1))
|
|
634
|
-
|
|
635
|
-
return wp.array(
|
|
636
|
-
data=None,
|
|
637
|
-
ptr=A.values.ptr,
|
|
638
|
-
capacity=A.values.capacity,
|
|
639
|
-
device=A.device,
|
|
640
|
-
dtype=A.scalar_type,
|
|
641
|
-
shape=(A.values.shape[0], A.block_shape[0], A.block_shape[1]),
|
|
642
|
-
)
|
|
643
|
-
|
|
644
|
-
|
|
645
712
|
def bsr_assign(
|
|
646
713
|
dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
|
|
647
714
|
src: BsrMatrixOrExpression[BlockType[Any, Any, Any]],
|
|
648
715
|
structure_only: bool = False,
|
|
716
|
+
masked: bool = False,
|
|
649
717
|
):
|
|
650
|
-
"""
|
|
718
|
+
"""Copy the content of the ``src`` BSR matrix to ``dest``.
|
|
651
719
|
|
|
652
720
|
Args:
|
|
653
|
-
src: Matrix to be copied
|
|
654
|
-
dest: Destination matrix. May have a different block shape
|
|
721
|
+
src: Matrix to be copied.
|
|
722
|
+
dest: Destination matrix. May have a different block shape or scalar type
|
|
723
|
+
than ``src``, in which case the required casting will be performed.
|
|
655
724
|
structure_only: If ``True``, only the non-zeros indices are copied, and uninitialized value storage is allocated
|
|
656
|
-
to accommodate at least
|
|
725
|
+
to accommodate at least ``src.nnz`` blocks. If ``structure_only`` is ``False``, values are also copied with implicit
|
|
657
726
|
casting if the two matrices use distinct scalar types.
|
|
727
|
+
masked: If ``True``, prevent the assignment operation from adding new non-zeros blocks to ``dest``.
|
|
658
728
|
"""
|
|
659
729
|
|
|
660
730
|
src, src_scale = _extract_matrix_and_scale(src)
|
|
@@ -662,13 +732,50 @@ def bsr_assign(
|
|
|
662
732
|
if dest.values.device != src.values.device:
|
|
663
733
|
raise ValueError("Source and destination matrices must reside on the same device")
|
|
664
734
|
|
|
665
|
-
if
|
|
666
|
-
|
|
667
|
-
|
|
735
|
+
if src.block_shape[0] >= dest.block_shape[0]:
|
|
736
|
+
src_subrows = src.block_shape[0] // dest.block_shape[0]
|
|
737
|
+
dest_subrows = 1
|
|
738
|
+
else:
|
|
739
|
+
dest_subrows = dest.block_shape[0] // src.block_shape[0]
|
|
740
|
+
src_subrows = 1
|
|
741
|
+
|
|
742
|
+
if src_subrows * dest.block_shape[0] != src.block_shape[0] * dest_subrows:
|
|
743
|
+
raise ValueError(
|
|
744
|
+
f"Incompatible dest and src block shapes; block rows must evenly divide one another (Got {src.block_shape[0]}, {dest.block_shape[0]})"
|
|
745
|
+
)
|
|
746
|
+
|
|
747
|
+
if src.block_shape[1] >= dest.block_shape[1]:
|
|
748
|
+
src_subcols = src.block_shape[1] // dest.block_shape[1]
|
|
749
|
+
dest_subcols = 1
|
|
750
|
+
else:
|
|
751
|
+
dest_subcols = dest.block_shape[1] // src.block_shape[1]
|
|
752
|
+
src_subcols = 1
|
|
753
|
+
|
|
754
|
+
if src_subcols * dest.block_shape[1] != src.block_shape[1] * dest_subcols:
|
|
755
|
+
raise ValueError(
|
|
756
|
+
f"Incompatible dest and src block shapes; block columns must evenly divide one another (Got {src.block_shape[1]}, {dest.block_shape[1]})"
|
|
757
|
+
)
|
|
758
|
+
|
|
759
|
+
dest_nrow = (src.nrow * src_subrows) // dest_subrows
|
|
760
|
+
dest_ncol = (src.ncol * src_subcols) // dest_subcols
|
|
668
761
|
|
|
669
|
-
|
|
762
|
+
if src.nrow * src_subrows != dest_nrow * dest_subrows or src.ncol * src_subcols != dest_ncol * dest_subcols:
|
|
763
|
+
raise ValueError("The requested block shape does not evenly divide the source matrix")
|
|
764
|
+
|
|
765
|
+
nnz_alloc = src.nnz * src_subrows * src_subcols
|
|
766
|
+
if masked:
|
|
767
|
+
if dest_nrow != dest.nrow or dest_ncol != dest.ncol:
|
|
768
|
+
raise ValueError(
|
|
769
|
+
f"Incompatible destination matrix size, expected ({dest_nrow}, {dest_ncol}), got ({dest.nrow}, {dest.ncol})"
|
|
770
|
+
)
|
|
771
|
+
else:
|
|
772
|
+
dest.nrow = dest_nrow
|
|
773
|
+
dest.ncol = dest_ncol
|
|
670
774
|
_bsr_ensure_fits(dest, nnz=nnz_alloc)
|
|
671
775
|
|
|
776
|
+
if dest.block_shape == src.block_shape and not masked:
|
|
777
|
+
# Direct copy
|
|
778
|
+
|
|
672
779
|
wp.copy(dest=dest.offsets, src=src.offsets, count=src.nrow + 1)
|
|
673
780
|
dest.copy_nnz_async()
|
|
674
781
|
|
|
@@ -679,86 +786,29 @@ def bsr_assign(
|
|
|
679
786
|
warp.utils.array_cast(out_array=dest.values, in_array=src.values, count=nnz_alloc)
|
|
680
787
|
bsr_scale(dest, src_scale)
|
|
681
788
|
|
|
682
|
-
|
|
683
|
-
#
|
|
684
|
-
|
|
685
|
-
row_factor = src.block_shape[0] // dest.block_shape[0]
|
|
686
|
-
col_factor = src.block_shape[1] // dest.block_shape[1]
|
|
687
|
-
|
|
688
|
-
if (
|
|
689
|
-
row_factor * dest.block_shape[0] != src.block_shape[0]
|
|
690
|
-
or col_factor * dest.block_shape[1] != src.block_shape[1]
|
|
691
|
-
):
|
|
692
|
-
raise ValueError(
|
|
693
|
-
f"Dest block shape {dest.block_shape} is not an exact divider of src block shape {src.block_shape}"
|
|
694
|
-
)
|
|
695
|
-
|
|
696
|
-
dest.nrow = src.nrow * row_factor
|
|
697
|
-
dest.ncol = src.ncol * col_factor
|
|
698
|
-
|
|
699
|
-
nnz_alloc = src.nnz * row_factor * col_factor
|
|
700
|
-
_bsr_ensure_fits(dest, nnz=nnz_alloc)
|
|
789
|
+
else:
|
|
790
|
+
# Masked and/or multiple src blocks per dest block, go through COO format
|
|
701
791
|
|
|
792
|
+
# Compute destination rows and columns
|
|
793
|
+
dest_rows = wp.empty(nnz_alloc, dtype=int, device=dest.device)
|
|
794
|
+
dest_cols = wp.empty(nnz_alloc, dtype=int, device=dest.device)
|
|
702
795
|
wp.launch(
|
|
703
|
-
|
|
704
|
-
dim=src.
|
|
705
|
-
device=dest.device,
|
|
706
|
-
inputs=[row_factor, col_factor, src.offsets, dest.offsets],
|
|
707
|
-
)
|
|
708
|
-
wp.launch(
|
|
709
|
-
_bsr_assign_split_blocks,
|
|
710
|
-
dim=dest.nnz,
|
|
796
|
+
_bsr_assign_list_blocks,
|
|
797
|
+
dim=(src.nnz, src_subrows, src_subcols),
|
|
711
798
|
device=dest.device,
|
|
712
799
|
inputs=[
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
800
|
+
src_subrows,
|
|
801
|
+
src_subcols,
|
|
802
|
+
dest_subrows,
|
|
803
|
+
dest_subcols,
|
|
804
|
+
src.nrow,
|
|
718
805
|
src.offsets,
|
|
719
806
|
src.columns,
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
dest.columns,
|
|
723
|
-
_bsr_values_as_3d_array(dest),
|
|
807
|
+
dest_rows,
|
|
808
|
+
dest_cols,
|
|
724
809
|
],
|
|
725
810
|
)
|
|
726
811
|
|
|
727
|
-
elif src.block_shape[0] <= dest.block_shape[0] and src.block_shape[1] <= dest.block_shape[1]:
|
|
728
|
-
# Merge blocks
|
|
729
|
-
|
|
730
|
-
row_factor = dest.block_shape[0] // src.block_shape[0]
|
|
731
|
-
col_factor = dest.block_shape[1] // src.block_shape[1]
|
|
732
|
-
|
|
733
|
-
if (
|
|
734
|
-
row_factor * src.block_shape[0] != dest.block_shape[0]
|
|
735
|
-
or col_factor * src.block_shape[1] != dest.block_shape[1]
|
|
736
|
-
):
|
|
737
|
-
raise ValueError(
|
|
738
|
-
f"Dest block shape {dest.block_shape} is not an exact multiple of src block shape {src.block_shape}"
|
|
739
|
-
)
|
|
740
|
-
|
|
741
|
-
if src.nrow % row_factor != 0 or src.ncol % col_factor != 0:
|
|
742
|
-
raise ValueError(
|
|
743
|
-
"The total rows and columns of the src matrix cannot be evenly divided using the requested block shape"
|
|
744
|
-
)
|
|
745
|
-
|
|
746
|
-
dest.nrow = src.nrow // row_factor
|
|
747
|
-
dest.ncol = src.ncol // col_factor
|
|
748
|
-
|
|
749
|
-
nnz_alloc = src.nnz # Conservative, in case all nnz in src belong to distinct merged blocks
|
|
750
|
-
_bsr_ensure_fits(dest, nnz=nnz_alloc)
|
|
751
|
-
|
|
752
|
-
# Compute destination rows and columns
|
|
753
|
-
dest_rows = wp.empty_like(src.columns)
|
|
754
|
-
dest_cols = wp.empty_like(src.columns)
|
|
755
|
-
wp.launch(
|
|
756
|
-
_bsr_assign_merge_row_col,
|
|
757
|
-
dim=src.nnz,
|
|
758
|
-
device=dest.device,
|
|
759
|
-
inputs=[row_factor, col_factor, src.nrow, src.offsets, src.columns, dest_rows, dest_cols],
|
|
760
|
-
)
|
|
761
|
-
|
|
762
812
|
# Compute destination offsets from triplets
|
|
763
813
|
from warp.context import runtime
|
|
764
814
|
|
|
@@ -773,11 +823,12 @@ def bsr_assign(
|
|
|
773
823
|
dest.block_shape[0],
|
|
774
824
|
dest.block_shape[1],
|
|
775
825
|
dest.nrow,
|
|
776
|
-
|
|
826
|
+
nnz_alloc,
|
|
777
827
|
ctypes.cast(dest_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
778
828
|
ctypes.cast(dest_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
779
829
|
0,
|
|
780
830
|
False,
|
|
831
|
+
masked,
|
|
781
832
|
ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
782
833
|
ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
783
834
|
0,
|
|
@@ -789,26 +840,25 @@ def bsr_assign(
|
|
|
789
840
|
if not structure_only:
|
|
790
841
|
dest.values.zero_()
|
|
791
842
|
wp.launch(
|
|
792
|
-
|
|
793
|
-
dim=src.nnz,
|
|
843
|
+
_bsr_assign_copy_blocks,
|
|
844
|
+
dim=(src.nnz, src_subrows, src_subcols),
|
|
794
845
|
device=dest.device,
|
|
795
846
|
inputs=[
|
|
796
847
|
src.scalar_type(src_scale),
|
|
797
|
-
|
|
798
|
-
|
|
848
|
+
src_subrows,
|
|
849
|
+
src_subcols,
|
|
850
|
+
dest_subrows,
|
|
851
|
+
dest_subcols,
|
|
799
852
|
src.nrow,
|
|
800
853
|
src.offsets,
|
|
801
854
|
src.columns,
|
|
802
|
-
|
|
855
|
+
src.scalar_values,
|
|
803
856
|
dest.offsets,
|
|
804
857
|
dest.columns,
|
|
805
|
-
|
|
858
|
+
dest.scalar_values,
|
|
806
859
|
],
|
|
807
860
|
)
|
|
808
861
|
|
|
809
|
-
else:
|
|
810
|
-
raise ValueError("Incompatible dest and src block shapes")
|
|
811
|
-
|
|
812
862
|
|
|
813
863
|
def bsr_copy(
|
|
814
864
|
A: BsrMatrixOrExpression,
|
|
@@ -816,15 +866,15 @@ def bsr_copy(
|
|
|
816
866
|
block_shape: Optional[Tuple[int, int]] = None,
|
|
817
867
|
structure_only: bool = False,
|
|
818
868
|
):
|
|
819
|
-
"""
|
|
869
|
+
"""Return a copy of matrix ``A``, possibly changing its scalar type.
|
|
820
870
|
|
|
821
871
|
Args:
|
|
822
|
-
A: Matrix to be copied
|
|
823
|
-
scalar_type: If provided, the returned matrix will use this scalar type instead of the one from
|
|
824
|
-
block_shape: If provided, the returned matrix will use blocks of this shape instead of the one from
|
|
825
|
-
Both dimensions of
|
|
872
|
+
A: Matrix to be copied.
|
|
873
|
+
scalar_type: If provided, the returned matrix will use this scalar type instead of the one from ``A``.
|
|
874
|
+
block_shape: If provided, the returned matrix will use blocks of this shape instead of the one from ``A``.
|
|
875
|
+
Both dimensions of ``block_shape`` must be either a multiple or an exact divider of the ones from ``A``.
|
|
826
876
|
structure_only: If ``True``, only the non-zeros indices are copied, and uninitialized value storage is allocated
|
|
827
|
-
to accommodate at least
|
|
877
|
+
to accommodate at least ``src.nnz`` blocks. If ``structure_only`` is ``False``, values are also copied with implicit
|
|
828
878
|
casting if the two matrices use distinct scalar types.
|
|
829
879
|
"""
|
|
830
880
|
if scalar_type is None:
|
|
@@ -835,7 +885,7 @@ def bsr_copy(
|
|
|
835
885
|
if block_shape == (1, 1):
|
|
836
886
|
block_type = scalar_type
|
|
837
887
|
else:
|
|
838
|
-
block_type = wp.
|
|
888
|
+
block_type = wp.mat(shape=block_shape, dtype=scalar_type)
|
|
839
889
|
|
|
840
890
|
copy = bsr_zeros(
|
|
841
891
|
rows_of_blocks=A.nrow,
|
|
@@ -851,7 +901,7 @@ def bsr_set_transpose(
|
|
|
851
901
|
dest: BsrMatrix[BlockType[Cols, Rows, Scalar]],
|
|
852
902
|
src: BsrMatrixOrExpression[BlockType[Rows, Cols, Scalar]],
|
|
853
903
|
):
|
|
854
|
-
"""
|
|
904
|
+
"""Assign the transposed matrix ``src`` to matrix ``dest``."""
|
|
855
905
|
|
|
856
906
|
src, src_scale = _extract_matrix_and_scale(src)
|
|
857
907
|
|
|
@@ -912,13 +962,13 @@ def bsr_set_transpose(
|
|
|
912
962
|
bsr_scale(dest, src_scale)
|
|
913
963
|
|
|
914
964
|
|
|
915
|
-
def bsr_transposed(A: BsrMatrixOrExpression):
|
|
916
|
-
"""
|
|
965
|
+
def bsr_transposed(A: BsrMatrixOrExpression) -> BsrMatrix:
|
|
966
|
+
"""Return a copy of the transposed matrix ``A``."""
|
|
917
967
|
|
|
918
968
|
if A.block_shape == (1, 1):
|
|
919
969
|
block_type = A.values.dtype
|
|
920
970
|
else:
|
|
921
|
-
block_type = wp.
|
|
971
|
+
block_type = wp.mat(shape=A.block_shape[::-1], dtype=A.scalar_type)
|
|
922
972
|
|
|
923
973
|
transposed = bsr_zeros(
|
|
924
974
|
rows_of_blocks=A.ncol,
|
|
@@ -939,21 +989,18 @@ def _bsr_get_diag_kernel(
|
|
|
939
989
|
out: wp.array(dtype=Any),
|
|
940
990
|
):
|
|
941
991
|
row = wp.tid()
|
|
942
|
-
beg = A_offsets[row]
|
|
943
|
-
end = A_offsets[row + 1]
|
|
944
992
|
|
|
945
|
-
diag =
|
|
946
|
-
if diag
|
|
947
|
-
|
|
948
|
-
out[row] = scale * A_values[diag]
|
|
993
|
+
diag = _bsr_block_index(row, row, A_offsets, A_columns)
|
|
994
|
+
if diag != -1:
|
|
995
|
+
out[row] = scale * A_values[diag]
|
|
949
996
|
|
|
950
997
|
|
|
951
998
|
def bsr_get_diag(A: BsrMatrixOrExpression[BlockType], out: "Optional[Array[BlockType]]" = None) -> "Array[BlockType]":
|
|
952
|
-
"""
|
|
999
|
+
"""Return the array of blocks that constitute the diagonal of a sparse matrix.
|
|
953
1000
|
|
|
954
1001
|
Args:
|
|
955
|
-
A:
|
|
956
|
-
out:
|
|
1002
|
+
A: The sparse matrix from which to extract the diagonal.
|
|
1003
|
+
out: If provided, the array into which to store the diagonal blocks.
|
|
957
1004
|
"""
|
|
958
1005
|
|
|
959
1006
|
A, scale = _extract_matrix_and_scale(A)
|
|
@@ -980,36 +1027,16 @@ def bsr_get_diag(A: BsrMatrixOrExpression[BlockType], out: "Optional[Array[Block
|
|
|
980
1027
|
return out
|
|
981
1028
|
|
|
982
1029
|
|
|
983
|
-
@wp.kernel
|
|
1030
|
+
@wp.kernel(enable_backward=False)
|
|
984
1031
|
def _bsr_set_diag_kernel(
|
|
985
|
-
|
|
986
|
-
A_offsets: wp.array(dtype=int),
|
|
987
|
-
A_columns: wp.array(dtype=int),
|
|
988
|
-
A_values: wp.array(dtype=Any),
|
|
989
|
-
):
|
|
990
|
-
row = wp.tid()
|
|
991
|
-
A_offsets[row + 1] = row + 1
|
|
992
|
-
A_columns[row] = row
|
|
993
|
-
A_values[row] = diag[row]
|
|
994
|
-
|
|
995
|
-
if row == 0:
|
|
996
|
-
A_offsets[0] = 0
|
|
997
|
-
|
|
998
|
-
|
|
999
|
-
@wp.kernel
|
|
1000
|
-
def _bsr_set_diag_constant_kernel(
|
|
1001
|
-
diag_value: Any,
|
|
1032
|
+
nnz: int,
|
|
1002
1033
|
A_offsets: wp.array(dtype=int),
|
|
1003
1034
|
A_columns: wp.array(dtype=int),
|
|
1004
|
-
A_values: wp.array(dtype=Any),
|
|
1005
1035
|
):
|
|
1006
1036
|
row = wp.tid()
|
|
1007
|
-
A_offsets[row
|
|
1008
|
-
|
|
1009
|
-
|
|
1010
|
-
|
|
1011
|
-
if row == 0:
|
|
1012
|
-
A_offsets[0] = 0
|
|
1037
|
+
A_offsets[row] = wp.min(row, nnz)
|
|
1038
|
+
if row < nnz:
|
|
1039
|
+
A_columns[row] = row
|
|
1013
1040
|
|
|
1014
1041
|
|
|
1015
1042
|
def bsr_set_diag(
|
|
@@ -1017,20 +1044,26 @@ def bsr_set_diag(
|
|
|
1017
1044
|
diag: "Union[BlockType, Array[BlockType]]",
|
|
1018
1045
|
rows_of_blocks: Optional[int] = None,
|
|
1019
1046
|
cols_of_blocks: Optional[int] = None,
|
|
1020
|
-
):
|
|
1021
|
-
"""
|
|
1047
|
+
) -> None:
|
|
1048
|
+
"""Set ``A`` as a block-diagonal matrix.
|
|
1022
1049
|
|
|
1023
1050
|
Args:
|
|
1024
|
-
A:
|
|
1025
|
-
diag:
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
|
|
1051
|
+
A: The sparse matrix to modify.
|
|
1052
|
+
diag: Specifies the values for diagonal blocks. Can be one of:
|
|
1053
|
+
|
|
1054
|
+
- A Warp array of type ``A.values.dtype``: Each element defines one block of the diagonal
|
|
1055
|
+
- A constant value of type ``A.values.dtype``: This value is assigned to all diagonal blocks
|
|
1056
|
+
- ``None``: Diagonal block values are left uninitialized
|
|
1057
|
+
|
|
1058
|
+
rows_of_blocks: If not ``None``, the new number of rows of blocks.
|
|
1059
|
+
cols_of_blocks: If not ``None``, the new number of columns of blocks.
|
|
1060
|
+
|
|
1061
|
+
The shape of the matrix will be defined one of the following, in this order:
|
|
1029
1062
|
|
|
1030
|
-
|
|
1031
|
-
|
|
1032
|
-
|
|
1033
|
-
|
|
1063
|
+
- ``rows_of_blocks`` and ``cols_of_blocks``, if provided.
|
|
1064
|
+
If only one is given, the second is assumed equal.
|
|
1065
|
+
- The first dimension of ``diag``, if ``diag`` is an array
|
|
1066
|
+
- The current dimensions of ``A`` otherwise
|
|
1034
1067
|
"""
|
|
1035
1068
|
|
|
1036
1069
|
if rows_of_blocks is None and cols_of_blocks is not None:
|
|
@@ -1038,7 +1071,7 @@ def bsr_set_diag(
|
|
|
1038
1071
|
if cols_of_blocks is None and rows_of_blocks is not None:
|
|
1039
1072
|
cols_of_blocks = rows_of_blocks
|
|
1040
1073
|
|
|
1041
|
-
if
|
|
1074
|
+
if is_array(diag):
|
|
1042
1075
|
if rows_of_blocks is None:
|
|
1043
1076
|
rows_of_blocks = diag.shape[0]
|
|
1044
1077
|
cols_of_blocks = diag.shape[0]
|
|
@@ -1050,43 +1083,45 @@ def bsr_set_diag(
|
|
|
1050
1083
|
nnz = min(A.nrow, A.ncol)
|
|
1051
1084
|
_bsr_ensure_fits(A, nnz=nnz)
|
|
1052
1085
|
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
|
|
1056
|
-
|
|
1057
|
-
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
wp.launch(
|
|
1065
|
-
kernel=_bsr_set_diag_constant_kernel,
|
|
1066
|
-
dim=nnz,
|
|
1067
|
-
device=A.values.device,
|
|
1068
|
-
inputs=[diag, A.offsets, A.columns, A.values],
|
|
1069
|
-
)
|
|
1086
|
+
wp.launch(
|
|
1087
|
+
kernel=_bsr_set_diag_kernel,
|
|
1088
|
+
dim=nnz + 1,
|
|
1089
|
+
device=A.offsets.device,
|
|
1090
|
+
inputs=[nnz, A.offsets, A.columns],
|
|
1091
|
+
)
|
|
1092
|
+
|
|
1093
|
+
if is_array(diag):
|
|
1094
|
+
wp.copy(src=diag, dest=A.values, count=nnz)
|
|
1095
|
+
elif diag is not None:
|
|
1096
|
+
A.values.fill_(diag)
|
|
1070
1097
|
|
|
1071
1098
|
A.copy_nnz_async(known_nnz=nnz)
|
|
1072
1099
|
|
|
1073
1100
|
|
|
1074
1101
|
def bsr_diag(
|
|
1075
|
-
diag:
|
|
1102
|
+
diag: Optional[Union[BlockType, Array[BlockType]]] = None,
|
|
1076
1103
|
rows_of_blocks: Optional[int] = None,
|
|
1077
1104
|
cols_of_blocks: Optional[int] = None,
|
|
1105
|
+
block_type: Optional[BlockType] = None,
|
|
1106
|
+
device=None,
|
|
1078
1107
|
) -> BsrMatrix["BlockType"]:
|
|
1079
|
-
"""
|
|
1108
|
+
"""Create and return a block-diagonal BSR matrix from an given block value or array of block values.
|
|
1080
1109
|
|
|
1081
1110
|
Args:
|
|
1082
|
-
diag:
|
|
1083
|
-
|
|
1111
|
+
diag: Specifies the values for diagonal blocks. Can be one of:
|
|
1112
|
+
|
|
1113
|
+
- A Warp array of type ``A.values.dtype``: Each element defines one block of the diagonal
|
|
1114
|
+
- A constant value of type ``A.values.dtype``: This value is assigned to all diagonal blocks
|
|
1084
1115
|
rows_of_blocks: If not ``None``, the new number of rows of blocks
|
|
1085
1116
|
cols_of_blocks: If not ``None``, the new number of columns of blocks
|
|
1117
|
+
block_type: If ``diag`` is ``None``, block type of the matrix. Otherwise deduced from ``diag``
|
|
1118
|
+
device: If ``diag`` is not a Warp array, device on which to allocate the matrix. Otherwise deduced from ``diag``
|
|
1119
|
+
|
|
1120
|
+
The shape of the matrix will be defined one of the following, in this order:
|
|
1086
1121
|
|
|
1087
|
-
|
|
1088
|
-
|
|
1089
|
-
|
|
1122
|
+
- ``rows_of_blocks`` and ``cols_of_blocks``, if provided.
|
|
1123
|
+
If only one is given, the second is assumed equal.
|
|
1124
|
+
- The first dimension of ``diag`` if ``diag`` is an array.
|
|
1090
1125
|
"""
|
|
1091
1126
|
|
|
1092
1127
|
if rows_of_blocks is None and cols_of_blocks is not None:
|
|
@@ -1094,43 +1129,39 @@ def bsr_diag(
|
|
|
1094
1129
|
if cols_of_blocks is None and rows_of_blocks is not None:
|
|
1095
1130
|
cols_of_blocks = rows_of_blocks
|
|
1096
1131
|
|
|
1097
|
-
if
|
|
1132
|
+
if is_array(diag):
|
|
1098
1133
|
if rows_of_blocks is None:
|
|
1099
1134
|
rows_of_blocks = diag.shape[0]
|
|
1100
1135
|
cols_of_blocks = diag.shape[0]
|
|
1101
1136
|
|
|
1102
|
-
|
|
1103
|
-
|
|
1104
|
-
cols_of_blocks,
|
|
1105
|
-
block_type=diag.dtype,
|
|
1106
|
-
device=diag.device,
|
|
1107
|
-
)
|
|
1137
|
+
block_type = diag.dtype
|
|
1138
|
+
device = diag.device
|
|
1108
1139
|
else:
|
|
1109
1140
|
if rows_of_blocks is None:
|
|
1110
1141
|
raise ValueError(
|
|
1111
1142
|
"rows_of_blocks and/or cols_of_blocks must be provided for constructing a diagonal matrix with uniform diagonal"
|
|
1112
1143
|
)
|
|
1113
1144
|
|
|
1145
|
+
if block_type is None:
|
|
1146
|
+
if diag is None:
|
|
1147
|
+
raise ValueError("Either `diag` or `block_type` needs to be provided")
|
|
1148
|
+
|
|
1114
1149
|
block_type = type(diag)
|
|
1115
|
-
if not
|
|
1150
|
+
if not type_is_matrix(block_type) and len(getattr(diag, "shape", ())) == 2:
|
|
1116
1151
|
block_type = wp.mat(shape=diag.shape, dtype=diag.dtype)
|
|
1117
1152
|
|
|
1118
|
-
|
|
1119
|
-
rows_of_blocks,
|
|
1120
|
-
cols_of_blocks,
|
|
1121
|
-
block_type=block_type,
|
|
1122
|
-
)
|
|
1123
|
-
|
|
1153
|
+
A = bsr_zeros(rows_of_blocks, cols_of_blocks, block_type=block_type, device=device)
|
|
1124
1154
|
bsr_set_diag(A, diag)
|
|
1125
1155
|
return A
|
|
1126
1156
|
|
|
1127
1157
|
|
|
1128
|
-
def bsr_set_identity(A: BsrMatrix, rows_of_blocks: Optional[int] = None):
|
|
1129
|
-
"""
|
|
1158
|
+
def bsr_set_identity(A: BsrMatrix, rows_of_blocks: Optional[int] = None) -> None:
|
|
1159
|
+
"""Set ``A`` as the identity matrix.
|
|
1130
1160
|
|
|
1131
1161
|
Args:
|
|
1132
|
-
A:
|
|
1133
|
-
rows_of_blocks:
|
|
1162
|
+
A: The sparse matrix to modify.
|
|
1163
|
+
rows_of_blocks: If provided, the matrix will be resized as a square
|
|
1164
|
+
matrix with ``rows_of_blocks`` rows and columns.
|
|
1134
1165
|
"""
|
|
1135
1166
|
|
|
1136
1167
|
if A.block_shape == (1, 1):
|
|
@@ -1148,11 +1179,11 @@ def bsr_identity(
|
|
|
1148
1179
|
block_type: BlockType[Rows, Rows, Scalar],
|
|
1149
1180
|
device: wp.context.Devicelike = None,
|
|
1150
1181
|
) -> BsrMatrix[BlockType[Rows, Rows, Scalar]]:
|
|
1151
|
-
"""
|
|
1182
|
+
"""Create and return a square identity matrix.
|
|
1152
1183
|
|
|
1153
1184
|
Args:
|
|
1154
1185
|
rows_of_blocks: Number of rows and columns of blocks in the created matrix.
|
|
1155
|
-
block_type: Block type for the newly created matrix
|
|
1186
|
+
block_type: Block type for the newly created matrix. Must be square
|
|
1156
1187
|
device: Device onto which to allocate the data arrays
|
|
1157
1188
|
"""
|
|
1158
1189
|
A = bsr_zeros(
|
|
@@ -1174,9 +1205,7 @@ def _bsr_scale_kernel(
|
|
|
1174
1205
|
|
|
1175
1206
|
|
|
1176
1207
|
def bsr_scale(x: BsrMatrixOrExpression, alpha: Scalar) -> BsrMatrix:
|
|
1177
|
-
"""
|
|
1178
|
-
Performs the operation ``x := alpha * x`` on BSR matrix `x` and returns `x`
|
|
1179
|
-
"""
|
|
1208
|
+
"""Perform the operation ``x := alpha * x`` on BSR matrix ``x`` and return ``x``."""
|
|
1180
1209
|
|
|
1181
1210
|
x, scale = _extract_matrix_and_scale(x)
|
|
1182
1211
|
alpha *= scale
|
|
@@ -1185,8 +1214,7 @@ def bsr_scale(x: BsrMatrixOrExpression, alpha: Scalar) -> BsrMatrix:
|
|
|
1185
1214
|
if alpha == 0.0:
|
|
1186
1215
|
bsr_set_zero(x)
|
|
1187
1216
|
else:
|
|
1188
|
-
|
|
1189
|
-
alpha = x.scalar_type(alpha)
|
|
1217
|
+
alpha = x.scalar_type(alpha)
|
|
1190
1218
|
|
|
1191
1219
|
wp.launch(
|
|
1192
1220
|
kernel=_bsr_scale_kernel,
|
|
@@ -1198,15 +1226,10 @@ def bsr_scale(x: BsrMatrixOrExpression, alpha: Scalar) -> BsrMatrix:
|
|
|
1198
1226
|
return x
|
|
1199
1227
|
|
|
1200
1228
|
|
|
1201
|
-
@wp.kernel
|
|
1202
|
-
def _bsr_get_block_row(
|
|
1203
|
-
|
|
1204
|
-
|
|
1205
|
-
if i >= bsr_offsets[row_count]:
|
|
1206
|
-
rows[dest_offset + i] = -1 # invalid
|
|
1207
|
-
else:
|
|
1208
|
-
row = wp.lower_bound(bsr_offsets, 0, row_count + 1, i + 1) - 1
|
|
1209
|
-
rows[dest_offset + i] = row
|
|
1229
|
+
@wp.kernel(enable_backward=False)
|
|
1230
|
+
def _bsr_get_block_row(row_count: int, bsr_offsets: wp.array(dtype=int), rows: wp.array(dtype=int)):
|
|
1231
|
+
block = wp.tid()
|
|
1232
|
+
rows[block] = _bsr_row_index(bsr_offsets, row_count, block)
|
|
1210
1233
|
|
|
1211
1234
|
|
|
1212
1235
|
@wp.kernel
|
|
@@ -1222,21 +1245,15 @@ def _bsr_axpy_add_block(
|
|
|
1222
1245
|
):
|
|
1223
1246
|
i = wp.tid()
|
|
1224
1247
|
row = rows[i + src_offset]
|
|
1225
|
-
|
|
1226
|
-
if row < 0:
|
|
1227
|
-
return
|
|
1228
|
-
|
|
1229
1248
|
col = cols[i + src_offset]
|
|
1230
|
-
beg = dst_offsets[row]
|
|
1231
|
-
end = dst_offsets[row + 1]
|
|
1232
1249
|
|
|
1233
|
-
block =
|
|
1234
|
-
|
|
1235
|
-
|
|
1250
|
+
block = _bsr_block_index(row, col, dst_offsets, dst_columns)
|
|
1251
|
+
if block != -1:
|
|
1252
|
+
dst_values[block] += scale * src_values[i]
|
|
1236
1253
|
|
|
1237
1254
|
|
|
1238
1255
|
class bsr_axpy_work_arrays:
|
|
1239
|
-
"""Opaque structure for persisting :func:`bsr_axpy` temporary work buffers across calls"""
|
|
1256
|
+
"""Opaque structure for persisting :func:`bsr_axpy` temporary work buffers across calls."""
|
|
1240
1257
|
|
|
1241
1258
|
def __init__(self):
|
|
1242
1259
|
self._reset(None)
|
|
@@ -1266,25 +1283,33 @@ def bsr_axpy(
|
|
|
1266
1283
|
y: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
|
|
1267
1284
|
alpha: Scalar = 1.0,
|
|
1268
1285
|
beta: Scalar = 1.0,
|
|
1286
|
+
masked: bool = False,
|
|
1269
1287
|
work_arrays: Optional[bsr_axpy_work_arrays] = None,
|
|
1270
1288
|
) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
|
|
1271
1289
|
"""
|
|
1272
|
-
|
|
1290
|
+
Perform the sparse matrix addition ``y := alpha * X + beta * y`` on BSR matrices ``x`` and ``y`` and return ``y``.
|
|
1273
1291
|
|
|
1274
|
-
The
|
|
1292
|
+
The ``x`` and ``y`` matrices are allowed to alias.
|
|
1275
1293
|
|
|
1276
1294
|
Args:
|
|
1277
1295
|
x: Read-only right-hand-side.
|
|
1278
|
-
y: Mutable left-hand-side. If
|
|
1279
|
-
alpha: Uniform scaling factor for
|
|
1280
|
-
beta: Uniform scaling factor for
|
|
1281
|
-
|
|
1296
|
+
y: Mutable left-hand-side. If ``y`` is not provided, it will be allocated and treated as zero.
|
|
1297
|
+
alpha: Uniform scaling factor for ``x``.
|
|
1298
|
+
beta: Uniform scaling factor for ``y``.
|
|
1299
|
+
masked: If ``True``, discard all blocks from ``x`` which are not
|
|
1300
|
+
existing non-zeros of ``y``.
|
|
1301
|
+
work_arrays: In most cases, this function will require the use of temporary storage.
|
|
1302
|
+
This storage can be reused across calls by passing an instance of
|
|
1303
|
+
:class:`bsr_axpy_work_arrays` in ``work_arrays``.
|
|
1282
1304
|
"""
|
|
1283
1305
|
|
|
1284
1306
|
x, x_scale = _extract_matrix_and_scale(x)
|
|
1285
1307
|
alpha *= x_scale
|
|
1286
1308
|
|
|
1287
1309
|
if y is None:
|
|
1310
|
+
if masked:
|
|
1311
|
+
raise ValueError("Left-hand-side 'y' matrix must be provided for masked addition")
|
|
1312
|
+
|
|
1288
1313
|
# If not output matrix is provided, allocate it for convenience
|
|
1289
1314
|
y = bsr_zeros(x.nrow, x.ncol, block_type=x.values.dtype, device=x.values.device)
|
|
1290
1315
|
beta = 0.0
|
|
@@ -1328,27 +1353,17 @@ def bsr_axpy(
|
|
|
1328
1353
|
work_arrays._allocate(device, y, sum_nnz)
|
|
1329
1354
|
|
|
1330
1355
|
wp.copy(work_arrays._sum_cols, y.columns, 0, 0, y_nnz)
|
|
1331
|
-
|
|
1332
|
-
kernel=_bsr_get_block_row,
|
|
1333
|
-
device=device,
|
|
1334
|
-
dim=y_nnz,
|
|
1335
|
-
inputs=[0, y.nrow, y.offsets, work_arrays._sum_rows],
|
|
1336
|
-
)
|
|
1356
|
+
y.uncompress_rows(out=work_arrays._sum_rows)
|
|
1337
1357
|
|
|
1338
1358
|
wp.copy(work_arrays._sum_cols, x.columns, y_nnz, 0, x_nnz)
|
|
1339
|
-
|
|
1340
|
-
kernel=_bsr_get_block_row,
|
|
1341
|
-
device=device,
|
|
1342
|
-
dim=x_nnz,
|
|
1343
|
-
inputs=[y_nnz, x.nrow, x.offsets, work_arrays._sum_rows],
|
|
1344
|
-
)
|
|
1359
|
+
x.uncompress_rows(out=work_arrays._sum_rows[y_nnz:])
|
|
1345
1360
|
|
|
1346
1361
|
# Save old y values before overwriting matrix
|
|
1347
1362
|
wp.copy(dest=work_arrays._old_y_values, src=y.values, count=y_nnz)
|
|
1348
1363
|
|
|
1349
1364
|
# Increase dest array sizes if needed
|
|
1350
|
-
if
|
|
1351
|
-
y
|
|
1365
|
+
if not masked:
|
|
1366
|
+
_bsr_ensure_fits(y, nnz=sum_nnz)
|
|
1352
1367
|
|
|
1353
1368
|
from warp.context import runtime
|
|
1354
1369
|
|
|
@@ -1370,6 +1385,7 @@ def bsr_axpy(
|
|
|
1370
1385
|
ctypes.cast(work_arrays._sum_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1371
1386
|
0,
|
|
1372
1387
|
False,
|
|
1388
|
+
masked,
|
|
1373
1389
|
ctypes.cast(y.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1374
1390
|
ctypes.cast(y.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1375
1391
|
0,
|
|
@@ -1377,8 +1393,6 @@ def bsr_axpy(
|
|
|
1377
1393
|
nnz_event,
|
|
1378
1394
|
)
|
|
1379
1395
|
|
|
1380
|
-
_bsr_ensure_fits(y, nnz=sum_nnz)
|
|
1381
|
-
|
|
1382
1396
|
y.values.zero_()
|
|
1383
1397
|
|
|
1384
1398
|
wp.launch(
|
|
@@ -1416,55 +1430,90 @@ def bsr_axpy(
|
|
|
1416
1430
|
return y
|
|
1417
1431
|
|
|
1418
1432
|
|
|
1419
|
-
@wp.kernel
|
|
1433
|
+
@wp.kernel(enable_backward=False)
|
|
1420
1434
|
def _bsr_mm_count_coeffs(
|
|
1435
|
+
y_ncol: int,
|
|
1421
1436
|
z_nnz: int,
|
|
1422
1437
|
x_offsets: wp.array(dtype=int),
|
|
1423
1438
|
x_columns: wp.array(dtype=int),
|
|
1424
1439
|
y_offsets: wp.array(dtype=int),
|
|
1425
|
-
|
|
1440
|
+
y_columns: wp.array(dtype=int),
|
|
1441
|
+
row_min: wp.array(dtype=int),
|
|
1442
|
+
block_counts: wp.array(dtype=int),
|
|
1426
1443
|
):
|
|
1427
1444
|
row = wp.tid()
|
|
1428
|
-
|
|
1445
|
+
row_count = int(0)
|
|
1429
1446
|
|
|
1430
1447
|
x_beg = x_offsets[row]
|
|
1431
1448
|
x_end = x_offsets[row + 1]
|
|
1432
1449
|
|
|
1450
|
+
min_col = y_ncol
|
|
1451
|
+
max_col = int(0)
|
|
1452
|
+
|
|
1433
1453
|
for x_block in range(x_beg, x_end):
|
|
1434
1454
|
x_col = x_columns[x_block]
|
|
1435
|
-
|
|
1436
|
-
|
|
1437
|
-
|
|
1455
|
+
y_row_end = y_offsets[x_col + 1]
|
|
1456
|
+
y_row_beg = y_offsets[x_col]
|
|
1457
|
+
block_count = y_row_end - y_row_beg
|
|
1458
|
+
if block_count != 0:
|
|
1459
|
+
min_col = wp.min(y_columns[y_row_beg], min_col)
|
|
1460
|
+
max_col = wp.max(y_columns[y_row_end - 1], max_col)
|
|
1461
|
+
|
|
1462
|
+
block_counts[x_block + 1] = block_count
|
|
1463
|
+
row_count += block_count
|
|
1464
|
+
|
|
1465
|
+
if row_count > wp.max(0, max_col - min_col):
|
|
1466
|
+
row_min[row] = min_col
|
|
1467
|
+
block_counts[x_end] = max_col + 1 - min_col
|
|
1468
|
+
for x_block in range(x_beg, x_end - 1):
|
|
1469
|
+
block_counts[x_block + 1] = 0
|
|
1470
|
+
else:
|
|
1471
|
+
row_min[row] = -1
|
|
1438
1472
|
|
|
1439
1473
|
if row == 0:
|
|
1440
|
-
|
|
1474
|
+
block_counts[0] = z_nnz
|
|
1441
1475
|
|
|
1442
1476
|
|
|
1443
|
-
@wp.kernel
|
|
1477
|
+
@wp.kernel(enable_backward=False)
|
|
1444
1478
|
def _bsr_mm_list_coeffs(
|
|
1479
|
+
x_nrow: int,
|
|
1445
1480
|
x_offsets: wp.array(dtype=int),
|
|
1446
1481
|
x_columns: wp.array(dtype=int),
|
|
1447
1482
|
y_offsets: wp.array(dtype=int),
|
|
1448
1483
|
y_columns: wp.array(dtype=int),
|
|
1484
|
+
mm_row_min: wp.array(dtype=int),
|
|
1449
1485
|
mm_offsets: wp.array(dtype=int),
|
|
1450
1486
|
mm_rows: wp.array(dtype=int),
|
|
1451
1487
|
mm_cols: wp.array(dtype=int),
|
|
1452
1488
|
):
|
|
1453
|
-
|
|
1454
|
-
mm_block = mm_offsets[
|
|
1489
|
+
x_block = wp.tid()
|
|
1490
|
+
mm_block = mm_offsets[x_block]
|
|
1455
1491
|
|
|
1456
|
-
|
|
1457
|
-
|
|
1492
|
+
row = _bsr_row_index(x_offsets, x_nrow, x_block)
|
|
1493
|
+
if row == -1:
|
|
1494
|
+
return
|
|
1458
1495
|
|
|
1459
|
-
|
|
1496
|
+
row_min_col = mm_row_min[row]
|
|
1497
|
+
if row_min_col != -1:
|
|
1460
1498
|
x_col = x_columns[x_block]
|
|
1461
1499
|
|
|
1462
1500
|
y_beg = y_offsets[x_col]
|
|
1463
1501
|
y_end = y_offsets[x_col + 1]
|
|
1502
|
+
|
|
1464
1503
|
for y_block in range(y_beg, y_end):
|
|
1465
|
-
|
|
1466
|
-
mm_rows[mm_block] = row
|
|
1467
|
-
mm_block
|
|
1504
|
+
col = y_columns[y_block]
|
|
1505
|
+
mm_rows[mm_block + col - row_min_col] = row
|
|
1506
|
+
mm_cols[mm_block + col - row_min_col] = col
|
|
1507
|
+
|
|
1508
|
+
return
|
|
1509
|
+
|
|
1510
|
+
x_col = x_columns[x_block]
|
|
1511
|
+
y_beg = y_offsets[x_col]
|
|
1512
|
+
y_end = y_offsets[x_col + 1]
|
|
1513
|
+
for y_block in range(y_beg, y_end):
|
|
1514
|
+
mm_cols[mm_block] = y_columns[y_block]
|
|
1515
|
+
mm_rows[mm_block] = row
|
|
1516
|
+
mm_block += 1
|
|
1468
1517
|
|
|
1469
1518
|
|
|
1470
1519
|
@wp.kernel
|
|
@@ -1483,7 +1532,10 @@ def _bsr_mm_compute_values(
|
|
|
1483
1532
|
):
|
|
1484
1533
|
mm_block = wp.tid()
|
|
1485
1534
|
|
|
1486
|
-
row =
|
|
1535
|
+
row = _bsr_row_index(mm_offsets, mm_row_count, mm_block)
|
|
1536
|
+
if row == -1:
|
|
1537
|
+
return
|
|
1538
|
+
|
|
1487
1539
|
col = mm_cols[mm_block]
|
|
1488
1540
|
|
|
1489
1541
|
mm_val = mm_values.dtype(type(alpha)(0.0))
|
|
@@ -1492,26 +1544,23 @@ def _bsr_mm_compute_values(
|
|
|
1492
1544
|
x_end = x_offsets[row + 1]
|
|
1493
1545
|
for x_block in range(x_beg, x_end):
|
|
1494
1546
|
x_col = x_columns[x_block]
|
|
1495
|
-
|
|
1496
|
-
|
|
1497
|
-
|
|
1498
|
-
y_block = wp.lower_bound(y_columns, y_beg, y_end, col)
|
|
1499
|
-
if y_block < y_end:
|
|
1500
|
-
if y_columns[y_block] == col:
|
|
1501
|
-
mm_val += x_values[x_block] * y_values[y_block]
|
|
1547
|
+
y_block = _bsr_block_index(x_col, col, y_offsets, y_columns)
|
|
1548
|
+
if y_block != -1:
|
|
1549
|
+
mm_val += x_values[x_block] * y_values[y_block]
|
|
1502
1550
|
|
|
1503
1551
|
mm_values[mm_block] += alpha * mm_val
|
|
1504
1552
|
|
|
1505
1553
|
|
|
1506
1554
|
class bsr_mm_work_arrays:
|
|
1507
|
-
"""Opaque structure for persisting :func:`bsr_mm` temporary work buffers across calls"""
|
|
1555
|
+
"""Opaque structure for persisting :func:`bsr_mm` temporary work buffers across calls."""
|
|
1508
1556
|
|
|
1509
1557
|
def __init__(self):
|
|
1510
1558
|
self._reset(None)
|
|
1511
1559
|
|
|
1512
1560
|
def _reset(self, device):
|
|
1513
1561
|
self.device = device
|
|
1514
|
-
self.
|
|
1562
|
+
self._mm_row_min = None
|
|
1563
|
+
self._mm_block_counts = None
|
|
1515
1564
|
self._mm_rows = None
|
|
1516
1565
|
self._mm_cols = None
|
|
1517
1566
|
self._old_z_values = None
|
|
@@ -1519,7 +1568,7 @@ class bsr_mm_work_arrays:
|
|
|
1519
1568
|
self._old_z_columns = None
|
|
1520
1569
|
self._mm_nnz = 0
|
|
1521
1570
|
|
|
1522
|
-
def _allocate_stage_1(self, device, z: BsrMatrix, beta: float, z_aliasing: bool):
|
|
1571
|
+
def _allocate_stage_1(self, device, x_nnz: int, z: BsrMatrix, beta: float, z_aliasing: bool):
|
|
1523
1572
|
if self.device != device:
|
|
1524
1573
|
self._reset(device)
|
|
1525
1574
|
|
|
@@ -1527,8 +1576,10 @@ class bsr_mm_work_arrays:
|
|
|
1527
1576
|
z_nnz = z.nnz_sync()
|
|
1528
1577
|
self._copied_z_nnz = z_nnz if beta != 0.0 or z_aliasing else 0
|
|
1529
1578
|
|
|
1530
|
-
if self.
|
|
1531
|
-
self.
|
|
1579
|
+
if self._mm_row_min is None or self._mm_block_counts.size < z.nrow + 1:
|
|
1580
|
+
self._mm_row_min = wp.empty(shape=(z.nrow + 1,), dtype=int, device=self.device)
|
|
1581
|
+
if self._mm_block_counts is None or self._mm_block_counts.size < x_nnz + 1:
|
|
1582
|
+
self._mm_block_counts = wp.empty(shape=(x_nnz + 1,), dtype=int, device=self.device)
|
|
1532
1583
|
|
|
1533
1584
|
if self._copied_z_nnz > 0:
|
|
1534
1585
|
if self._old_z_values is None or self._old_z_values.size < self._copied_z_nnz:
|
|
@@ -1555,25 +1606,31 @@ def bsr_mm(
|
|
|
1555
1606
|
z: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
|
|
1556
1607
|
alpha: Scalar = 1.0,
|
|
1557
1608
|
beta: Scalar = 0.0,
|
|
1609
|
+
masked: bool = False,
|
|
1558
1610
|
work_arrays: Optional[bsr_mm_work_arrays] = None,
|
|
1559
1611
|
reuse_topology: bool = False,
|
|
1560
1612
|
) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
|
|
1561
1613
|
"""
|
|
1562
|
-
|
|
1614
|
+
Perform the sparse matrix-matrix multiplication ``z := alpha * x @ y + beta * z`` on BSR matrices ``x``, ``y`` and ``z``, and return ``z``.
|
|
1563
1615
|
|
|
1564
|
-
The
|
|
1565
|
-
If the matrix
|
|
1616
|
+
The ``x``, ``y`` and ``z`` matrices are allowed to alias.
|
|
1617
|
+
If the matrix ``z`` is not provided as input, it will be allocated and treated as zero.
|
|
1566
1618
|
|
|
1567
1619
|
Args:
|
|
1568
1620
|
x: Read-only left factor of the matrix-matrix product.
|
|
1569
1621
|
y: Read-only right factor of the matrix-matrix product.
|
|
1570
|
-
z: Mutable left-hand-side. If
|
|
1571
|
-
alpha: Uniform scaling factor for the ``x
|
|
1572
|
-
beta: Uniform scaling factor for
|
|
1573
|
-
|
|
1574
|
-
|
|
1575
|
-
|
|
1576
|
-
|
|
1622
|
+
z: Mutable left-hand-side. If ``z`` is not provided, it will be allocated and treated as zero.
|
|
1623
|
+
alpha: Uniform scaling factor for the ``x @ y`` product
|
|
1624
|
+
beta: Uniform scaling factor for ``z``
|
|
1625
|
+
masked: If ``True``, ignore all blocks from ``x @ y`` which are not existing non-zeros of ``y``
|
|
1626
|
+
work_arrays: In most cases, this function will require the use of temporary storage.
|
|
1627
|
+
This storage can be reused across calls by passing an instance of
|
|
1628
|
+
:class:`bsr_mm_work_arrays` in ``work_arrays``.
|
|
1629
|
+
reuse_topology: If ``True``, reuse the product topology information
|
|
1630
|
+
stored in ``work_arrays`` rather than recompute it from scratch.
|
|
1631
|
+
The matrices ``x``, ``y`` and ``z`` must be structurally similar to
|
|
1632
|
+
the previous call in which ``work_arrays`` were populated.
|
|
1633
|
+
This is necessary for ``bsr_mm`` to be captured in a CUDA graph.
|
|
1577
1634
|
"""
|
|
1578
1635
|
|
|
1579
1636
|
x, x_scale = _extract_matrix_and_scale(x)
|
|
@@ -1582,12 +1639,15 @@ def bsr_mm(
|
|
|
1582
1639
|
alpha *= y_scale
|
|
1583
1640
|
|
|
1584
1641
|
if z is None:
|
|
1642
|
+
if masked:
|
|
1643
|
+
raise ValueError("Left-hand-side 'z' matrix must be provided for masked multiplication")
|
|
1644
|
+
|
|
1585
1645
|
# If not output matrix is provided, allocate it for convenience
|
|
1586
1646
|
z_block_shape = (x.block_shape[0], y.block_shape[1])
|
|
1587
1647
|
if z_block_shape == (1, 1):
|
|
1588
1648
|
z_block_type = x.scalar_type
|
|
1589
1649
|
else:
|
|
1590
|
-
z_block_type = wp.
|
|
1650
|
+
z_block_type = wp.mat(shape=z_block_shape, dtype=x.scalar_type)
|
|
1591
1651
|
z = bsr_zeros(x.nrow, y.ncol, block_type=z_block_type, device=x.values.device)
|
|
1592
1652
|
beta = 0.0
|
|
1593
1653
|
|
|
@@ -1613,14 +1673,22 @@ def bsr_mm(
|
|
|
1613
1673
|
# Easy case
|
|
1614
1674
|
return bsr_scale(z, beta)
|
|
1615
1675
|
|
|
1616
|
-
if not isinstance(alpha, z.scalar_type):
|
|
1617
|
-
alpha = z.scalar_type(alpha)
|
|
1618
|
-
if not isinstance(beta, z.scalar_type):
|
|
1619
|
-
beta = z.scalar_type(beta)
|
|
1620
|
-
|
|
1621
1676
|
z_aliasing = z == x or z == y
|
|
1622
1677
|
|
|
1623
|
-
if
|
|
1678
|
+
if masked:
|
|
1679
|
+
# no need to copy z, scale in-place
|
|
1680
|
+
copied_z_nnz = 0
|
|
1681
|
+
mm_nnz = z.nnz
|
|
1682
|
+
|
|
1683
|
+
if z_aliasing:
|
|
1684
|
+
raise ValueError("`masked=True` is not supported for aliased inputs")
|
|
1685
|
+
|
|
1686
|
+
if beta == 0.0:
|
|
1687
|
+
# do not bsr_scale(0), this would not preserve topology
|
|
1688
|
+
z.values.zero_()
|
|
1689
|
+
else:
|
|
1690
|
+
bsr_scale(z, beta)
|
|
1691
|
+
elif reuse_topology:
|
|
1624
1692
|
if work_arrays is None:
|
|
1625
1693
|
raise ValueError("`work_arrays` must not be ``None`` in order to reuse matrix-matrix product topology")
|
|
1626
1694
|
|
|
@@ -1633,133 +1701,142 @@ def bsr_mm(
|
|
|
1633
1701
|
if work_arrays is None:
|
|
1634
1702
|
work_arrays = bsr_mm_work_arrays()
|
|
1635
1703
|
|
|
1636
|
-
work_arrays._allocate_stage_1(device, z, beta, z_aliasing)
|
|
1704
|
+
work_arrays._allocate_stage_1(device, x.nnz, z, beta, z_aliasing)
|
|
1637
1705
|
copied_z_nnz = work_arrays._copied_z_nnz
|
|
1638
1706
|
|
|
1639
1707
|
# Prefix sum of number of (unmerged) mm blocks per row
|
|
1708
|
+
work_arrays._mm_block_counts.zero_()
|
|
1640
1709
|
wp.launch(
|
|
1641
1710
|
kernel=_bsr_mm_count_coeffs,
|
|
1642
1711
|
device=device,
|
|
1643
1712
|
dim=z.nrow,
|
|
1644
1713
|
inputs=[
|
|
1714
|
+
y.ncol,
|
|
1645
1715
|
copied_z_nnz,
|
|
1646
1716
|
x.offsets,
|
|
1647
1717
|
x.columns,
|
|
1648
1718
|
y.offsets,
|
|
1649
|
-
|
|
1719
|
+
y.columns,
|
|
1720
|
+
work_arrays._mm_row_min,
|
|
1721
|
+
work_arrays._mm_block_counts,
|
|
1650
1722
|
],
|
|
1651
1723
|
)
|
|
1652
|
-
warp.utils.array_scan(work_arrays.
|
|
1724
|
+
warp.utils.array_scan(work_arrays._mm_block_counts, work_arrays._mm_block_counts)
|
|
1653
1725
|
|
|
1654
1726
|
# Get back total counts on host -- we need a synchronization here
|
|
1655
1727
|
# Use pinned buffer from z, we are going to need it later anyway
|
|
1656
1728
|
nnz_buf, _ = z._nnz_transfer_buf_and_event()
|
|
1657
1729
|
stream = wp.get_stream(device) if device.is_cuda else None
|
|
1658
|
-
wp.copy(dest=nnz_buf, src=work_arrays.
|
|
1730
|
+
wp.copy(dest=nnz_buf, src=work_arrays._mm_block_counts, src_offset=x.nnz, count=1, stream=stream)
|
|
1659
1731
|
if device.is_cuda:
|
|
1660
1732
|
wp.synchronize_stream(stream)
|
|
1661
1733
|
mm_nnz = int(nnz_buf.numpy()[0])
|
|
1662
1734
|
|
|
1735
|
+
if mm_nnz == copied_z_nnz:
|
|
1736
|
+
# x@y = 0
|
|
1737
|
+
return bsr_scale(z, beta)
|
|
1738
|
+
|
|
1663
1739
|
work_arrays._allocate_stage_2(mm_nnz)
|
|
1664
1740
|
|
|
1665
1741
|
# If z has a non-zero scale, save current data before overwriting it
|
|
1666
1742
|
if copied_z_nnz > 0:
|
|
1667
1743
|
# Copy z row and column indices
|
|
1668
1744
|
wp.copy(dest=work_arrays._mm_cols, src=z.columns, count=copied_z_nnz)
|
|
1669
|
-
|
|
1670
|
-
kernel=_bsr_get_block_row,
|
|
1671
|
-
device=device,
|
|
1672
|
-
dim=copied_z_nnz,
|
|
1673
|
-
inputs=[0, z.nrow, z.offsets, work_arrays._mm_rows],
|
|
1674
|
-
)
|
|
1745
|
+
z.uncompress_rows(out=work_arrays._mm_rows)
|
|
1675
1746
|
if z_aliasing:
|
|
1676
1747
|
# If z is aliasing with x or y, need to save topology as well
|
|
1677
1748
|
wp.copy(src=z.columns, dest=work_arrays._old_z_columns, count=copied_z_nnz)
|
|
1678
1749
|
wp.copy(src=z.offsets, dest=work_arrays._old_z_offsets, count=z.nrow + 1)
|
|
1679
1750
|
|
|
1680
1751
|
# Fill unmerged mm blocks rows and columns
|
|
1752
|
+
work_arrays._mm_rows[copied_z_nnz:].fill_(-1)
|
|
1681
1753
|
wp.launch(
|
|
1682
1754
|
kernel=_bsr_mm_list_coeffs,
|
|
1683
1755
|
device=device,
|
|
1684
|
-
dim=
|
|
1756
|
+
dim=x.nnz,
|
|
1685
1757
|
inputs=[
|
|
1758
|
+
x.nrow,
|
|
1686
1759
|
x.offsets,
|
|
1687
1760
|
x.columns,
|
|
1688
1761
|
y.offsets,
|
|
1689
1762
|
y.columns,
|
|
1690
|
-
work_arrays.
|
|
1763
|
+
work_arrays._mm_row_min,
|
|
1764
|
+
work_arrays._mm_block_counts,
|
|
1691
1765
|
work_arrays._mm_rows,
|
|
1692
1766
|
work_arrays._mm_cols,
|
|
1693
1767
|
],
|
|
1694
1768
|
)
|
|
1695
1769
|
|
|
1770
|
+
alpha = z.scalar_type(alpha)
|
|
1771
|
+
beta = z.scalar_type(beta)
|
|
1772
|
+
|
|
1696
1773
|
if copied_z_nnz > 0:
|
|
1697
1774
|
# Save current z values in temporary buffer
|
|
1698
1775
|
wp.copy(src=z.values, dest=work_arrays._old_z_values, count=copied_z_nnz)
|
|
1699
1776
|
|
|
1700
|
-
|
|
1701
|
-
|
|
1702
|
-
z.columns
|
|
1777
|
+
if not masked:
|
|
1778
|
+
# Increase dest array size if needed
|
|
1779
|
+
if z.columns.shape[0] < mm_nnz:
|
|
1780
|
+
z.columns = wp.empty(shape=(mm_nnz,), dtype=int, device=device)
|
|
1703
1781
|
|
|
1704
|
-
|
|
1782
|
+
from warp.context import runtime
|
|
1705
1783
|
|
|
1706
|
-
|
|
1707
|
-
|
|
1708
|
-
|
|
1709
|
-
|
|
1784
|
+
if device.is_cpu:
|
|
1785
|
+
native_func = runtime.core.bsr_matrix_from_triplets_float_host
|
|
1786
|
+
else:
|
|
1787
|
+
native_func = runtime.core.bsr_matrix_from_triplets_float_device
|
|
1710
1788
|
|
|
1711
|
-
|
|
1789
|
+
nnz_buf, nnz_event = z._nnz_transfer_buf_and_event()
|
|
1712
1790
|
|
|
1713
|
-
|
|
1714
|
-
|
|
1715
|
-
|
|
1716
|
-
|
|
1717
|
-
|
|
1718
|
-
|
|
1719
|
-
|
|
1720
|
-
|
|
1721
|
-
|
|
1722
|
-
|
|
1723
|
-
|
|
1724
|
-
|
|
1725
|
-
|
|
1726
|
-
|
|
1727
|
-
|
|
1728
|
-
|
|
1791
|
+
with wp.ScopedDevice(z.device):
|
|
1792
|
+
native_func(
|
|
1793
|
+
z.block_shape[0],
|
|
1794
|
+
z.block_shape[1],
|
|
1795
|
+
z.nrow,
|
|
1796
|
+
mm_nnz,
|
|
1797
|
+
ctypes.cast(work_arrays._mm_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1798
|
+
ctypes.cast(work_arrays._mm_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1799
|
+
0,
|
|
1800
|
+
False,
|
|
1801
|
+
masked,
|
|
1802
|
+
ctypes.cast(z.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1803
|
+
ctypes.cast(z.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1804
|
+
0,
|
|
1805
|
+
ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1806
|
+
nnz_event,
|
|
1807
|
+
)
|
|
1729
1808
|
|
|
1730
|
-
|
|
1731
|
-
|
|
1732
|
-
|
|
1733
|
-
|
|
1734
|
-
_bsr_ensure_fits(z, nnz=work_arrays.result_nnz)
|
|
1809
|
+
# Resize z to fit mm result if necessary
|
|
1810
|
+
# If we are not reusing the product topology, this needs another synchronization
|
|
1811
|
+
if not reuse_topology:
|
|
1812
|
+
work_arrays.result_nnz = z.nnz_sync()
|
|
1735
1813
|
|
|
1736
|
-
|
|
1814
|
+
_bsr_ensure_fits(z, nnz=work_arrays.result_nnz)
|
|
1815
|
+
z.values.zero_()
|
|
1737
1816
|
|
|
1738
|
-
|
|
1739
|
-
|
|
1740
|
-
|
|
1741
|
-
|
|
1742
|
-
|
|
1743
|
-
|
|
1744
|
-
|
|
1745
|
-
|
|
1746
|
-
|
|
1747
|
-
|
|
1748
|
-
|
|
1749
|
-
|
|
1750
|
-
|
|
1751
|
-
|
|
1752
|
-
|
|
1753
|
-
|
|
1754
|
-
|
|
1817
|
+
if copied_z_nnz > 0:
|
|
1818
|
+
# Add back original z values
|
|
1819
|
+
wp.launch(
|
|
1820
|
+
kernel=_bsr_axpy_add_block,
|
|
1821
|
+
device=device,
|
|
1822
|
+
dim=copied_z_nnz,
|
|
1823
|
+
inputs=[
|
|
1824
|
+
0,
|
|
1825
|
+
beta,
|
|
1826
|
+
work_arrays._mm_rows,
|
|
1827
|
+
work_arrays._mm_cols,
|
|
1828
|
+
z.offsets,
|
|
1829
|
+
z.columns,
|
|
1830
|
+
work_arrays._old_z_values,
|
|
1831
|
+
z.values,
|
|
1832
|
+
],
|
|
1833
|
+
)
|
|
1755
1834
|
|
|
1756
1835
|
# Add mm blocks to z values
|
|
1757
|
-
if (
|
|
1758
|
-
warp.types.type_is_matrix(z.values.dtype)
|
|
1759
|
-
):
|
|
1836
|
+
if (type_is_matrix(x.values.dtype) or type_is_matrix(y.values.dtype)) and not (type_is_matrix(z.values.dtype)):
|
|
1760
1837
|
# Result block type is scalar, but operands are matrices
|
|
1761
1838
|
# Cast result to (1x1) matrix to perform multiplication
|
|
1762
|
-
mm_values = z.values.view(wp.
|
|
1839
|
+
mm_values = z.values.view(wp.mat(shape=(1, 1), dtype=z.scalar_type))
|
|
1763
1840
|
else:
|
|
1764
1841
|
mm_values = z.values
|
|
1765
1842
|
|
|
@@ -1832,15 +1909,31 @@ def _bsr_mv_transpose_kernel(
|
|
|
1832
1909
|
wp.atomic_add(y, A_columns[block], v)
|
|
1833
1910
|
|
|
1834
1911
|
|
|
1835
|
-
def
|
|
1836
|
-
|
|
1912
|
+
def _vec_array_view(array: wp.array, dtype: type, expected_scalar_count: int) -> wp.array:
|
|
1913
|
+
# cast a 1d or 2d array to a 1d array with the target dtype, adjusting shape as required
|
|
1914
|
+
|
|
1915
|
+
scalar_count = array.size * type_length(array.dtype)
|
|
1916
|
+
if scalar_count != expected_scalar_count:
|
|
1917
|
+
raise ValueError(f"Invalid array scalar size, expected {expected_scalar_count}, got {scalar_count}")
|
|
1918
|
+
|
|
1919
|
+
if array.ndim == 1 and types_equal(array.dtype, dtype):
|
|
1837
1920
|
return array
|
|
1838
1921
|
|
|
1922
|
+
if type_scalar_type(array.dtype) != type_scalar_type(dtype):
|
|
1923
|
+
raise ValueError(f"Incompatible scalar types, {type_repr(array.dtype)} vs {type_repr(dtype)}")
|
|
1924
|
+
|
|
1839
1925
|
if array.ndim > 2:
|
|
1840
1926
|
raise ValueError(f"Incompatible array number of dimensions {array.ndim}")
|
|
1841
1927
|
|
|
1842
1928
|
if not array.is_contiguous:
|
|
1843
|
-
raise ValueError("
|
|
1929
|
+
raise ValueError("Array must be contiguous")
|
|
1930
|
+
|
|
1931
|
+
vec_length = type_length(dtype)
|
|
1932
|
+
vec_count = scalar_count // vec_length
|
|
1933
|
+
if vec_count * vec_length != scalar_count:
|
|
1934
|
+
raise ValueError(
|
|
1935
|
+
f"Array of shape {array.shape} and type {type_repr(array.dtype)} cannot be reshaped to an array of type {type_repr(dtype)}"
|
|
1936
|
+
)
|
|
1844
1937
|
|
|
1845
1938
|
def vec_view(array):
|
|
1846
1939
|
return wp.array(
|
|
@@ -1848,8 +1941,8 @@ def _bsr_mv_as_vec_array(array: wp.array) -> wp.array:
|
|
|
1848
1941
|
ptr=array.ptr,
|
|
1849
1942
|
capacity=array.capacity,
|
|
1850
1943
|
device=array.device,
|
|
1851
|
-
dtype=
|
|
1852
|
-
shape=
|
|
1944
|
+
dtype=dtype,
|
|
1945
|
+
shape=vec_count,
|
|
1853
1946
|
grad=None if array.grad is None else vec_view(array.grad),
|
|
1854
1947
|
)
|
|
1855
1948
|
|
|
@@ -1867,20 +1960,20 @@ def bsr_mv(
|
|
|
1867
1960
|
transpose: bool = False,
|
|
1868
1961
|
work_buffer: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
|
|
1869
1962
|
) -> "Array[Vector[Rows, Scalar] | Scalar]":
|
|
1870
|
-
"""
|
|
1871
|
-
Performs the sparse matrix-vector product ``y := alpha * A * x + beta * y`` and returns `y`.
|
|
1963
|
+
"""Perform the sparse matrix-vector product ``y := alpha * A * x + beta * y`` and return ``y``.
|
|
1872
1964
|
|
|
1873
|
-
The
|
|
1965
|
+
The ``x`` and ``y`` vectors are allowed to alias.
|
|
1874
1966
|
|
|
1875
1967
|
Args:
|
|
1876
1968
|
A: Read-only, left matrix factor of the matrix-vector product.
|
|
1877
1969
|
x: Read-only, right vector factor of the matrix-vector product.
|
|
1878
|
-
y: Mutable left-hand-side. If
|
|
1879
|
-
alpha: Uniform scaling factor for
|
|
1880
|
-
beta: Uniform scaling factor for
|
|
1881
|
-
transpose: If ``True``, use the transpose of the matrix
|
|
1882
|
-
work_buffer: Temporary storage is required if and only if
|
|
1883
|
-
|
|
1970
|
+
y: Mutable left-hand-side. If ``y`` is not provided, it will be allocated and treated as zero.
|
|
1971
|
+
alpha: Uniform scaling factor for ``x``. If zero, ``x`` will not be read and may be left uninitialized.
|
|
1972
|
+
beta: Uniform scaling factor for ``y``. If zero, ``y`` will not be read and may be left uninitialized.
|
|
1973
|
+
transpose: If ``True``, use the transpose of the matrix ``A``. In this case the result is **non-deterministic**.
|
|
1974
|
+
work_buffer: Temporary storage is required if and only if ``x`` and ``y`` are the same vector.
|
|
1975
|
+
If provided, the ``work_buffer`` array will be used for this purpose,
|
|
1976
|
+
otherwise a temporary allocation will be performed.
|
|
1884
1977
|
"""
|
|
1885
1978
|
|
|
1886
1979
|
A, A_scale = _extract_matrix_and_scale(A)
|
|
@@ -1900,22 +1993,11 @@ def bsr_mv(
|
|
|
1900
1993
|
y = wp.empty(shape=(nrow,), device=A.values.device, dtype=y_dtype)
|
|
1901
1994
|
beta = 0.0
|
|
1902
1995
|
|
|
1903
|
-
|
|
1904
|
-
|
|
1905
|
-
if not isinstance(beta, A.scalar_type):
|
|
1906
|
-
beta = A.scalar_type(beta)
|
|
1996
|
+
alpha = A.scalar_type(alpha)
|
|
1997
|
+
beta = A.scalar_type(beta)
|
|
1907
1998
|
|
|
1908
1999
|
if A.values.device != x.device or A.values.device != y.device:
|
|
1909
|
-
raise ValueError("A, x and y must reside on the same device")
|
|
1910
|
-
|
|
1911
|
-
if x.shape[0] != ncol:
|
|
1912
|
-
raise ValueError("Number of columns of A must match number of rows of x")
|
|
1913
|
-
if y.shape[0] != nrow:
|
|
1914
|
-
raise ValueError("Number of rows of A must match number of rows of y")
|
|
1915
|
-
|
|
1916
|
-
# View 2d arrays as arrays of vecs
|
|
1917
|
-
x = _bsr_mv_as_vec_array(x)
|
|
1918
|
-
y = _bsr_mv_as_vec_array(y)
|
|
2000
|
+
raise ValueError("A, x, and y must reside on the same device")
|
|
1919
2001
|
|
|
1920
2002
|
if x.ptr == y.ptr:
|
|
1921
2003
|
# Aliasing case, need temporary storage
|
|
@@ -1923,24 +2005,29 @@ def bsr_mv(
|
|
|
1923
2005
|
work_buffer = wp.empty_like(y)
|
|
1924
2006
|
elif work_buffer.size < y.size:
|
|
1925
2007
|
raise ValueError(f"Work buffer size is insufficient, needs to be at least {y.size}")
|
|
1926
|
-
elif not
|
|
1927
|
-
raise ValueError(f"Work buffer must have same data type as y, {
|
|
2008
|
+
elif not types_equal(work_buffer.dtype, y.dtype):
|
|
2009
|
+
raise ValueError(f"Work buffer must have same data type as y, {type_repr(y.dtype)}")
|
|
1928
2010
|
|
|
1929
2011
|
# Save old y values before overwriting vector
|
|
1930
2012
|
wp.copy(dest=work_buffer, src=y, count=y.size)
|
|
1931
2013
|
x = work_buffer
|
|
1932
2014
|
|
|
1933
2015
|
# Promote scalar vectors to length-1 vecs and conversely
|
|
1934
|
-
if
|
|
1935
|
-
|
|
1936
|
-
|
|
1937
|
-
if block_shape[1] == 1 and x.dtype == A.scalar_type:
|
|
1938
|
-
x = x.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
|
|
2016
|
+
if type_is_matrix(A.values.dtype):
|
|
2017
|
+
x_dtype = wp.vec(length=block_shape[1], dtype=A.scalar_type)
|
|
2018
|
+
y_dtype = wp.vec(length=block_shape[0], dtype=A.scalar_type)
|
|
1939
2019
|
else:
|
|
1940
|
-
|
|
1941
|
-
|
|
1942
|
-
|
|
1943
|
-
|
|
2020
|
+
x_dtype = A.scalar_type
|
|
2021
|
+
y_dtype = A.scalar_type
|
|
2022
|
+
|
|
2023
|
+
try:
|
|
2024
|
+
x_view = _vec_array_view(x, x_dtype, expected_scalar_count=ncol * block_shape[1])
|
|
2025
|
+
except ValueError as err:
|
|
2026
|
+
raise ValueError("Incompatible 'x' vector for bsr_mv") from err
|
|
2027
|
+
try:
|
|
2028
|
+
y_view = _vec_array_view(y, y_dtype, expected_scalar_count=nrow * block_shape[0])
|
|
2029
|
+
except ValueError as err:
|
|
2030
|
+
raise ValueError("Incompatible 'y' vector for bsr_mv") from err
|
|
1944
2031
|
|
|
1945
2032
|
if transpose:
|
|
1946
2033
|
if beta.value == 0.0:
|
|
@@ -1957,14 +2044,14 @@ def bsr_mv(
|
|
|
1957
2044
|
kernel=_bsr_mv_transpose_kernel,
|
|
1958
2045
|
device=A.values.device,
|
|
1959
2046
|
dim=ncol,
|
|
1960
|
-
inputs=[alpha, A.offsets, A.columns, A.values,
|
|
2047
|
+
inputs=[alpha, A.offsets, A.columns, A.values, x_view, y_view],
|
|
1961
2048
|
)
|
|
1962
2049
|
else:
|
|
1963
2050
|
wp.launch(
|
|
1964
2051
|
kernel=_bsr_mv_kernel,
|
|
1965
2052
|
device=A.values.device,
|
|
1966
2053
|
dim=nrow,
|
|
1967
|
-
inputs=[alpha, A.offsets, A.columns, A.values,
|
|
2054
|
+
inputs=[alpha, A.offsets, A.columns, A.values, x_view, beta, y_view],
|
|
1968
2055
|
)
|
|
1969
2056
|
|
|
1970
2057
|
return y
|