warp-lang 1.2.2__py3-none-manylinux2014_aarch64.whl → 1.3.0__py3-none-manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +8 -6
- warp/autograd.py +823 -0
- warp/bin/warp.so +0 -0
- warp/build.py +6 -2
- warp/builtins.py +1410 -886
- warp/codegen.py +503 -166
- warp/config.py +48 -18
- warp/context.py +400 -198
- warp/dlpack.py +8 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +1 -1
- warp/examples/benchmarks/benchmark_interop_torch.py +158 -0
- warp/examples/benchmarks/benchmark_launches.py +1 -1
- warp/examples/core/example_cupy.py +78 -0
- warp/examples/fem/example_apic_fluid.py +17 -36
- warp/examples/fem/example_burgers.py +9 -18
- warp/examples/fem/example_convection_diffusion.py +7 -17
- warp/examples/fem/example_convection_diffusion_dg.py +27 -47
- warp/examples/fem/example_deformed_geometry.py +11 -22
- warp/examples/fem/example_diffusion.py +7 -18
- warp/examples/fem/example_diffusion_3d.py +24 -28
- warp/examples/fem/example_diffusion_mgpu.py +7 -14
- warp/examples/fem/example_magnetostatics.py +190 -0
- warp/examples/fem/example_mixed_elasticity.py +111 -80
- warp/examples/fem/example_navier_stokes.py +30 -34
- warp/examples/fem/example_nonconforming_contact.py +290 -0
- warp/examples/fem/example_stokes.py +17 -32
- warp/examples/fem/example_stokes_transfer.py +12 -21
- warp/examples/fem/example_streamlines.py +350 -0
- warp/examples/fem/utils.py +936 -0
- warp/fabric.py +5 -2
- warp/fem/__init__.py +13 -3
- warp/fem/cache.py +161 -11
- warp/fem/dirichlet.py +37 -28
- warp/fem/domain.py +105 -14
- warp/fem/field/__init__.py +14 -3
- warp/fem/field/field.py +454 -11
- warp/fem/field/nodal_field.py +33 -18
- warp/fem/geometry/deformed_geometry.py +50 -15
- warp/fem/geometry/hexmesh.py +12 -24
- warp/fem/geometry/nanogrid.py +106 -31
- warp/fem/geometry/quadmesh_2d.py +6 -11
- warp/fem/geometry/tetmesh.py +103 -61
- warp/fem/geometry/trimesh_2d.py +98 -47
- warp/fem/integrate.py +231 -186
- warp/fem/operator.py +14 -9
- warp/fem/quadrature/pic_quadrature.py +35 -9
- warp/fem/quadrature/quadrature.py +119 -32
- warp/fem/space/basis_space.py +98 -22
- warp/fem/space/collocated_function_space.py +3 -1
- warp/fem/space/function_space.py +7 -2
- warp/fem/space/grid_2d_function_space.py +3 -3
- warp/fem/space/grid_3d_function_space.py +4 -4
- warp/fem/space/hexmesh_function_space.py +3 -2
- warp/fem/space/nanogrid_function_space.py +12 -14
- warp/fem/space/partition.py +45 -47
- warp/fem/space/restriction.py +19 -16
- warp/fem/space/shape/cube_shape_function.py +91 -3
- warp/fem/space/shape/shape_function.py +7 -0
- warp/fem/space/shape/square_shape_function.py +32 -0
- warp/fem/space/shape/tet_shape_function.py +11 -7
- warp/fem/space/shape/triangle_shape_function.py +10 -1
- warp/fem/space/topology.py +116 -42
- warp/fem/types.py +8 -1
- warp/fem/utils.py +301 -83
- warp/native/array.h +16 -0
- warp/native/builtin.h +0 -15
- warp/native/cuda_util.cpp +14 -6
- warp/native/exports.h +1348 -1308
- warp/native/quat.h +79 -0
- warp/native/rand.h +27 -4
- warp/native/sparse.cpp +83 -81
- warp/native/sparse.cu +381 -453
- warp/native/vec.h +64 -0
- warp/native/volume.cpp +40 -49
- warp/native/volume_builder.cu +2 -3
- warp/native/volume_builder.h +12 -17
- warp/native/warp.cu +3 -3
- warp/native/warp.h +69 -59
- warp/render/render_opengl.py +17 -9
- warp/sim/articulation.py +117 -17
- warp/sim/collide.py +35 -29
- warp/sim/model.py +123 -18
- warp/sim/render.py +3 -1
- warp/sparse.py +867 -203
- warp/stubs.py +312 -541
- warp/tape.py +29 -1
- warp/tests/disabled_kinematics.py +1 -1
- warp/tests/test_adam.py +1 -1
- warp/tests/test_arithmetic.py +1 -1
- warp/tests/test_array.py +58 -1
- warp/tests/test_array_reduce.py +1 -1
- warp/tests/test_async.py +1 -1
- warp/tests/test_atomic.py +1 -1
- warp/tests/test_bool.py +1 -1
- warp/tests/test_builtins_resolution.py +1 -1
- warp/tests/test_bvh.py +6 -1
- warp/tests/test_closest_point_edge_edge.py +1 -1
- warp/tests/test_codegen.py +66 -1
- warp/tests/test_compile_consts.py +1 -1
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_copy.py +1 -1
- warp/tests/test_ctypes.py +1 -1
- warp/tests/test_dense.py +1 -1
- warp/tests/test_devices.py +1 -1
- warp/tests/test_dlpack.py +1 -1
- warp/tests/test_examples.py +33 -4
- warp/tests/test_fabricarray.py +5 -2
- warp/tests/test_fast_math.py +1 -1
- warp/tests/test_fem.py +213 -6
- warp/tests/test_fp16.py +1 -1
- warp/tests/test_func.py +1 -1
- warp/tests/test_future_annotations.py +90 -0
- warp/tests/test_generics.py +1 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_grad_customs.py +1 -1
- warp/tests/test_grad_debug.py +247 -0
- warp/tests/test_hash_grid.py +6 -1
- warp/tests/test_implicit_init.py +354 -0
- warp/tests/test_import.py +1 -1
- warp/tests/test_indexedarray.py +1 -1
- warp/tests/test_intersect.py +1 -1
- warp/tests/test_jax.py +1 -1
- warp/tests/test_large.py +1 -1
- warp/tests/test_launch.py +1 -1
- warp/tests/test_lerp.py +1 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_lvalue.py +1 -1
- warp/tests/test_marching_cubes.py +5 -2
- warp/tests/test_mat.py +34 -35
- warp/tests/test_mat_lite.py +2 -1
- warp/tests/test_mat_scalar_ops.py +1 -1
- warp/tests/test_math.py +1 -1
- warp/tests/test_matmul.py +20 -16
- warp/tests/test_matmul_lite.py +1 -1
- warp/tests/test_mempool.py +1 -1
- warp/tests/test_mesh.py +5 -2
- warp/tests/test_mesh_query_aabb.py +1 -1
- warp/tests/test_mesh_query_point.py +1 -1
- warp/tests/test_mesh_query_ray.py +1 -1
- warp/tests/test_mlp.py +1 -1
- warp/tests/test_model.py +1 -1
- warp/tests/test_module_hashing.py +77 -1
- warp/tests/test_modules_lite.py +1 -1
- warp/tests/test_multigpu.py +1 -1
- warp/tests/test_noise.py +1 -1
- warp/tests/test_operators.py +1 -1
- warp/tests/test_options.py +1 -1
- warp/tests/test_overwrite.py +542 -0
- warp/tests/test_peer.py +1 -1
- warp/tests/test_pinned.py +1 -1
- warp/tests/test_print.py +1 -1
- warp/tests/test_quat.py +15 -1
- warp/tests/test_rand.py +1 -1
- warp/tests/test_reload.py +1 -1
- warp/tests/test_rounding.py +1 -1
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +95 -0
- warp/tests/test_sim_grad.py +1 -1
- warp/tests/test_sim_kinematics.py +1 -1
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +82 -15
- warp/tests/test_spatial.py +1 -1
- warp/tests/test_special_values.py +2 -11
- warp/tests/test_streams.py +11 -1
- warp/tests/test_struct.py +1 -1
- warp/tests/test_tape.py +1 -1
- warp/tests/test_torch.py +194 -1
- warp/tests/test_transient_module.py +1 -1
- warp/tests/test_types.py +1 -1
- warp/tests/test_utils.py +1 -1
- warp/tests/test_vec.py +15 -63
- warp/tests/test_vec_lite.py +2 -1
- warp/tests/test_vec_scalar_ops.py +65 -1
- warp/tests/test_verify_fp.py +1 -1
- warp/tests/test_volume.py +28 -2
- warp/tests/test_volume_write.py +1 -1
- warp/tests/unittest_serial.py +1 -1
- warp/tests/unittest_suites.py +9 -1
- warp/tests/walkthrough_debug.py +1 -1
- warp/thirdparty/unittest_parallel.py +2 -5
- warp/torch.py +103 -41
- warp/types.py +341 -224
- warp/utils.py +11 -2
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/METADATA +99 -46
- warp_lang-1.3.0.dist-info/RECORD +368 -0
- warp/examples/fem/bsr_utils.py +0 -378
- warp/examples/fem/mesh_utils.py +0 -133
- warp/examples/fem/plot_utils.py +0 -292
- warp_lang-1.2.2.dist-info/RECORD +0 -359
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/top_level.txt +0 -0
warp/sparse.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import ctypes
|
|
1
2
|
from typing import Any, Generic, Optional, Tuple, TypeVar, Union
|
|
2
3
|
|
|
3
4
|
import warp as wp
|
|
@@ -31,7 +32,7 @@ class BsrMatrix(Generic[_BlockType]):
|
|
|
31
32
|
Attributes:
|
|
32
33
|
nrow (int): Number of rows of blocks
|
|
33
34
|
ncol (int): Number of columns of blocks
|
|
34
|
-
nnz (int):
|
|
35
|
+
nnz (int): Upper bound for the number of non-zero blocks, used for dimensioning launches; the exact number is at ``offsets[nrow-1]``. See also :meth:`nnz_sync`.
|
|
35
36
|
offsets (Array[int]): Array of size at least ``1 + nrows`` such that the start and end indices of the blocks of row ``r`` are ``offsets[r]`` and ``offsets[r+1]``, respectively.
|
|
36
37
|
columns (Array[int]): Array of size at least equal to ``nnz`` containing block column indices
|
|
37
38
|
values (Array[BlockType]): Array of size at least equal to ``nnz`` containing block values
|
|
@@ -68,6 +69,111 @@ class BsrMatrix(Generic[_BlockType]):
|
|
|
68
69
|
"""Device on which offsets, columns and values are allocated -- assumed to be the same for all three arrays"""
|
|
69
70
|
return self.values.device
|
|
70
71
|
|
|
72
|
+
def nnz_sync(self):
|
|
73
|
+
"""Ensures that any ongoing transfer of the exact nnz number from the device offsets array to the host has completed,
|
|
74
|
+
and updates the nnz upper bound.
|
|
75
|
+
|
|
76
|
+
See also :meth:`copy_nnz_async`
|
|
77
|
+
"""
|
|
78
|
+
|
|
79
|
+
if self._is_nnz_transfer_setup():
|
|
80
|
+
if self.device.is_cuda:
|
|
81
|
+
wp.synchronize_event(self._nnz_event)
|
|
82
|
+
self.nnz = int(self._nnz_buf.numpy()[0])
|
|
83
|
+
return self.nnz
|
|
84
|
+
|
|
85
|
+
def copy_nnz_async(self, known_nnz: int = None):
|
|
86
|
+
"""
|
|
87
|
+
Starts the asynchronous transfer of the exact nnz from the device offsets array to host, and records an event for completion.
|
|
88
|
+
Needs to be called whenever the offsets array has been modified from outside ``warp.sparse``.
|
|
89
|
+
|
|
90
|
+
See also :meth:`nnz_sync`
|
|
91
|
+
"""
|
|
92
|
+
if known_nnz is not None:
|
|
93
|
+
self.nnz = int(known_nnz)
|
|
94
|
+
else:
|
|
95
|
+
self._setup_nnz_transfer()
|
|
96
|
+
|
|
97
|
+
# If a transfer is already ongoing, or if the actual nnz is unknown, schedule a new transfer
|
|
98
|
+
if self._is_nnz_transfer_setup():
|
|
99
|
+
stream = wp.get_stream(self.device) if self.device.is_cuda else None
|
|
100
|
+
wp.copy(src=self.offsets, dest=self._nnz_buf, src_offset=self.nrow, count=1, stream=stream)
|
|
101
|
+
if self.device.is_cuda:
|
|
102
|
+
stream.record_event(self._nnz_event)
|
|
103
|
+
|
|
104
|
+
def _setup_nnz_transfer(self):
|
|
105
|
+
if self._is_nnz_transfer_setup():
|
|
106
|
+
return
|
|
107
|
+
|
|
108
|
+
BsrMatrix.__setattr__(
|
|
109
|
+
self, "_nnz_buf", wp.zeros(dtype=int, shape=(1,), device="cpu", pinned=self.device.is_cuda)
|
|
110
|
+
)
|
|
111
|
+
if self.device.is_cuda:
|
|
112
|
+
BsrMatrix.__setattr__(self, "_nnz_event", wp.Event(self.device))
|
|
113
|
+
|
|
114
|
+
def _is_nnz_transfer_setup(self):
|
|
115
|
+
return hasattr(self, "_nnz_buf")
|
|
116
|
+
|
|
117
|
+
def _nnz_transfer_buf_and_event(self):
|
|
118
|
+
self._setup_nnz_transfer()
|
|
119
|
+
|
|
120
|
+
if not self.device.is_cuda:
|
|
121
|
+
return self._nnz_buf, ctypes.c_void_p(None)
|
|
122
|
+
return self._nnz_buf, self._nnz_event.cuda_event
|
|
123
|
+
|
|
124
|
+
# Overloaded math operators
|
|
125
|
+
def __add__(self, y):
|
|
126
|
+
return bsr_axpy(y, bsr_copy(self))
|
|
127
|
+
|
|
128
|
+
def __iadd__(self, y):
|
|
129
|
+
return bsr_axpy(y, self)
|
|
130
|
+
|
|
131
|
+
def __radd__(self, x):
|
|
132
|
+
return bsr_axpy(x, bsr_copy(self))
|
|
133
|
+
|
|
134
|
+
def __sub__(self, y):
|
|
135
|
+
return bsr_axpy(y, bsr_copy(self), alpha=-1.0)
|
|
136
|
+
|
|
137
|
+
def __rsub__(self, x):
|
|
138
|
+
return bsr_axpy(x, bsr_copy(self), beta=-1.0)
|
|
139
|
+
|
|
140
|
+
def __isub__(self, y):
|
|
141
|
+
return bsr_axpy(y, self, alpha=-1.0)
|
|
142
|
+
|
|
143
|
+
def __mul__(self, y):
|
|
144
|
+
return _BsrScalingExpression(self, y)
|
|
145
|
+
|
|
146
|
+
def __rmul__(self, x):
|
|
147
|
+
return _BsrScalingExpression(self, x)
|
|
148
|
+
|
|
149
|
+
def __imul__(self, y):
|
|
150
|
+
return bsr_scale(self, y)
|
|
151
|
+
|
|
152
|
+
def __matmul__(self, y):
|
|
153
|
+
if isinstance(y, wp.array):
|
|
154
|
+
return bsr_mv(self, y)
|
|
155
|
+
|
|
156
|
+
return bsr_mm(self, y)
|
|
157
|
+
|
|
158
|
+
def __rmatmul__(self, x):
|
|
159
|
+
if isinstance(x, wp.array):
|
|
160
|
+
return bsr_mv(self, x, transpose=True)
|
|
161
|
+
|
|
162
|
+
return bsr_mm(x, self)
|
|
163
|
+
|
|
164
|
+
def __imatmul__(self, y):
|
|
165
|
+
return bsr_mm(self, y, self)
|
|
166
|
+
|
|
167
|
+
def __truediv__(self, y):
|
|
168
|
+
return _BsrScalingExpression(self, 1.0 / y)
|
|
169
|
+
|
|
170
|
+
def __neg__(self):
|
|
171
|
+
return _BsrScalingExpression(self, -1.0)
|
|
172
|
+
|
|
173
|
+
def transpose(self):
|
|
174
|
+
"""Returns a transposed copy of this matrix"""
|
|
175
|
+
return bsr_transposed(self)
|
|
176
|
+
|
|
71
177
|
|
|
72
178
|
def bsr_matrix_t(dtype: BlockType):
|
|
73
179
|
dtype = wp.types.type_to_warp(dtype)
|
|
@@ -83,7 +189,7 @@ def bsr_matrix_t(dtype: BlockType):
|
|
|
83
189
|
ncol: int
|
|
84
190
|
"""Number of columns of blocks"""
|
|
85
191
|
nnz: int
|
|
86
|
-
"""
|
|
192
|
+
"""Upper bound for the number of non-zeros"""
|
|
87
193
|
offsets: wp.array(dtype=int)
|
|
88
194
|
"""Array of size at least 1 + nrows"""
|
|
89
195
|
columns: wp.array(dtype=int)
|
|
@@ -130,7 +236,7 @@ def bsr_zeros(
|
|
|
130
236
|
|
|
131
237
|
bsr.nrow = int(rows_of_blocks)
|
|
132
238
|
bsr.ncol = int(cols_of_blocks)
|
|
133
|
-
bsr.nnz = 0
|
|
239
|
+
bsr.nnz = int(0)
|
|
134
240
|
bsr.columns = wp.empty(shape=(0,), dtype=int, device=device)
|
|
135
241
|
bsr.values = wp.empty(shape=(0,), dtype=block_type, device=device)
|
|
136
242
|
bsr.offsets = wp.zeros(shape=(bsr.nrow + 1,), dtype=int, device=device)
|
|
@@ -143,6 +249,9 @@ def _bsr_ensure_fits(bsr: BsrMatrix, nrow: int = None, nnz: int = None):
|
|
|
143
249
|
nrow = bsr.nrow
|
|
144
250
|
if nnz is None:
|
|
145
251
|
nnz = bsr.nnz
|
|
252
|
+
else:
|
|
253
|
+
# update nnz upper bound
|
|
254
|
+
bsr.nnz = int(nnz)
|
|
146
255
|
|
|
147
256
|
if bsr.offsets.size < nrow + 1:
|
|
148
257
|
bsr.offsets = wp.empty(shape=(nrow + 1,), dtype=int, device=bsr.offsets.device)
|
|
@@ -170,9 +279,10 @@ def bsr_set_zero(
|
|
|
170
279
|
bsr.nrow = int(rows_of_blocks)
|
|
171
280
|
if cols_of_blocks is not None:
|
|
172
281
|
bsr.ncol = int(cols_of_blocks)
|
|
173
|
-
|
|
174
|
-
_bsr_ensure_fits(bsr)
|
|
282
|
+
|
|
283
|
+
_bsr_ensure_fits(bsr, nnz=0)
|
|
175
284
|
bsr.offsets.zero_()
|
|
285
|
+
bsr.copy_nnz_async(known_nnz=0)
|
|
176
286
|
|
|
177
287
|
|
|
178
288
|
def bsr_set_from_triplets(
|
|
@@ -180,11 +290,12 @@ def bsr_set_from_triplets(
|
|
|
180
290
|
rows: "Array[int]",
|
|
181
291
|
columns: "Array[int]",
|
|
182
292
|
values: "Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]",
|
|
293
|
+
prune_numerical_zeros: bool = True,
|
|
183
294
|
):
|
|
184
295
|
"""
|
|
185
296
|
Fills a BSR matrix with values defined by coordinate-oriented (COO) triplets, discarding existing blocks.
|
|
186
297
|
|
|
187
|
-
The first dimension of the three input arrays must match
|
|
298
|
+
The first dimension of the three input arrays must match and indicates the number of COO triplets.
|
|
188
299
|
|
|
189
300
|
Args:
|
|
190
301
|
dest: Sparse matrix to populate
|
|
@@ -192,6 +303,7 @@ def bsr_set_from_triplets(
|
|
|
192
303
|
columns: Columns index for each non-zero
|
|
193
304
|
values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
|
|
194
305
|
to the `dest` matrix's block type, or a 3d array with data type equal to the `dest` matrix's scalar type.
|
|
306
|
+
prune_numerical_zeros: If True, will ignore the zero-valued blocks
|
|
195
307
|
"""
|
|
196
308
|
|
|
197
309
|
if values.device != columns.device or values.device != rows.device or values.device != dest.values.device:
|
|
@@ -244,62 +356,477 @@ def bsr_set_from_triplets(
|
|
|
244
356
|
if not native_func:
|
|
245
357
|
raise NotImplementedError(f"bsr_from_triplets not implemented for scalar type {scalar_type}")
|
|
246
358
|
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
359
|
+
nnz_buf, nnz_event = dest._nnz_transfer_buf_and_event()
|
|
360
|
+
|
|
361
|
+
with wp.ScopedDevice(device):
|
|
362
|
+
native_func(
|
|
363
|
+
dest.block_shape[0],
|
|
364
|
+
dest.block_shape[1],
|
|
365
|
+
dest.nrow,
|
|
366
|
+
nnz,
|
|
367
|
+
ctypes.cast(rows.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
368
|
+
ctypes.cast(columns.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
369
|
+
ctypes.cast(values.ptr, ctypes.c_void_p),
|
|
370
|
+
prune_numerical_zeros,
|
|
371
|
+
ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
372
|
+
ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
373
|
+
ctypes.cast(dest.values.ptr, ctypes.c_void_p),
|
|
374
|
+
ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
375
|
+
nnz_event,
|
|
376
|
+
)
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
class _BsrExpression(Generic[_BlockType]):
|
|
380
|
+
pass
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
class _BsrScalingExpression(_BsrExpression):
|
|
384
|
+
def __init__(self, mat, scale):
|
|
385
|
+
self.mat = mat
|
|
386
|
+
self.scale = scale
|
|
387
|
+
|
|
388
|
+
def eval(self):
|
|
389
|
+
return bsr_copy(self)
|
|
390
|
+
|
|
391
|
+
@property
|
|
392
|
+
def nrow(self) -> int:
|
|
393
|
+
return self.mat.nrow
|
|
394
|
+
|
|
395
|
+
@property
|
|
396
|
+
def ncol(self) -> int:
|
|
397
|
+
return self.mat.ncol
|
|
398
|
+
|
|
399
|
+
@property
|
|
400
|
+
def nnz(self) -> int:
|
|
401
|
+
return self.mat.nnz
|
|
402
|
+
|
|
403
|
+
@property
|
|
404
|
+
def offsets(self) -> wp.array:
|
|
405
|
+
return self.mat.offsets
|
|
406
|
+
|
|
407
|
+
@property
|
|
408
|
+
def columns(self) -> wp.array:
|
|
409
|
+
return self.mat.columns
|
|
410
|
+
|
|
411
|
+
@property
|
|
412
|
+
def scalar_type(self) -> Scalar:
|
|
413
|
+
return self.mat.scalar_type
|
|
414
|
+
|
|
415
|
+
@property
|
|
416
|
+
def block_shape(self) -> Tuple[int, int]:
|
|
417
|
+
return self.mat.block_shape
|
|
418
|
+
|
|
419
|
+
@property
|
|
420
|
+
def block_size(self) -> int:
|
|
421
|
+
return self.mat.block_size
|
|
422
|
+
|
|
423
|
+
@property
|
|
424
|
+
def shape(self) -> Tuple[int, int]:
|
|
425
|
+
return self.mat.shape
|
|
426
|
+
|
|
427
|
+
@property
|
|
428
|
+
def dtype(self) -> type:
|
|
429
|
+
return self.mat.dtype
|
|
430
|
+
|
|
431
|
+
@property
|
|
432
|
+
def device(self) -> wp.context.Device:
|
|
433
|
+
return self.mat.device
|
|
434
|
+
|
|
435
|
+
# Overloaded math operators
|
|
436
|
+
def __add__(self, y):
|
|
437
|
+
return bsr_axpy(y, bsr_copy(self.mat), alpha=self.scale)
|
|
438
|
+
|
|
439
|
+
def __radd__(self, x):
|
|
440
|
+
return bsr_axpy(x, bsr_copy(self.mat), beta=self.scale)
|
|
441
|
+
|
|
442
|
+
def __sub__(self, y):
|
|
443
|
+
return bsr_axpy(y, bsr_copy(self.mat), alpha=-self.scale)
|
|
444
|
+
|
|
445
|
+
def __rsub__(self, x):
|
|
446
|
+
return bsr_axpy(x, bsr_copy(self.mat), beta=-self.scale)
|
|
447
|
+
|
|
448
|
+
def __mul__(self, y):
|
|
449
|
+
return _BsrScalingExpression(self.mat, y * self.scale)
|
|
450
|
+
|
|
451
|
+
def __rmul__(self, x):
|
|
452
|
+
return _BsrScalingExpression(self.mat, x * self.scale)
|
|
453
|
+
|
|
454
|
+
def __matmul__(self, y):
|
|
455
|
+
if isinstance(y, wp.array):
|
|
456
|
+
return bsr_mv(self.mat, y, alpha=self.scale)
|
|
457
|
+
|
|
458
|
+
return bsr_mm(self.mat, y, alpha=self.scale)
|
|
459
|
+
|
|
460
|
+
def __rmatmul__(self, x):
|
|
461
|
+
if isinstance(x, wp.array):
|
|
462
|
+
return bsr_mv(self.mat, x, alpha=self.scale, transpose=True)
|
|
463
|
+
|
|
464
|
+
return bsr_mm(x, self.mat, alpha=self.scale)
|
|
465
|
+
|
|
466
|
+
def __truediv__(self, y):
|
|
467
|
+
return _BsrScalingExpression(self.mat, self.scale / y)
|
|
468
|
+
|
|
469
|
+
def __neg__(self):
|
|
470
|
+
return _BsrScalingExpression(self.mat, -self.scale)
|
|
471
|
+
|
|
472
|
+
def transpose(self):
|
|
473
|
+
"""Returns a transposed copy of this matrix"""
|
|
474
|
+
return _BsrScalingExpression(self.mat.transpose(), self.scale)
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
BsrMatrixOrExpression = Union[BsrMatrix[_BlockType], _BsrExpression[_BlockType]]
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
def _extract_matrix_and_scale(bsr: BsrMatrixOrExpression):
|
|
481
|
+
if isinstance(bsr, BsrMatrix):
|
|
482
|
+
return bsr, 1.0
|
|
483
|
+
if isinstance(bsr, _BsrScalingExpression):
|
|
484
|
+
return bsr.mat, bsr.scale
|
|
485
|
+
|
|
486
|
+
raise ValueError("Argument cannot be interpreted as a BsrMatrix")
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
@wp.kernel
|
|
490
|
+
def _bsr_assign_split_offsets(
|
|
491
|
+
row_factor: int,
|
|
492
|
+
col_factor: int,
|
|
493
|
+
src_offsets: wp.array(dtype=int),
|
|
494
|
+
dest_offsets: wp.array(dtype=int),
|
|
495
|
+
):
|
|
496
|
+
row = wp.tid()
|
|
497
|
+
|
|
498
|
+
base_offset = src_offsets[row] * row_factor * col_factor
|
|
499
|
+
row_count = src_offsets[1 + row] - src_offsets[row]
|
|
500
|
+
|
|
501
|
+
for k in range(row_factor):
|
|
502
|
+
dest_offsets[1 + k + row_factor * row] = base_offset + row_count * col_factor * (k + 1)
|
|
503
|
+
|
|
504
|
+
if row == 0:
|
|
505
|
+
dest_offsets[0] = 0
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
@wp.kernel
|
|
509
|
+
def _bsr_assign_split_blocks(
|
|
510
|
+
structure_only: wp.bool,
|
|
511
|
+
scale: Any,
|
|
512
|
+
row_factor: int,
|
|
513
|
+
col_factor: int,
|
|
514
|
+
dest_row_count: int,
|
|
515
|
+
src_offsets: wp.array(dtype=int),
|
|
516
|
+
src_columns: wp.array(dtype=int),
|
|
517
|
+
src_values: wp.array3d(dtype=Any),
|
|
518
|
+
dest_offsets: wp.array(dtype=int),
|
|
519
|
+
dest_columns: wp.array(dtype=int),
|
|
520
|
+
dest_values: wp.array3d(dtype=Any),
|
|
521
|
+
):
|
|
522
|
+
dest_block = wp.tid()
|
|
523
|
+
|
|
524
|
+
if dest_block >= dest_offsets[dest_row_count]:
|
|
525
|
+
return
|
|
526
|
+
|
|
527
|
+
dest_row = wp.lower_bound(dest_offsets, dest_block + 1) - 1
|
|
528
|
+
src_row = dest_row // row_factor
|
|
529
|
+
|
|
530
|
+
dest_col_in_row = dest_block - dest_offsets[dest_row]
|
|
531
|
+
src_col_in_row = dest_col_in_row // col_factor
|
|
532
|
+
|
|
533
|
+
src_block = src_offsets[src_row] + src_col_in_row
|
|
534
|
+
|
|
535
|
+
dest_rows_per_block = dest_values.shape[1]
|
|
536
|
+
dest_cols_per_block = dest_values.shape[2]
|
|
537
|
+
|
|
538
|
+
split_row = dest_row - row_factor * src_row
|
|
539
|
+
split_col = dest_col_in_row - col_factor * src_col_in_row
|
|
540
|
+
|
|
541
|
+
dest_columns[dest_block] = src_columns[src_block] * col_factor + split_col
|
|
542
|
+
|
|
543
|
+
if not structure_only:
|
|
544
|
+
src_base_i = split_row * dest_rows_per_block
|
|
545
|
+
src_base_j = split_col * dest_cols_per_block
|
|
546
|
+
for i in range(dest_rows_per_block):
|
|
547
|
+
for j in range(dest_cols_per_block):
|
|
548
|
+
dest_values[dest_block, i, j] = dest_values.dtype(
|
|
549
|
+
scale * src_values[src_block, i + src_base_i, j + src_base_j]
|
|
550
|
+
)
|
|
551
|
+
|
|
552
|
+
|
|
553
|
+
@wp.kernel
|
|
554
|
+
def _bsr_assign_merge_row_col(
|
|
555
|
+
row_factor: int,
|
|
556
|
+
col_factor: int,
|
|
557
|
+
src_row_count: int,
|
|
558
|
+
src_offsets: wp.array(dtype=int),
|
|
559
|
+
src_columns: wp.array(dtype=int),
|
|
560
|
+
dest_rows: wp.array(dtype=int),
|
|
561
|
+
dest_cols: wp.array(dtype=int),
|
|
562
|
+
):
|
|
563
|
+
block = wp.tid()
|
|
564
|
+
|
|
565
|
+
if block >= src_offsets[src_row_count]:
|
|
566
|
+
dest_rows[block] = -1 # invalid
|
|
567
|
+
dest_cols[block] = -1
|
|
568
|
+
else:
|
|
569
|
+
row = wp.lower_bound(src_offsets, block + 1) - 1
|
|
570
|
+
dest_rows[block] = row // row_factor
|
|
571
|
+
dest_cols[block] = src_columns[block] // col_factor
|
|
572
|
+
|
|
573
|
+
|
|
574
|
+
@wp.kernel
|
|
575
|
+
def _bsr_assign_merge_blocks(
|
|
576
|
+
scale: Any,
|
|
577
|
+
row_factor: int,
|
|
578
|
+
col_factor: int,
|
|
579
|
+
src_row_count: int,
|
|
580
|
+
src_offsets: wp.array(dtype=int),
|
|
581
|
+
src_columns: wp.array(dtype=int),
|
|
582
|
+
src_values: wp.array3d(dtype=Any),
|
|
583
|
+
dest_offsets: wp.array(dtype=int),
|
|
584
|
+
dest_columns: wp.array(dtype=int),
|
|
585
|
+
dest_values: wp.array3d(dtype=Any),
|
|
586
|
+
):
|
|
587
|
+
src_block = wp.tid()
|
|
588
|
+
|
|
589
|
+
if src_block >= src_offsets[src_row_count]:
|
|
590
|
+
return
|
|
591
|
+
|
|
592
|
+
src_row = wp.lower_bound(src_offsets, src_block + 1) - 1
|
|
593
|
+
src_col = src_columns[src_block]
|
|
594
|
+
|
|
595
|
+
dest_row = src_row // row_factor
|
|
596
|
+
dest_col = src_col // col_factor
|
|
597
|
+
|
|
598
|
+
dest_block = wp.lower_bound(dest_columns, dest_offsets[dest_row], dest_offsets[dest_row + 1], dest_col)
|
|
599
|
+
|
|
600
|
+
src_rows_per_block = src_values.shape[1]
|
|
601
|
+
src_cols_per_block = src_values.shape[2]
|
|
602
|
+
|
|
603
|
+
split_row = src_row - row_factor * dest_row
|
|
604
|
+
split_col = src_col - col_factor * dest_col
|
|
605
|
+
|
|
606
|
+
dest_base_i = split_row * src_rows_per_block
|
|
607
|
+
dest_base_j = split_col * src_cols_per_block
|
|
608
|
+
|
|
609
|
+
for i in range(src_rows_per_block):
|
|
610
|
+
for j in range(src_cols_per_block):
|
|
611
|
+
dest_values[dest_block, i + dest_base_i, j + dest_base_j] = dest_values.dtype(
|
|
612
|
+
scale * src_values[src_block, i, j]
|
|
613
|
+
)
|
|
614
|
+
|
|
615
|
+
|
|
616
|
+
def _bsr_values_as_3d_array(A: BsrMatrix) -> wp.array:
|
|
617
|
+
if A.block_shape == (1, 1):
|
|
618
|
+
return A.values.reshape((A.values.shape[0], 1, 1))
|
|
619
|
+
|
|
620
|
+
return wp.array(
|
|
621
|
+
data=None,
|
|
622
|
+
ptr=A.values.ptr,
|
|
623
|
+
capacity=A.values.capacity,
|
|
624
|
+
device=A.device,
|
|
625
|
+
dtype=A.scalar_type,
|
|
626
|
+
shape=(A.values.shape[0], A.block_shape[0], A.block_shape[1]),
|
|
258
627
|
)
|
|
259
628
|
|
|
260
629
|
|
|
261
630
|
def bsr_assign(
|
|
262
631
|
dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
|
|
263
|
-
src:
|
|
632
|
+
src: BsrMatrixOrExpression[BlockType[Any, Any, Any]],
|
|
633
|
+
structure_only: bool = False,
|
|
264
634
|
):
|
|
265
|
-
"""Copies the content of the `src` matrix to `dest
|
|
635
|
+
"""Copies the content of the `src` BSR matrix to `dest`.
|
|
636
|
+
|
|
637
|
+
Args:
|
|
638
|
+
src: Matrix to be copied
|
|
639
|
+
dest: Destination matrix. May have a different block shape of scalar type than `src`, in which case the required casting will be performed.
|
|
640
|
+
structure_only: If ``True``, only the non-zeros indices are copied, and uninitialized value storage is allocated
|
|
641
|
+
to accommodate at least `src.nnz` blocks. If `structure_only` is ``False``, values are also copied with implicit
|
|
642
|
+
casting if the two matrices use distinct scalar types.
|
|
643
|
+
"""
|
|
644
|
+
|
|
645
|
+
src, src_scale = _extract_matrix_and_scale(src)
|
|
266
646
|
|
|
267
647
|
if dest.values.device != src.values.device:
|
|
268
648
|
raise ValueError("Source and destination matrices must reside on the same device")
|
|
269
649
|
|
|
270
|
-
if dest.block_shape
|
|
271
|
-
|
|
650
|
+
if dest.block_shape == src.block_shape:
|
|
651
|
+
dest.nrow = src.nrow
|
|
652
|
+
dest.ncol = src.ncol
|
|
653
|
+
|
|
654
|
+
nnz_alloc = src.nnz
|
|
655
|
+
_bsr_ensure_fits(dest, nnz=nnz_alloc)
|
|
656
|
+
|
|
657
|
+
wp.copy(dest=dest.offsets, src=src.offsets, count=src.nrow + 1)
|
|
658
|
+
dest.copy_nnz_async()
|
|
659
|
+
|
|
660
|
+
if nnz_alloc > 0:
|
|
661
|
+
wp.copy(dest=dest.columns, src=src.columns, count=nnz_alloc)
|
|
272
662
|
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
663
|
+
if not structure_only:
|
|
664
|
+
warp.utils.array_cast(out_array=dest.values, in_array=src.values, count=nnz_alloc)
|
|
665
|
+
bsr_scale(dest, src_scale)
|
|
276
666
|
|
|
277
|
-
|
|
667
|
+
elif src.block_shape[0] >= dest.block_shape[0] and src.block_shape[1] >= dest.block_shape[1]:
|
|
668
|
+
# Split blocks
|
|
669
|
+
|
|
670
|
+
row_factor = src.block_shape[0] // dest.block_shape[0]
|
|
671
|
+
col_factor = src.block_shape[1] // dest.block_shape[1]
|
|
672
|
+
|
|
673
|
+
if (
|
|
674
|
+
row_factor * dest.block_shape[0] != src.block_shape[0]
|
|
675
|
+
or col_factor * dest.block_shape[1] != src.block_shape[1]
|
|
676
|
+
):
|
|
677
|
+
raise ValueError(
|
|
678
|
+
f"Dest block shape {dest.block_shape} is not an exact divider of src block shape {src.block_shape}"
|
|
679
|
+
)
|
|
278
680
|
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
wp.copy(dest=dest.columns, src=src.columns, count=src.nnz)
|
|
282
|
-
warp.utils.array_cast(out_array=dest.values, in_array=src.values, count=src.nnz)
|
|
681
|
+
dest.nrow = src.nrow * row_factor
|
|
682
|
+
dest.ncol = src.ncol * col_factor
|
|
283
683
|
|
|
684
|
+
nnz_alloc = src.nnz * row_factor * col_factor
|
|
685
|
+
_bsr_ensure_fits(dest, nnz=nnz_alloc)
|
|
284
686
|
|
|
285
|
-
|
|
687
|
+
wp.launch(
|
|
688
|
+
_bsr_assign_split_offsets,
|
|
689
|
+
dim=src.nrow,
|
|
690
|
+
device=dest.device,
|
|
691
|
+
inputs=[row_factor, col_factor, src.offsets, dest.offsets],
|
|
692
|
+
)
|
|
693
|
+
wp.launch(
|
|
694
|
+
_bsr_assign_split_blocks,
|
|
695
|
+
dim=dest.nnz,
|
|
696
|
+
device=dest.device,
|
|
697
|
+
inputs=[
|
|
698
|
+
wp.bool(structure_only),
|
|
699
|
+
src.scalar_type(src_scale),
|
|
700
|
+
row_factor,
|
|
701
|
+
col_factor,
|
|
702
|
+
dest.nrow,
|
|
703
|
+
src.offsets,
|
|
704
|
+
src.columns,
|
|
705
|
+
_bsr_values_as_3d_array(src),
|
|
706
|
+
dest.offsets,
|
|
707
|
+
dest.columns,
|
|
708
|
+
_bsr_values_as_3d_array(dest),
|
|
709
|
+
],
|
|
710
|
+
)
|
|
711
|
+
|
|
712
|
+
elif src.block_shape[0] <= dest.block_shape[0] and src.block_shape[1] <= dest.block_shape[1]:
|
|
713
|
+
# Merge blocks
|
|
714
|
+
|
|
715
|
+
row_factor = dest.block_shape[0] // src.block_shape[0]
|
|
716
|
+
col_factor = dest.block_shape[1] // src.block_shape[1]
|
|
717
|
+
|
|
718
|
+
if (
|
|
719
|
+
row_factor * src.block_shape[0] != dest.block_shape[0]
|
|
720
|
+
or col_factor * src.block_shape[1] != dest.block_shape[1]
|
|
721
|
+
):
|
|
722
|
+
raise ValueError(
|
|
723
|
+
f"Dest block shape {dest.block_shape} is not an exact multiple of src block shape {src.block_shape}"
|
|
724
|
+
)
|
|
725
|
+
|
|
726
|
+
if src.nrow % row_factor != 0 or src.ncol % col_factor != 0:
|
|
727
|
+
raise ValueError(
|
|
728
|
+
"The total rows and columns of the src matrix cannot be evenly divided using the requested block shape"
|
|
729
|
+
)
|
|
730
|
+
|
|
731
|
+
dest.nrow = src.nrow // row_factor
|
|
732
|
+
dest.ncol = src.ncol // col_factor
|
|
733
|
+
|
|
734
|
+
nnz_alloc = src.nnz # Conservative, in case all nnz in src belong to distinct merged blocks
|
|
735
|
+
_bsr_ensure_fits(dest, nnz=nnz_alloc)
|
|
736
|
+
|
|
737
|
+
# Compute destination rows and columns
|
|
738
|
+
dest_rows = wp.empty_like(src.columns)
|
|
739
|
+
dest_cols = wp.empty_like(src.columns)
|
|
740
|
+
wp.launch(
|
|
741
|
+
_bsr_assign_merge_row_col,
|
|
742
|
+
dim=src.nnz,
|
|
743
|
+
device=dest.device,
|
|
744
|
+
inputs=[row_factor, col_factor, src.nrow, src.offsets, src.columns, dest_rows, dest_cols],
|
|
745
|
+
)
|
|
746
|
+
|
|
747
|
+
# Compute destination offsets from triplets
|
|
748
|
+
from warp.context import runtime
|
|
749
|
+
|
|
750
|
+
if dest.device.is_cpu:
|
|
751
|
+
native_func = runtime.core.bsr_matrix_from_triplets_float_host
|
|
752
|
+
else:
|
|
753
|
+
native_func = runtime.core.bsr_matrix_from_triplets_float_device
|
|
754
|
+
|
|
755
|
+
nnz_buf, nnz_event = dest._nnz_transfer_buf_and_event()
|
|
756
|
+
with wp.ScopedDevice(dest.device):
|
|
757
|
+
native_func(
|
|
758
|
+
dest.block_shape[0],
|
|
759
|
+
dest.block_shape[1],
|
|
760
|
+
dest.nrow,
|
|
761
|
+
dest.nnz,
|
|
762
|
+
ctypes.cast(dest_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
763
|
+
ctypes.cast(dest_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
764
|
+
0,
|
|
765
|
+
False,
|
|
766
|
+
ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
767
|
+
ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
768
|
+
0,
|
|
769
|
+
ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
770
|
+
nnz_event,
|
|
771
|
+
)
|
|
772
|
+
|
|
773
|
+
# merge block values
|
|
774
|
+
if not structure_only:
|
|
775
|
+
dest.values.zero_()
|
|
776
|
+
wp.launch(
|
|
777
|
+
_bsr_assign_merge_blocks,
|
|
778
|
+
dim=src.nnz,
|
|
779
|
+
device=dest.device,
|
|
780
|
+
inputs=[
|
|
781
|
+
src.scalar_type(src_scale),
|
|
782
|
+
row_factor,
|
|
783
|
+
col_factor,
|
|
784
|
+
src.nrow,
|
|
785
|
+
src.offsets,
|
|
786
|
+
src.columns,
|
|
787
|
+
_bsr_values_as_3d_array(src),
|
|
788
|
+
dest.offsets,
|
|
789
|
+
dest.columns,
|
|
790
|
+
_bsr_values_as_3d_array(dest),
|
|
791
|
+
],
|
|
792
|
+
)
|
|
793
|
+
|
|
794
|
+
else:
|
|
795
|
+
raise ValueError("Incompatible dest and src block shapes")
|
|
796
|
+
|
|
797
|
+
|
|
798
|
+
def bsr_copy(
|
|
799
|
+
A: BsrMatrixOrExpression,
|
|
800
|
+
scalar_type: Optional[Scalar] = None,
|
|
801
|
+
block_shape: Optional[Tuple[int, int]] = None,
|
|
802
|
+
structure_only: bool = False,
|
|
803
|
+
):
|
|
286
804
|
"""Returns a copy of matrix ``A``, possibly changing its scalar type.
|
|
287
805
|
|
|
288
806
|
Args:
|
|
807
|
+
A: Matrix to be copied
|
|
289
808
|
scalar_type: If provided, the returned matrix will use this scalar type instead of the one from `A`.
|
|
809
|
+
block_shape: If provided, the returned matrix will use blocks of this shape instead of the one from `A`.
|
|
810
|
+
Both dimensions of `block_shape` must be either a multiple or an exact divider of the ones from `A`.
|
|
811
|
+
structure_only: If ``True``, only the non-zeros indices are copied, and uninitialized value storage is allocated
|
|
812
|
+
to accommodate at least `src.nnz` blocks. If `structure_only` is ``False``, values are also copied with implicit
|
|
813
|
+
casting if the two matrices use distinct scalar types.
|
|
290
814
|
"""
|
|
291
815
|
if scalar_type is None:
|
|
292
|
-
|
|
293
|
-
|
|
816
|
+
scalar_type = A.scalar_type
|
|
817
|
+
if block_shape is None:
|
|
818
|
+
block_shape = A.block_shape
|
|
819
|
+
|
|
820
|
+
if block_shape == (1, 1):
|
|
294
821
|
block_type = scalar_type
|
|
295
822
|
else:
|
|
296
|
-
block_type = wp.types.matrix(shape=
|
|
823
|
+
block_type = wp.types.matrix(shape=block_shape, dtype=scalar_type)
|
|
297
824
|
|
|
298
825
|
copy = bsr_zeros(
|
|
299
826
|
rows_of_blocks=A.nrow,
|
|
300
827
|
cols_of_blocks=A.ncol,
|
|
301
828
|
block_type=block_type,
|
|
302
|
-
device=A.
|
|
829
|
+
device=A.device,
|
|
303
830
|
)
|
|
304
831
|
bsr_assign(dest=copy, src=A)
|
|
305
832
|
return copy
|
|
@@ -307,10 +834,12 @@ def bsr_copy(A: BsrMatrix, scalar_type: Optional[Scalar] = None):
|
|
|
307
834
|
|
|
308
835
|
def bsr_set_transpose(
|
|
309
836
|
dest: BsrMatrix[BlockType[Cols, Rows, Scalar]],
|
|
310
|
-
src:
|
|
837
|
+
src: BsrMatrixOrExpression[BlockType[Rows, Cols, Scalar]],
|
|
311
838
|
):
|
|
312
839
|
"""Assigns the transposed matrix `src` to matrix `dest`"""
|
|
313
840
|
|
|
841
|
+
src, src_scale = _extract_matrix_and_scale(src)
|
|
842
|
+
|
|
314
843
|
if dest.values.device != src.values.device:
|
|
315
844
|
raise ValueError("All arguments must reside on the same device")
|
|
316
845
|
|
|
@@ -322,15 +851,16 @@ def bsr_set_transpose(
|
|
|
322
851
|
if dest.block_shape != transpose_block_shape:
|
|
323
852
|
raise ValueError(f"Destination block shape must be {transpose_block_shape}")
|
|
324
853
|
|
|
854
|
+
nnz = src.nnz
|
|
325
855
|
dest.nrow = src.ncol
|
|
326
856
|
dest.ncol = src.nrow
|
|
327
|
-
dest.nnz = src.nnz
|
|
328
857
|
|
|
329
|
-
if
|
|
858
|
+
if nnz == 0:
|
|
859
|
+
bsr_set_zero(dest)
|
|
330
860
|
return
|
|
331
861
|
|
|
332
862
|
# Increase dest array sizes if needed
|
|
333
|
-
_bsr_ensure_fits(dest)
|
|
863
|
+
_bsr_ensure_fits(dest, nnz=nnz)
|
|
334
864
|
|
|
335
865
|
from warp.context import runtime
|
|
336
866
|
|
|
@@ -348,22 +878,26 @@ def bsr_set_transpose(
|
|
|
348
878
|
if not native_func:
|
|
349
879
|
raise NotImplementedError(f"bsr_set_transpose not implemented for scalar type {dest.scalar_type}")
|
|
350
880
|
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
881
|
+
with wp.ScopedDevice(dest.device):
|
|
882
|
+
native_func(
|
|
883
|
+
src.block_shape[0],
|
|
884
|
+
src.block_shape[1],
|
|
885
|
+
src.nrow,
|
|
886
|
+
src.ncol,
|
|
887
|
+
nnz,
|
|
888
|
+
ctypes.cast(src.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
889
|
+
ctypes.cast(src.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
890
|
+
ctypes.cast(src.values.ptr, ctypes.c_void_p),
|
|
891
|
+
ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
892
|
+
ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
893
|
+
ctypes.cast(dest.values.ptr, ctypes.c_void_p),
|
|
894
|
+
)
|
|
364
895
|
|
|
896
|
+
dest.copy_nnz_async()
|
|
897
|
+
bsr_scale(dest, src_scale)
|
|
365
898
|
|
|
366
|
-
|
|
899
|
+
|
|
900
|
+
def bsr_transposed(A: BsrMatrixOrExpression):
|
|
367
901
|
"""Returns a copy of the transposed matrix `A`"""
|
|
368
902
|
|
|
369
903
|
if A.block_shape == (1, 1):
|
|
@@ -375,7 +909,7 @@ def bsr_transposed(A: BsrMatrix):
|
|
|
375
909
|
rows_of_blocks=A.ncol,
|
|
376
910
|
cols_of_blocks=A.nrow,
|
|
377
911
|
block_type=block_type,
|
|
378
|
-
device=A.
|
|
912
|
+
device=A.device,
|
|
379
913
|
)
|
|
380
914
|
bsr_set_transpose(dest=transposed, src=A)
|
|
381
915
|
return transposed
|
|
@@ -383,6 +917,7 @@ def bsr_transposed(A: BsrMatrix):
|
|
|
383
917
|
|
|
384
918
|
@wp.kernel
|
|
385
919
|
def _bsr_get_diag_kernel(
|
|
920
|
+
scale: Any,
|
|
386
921
|
A_offsets: wp.array(dtype=int),
|
|
387
922
|
A_columns: wp.array(dtype=int),
|
|
388
923
|
A_values: wp.array(dtype=Any),
|
|
@@ -395,10 +930,10 @@ def _bsr_get_diag_kernel(
|
|
|
395
930
|
diag = wp.lower_bound(A_columns, beg, end, row)
|
|
396
931
|
if diag < end:
|
|
397
932
|
if A_columns[diag] == row:
|
|
398
|
-
out[row] = A_values[diag]
|
|
933
|
+
out[row] = scale * A_values[diag]
|
|
399
934
|
|
|
400
935
|
|
|
401
|
-
def bsr_get_diag(A:
|
|
936
|
+
def bsr_get_diag(A: BsrMatrixOrExpression[BlockType], out: "Optional[Array[BlockType]]" = None) -> "Array[BlockType]":
|
|
402
937
|
"""Returns the array of blocks that constitute the diagonal of a sparse matrix.
|
|
403
938
|
|
|
404
939
|
Args:
|
|
@@ -406,6 +941,8 @@ def bsr_get_diag(A: BsrMatrix[_BlockType], out: "Optional[Array[BlockType]]" = N
|
|
|
406
941
|
out: if provided, the array into which to store the diagonal blocks
|
|
407
942
|
"""
|
|
408
943
|
|
|
944
|
+
A, scale = _extract_matrix_and_scale(A)
|
|
945
|
+
|
|
409
946
|
dim = min(A.nrow, A.ncol)
|
|
410
947
|
|
|
411
948
|
if out is None:
|
|
@@ -422,7 +959,7 @@ def bsr_get_diag(A: BsrMatrix[_BlockType], out: "Optional[Array[BlockType]]" = N
|
|
|
422
959
|
kernel=_bsr_get_diag_kernel,
|
|
423
960
|
dim=dim,
|
|
424
961
|
device=A.values.device,
|
|
425
|
-
inputs=[A.offsets, A.columns, A.values, out],
|
|
962
|
+
inputs=[A.scalar_type(scale), A.offsets, A.columns, A.values, out],
|
|
426
963
|
)
|
|
427
964
|
|
|
428
965
|
return out
|
|
@@ -495,13 +1032,13 @@ def bsr_set_diag(
|
|
|
495
1032
|
A.nrow = rows_of_blocks
|
|
496
1033
|
A.ncol = cols_of_blocks
|
|
497
1034
|
|
|
498
|
-
|
|
499
|
-
_bsr_ensure_fits(A)
|
|
1035
|
+
nnz = min(A.nrow, A.ncol)
|
|
1036
|
+
_bsr_ensure_fits(A, nnz=nnz)
|
|
500
1037
|
|
|
501
1038
|
if warp.types.is_array(diag):
|
|
502
1039
|
wp.launch(
|
|
503
1040
|
kernel=_bsr_set_diag_kernel,
|
|
504
|
-
dim=
|
|
1041
|
+
dim=nnz,
|
|
505
1042
|
device=A.values.device,
|
|
506
1043
|
inputs=[diag, A.offsets, A.columns, A.values],
|
|
507
1044
|
)
|
|
@@ -511,11 +1048,13 @@ def bsr_set_diag(
|
|
|
511
1048
|
diag = A.values.dtype(diag)
|
|
512
1049
|
wp.launch(
|
|
513
1050
|
kernel=_bsr_set_diag_constant_kernel,
|
|
514
|
-
dim=
|
|
1051
|
+
dim=nnz,
|
|
515
1052
|
device=A.values.device,
|
|
516
1053
|
inputs=[diag, A.offsets, A.columns, A.values],
|
|
517
1054
|
)
|
|
518
1055
|
|
|
1056
|
+
A.copy_nnz_async(known_nnz=nnz)
|
|
1057
|
+
|
|
519
1058
|
|
|
520
1059
|
def bsr_diag(
|
|
521
1060
|
diag: "Union[BlockType, Array[BlockType]]",
|
|
@@ -619,11 +1158,14 @@ def _bsr_scale_kernel(
|
|
|
619
1158
|
values[wp.tid()] = alpha * values[wp.tid()]
|
|
620
1159
|
|
|
621
1160
|
|
|
622
|
-
def bsr_scale(x:
|
|
1161
|
+
def bsr_scale(x: BsrMatrixOrExpression, alpha: Scalar) -> BsrMatrix:
|
|
623
1162
|
"""
|
|
624
1163
|
Performs the operation ``x := alpha * x`` on BSR matrix `x` and returns `x`
|
|
625
1164
|
"""
|
|
626
1165
|
|
|
1166
|
+
x, scale = _extract_matrix_and_scale(x)
|
|
1167
|
+
alpha *= scale
|
|
1168
|
+
|
|
627
1169
|
if alpha != 1.0 and x.nnz > 0:
|
|
628
1170
|
if alpha == 0.0:
|
|
629
1171
|
bsr_set_zero(x)
|
|
@@ -642,11 +1184,14 @@ def bsr_scale(x: BsrMatrix, alpha: Scalar) -> BsrMatrix:
|
|
|
642
1184
|
|
|
643
1185
|
|
|
644
1186
|
@wp.kernel
|
|
645
|
-
def _bsr_get_block_row(dest_offset: int, bsr_offsets: wp.array(dtype=int), rows: wp.array(dtype=int)):
|
|
1187
|
+
def _bsr_get_block_row(dest_offset: int, row_count: int, bsr_offsets: wp.array(dtype=int), rows: wp.array(dtype=int)):
|
|
646
1188
|
i = wp.tid()
|
|
647
1189
|
|
|
648
|
-
|
|
649
|
-
|
|
1190
|
+
if i >= bsr_offsets[row_count]:
|
|
1191
|
+
rows[dest_offset + i] = -1 # invalid
|
|
1192
|
+
else:
|
|
1193
|
+
row = wp.lower_bound(bsr_offsets, i + 1) - 1
|
|
1194
|
+
rows[dest_offset + i] = row
|
|
650
1195
|
|
|
651
1196
|
|
|
652
1197
|
@wp.kernel
|
|
@@ -662,6 +1207,10 @@ def _bsr_axpy_add_block(
|
|
|
662
1207
|
):
|
|
663
1208
|
i = wp.tid()
|
|
664
1209
|
row = rows[i + src_offset]
|
|
1210
|
+
|
|
1211
|
+
if row < 0:
|
|
1212
|
+
return
|
|
1213
|
+
|
|
665
1214
|
col = cols[i + src_offset]
|
|
666
1215
|
beg = dst_offsets[row]
|
|
667
1216
|
end = dst_offsets[row + 1]
|
|
@@ -694,11 +1243,11 @@ class bsr_axpy_work_arrays:
|
|
|
694
1243
|
self._sum_cols = wp.empty(shape=(sum_nnz), dtype=int, device=self.device)
|
|
695
1244
|
|
|
696
1245
|
if self._old_y_values is None or self._old_y_values.size < y.nnz:
|
|
697
|
-
self._old_y_values = wp.empty(shape=(y.nnz), dtype=y.values.dtype, device=self.device)
|
|
1246
|
+
self._old_y_values = wp.empty(shape=(y.nnz,), dtype=y.values.dtype, device=self.device)
|
|
698
1247
|
|
|
699
1248
|
|
|
700
1249
|
def bsr_axpy(
|
|
701
|
-
x:
|
|
1250
|
+
x: BsrMatrixOrExpression,
|
|
702
1251
|
y: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
|
|
703
1252
|
alpha: Scalar = 1.0,
|
|
704
1253
|
beta: Scalar = 1.0,
|
|
@@ -717,17 +1266,23 @@ def bsr_axpy(
|
|
|
717
1266
|
work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_axpy_work_arrays` in `work_arrays`.
|
|
718
1267
|
"""
|
|
719
1268
|
|
|
1269
|
+
x, x_scale = _extract_matrix_and_scale(x)
|
|
1270
|
+
alpha *= x_scale
|
|
1271
|
+
|
|
720
1272
|
if y is None:
|
|
721
1273
|
# If not output matrix is provided, allocate it for convenience
|
|
722
1274
|
y = bsr_zeros(x.nrow, x.ncol, block_type=x.values.dtype, device=x.values.device)
|
|
723
1275
|
beta = 0.0
|
|
724
1276
|
|
|
1277
|
+
x_nnz = x.nnz
|
|
1278
|
+
y_nnz = y.nnz
|
|
1279
|
+
|
|
725
1280
|
# Handle easy cases first
|
|
726
|
-
if beta == 0.0 or
|
|
1281
|
+
if beta == 0.0 or y_nnz == 0:
|
|
727
1282
|
bsr_assign(src=x, dest=y)
|
|
728
1283
|
return bsr_scale(y, alpha=alpha)
|
|
729
1284
|
|
|
730
|
-
if alpha == 0.0 or
|
|
1285
|
+
if alpha == 0.0 or x_nnz == 0:
|
|
731
1286
|
return bsr_scale(y, alpha=beta)
|
|
732
1287
|
|
|
733
1288
|
if not isinstance(alpha, y.scalar_type):
|
|
@@ -753,28 +1308,28 @@ def bsr_axpy(
|
|
|
753
1308
|
if work_arrays is None:
|
|
754
1309
|
work_arrays = bsr_axpy_work_arrays()
|
|
755
1310
|
|
|
756
|
-
sum_nnz =
|
|
1311
|
+
sum_nnz = x_nnz + y_nnz
|
|
757
1312
|
device = y.values.device
|
|
758
1313
|
work_arrays._allocate(device, y, sum_nnz)
|
|
759
1314
|
|
|
760
|
-
wp.copy(work_arrays._sum_cols, y.columns, 0, 0,
|
|
1315
|
+
wp.copy(work_arrays._sum_cols, y.columns, 0, 0, y_nnz)
|
|
761
1316
|
wp.launch(
|
|
762
1317
|
kernel=_bsr_get_block_row,
|
|
763
1318
|
device=device,
|
|
764
|
-
dim=
|
|
765
|
-
inputs=[0, y.offsets, work_arrays._sum_rows],
|
|
1319
|
+
dim=y_nnz,
|
|
1320
|
+
inputs=[0, y.nrow, y.offsets, work_arrays._sum_rows],
|
|
766
1321
|
)
|
|
767
1322
|
|
|
768
|
-
wp.copy(work_arrays._sum_cols, x.columns,
|
|
1323
|
+
wp.copy(work_arrays._sum_cols, x.columns, y_nnz, 0, x_nnz)
|
|
769
1324
|
wp.launch(
|
|
770
1325
|
kernel=_bsr_get_block_row,
|
|
771
1326
|
device=device,
|
|
772
|
-
dim=
|
|
773
|
-
inputs=[
|
|
1327
|
+
dim=x_nnz,
|
|
1328
|
+
inputs=[y_nnz, x.nrow, x.offsets, work_arrays._sum_rows],
|
|
774
1329
|
)
|
|
775
1330
|
|
|
776
1331
|
# Save old y values before overwriting matrix
|
|
777
|
-
wp.copy(dest=work_arrays._old_y_values, src=y.values, count=
|
|
1332
|
+
wp.copy(dest=work_arrays._old_y_values, src=y.values, count=y_nnz)
|
|
778
1333
|
|
|
779
1334
|
# Increase dest array sizes if needed
|
|
780
1335
|
if y.columns.shape[0] < sum_nnz:
|
|
@@ -787,21 +1342,28 @@ def bsr_axpy(
|
|
|
787
1342
|
else:
|
|
788
1343
|
native_func = runtime.core.bsr_matrix_from_triplets_float_device
|
|
789
1344
|
|
|
790
|
-
old_y_nnz =
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
1345
|
+
old_y_nnz = y_nnz
|
|
1346
|
+
nnz_buf, nnz_event = y._nnz_transfer_buf_and_event()
|
|
1347
|
+
|
|
1348
|
+
with wp.ScopedDevice(y.device):
|
|
1349
|
+
native_func(
|
|
1350
|
+
y.block_shape[0],
|
|
1351
|
+
y.block_shape[1],
|
|
1352
|
+
y.nrow,
|
|
1353
|
+
sum_nnz,
|
|
1354
|
+
ctypes.cast(work_arrays._sum_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1355
|
+
ctypes.cast(work_arrays._sum_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1356
|
+
0,
|
|
1357
|
+
False,
|
|
1358
|
+
ctypes.cast(y.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1359
|
+
ctypes.cast(y.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1360
|
+
0,
|
|
1361
|
+
ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1362
|
+
nnz_event,
|
|
1363
|
+
)
|
|
1364
|
+
|
|
1365
|
+
_bsr_ensure_fits(y, nnz=sum_nnz)
|
|
803
1366
|
|
|
804
|
-
_bsr_ensure_fits(y)
|
|
805
1367
|
y.values.zero_()
|
|
806
1368
|
|
|
807
1369
|
wp.launch(
|
|
@@ -823,7 +1385,7 @@ def bsr_axpy(
|
|
|
823
1385
|
wp.launch(
|
|
824
1386
|
kernel=_bsr_axpy_add_block,
|
|
825
1387
|
device=device,
|
|
826
|
-
dim=
|
|
1388
|
+
dim=x_nnz,
|
|
827
1389
|
inputs=[
|
|
828
1390
|
old_y_nnz,
|
|
829
1391
|
alpha,
|
|
@@ -918,8 +1480,9 @@ def _bsr_mm_compute_values(
|
|
|
918
1480
|
y_end = y_offsets[x_col + 1]
|
|
919
1481
|
|
|
920
1482
|
y_block = wp.lower_bound(y_columns, y_beg, y_end, col)
|
|
921
|
-
if y_block < y_end
|
|
922
|
-
|
|
1483
|
+
if y_block < y_end:
|
|
1484
|
+
if y_columns[y_block] == col:
|
|
1485
|
+
mm_val += x_values[x_block] * y_values[y_block]
|
|
923
1486
|
|
|
924
1487
|
mm_values[mm_block] += alpha * mm_val
|
|
925
1488
|
|
|
@@ -932,38 +1495,38 @@ class bsr_mm_work_arrays:
|
|
|
932
1495
|
|
|
933
1496
|
def _reset(self, device):
|
|
934
1497
|
self.device = device
|
|
935
|
-
self._pinned_count_buffer = None
|
|
936
1498
|
self._mm_row_counts = None
|
|
937
1499
|
self._mm_rows = None
|
|
938
1500
|
self._mm_cols = None
|
|
939
1501
|
self._old_z_values = None
|
|
940
1502
|
self._old_z_offsets = None
|
|
941
1503
|
self._old_z_columns = None
|
|
1504
|
+
self._mm_nnz = 0
|
|
942
1505
|
|
|
943
|
-
def _allocate_stage_1(self, device, z: BsrMatrix,
|
|
1506
|
+
def _allocate_stage_1(self, device, z: BsrMatrix, beta: float, z_aliasing: bool):
|
|
944
1507
|
if self.device != device:
|
|
945
1508
|
self._reset(device)
|
|
946
1509
|
|
|
947
1510
|
# Allocations that do not depend on any computation
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
self._pinned_count_buffer = wp.empty(shape=(1,), dtype=int, pinned=True, device="cpu")
|
|
1511
|
+
z_nnz = z.nnz_sync()
|
|
1512
|
+
self._copied_z_nnz = z_nnz if beta != 0.0 or z_aliasing else 0
|
|
951
1513
|
|
|
952
1514
|
if self._mm_row_counts is None or self._mm_row_counts.size < z.nrow + 1:
|
|
953
1515
|
self._mm_row_counts = wp.empty(shape=(z.nrow + 1,), dtype=int, device=self.device)
|
|
954
1516
|
|
|
955
|
-
if
|
|
956
|
-
if self._old_z_values is None or self._old_z_values.size <
|
|
957
|
-
self._old_z_values = wp.empty(shape=(
|
|
1517
|
+
if self._copied_z_nnz > 0:
|
|
1518
|
+
if self._old_z_values is None or self._old_z_values.size < self._copied_z_nnz:
|
|
1519
|
+
self._old_z_values = wp.empty(shape=(self._copied_z_nnz,), dtype=z.values.dtype, device=self.device)
|
|
958
1520
|
|
|
959
1521
|
if z_aliasing:
|
|
960
|
-
if self._old_z_columns is None or self._old_z_columns.size <
|
|
961
|
-
self._old_z_columns = wp.empty(shape=(
|
|
1522
|
+
if self._old_z_columns is None or self._old_z_columns.size < z_nnz:
|
|
1523
|
+
self._old_z_columns = wp.empty(shape=(z_nnz,), dtype=z.columns.dtype, device=self.device)
|
|
962
1524
|
if self._old_z_offsets is None or self._old_z_offsets.size < z.nrow + 1:
|
|
963
1525
|
self._old_z_offsets = wp.empty(shape=(z.nrow + 1,), dtype=z.offsets.dtype, device=self.device)
|
|
964
1526
|
|
|
965
1527
|
def _allocate_stage_2(self, mm_nnz: int):
|
|
966
1528
|
# Allocations that depend on unmerged nnz estimate
|
|
1529
|
+
self._mm_nnz = mm_nnz
|
|
967
1530
|
if self._mm_rows is None or self._mm_rows.size < mm_nnz:
|
|
968
1531
|
self._mm_rows = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
|
|
969
1532
|
if self._mm_cols is None or self._mm_cols.size < mm_nnz:
|
|
@@ -971,12 +1534,13 @@ class bsr_mm_work_arrays:
|
|
|
971
1534
|
|
|
972
1535
|
|
|
973
1536
|
def bsr_mm(
|
|
974
|
-
x:
|
|
975
|
-
y:
|
|
1537
|
+
x: BsrMatrixOrExpression[BlockType[Rows, Any, Scalar]],
|
|
1538
|
+
y: BsrMatrixOrExpression[BlockType[Any, Cols, Scalar]],
|
|
976
1539
|
z: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
|
|
977
1540
|
alpha: Scalar = 1.0,
|
|
978
1541
|
beta: Scalar = 0.0,
|
|
979
1542
|
work_arrays: Optional[bsr_mm_work_arrays] = None,
|
|
1543
|
+
reuse_topology: bool = False,
|
|
980
1544
|
) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
|
|
981
1545
|
"""
|
|
982
1546
|
Performs the sparse matrix-matrix multiplication ``z := alpha * x * y + beta * z`` on BSR matrices `x`, `y` and `z`, and returns `z`.
|
|
@@ -991,8 +1555,16 @@ def bsr_mm(
|
|
|
991
1555
|
alpha: Uniform scaling factor for the ``x * y`` product
|
|
992
1556
|
beta: Uniform scaling factor for `z`
|
|
993
1557
|
work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_mm_work_arrays` in `work_arrays`.
|
|
1558
|
+
reuse_topology: If True, reuse the product topology information stored in `work_arrays` rather than recompute it from scratch.
|
|
1559
|
+
The matrices x, y and z must be structurally similar to the previous call in which `work_arrays` were populated.
|
|
1560
|
+
This is necessary for `bsr_mm` to be captured in a CUDA graph.
|
|
994
1561
|
"""
|
|
995
1562
|
|
|
1563
|
+
x, x_scale = _extract_matrix_and_scale(x)
|
|
1564
|
+
alpha *= x_scale
|
|
1565
|
+
y, y_scale = _extract_matrix_and_scale(y)
|
|
1566
|
+
alpha *= y_scale
|
|
1567
|
+
|
|
996
1568
|
if z is None:
|
|
997
1569
|
# If not output matrix is provided, allocate it for convenience
|
|
998
1570
|
z_block_shape = (x.block_shape[0], y.block_shape[1])
|
|
@@ -1030,76 +1602,84 @@ def bsr_mm(
|
|
|
1030
1602
|
if not isinstance(beta, z.scalar_type):
|
|
1031
1603
|
beta = z.scalar_type(beta)
|
|
1032
1604
|
|
|
1033
|
-
if work_arrays is None:
|
|
1034
|
-
work_arrays = bsr_mm_work_arrays()
|
|
1035
|
-
|
|
1036
1605
|
z_aliasing = z == x or z == y
|
|
1037
|
-
copied_z_nnz = z.nnz if beta != 0.0 or z_aliasing else 0
|
|
1038
1606
|
|
|
1039
|
-
|
|
1607
|
+
if reuse_topology:
|
|
1608
|
+
if work_arrays is None:
|
|
1609
|
+
raise ValueError("`work_arrays` must not be ``None`` in order to reuse matrix-matrix product topology")
|
|
1040
1610
|
|
|
1041
|
-
|
|
1042
|
-
|
|
1043
|
-
kernel=_bsr_mm_count_coeffs,
|
|
1044
|
-
device=device,
|
|
1045
|
-
dim=z.nrow,
|
|
1046
|
-
inputs=[
|
|
1047
|
-
copied_z_nnz,
|
|
1048
|
-
x.offsets,
|
|
1049
|
-
x.columns,
|
|
1050
|
-
y.offsets,
|
|
1051
|
-
work_arrays._mm_row_counts,
|
|
1052
|
-
],
|
|
1053
|
-
)
|
|
1054
|
-
warp.utils.array_scan(work_arrays._mm_row_counts, work_arrays._mm_row_counts)
|
|
1055
|
-
|
|
1056
|
-
# Get back total counts on host
|
|
1057
|
-
if device.is_cuda:
|
|
1058
|
-
wp.copy(
|
|
1059
|
-
dest=work_arrays._pinned_count_buffer,
|
|
1060
|
-
src=work_arrays._mm_row_counts,
|
|
1061
|
-
src_offset=z.nrow,
|
|
1062
|
-
count=1,
|
|
1063
|
-
)
|
|
1064
|
-
wp.synchronize_stream(wp.get_stream(device))
|
|
1065
|
-
mm_nnz = int(work_arrays._pinned_count_buffer.numpy()[0])
|
|
1611
|
+
copied_z_nnz = work_arrays._copied_z_nnz
|
|
1612
|
+
mm_nnz = work_arrays._mm_nnz
|
|
1066
1613
|
else:
|
|
1067
|
-
|
|
1614
|
+
if device.is_capturing:
|
|
1615
|
+
raise RuntimeError("`bsr_mm` requires `reuse_topology=True` for use in graph capture")
|
|
1068
1616
|
|
|
1069
|
-
|
|
1617
|
+
if work_arrays is None:
|
|
1618
|
+
work_arrays = bsr_mm_work_arrays()
|
|
1070
1619
|
|
|
1071
|
-
|
|
1072
|
-
|
|
1073
|
-
|
|
1074
|
-
|
|
1620
|
+
work_arrays._allocate_stage_1(device, z, beta, z_aliasing)
|
|
1621
|
+
copied_z_nnz = work_arrays._copied_z_nnz
|
|
1622
|
+
|
|
1623
|
+
# Prefix sum of number of (unmerged) mm blocks per row
|
|
1075
1624
|
wp.launch(
|
|
1076
|
-
kernel=
|
|
1625
|
+
kernel=_bsr_mm_count_coeffs,
|
|
1077
1626
|
device=device,
|
|
1078
|
-
dim=
|
|
1079
|
-
inputs=[
|
|
1627
|
+
dim=z.nrow,
|
|
1628
|
+
inputs=[
|
|
1629
|
+
copied_z_nnz,
|
|
1630
|
+
x.offsets,
|
|
1631
|
+
x.columns,
|
|
1632
|
+
y.offsets,
|
|
1633
|
+
work_arrays._mm_row_counts,
|
|
1634
|
+
],
|
|
1635
|
+
)
|
|
1636
|
+
warp.utils.array_scan(work_arrays._mm_row_counts, work_arrays._mm_row_counts)
|
|
1637
|
+
|
|
1638
|
+
# Get back total counts on host -- we need a synchronization here
|
|
1639
|
+
# Use pinned buffer from z, we are going to need it later anyway
|
|
1640
|
+
nnz_buf, _ = z._nnz_transfer_buf_and_event()
|
|
1641
|
+
stream = wp.get_stream(device) if device.is_cuda else None
|
|
1642
|
+
wp.copy(dest=nnz_buf, src=work_arrays._mm_row_counts, src_offset=z.nrow, count=1, stream=stream)
|
|
1643
|
+
if device.is_cuda:
|
|
1644
|
+
wp.synchronize_stream(stream)
|
|
1645
|
+
mm_nnz = int(nnz_buf.numpy()[0])
|
|
1646
|
+
|
|
1647
|
+
work_arrays._allocate_stage_2(mm_nnz)
|
|
1648
|
+
|
|
1649
|
+
# If z has a non-zero scale, save current data before overwriting it
|
|
1650
|
+
if copied_z_nnz > 0:
|
|
1651
|
+
# Copy z row and column indices
|
|
1652
|
+
wp.copy(dest=work_arrays._mm_cols, src=z.columns, count=copied_z_nnz)
|
|
1653
|
+
wp.launch(
|
|
1654
|
+
kernel=_bsr_get_block_row,
|
|
1655
|
+
device=device,
|
|
1656
|
+
dim=copied_z_nnz,
|
|
1657
|
+
inputs=[0, z.nrow, z.offsets, work_arrays._mm_rows],
|
|
1658
|
+
)
|
|
1659
|
+
if z_aliasing:
|
|
1660
|
+
# If z is aliasing with x or y, need to save topology as well
|
|
1661
|
+
wp.copy(src=z.columns, dest=work_arrays._old_z_columns, count=copied_z_nnz)
|
|
1662
|
+
wp.copy(src=z.offsets, dest=work_arrays._old_z_offsets, count=z.nrow + 1)
|
|
1663
|
+
|
|
1664
|
+
# Fill unmerged mm blocks rows and columns
|
|
1665
|
+
wp.launch(
|
|
1666
|
+
kernel=_bsr_mm_list_coeffs,
|
|
1667
|
+
device=device,
|
|
1668
|
+
dim=z.nrow,
|
|
1669
|
+
inputs=[
|
|
1670
|
+
x.offsets,
|
|
1671
|
+
x.columns,
|
|
1672
|
+
y.offsets,
|
|
1673
|
+
y.columns,
|
|
1674
|
+
work_arrays._mm_row_counts,
|
|
1675
|
+
work_arrays._mm_rows,
|
|
1676
|
+
work_arrays._mm_cols,
|
|
1677
|
+
],
|
|
1080
1678
|
)
|
|
1679
|
+
|
|
1680
|
+
if copied_z_nnz > 0:
|
|
1081
1681
|
# Save current z values in temporary buffer
|
|
1082
1682
|
wp.copy(src=z.values, dest=work_arrays._old_z_values, count=copied_z_nnz)
|
|
1083
|
-
if z_aliasing:
|
|
1084
|
-
# If z is aliasing with x or y, need to save topology as well
|
|
1085
|
-
wp.copy(src=z.columns, dest=work_arrays._old_z_columns, count=copied_z_nnz)
|
|
1086
|
-
wp.copy(src=z.offsets, dest=work_arrays._old_z_offsets, count=z.nrow + 1)
|
|
1087
|
-
|
|
1088
|
-
# Fill unmerged mm blocks rows and columns
|
|
1089
|
-
wp.launch(
|
|
1090
|
-
kernel=_bsr_mm_list_coeffs,
|
|
1091
|
-
device=device,
|
|
1092
|
-
dim=z.nrow,
|
|
1093
|
-
inputs=[
|
|
1094
|
-
x.offsets,
|
|
1095
|
-
x.columns,
|
|
1096
|
-
y.offsets,
|
|
1097
|
-
y.columns,
|
|
1098
|
-
work_arrays._mm_row_counts,
|
|
1099
|
-
work_arrays._mm_rows,
|
|
1100
|
-
work_arrays._mm_cols,
|
|
1101
|
-
],
|
|
1102
|
-
)
|
|
1103
1683
|
|
|
1104
1684
|
# Increase dest array size if needed
|
|
1105
1685
|
if z.columns.shape[0] < mm_nnz:
|
|
@@ -1112,20 +1692,31 @@ def bsr_mm(
|
|
|
1112
1692
|
else:
|
|
1113
1693
|
native_func = runtime.core.bsr_matrix_from_triplets_float_device
|
|
1114
1694
|
|
|
1115
|
-
|
|
1116
|
-
|
|
1117
|
-
|
|
1118
|
-
|
|
1119
|
-
|
|
1120
|
-
|
|
1121
|
-
|
|
1122
|
-
|
|
1123
|
-
|
|
1124
|
-
|
|
1125
|
-
|
|
1126
|
-
|
|
1695
|
+
nnz_buf, nnz_event = z._nnz_transfer_buf_and_event()
|
|
1696
|
+
|
|
1697
|
+
with wp.ScopedDevice(z.device):
|
|
1698
|
+
native_func(
|
|
1699
|
+
z.block_shape[0],
|
|
1700
|
+
z.block_shape[1],
|
|
1701
|
+
z.nrow,
|
|
1702
|
+
mm_nnz,
|
|
1703
|
+
ctypes.cast(work_arrays._mm_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1704
|
+
ctypes.cast(work_arrays._mm_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1705
|
+
0,
|
|
1706
|
+
False,
|
|
1707
|
+
ctypes.cast(z.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1708
|
+
ctypes.cast(z.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1709
|
+
0,
|
|
1710
|
+
ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1711
|
+
nnz_event,
|
|
1712
|
+
)
|
|
1713
|
+
|
|
1714
|
+
# Resize z to fit mm result if necessary
|
|
1715
|
+
# If we are not reusing the product topology, this needs another synchronization
|
|
1716
|
+
if not reuse_topology:
|
|
1717
|
+
work_arrays.result_nnz = z.nnz_sync()
|
|
1718
|
+
_bsr_ensure_fits(z, nnz=work_arrays.result_nnz)
|
|
1127
1719
|
|
|
1128
|
-
_bsr_ensure_fits(z)
|
|
1129
1720
|
z.values.zero_()
|
|
1130
1721
|
|
|
1131
1722
|
if copied_z_nnz > 0:
|
|
@@ -1206,12 +1797,57 @@ def _bsr_mv_kernel(
|
|
|
1206
1797
|
y[row] = v
|
|
1207
1798
|
|
|
1208
1799
|
|
|
1800
|
+
@wp.kernel
|
|
1801
|
+
def _bsr_mv_transpose_kernel(
|
|
1802
|
+
alpha: Any,
|
|
1803
|
+
A_offsets: wp.array(dtype=int),
|
|
1804
|
+
A_columns: wp.array(dtype=int),
|
|
1805
|
+
A_values: wp.array(dtype=Any),
|
|
1806
|
+
x: wp.array(dtype=Any),
|
|
1807
|
+
y: wp.array(dtype=Any),
|
|
1808
|
+
):
|
|
1809
|
+
row = wp.tid()
|
|
1810
|
+
beg = A_offsets[row]
|
|
1811
|
+
end = A_offsets[row + 1]
|
|
1812
|
+
xr = alpha * x[row]
|
|
1813
|
+
for block in range(beg, end):
|
|
1814
|
+
v = wp.transpose(A_values[block]) * xr
|
|
1815
|
+
wp.atomic_add(y, A_columns[block], v)
|
|
1816
|
+
|
|
1817
|
+
|
|
1818
|
+
def _bsr_mv_as_vec_array(array: wp.array) -> wp.array:
|
|
1819
|
+
if array.ndim == 1:
|
|
1820
|
+
return array
|
|
1821
|
+
|
|
1822
|
+
if array.ndim > 2:
|
|
1823
|
+
raise ValueError(f"Incompatible array number of dimensions {array.ndim}")
|
|
1824
|
+
|
|
1825
|
+
if not array.is_contiguous:
|
|
1826
|
+
raise ValueError("2d array must be contiguous")
|
|
1827
|
+
|
|
1828
|
+
def vec_view(array):
|
|
1829
|
+
return wp.array(
|
|
1830
|
+
data=None,
|
|
1831
|
+
ptr=array.ptr,
|
|
1832
|
+
capacity=array.capacity,
|
|
1833
|
+
device=array.device,
|
|
1834
|
+
dtype=wp.vec(length=array.shape[1], dtype=array.dtype),
|
|
1835
|
+
shape=array.shape[0],
|
|
1836
|
+
grad=None if array.grad is None else vec_view(array.grad),
|
|
1837
|
+
)
|
|
1838
|
+
|
|
1839
|
+
view = vec_view(array)
|
|
1840
|
+
view._ref = array
|
|
1841
|
+
return view
|
|
1842
|
+
|
|
1843
|
+
|
|
1209
1844
|
def bsr_mv(
|
|
1210
|
-
A:
|
|
1845
|
+
A: BsrMatrixOrExpression[BlockType[Rows, Cols, Scalar]],
|
|
1211
1846
|
x: "Array[Vector[Cols, Scalar] | Scalar]",
|
|
1212
1847
|
y: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
|
|
1213
1848
|
alpha: Scalar = 1.0,
|
|
1214
1849
|
beta: Scalar = 0.0,
|
|
1850
|
+
transpose: bool = False,
|
|
1215
1851
|
work_buffer: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
|
|
1216
1852
|
) -> "Array[Vector[Rows, Scalar] | Scalar]":
|
|
1217
1853
|
"""
|
|
@@ -1225,16 +1861,26 @@ def bsr_mv(
|
|
|
1225
1861
|
y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero.
|
|
1226
1862
|
alpha: Uniform scaling factor for `x`. If zero, `x` will not be read and may be left uninitialized.
|
|
1227
1863
|
beta: Uniform scaling factor for `y`. If zero, `y` will not be read and may be left uninitialized.
|
|
1864
|
+
transpose: If ``True``, use the transpose of the matrix `A`. In this case the result is **non-deterministic**.
|
|
1228
1865
|
work_buffer: Temporary storage is required if and only if `x` and `y` are the same vector. If provided the `work_buffer` array
|
|
1229
1866
|
will be used for this purpose, otherwise a temporary allocation will be performed.
|
|
1230
1867
|
"""
|
|
1231
1868
|
|
|
1869
|
+
A, A_scale = _extract_matrix_and_scale(A)
|
|
1870
|
+
alpha *= A_scale
|
|
1871
|
+
|
|
1872
|
+
if transpose:
|
|
1873
|
+
block_shape = A.block_shape[1], A.block_shape[0]
|
|
1874
|
+
nrow, ncol = A.ncol, A.nrow
|
|
1875
|
+
else:
|
|
1876
|
+
block_shape = A.block_shape
|
|
1877
|
+
nrow, ncol = A.nrow, A.ncol
|
|
1878
|
+
|
|
1232
1879
|
if y is None:
|
|
1233
1880
|
# If no output array is provided, allocate one for convenience
|
|
1234
|
-
y_vec_len =
|
|
1881
|
+
y_vec_len = block_shape[0]
|
|
1235
1882
|
y_dtype = A.scalar_type if y_vec_len == 1 else wp.vec(length=y_vec_len, dtype=A.scalar_type)
|
|
1236
|
-
y = wp.empty(shape=(
|
|
1237
|
-
y.zero_()
|
|
1883
|
+
y = wp.empty(shape=(nrow,), device=A.values.device, dtype=y_dtype)
|
|
1238
1884
|
beta = 0.0
|
|
1239
1885
|
|
|
1240
1886
|
if not isinstance(alpha, A.scalar_type):
|
|
@@ -1245,12 +1891,16 @@ def bsr_mv(
|
|
|
1245
1891
|
if A.values.device != x.device or A.values.device != y.device:
|
|
1246
1892
|
raise ValueError("A, x and y must reside on the same device")
|
|
1247
1893
|
|
|
1248
|
-
if x.shape[0] !=
|
|
1894
|
+
if x.shape[0] != ncol:
|
|
1249
1895
|
raise ValueError("Number of columns of A must match number of rows of x")
|
|
1250
|
-
if y.shape[0] !=
|
|
1896
|
+
if y.shape[0] != nrow:
|
|
1251
1897
|
raise ValueError("Number of rows of A must match number of rows of y")
|
|
1252
1898
|
|
|
1253
|
-
|
|
1899
|
+
# View 2d arrays as arrays of vecs
|
|
1900
|
+
x = _bsr_mv_as_vec_array(x)
|
|
1901
|
+
y = _bsr_mv_as_vec_array(y)
|
|
1902
|
+
|
|
1903
|
+
if x.ptr == y.ptr:
|
|
1254
1904
|
# Aliasing case, need temporary storage
|
|
1255
1905
|
if work_buffer is None:
|
|
1256
1906
|
work_buffer = wp.empty_like(y)
|
|
@@ -1265,25 +1915,39 @@ def bsr_mv(
|
|
|
1265
1915
|
|
|
1266
1916
|
# Promote scalar vectors to length-1 vecs and conversely
|
|
1267
1917
|
if warp.types.type_is_matrix(A.values.dtype):
|
|
1268
|
-
if
|
|
1269
|
-
|
|
1270
|
-
|
|
1271
|
-
|
|
1272
|
-
if x.dtype == A.scalar_type:
|
|
1273
|
-
x = x.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
|
|
1918
|
+
if block_shape[0] == 1 and y.dtype == A.scalar_type:
|
|
1919
|
+
y = y.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
|
|
1920
|
+
if block_shape[1] == 1 and x.dtype == A.scalar_type:
|
|
1921
|
+
x = x.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
|
|
1274
1922
|
else:
|
|
1275
|
-
if
|
|
1276
|
-
|
|
1277
|
-
|
|
1278
|
-
|
|
1279
|
-
|
|
1280
|
-
|
|
1281
|
-
|
|
1282
|
-
|
|
1283
|
-
|
|
1284
|
-
|
|
1285
|
-
|
|
1286
|
-
|
|
1287
|
-
|
|
1923
|
+
if block_shape[0] == 1 and y.dtype != A.scalar_type:
|
|
1924
|
+
y = y.view(dtype=A.scalar_type)
|
|
1925
|
+
if block_shape[1] == 1 and x.dtype != A.scalar_type:
|
|
1926
|
+
x = x.view(dtype=A.scalar_type)
|
|
1927
|
+
|
|
1928
|
+
if transpose:
|
|
1929
|
+
if beta.value == 0.0:
|
|
1930
|
+
y.zero_()
|
|
1931
|
+
elif beta.value != 1.0:
|
|
1932
|
+
wp.launch(
|
|
1933
|
+
kernel=_bsr_scale_kernel,
|
|
1934
|
+
device=y.device,
|
|
1935
|
+
dim=y.shape[0],
|
|
1936
|
+
inputs=[beta, y],
|
|
1937
|
+
)
|
|
1938
|
+
if alpha.value != 0.0:
|
|
1939
|
+
wp.launch(
|
|
1940
|
+
kernel=_bsr_mv_transpose_kernel,
|
|
1941
|
+
device=A.values.device,
|
|
1942
|
+
dim=ncol,
|
|
1943
|
+
inputs=[alpha, A.offsets, A.columns, A.values, x, y],
|
|
1944
|
+
)
|
|
1945
|
+
else:
|
|
1946
|
+
wp.launch(
|
|
1947
|
+
kernel=_bsr_mv_kernel,
|
|
1948
|
+
device=A.values.device,
|
|
1949
|
+
dim=nrow,
|
|
1950
|
+
inputs=[alpha, A.offsets, A.columns, A.values, x, beta, y],
|
|
1951
|
+
)
|
|
1288
1952
|
|
|
1289
1953
|
return y
|