warp-lang 1.8.0__py3-none-win_amd64.whl → 1.9.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic.
- warp/__init__.py +282 -103
- warp/__init__.pyi +482 -110
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +93 -30
- warp/build_dll.py +48 -63
- warp/builtins.py +955 -137
- warp/codegen.py +327 -209
- warp/config.py +1 -1
- warp/context.py +1363 -800
- warp/examples/core/example_marching_cubes.py +1 -0
- warp/examples/core/example_render_opengl.py +100 -3
- warp/examples/fem/example_apic_fluid.py +98 -52
- warp/examples/fem/example_convection_diffusion_dg.py +25 -4
- warp/examples/fem/example_diffusion_mgpu.py +8 -3
- warp/examples/fem/utils.py +68 -22
- warp/examples/interop/example_jax_callable.py +34 -4
- warp/examples/interop/example_jax_kernel.py +27 -1
- warp/fabric.py +1 -1
- warp/fem/cache.py +27 -19
- warp/fem/domain.py +2 -2
- warp/fem/field/nodal_field.py +2 -2
- warp/fem/field/virtual.py +266 -166
- warp/fem/geometry/geometry.py +5 -5
- warp/fem/integrate.py +200 -91
- warp/fem/space/restriction.py +4 -0
- warp/fem/space/shape/tet_shape_function.py +3 -10
- warp/jax_experimental/custom_call.py +1 -1
- warp/jax_experimental/ffi.py +203 -54
- warp/marching_cubes.py +708 -0
- warp/native/array.h +103 -8
- warp/native/builtin.h +90 -9
- warp/native/bvh.cpp +64 -28
- warp/native/bvh.cu +58 -58
- warp/native/bvh.h +2 -2
- warp/native/clang/clang.cpp +7 -7
- warp/native/coloring.cpp +13 -3
- warp/native/crt.cpp +2 -2
- warp/native/crt.h +3 -5
- warp/native/cuda_util.cpp +42 -11
- warp/native/cuda_util.h +10 -4
- warp/native/exports.h +1842 -1908
- warp/native/fabric.h +2 -1
- warp/native/hashgrid.cpp +37 -37
- warp/native/hashgrid.cu +2 -2
- warp/native/initializer_array.h +1 -1
- warp/native/intersect.h +4 -4
- warp/native/mat.h +1913 -119
- warp/native/mathdx.cpp +43 -43
- warp/native/mesh.cpp +24 -24
- warp/native/mesh.cu +26 -26
- warp/native/mesh.h +5 -3
- warp/native/nanovdb/GridHandle.h +179 -12
- warp/native/nanovdb/HostBuffer.h +8 -7
- warp/native/nanovdb/NanoVDB.h +517 -895
- warp/native/nanovdb/NodeManager.h +323 -0
- warp/native/nanovdb/PNanoVDB.h +2 -2
- warp/native/quat.h +337 -16
- warp/native/rand.h +7 -7
- warp/native/range.h +7 -1
- warp/native/reduce.cpp +10 -10
- warp/native/reduce.cu +13 -14
- warp/native/runlength_encode.cpp +2 -2
- warp/native/runlength_encode.cu +5 -5
- warp/native/scan.cpp +3 -3
- warp/native/scan.cu +4 -4
- warp/native/sort.cpp +10 -10
- warp/native/sort.cu +22 -22
- warp/native/sparse.cpp +8 -8
- warp/native/sparse.cu +14 -14
- warp/native/spatial.h +366 -17
- warp/native/svd.h +23 -8
- warp/native/temp_buffer.h +2 -2
- warp/native/tile.h +303 -70
- warp/native/tile_radix_sort.h +5 -1
- warp/native/tile_reduce.h +16 -25
- warp/native/tuple.h +2 -2
- warp/native/vec.h +385 -18
- warp/native/volume.cpp +54 -54
- warp/native/volume.cu +1 -1
- warp/native/volume.h +2 -1
- warp/native/volume_builder.cu +30 -37
- warp/native/warp.cpp +150 -149
- warp/native/warp.cu +337 -193
- warp/native/warp.h +227 -226
- warp/optim/linear.py +736 -271
- warp/render/imgui_manager.py +289 -0
- warp/render/render_opengl.py +137 -57
- warp/render/render_usd.py +0 -1
- warp/sim/collide.py +1 -2
- warp/sim/graph_coloring.py +2 -2
- warp/sim/integrator_vbd.py +10 -2
- warp/sparse.py +559 -176
- warp/tape.py +2 -0
- warp/tests/aux_test_module_aot.py +7 -0
- warp/tests/cuda/test_async.py +3 -3
- warp/tests/cuda/test_conditional_captures.py +101 -0
- warp/tests/geometry/test_marching_cubes.py +233 -12
- warp/tests/sim/test_cloth.py +89 -6
- warp/tests/sim/test_coloring.py +82 -7
- warp/tests/test_array.py +56 -5
- warp/tests/test_assert.py +53 -0
- warp/tests/test_atomic_cas.py +127 -114
- warp/tests/test_codegen.py +3 -2
- warp/tests/test_context.py +8 -15
- warp/tests/test_enum.py +136 -0
- warp/tests/test_examples.py +2 -2
- warp/tests/test_fem.py +45 -2
- warp/tests/test_fixedarray.py +229 -0
- warp/tests/test_func.py +18 -15
- warp/tests/test_future_annotations.py +7 -5
- warp/tests/test_linear_solvers.py +30 -0
- warp/tests/test_map.py +1 -1
- warp/tests/test_mat.py +1540 -378
- warp/tests/test_mat_assign_copy.py +178 -0
- warp/tests/test_mat_constructors.py +574 -0
- warp/tests/test_module_aot.py +287 -0
- warp/tests/test_print.py +69 -0
- warp/tests/test_quat.py +162 -34
- warp/tests/test_quat_assign_copy.py +145 -0
- warp/tests/test_reload.py +2 -1
- warp/tests/test_sparse.py +103 -0
- warp/tests/test_spatial.py +140 -34
- warp/tests/test_spatial_assign_copy.py +160 -0
- warp/tests/test_static.py +48 -0
- warp/tests/test_struct.py +43 -3
- warp/tests/test_tape.py +38 -0
- warp/tests/test_types.py +0 -20
- warp/tests/test_vec.py +216 -441
- warp/tests/test_vec_assign_copy.py +143 -0
- warp/tests/test_vec_constructors.py +325 -0
- warp/tests/tile/test_tile.py +206 -152
- warp/tests/tile/test_tile_cholesky.py +605 -0
- warp/tests/tile/test_tile_load.py +169 -0
- warp/tests/tile/test_tile_mathdx.py +2 -558
- warp/tests/tile/test_tile_matmul.py +179 -0
- warp/tests/tile/test_tile_mlp.py +1 -1
- warp/tests/tile/test_tile_reduce.py +100 -11
- warp/tests/tile/test_tile_shared_memory.py +16 -16
- warp/tests/tile/test_tile_sort.py +59 -55
- warp/tests/unittest_suites.py +16 -0
- warp/tests/walkthrough_debug.py +1 -1
- warp/thirdparty/unittest_parallel.py +108 -9
- warp/types.py +554 -264
- warp/utils.py +68 -86
- {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
- {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/RECORD +150 -138
- warp/native/marching.cpp +0 -19
- warp/native/marching.cu +0 -514
- warp/native/marching.h +0 -19
- {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
warp/sparse.py
CHANGED
@@ -36,6 +36,30 @@ from warp.types import (
     types_equal,
 )
 
+__all__ = [
+    "BsrMatrix",
+    "bsr_assign",
+    "bsr_axpy",
+    "bsr_copy",
+    "bsr_diag",
+    "bsr_from_triplets",
+    "bsr_get_diag",
+    "bsr_identity",
+    "bsr_matrix_t",
+    "bsr_mm",
+    "bsr_mm_work_arrays",
+    "bsr_mv",
+    "bsr_scale",
+    "bsr_set_diag",
+    "bsr_set_from_triplets",
+    "bsr_set_identity",
+    "bsr_set_transpose",
+    "bsr_set_zero",
+    "bsr_transposed",
+    "bsr_zeros",
+]
+
+
 # typing hints
 
 _BlockType = TypeVar("BlockType")  # noqa: PLC0132
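The new __all__ list makes the public surface of warp.sparse explicit. As a minimal usage sketch (not taken from the diff; it assumes the documented signatures of the listed constructors):

import warp as wp
from warp.sparse import bsr_identity, bsr_zeros  # names re-exported via the new __all__

wp.init()

# 128x128 grid of 3x3 float32 blocks; arguments follow the documented
# (rows_of_blocks, cols_of_blocks, block_type, device) signature
A = bsr_zeros(128, 128, block_type=wp.mat33, device="cpu")
I = bsr_identity(128, block_type=wp.mat33, device="cpu")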
@@ -52,6 +76,7 @@ class _ScalarBlockType(Generic[Scalar]):
 BlockType = Union[_MatrixBlockType[Rows, Cols, Scalar], _ScalarBlockType[Scalar]]
 
 _struct_cache = {}
+_transfer_buffer_cache = {}
 
 
 class BsrMatrix(Generic[_BlockType]):
@@ -131,17 +156,21 @@ class BsrMatrix(Generic[_BlockType]):
         return out
 
     def nnz_sync(self):
-        """
-        and
+        """Ensures that any ongoing transfer of the exact nnz number from the device offsets array to the host has completed,
+        or, if none has been scheduled yet, starts a new transfer and waits for it to complete.
+        Then updates the nnz upper bound.
 
         See also :meth:`copy_nnz_async`.
        """
 
         buf, event = self._nnz_transfer_if_any()
-        if buf is
-
-
-
+        if buf is None:
+            self.copy_nnz_async()
+            buf, event = self._nnz_transfer_if_any()
+
+        if event is not None:
+            wp.synchronize_event(event)
+        self.nnz = int(buf.numpy()[0])
         return self.nnz
 
     def copy_nnz_async(self) -> None:
@@ -161,17 +190,23 @@ class BsrMatrix(Generic[_BlockType]):
 
     def _setup_nnz_transfer(self):
         buf, event = self._nnz_transfer_if_any()
-        if buf is not None
+        if buf is not None:
             return buf, event
 
-        buf
-
-
+        buf, event = _allocate_transfer_buf(self.device)
+        if buf is not None:
+            BsrMatrix.__setattr__(self, "_nnz_transfer", (buf, event))
+
         return buf, event
 
     def _nnz_transfer_if_any(self):
         return getattr(self, "_nnz_transfer", (None, None))
 
+    def __del__(self):
+        buf, event = self._nnz_transfer_if_any()
+        if buf is not None:
+            _redeem_transfer_buf(self.device, buf, event)
+
     # Overloaded math operators
     def __add__(self, y):
         return bsr_axpy(y, bsr_copy(self))
@@ -226,6 +261,31 @@ class BsrMatrix(Generic[_BlockType]):
         return bsr_transposed(self)
 
 
+def _allocate_transfer_buf(device):
+    if device.ordinal in _transfer_buffer_cache:
+        all_, pool = _transfer_buffer_cache[device.ordinal]
+    else:
+        all_ = []
+        pool = []
+        _transfer_buffer_cache[device.ordinal] = (all_, pool)
+
+    if pool:
+        return pool.pop()
+
+    if device.is_capturing:
+        return None, None
+
+    buf = wp.empty(dtype=int, shape=(1,), device="cpu", pinned=device.is_cuda)
+    event = wp.Event(device) if device.is_cuda else None
+    all_.append((buf, event))  # keep a reference to the buffer and event, prevent garbage collection before redeem
+    return buf, event
+
+
+def _redeem_transfer_buf(device, buf, event):
+    all_, pool = _transfer_buffer_cache[device.ordinal]
+    pool.append((buf, event))
+
+
 def bsr_matrix_t(dtype: BlockType):
     dtype = type_to_warp(dtype)
 
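A note on the buffer pool introduced above (the following is a sketch of the intended call pattern, assuming only the documented warp.sparse API): with __del__ redeeming buffers into _transfer_buffer_cache, matrices reuse pinned staging memory for the nnz readback instead of allocating one buffer per instance, and the allocation is skipped entirely while a CUDA graph capture is active.

import warp as wp
import warp.sparse as sparse

wp.init()

A = sparse.bsr_zeros(1024, 1024, block_type=wp.float32)

# Topology-changing calls such as bsr_set_from_triplets() schedule an async
# device-to-host copy of the exact block count into a pooled pinned buffer
# (see _setup_nnz_transfer above). Until nnz_sync() waits on the recorded
# event, A.nnz is only an upper bound.
exact_nnz = A.nnz_sync()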
@@ -483,16 +543,16 @@ def bsr_set_from_triplets(
 
     device = dest.values.device
     scalar_type = dest.scalar_type
-    zero_value_mask = _zero_value_masks.get(scalar_type, 0)
+    zero_value_mask = _zero_value_masks.get(scalar_type, 0) if prune_numerical_zeros else 0
 
     # compute the BSR topology
 
     from warp.context import runtime
 
     if device.is_cpu:
-        native_func = runtime.core.
+        native_func = runtime.core.wp_bsr_matrix_from_triplets_host
     else:
-        native_func = runtime.core.
+        native_func = runtime.core.wp_bsr_matrix_from_triplets_device
 
     nnz_buf, nnz_event = dest._setup_nnz_transfer()
     summed_triplet_offsets = wp.empty(shape=(nnz,), dtype=wp.int32, device=device)
@@ -901,9 +961,9 @@ def bsr_assign(
     from warp.context import runtime
 
     if dest.device.is_cpu:
-        native_func = runtime.core.
+        native_func = runtime.core.wp_bsr_matrix_from_triplets_host
     else:
-        native_func = runtime.core.
+        native_func = runtime.core.wp_bsr_matrix_from_triplets_device
 
     nnz_buf, nnz_event = dest._setup_nnz_transfer()
     with wp.ScopedDevice(dest.device):
@@ -1041,9 +1101,9 @@ def bsr_set_transpose(
     from warp.context import runtime
 
     if dest.values.device.is_cpu:
-        native_func = runtime.core.
+        native_func = runtime.core.wp_bsr_transpose_host
     else:
-        native_func = runtime.core.
+        native_func = runtime.core.wp_bsr_transpose_device
 
     block_index_map = wp.empty(shape=2 * nnz, dtype=int, device=src.device)
 
@@ -1094,14 +1154,14 @@ def _bsr_get_diag_kernel(
     scale: Any,
     A_offsets: wp.array(dtype=int),
     A_columns: wp.array(dtype=int),
-    A_values: wp.
-    out: wp.
+    A_values: wp.array3d(dtype=Any),
+    out: wp.array3d(dtype=Any),
 ):
-    row = wp.tid()
+    row, br, bc = wp.tid()
 
     diag = _bsr_block_index(row, row, A_offsets, A_columns)
     if diag != -1:
-        out[row] = scale * A_values[diag]
+        out[row, br, bc] = scale * A_values[diag, br, bc]
 
 
 def bsr_get_diag(A: BsrMatrixOrExpression[BlockType], out: "Optional[Array[BlockType]]" = None) -> "Array[BlockType]":
@@ -1128,9 +1188,9 @@ def bsr_get_diag(A: BsrMatrixOrExpression[BlockType], out: "Optional[Array[Block
 
     wp.launch(
         kernel=_bsr_get_diag_kernel,
-        dim=dim,
+        dim=(dim, *A.block_shape),
         device=A.values.device,
-        inputs=[A.scalar_type(scale), A.offsets, A.columns, A.
+        inputs=[A.scalar_type(scale), A.offsets, A.columns, A.scalar_values, _as_3d_array(out, A.block_shape)],
     )
 
     return out
@@ -1312,7 +1372,17 @@ def _bsr_scale_kernel(
     alpha: Any,
     values: wp.array(dtype=Any),
 ):
-
+    row = wp.tid()
+    values[row] = alpha * values[row]
+
+
+@wp.kernel
+def _bsr_scale_kernel(
+    alpha: Any,
+    values: wp.array3d(dtype=Any),
+):
+    row, br, bc = wp.tid()
+    values[row, br, bc] = alpha * values[row, br, bc]
 
 
 def bsr_scale(x: BsrMatrixOrExpression, alpha: Scalar) -> BsrMatrix:
@@ -1329,9 +1399,9 @@ def bsr_scale(x: BsrMatrixOrExpression, alpha: Scalar) -> BsrMatrix:
 
     wp.launch(
         kernel=_bsr_scale_kernel,
-        dim=x.nnz,
+        dim=(x.nnz, *x.block_shape),
         device=x.values.device,
-        inputs=[alpha, x.
+        inputs=[alpha, x.scalar_values],
     )
 
     return x
@@ -1351,16 +1421,16 @@ def _bsr_axpy_add_block(
     cols: wp.array(dtype=int),
     dst_offsets: wp.array(dtype=int),
     dst_columns: wp.array(dtype=int),
-    src_values: wp.
-    dst_values: wp.
+    src_values: wp.array3d(dtype=Any),
+    dst_values: wp.array3d(dtype=Any),
 ):
-    i = wp.tid()
+    i, br, bc = wp.tid()
     row = rows[i + src_offset]
     col = cols[i + src_offset]
 
     block = _bsr_block_index(row, col, dst_offsets, dst_columns)
     if block != -1:
-        dst_values[block] += scale * src_values[i]
+        dst_values[block, br, bc] += scale * src_values[i, br, bc]
 
 
 class bsr_axpy_work_arrays:
@@ -1386,7 +1456,7 @@ class bsr_axpy_work_arrays:
         self._sum_cols = wp.empty(shape=(sum_nnz), dtype=int, device=self.device)
 
         if self._old_y_values is None or self._old_y_values.size < y.nnz:
-            self._old_y_values = wp.
+            self._old_y_values = wp.empty_like(y.values[: y.nnz])
 
 
 def bsr_axpy(
@@ -1475,7 +1545,7 @@ def bsr_axpy(
         x.uncompress_rows(out=work_arrays._sum_rows[y_nnz:])
 
     # Save old y values before overwriting matrix
-    wp.copy(dest=work_arrays._old_y_values, src=y.values, count=
+    wp.copy(dest=work_arrays._old_y_values, src=y.values, count=y.nnz)
 
     # Increase dest array sizes if needed
     if not masked:
@@ -1484,9 +1554,9 @@ def bsr_axpy(
     from warp.context import runtime
 
     if device.is_cpu:
-        native_func = runtime.core.
+        native_func = runtime.core.wp_bsr_matrix_from_triplets_host
     else:
-        native_func = runtime.core.
+        native_func = runtime.core.wp_bsr_matrix_from_triplets_device
 
     old_y_nnz = y_nnz
     nnz_buf, nnz_event = y._setup_nnz_transfer()
@@ -1517,7 +1587,7 @@ def bsr_axpy(
     wp.launch(
         kernel=_bsr_axpy_add_block,
         device=device,
-        dim=old_y_nnz,
+        dim=(old_y_nnz, y.block_shape[0], y.block_shape[1]),
         inputs=[
             0,
             beta,
@@ -1525,15 +1595,15 @@ def bsr_axpy(
             work_arrays._sum_cols,
             y.offsets,
             y.columns,
-            work_arrays._old_y_values,
-            y.
+            _as_3d_array(work_arrays._old_y_values, y.block_shape),
+            y.scalar_values,
         ],
     )
 
     wp.launch(
         kernel=_bsr_axpy_add_block,
         device=device,
-        dim=x_nnz,
+        dim=(x_nnz, y.block_shape[0], y.block_shape[1]),
         inputs=[
             old_y_nnz,
             alpha,
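Several of the kernels above (diagonal extraction, scaling, axpy) switch from one thread per block to one thread per scalar coefficient by launching over (block_count, block_rows, block_cols) and indexing the 3D scalar_values views. A self-contained sketch of that launch pattern, using only the public warp API rather than code from this diff:

import warp as wp

@wp.kernel
def scale_blocks(alpha: float, values: wp.array3d(dtype=float)):
    # one thread per scalar coefficient: block index, row within block, column within block
    i, r, c = wp.tid()
    values[i, r, c] = alpha * values[i, r, c]

wp.init()
vals = wp.ones(shape=(16, 3, 3), dtype=float)
wp.launch(scale_blocks, dim=(vals.shape[0], vals.shape[1], vals.shape[2]), inputs=[2.0, vals])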
@@ -1541,60 +1611,78 @@ def bsr_axpy(
             work_arrays._sum_cols,
             y.offsets,
             y.columns,
-            x.
-            y.
+            x.scalar_values,
+            y.scalar_values,
         ],
     )
 
     return y
 
 
-
-
-    y_ncol: int,
-    z_nnz: int,
-    x_offsets: wp.array(dtype=int),
-    x_columns: wp.array(dtype=int),
-    y_offsets: wp.array(dtype=int),
-    y_columns: wp.array(dtype=int),
-    row_min: wp.array(dtype=int),
-    block_counts: wp.array(dtype=int),
-):
-    row = wp.tid()
-    row_count = int(0)
+def make_bsr_mm_count_coeffs(tile_size):
+    from warp.fem.cache import dynamic_kernel
 
-
-
+    @dynamic_kernel(suffix=tile_size)
+    def bsr_mm_count_coeffs(
+        y_ncol: int,
+        z_nnz: int,
+        x_offsets: wp.array(dtype=int),
+        x_columns: wp.array(dtype=int),
+        y_offsets: wp.array(dtype=int),
+        y_columns: wp.array(dtype=int),
+        row_min: wp.array(dtype=int),
+        block_counts: wp.array(dtype=int),
+    ):
+        row, lane = wp.tid()
+        row_count = int(0)
 
-
-
+        x_beg = x_offsets[row]
+        x_end = x_offsets[row + 1]
 
-
-
-        y_row_end = y_offsets[x_col + 1]
-        y_row_beg = y_offsets[x_col]
-        block_count = y_row_end - y_row_beg
-        if block_count != 0:
-            min_col = wp.min(y_columns[y_row_beg], min_col)
-            max_col = wp.max(y_columns[y_row_end - 1], max_col)
-
-        block_counts[x_block + 1] = block_count
-        row_count += block_count
-
-    if row_count > wp.max(0, max_col - min_col):
-        row_min[row] = min_col
-        block_counts[x_end] = max_col + 1 - min_col
-        for x_block in range(x_beg, x_end - 1):
-            block_counts[x_block + 1] = 0
-    else:
-        row_min[row] = -1
+        min_col = y_ncol
+        max_col = int(0)
 
-
-
+        for x_block in range(x_beg + lane, x_end, tile_size):
+            x_col = x_columns[x_block]
+            y_row_end = y_offsets[x_col + 1]
+            y_row_beg = y_offsets[x_col]
+            block_count = y_row_end - y_row_beg
+            if block_count != 0:
+                min_col = wp.min(y_columns[y_row_beg], min_col)
+                max_col = wp.max(y_columns[y_row_end - 1], max_col)
+
+            block_counts[x_block + 1] = block_count
+            row_count += block_count
+
+        if wp.static(tile_size) > 1:
+            row_count = wp.tile_sum(wp.tile(row_count))[0]
+            min_col = wp.tile_min(wp.tile(min_col))[0]
+            max_col = wp.tile_max(wp.tile(max_col))[0]
+        col_range_size = wp.max(0, max_col - min_col + 1)
+
+        if row_count > col_range_size:
+            # Optimization for deep products.
+            # Do not store the whole whole list of src product terms, they would be highly redundant
+            # Instead just mark a range in the output matrix
+
+            if lane == 0:
+                row_min[row] = min_col
+                block_counts[x_end] = col_range_size
+
+            for x_block in range(x_beg + lane, x_end - 1, tile_size):
+                block_counts[x_block + 1] = 0
+        elif lane == 0:
+            row_min[row] = -1
+
+        if lane == 0 and row == 0:
+            block_counts[0] = z_nnz
+
+    return bsr_mm_count_coeffs
 
 
 @wp.kernel(enable_backward=False)
 def _bsr_mm_list_coeffs(
+    copied_z_nnz: int,
     x_nrow: int,
     x_offsets: wp.array(dtype=int),
     x_columns: wp.array(dtype=int),
@@ -1604,38 +1692,58 @@ def _bsr_mm_list_coeffs(
     mm_offsets: wp.array(dtype=int),
     mm_rows: wp.array(dtype=int),
     mm_cols: wp.array(dtype=int),
+    mm_src_blocks: wp.array(dtype=int),
 ):
-
-
+    mm_block = wp.tid() + copied_z_nnz
+
+    x_nnz = x_offsets[x_nrow]
+    x_block = wp.lower_bound(mm_offsets, 0, x_nnz + 1, mm_block + 1) - 1
+    pos = mm_block - mm_offsets[x_block]
 
     row = _bsr_row_index(x_offsets, x_nrow, x_block)
-    if row == -1:
-        return
 
     row_min_col = mm_row_min[row]
-    if row_min_col
+    if row_min_col == -1:
         x_col = x_columns[x_block]
-
         y_beg = y_offsets[x_col]
-
+        y_block = y_beg + pos
+        col = y_columns[y_block]
+        src_block = x_block
+    else:
+        col = row_min_col + pos
+        src_block = -1
 
-
-
-
-        mm_cols[mm_block + col - row_min_col] = col
+    mm_cols[mm_block] = col
+    mm_rows[mm_block] = row
+    mm_src_blocks[mm_block] = src_block
 
-        return
 
-
-
-
-
-
-
+@wp.func
+def _bsr_mm_use_triplets(
+    row: int,
+    mm_block: int,
+    mm_row_min: wp.array(dtype=int),
+    row_offsets: wp.array(dtype=int),
+    summed_triplet_offsets: wp.array(dtype=int),
+):
+    x_beg = row_offsets[row]
+    x_end = row_offsets[row + 1]
 
+    if mm_row_min:
+        if mm_row_min[row] == -1:
+            if mm_block == 0:
+                block_beg = 0
+            else:
+                block_beg = summed_triplet_offsets[mm_block - 1]
+            block_end = summed_triplet_offsets[mm_block]
 
-
+            if x_end - x_beg > 3 * (block_end - block_beg):
+                return True, block_beg, block_end
+
+    return False, x_beg, x_end
+
+
+@wp.kernel(enable_backward=False)
 def _bsr_mm_compute_values(
     alpha: Any,
     x_offsets: wp.array(dtype=int),
@@ -1644,6 +1752,9 @@ def _bsr_mm_compute_values(
     y_offsets: wp.array(dtype=int),
     y_columns: wp.array(dtype=int),
     y_values: wp.array(dtype=Any),
+    mm_row_min: wp.array(dtype=int),
+    summed_triplet_offsets: wp.array(dtype=int),
+    summed_triplet_src_blocks: wp.indexedarray(dtype=int),
     mm_row_count: int,
     mm_offsets: wp.array(dtype=int),
     mm_cols: wp.array(dtype=int),
@@ -1655,21 +1766,135 @@ def _bsr_mm_compute_values(
     if row == -1:
         return
 
-
+    use_triplets, block_beg, block_end = _bsr_mm_use_triplets(
+        row, mm_block, mm_row_min, x_offsets, summed_triplet_offsets
+    )
 
     mm_val = mm_values.dtype(type(alpha)(0.0))
-
-
-
-
-
-
-
-
+    col = mm_cols[mm_block]
+    if use_triplets:
+        for tpl_idx in range(block_beg, block_end):
+            x_block = summed_triplet_src_blocks[tpl_idx]
+            x_col = x_columns[x_block]
+            if x_block != -1:
+                y_block = _bsr_block_index(x_col, col, y_offsets, y_columns)
+                mm_val += x_values[x_block] * y_values[y_block]
+    else:
+        for x_block in range(block_beg, block_end):
+            x_col = x_columns[x_block]
+            y_block = _bsr_block_index(x_col, col, y_offsets, y_columns)
+            if y_block != -1:
+                mm_val += x_values[x_block] * y_values[y_block]
 
     mm_values[mm_block] += alpha * mm_val
 
 
+def make_bsr_mm_compute_values_tiled_outer(subblock_rows, subblock_cols, block_depth, scalar_type, tile_size):
+    from warp.fem.cache import dynamic_func, dynamic_kernel
+
+    mm_type = wp.mat(dtype=scalar_type, shape=(subblock_rows, subblock_cols))
+
+    x_col_vec_t = wp.vec(dtype=scalar_type, length=subblock_rows)
+    y_row_vec_t = wp.vec(dtype=scalar_type, length=subblock_cols)
+
+    suffix = f"{subblock_rows}{subblock_cols}{block_depth}{tile_size}{scalar_type.__name__}"
+
+    @dynamic_func(suffix=suffix)
+    def _outer_product(
+        x_values: wp.array2d(dtype=scalar_type),
+        y_values: wp.array2d(dtype=scalar_type),
+        brow_off: int,
+        bcol_off: int,
+        block_col: int,
+        brow_count: int,
+        bcol_count: int,
+    ):
+        x_col_vec = x_col_vec_t()
+        y_row_vec = y_row_vec_t()
+
+        for k in range(brow_count):
+            x_col_vec[k] = x_values[brow_off + k, block_col]
+        for k in range(bcol_count):
+            y_row_vec[k] = y_values[block_col, bcol_off + k]
+
+        return wp.outer(x_col_vec, y_row_vec)
+
+    @dynamic_kernel(suffix=suffix, kernel_options={"enable_backward": False})
+    def bsr_mm_compute_values(
+        alpha: scalar_type,
+        x_offsets: wp.array(dtype=int),
+        x_columns: wp.array(dtype=int),
+        x_values: wp.array3d(dtype=scalar_type),
+        y_offsets: wp.array(dtype=int),
+        y_columns: wp.array(dtype=int),
+        y_values: wp.array3d(dtype=scalar_type),
+        mm_row_min: wp.array(dtype=int),
+        summed_triplet_offsets: wp.array(dtype=int),
+        summed_triplet_src_blocks: wp.indexedarray(dtype=int),
+        mm_row_count: int,
+        mm_offsets: wp.array(dtype=int),
+        mm_cols: wp.array(dtype=int),
+        mm_values: wp.array3d(dtype=scalar_type),
+    ):
+        mm_block, subrow, subcol, lane = wp.tid()
+
+        brow_off = subrow * wp.static(subblock_rows)
+        bcol_off = subcol * wp.static(subblock_cols)
+
+        brow_count = wp.min(mm_values.shape[1] - brow_off, subblock_rows)
+        bcol_count = wp.min(mm_values.shape[2] - bcol_off, subblock_cols)
+
+        mm_row = _bsr_row_index(mm_offsets, mm_row_count, mm_block)
+        if mm_row == -1:
+            return
+
+        lane_val = mm_type()
+
+        use_triplets, block_beg, block_end = _bsr_mm_use_triplets(
+            mm_row, mm_block, mm_row_min, x_offsets, summed_triplet_offsets
+        )
+
+        col_count = (block_end - block_beg) * block_depth
+
+        mm_col = mm_cols[mm_block]
+        if use_triplets:
+            for col in range(lane, col_count, tile_size):
+                tpl_block = col // wp.static(block_depth)
+                block_col = col - tpl_block * wp.static(block_depth)
+                tpl_block += block_beg
+
+                x_block = summed_triplet_src_blocks[tpl_block]
+                if x_block != -1:
+                    x_col = x_columns[x_block]
+                    y_block = _bsr_block_index(x_col, mm_col, y_offsets, y_columns)
+                    lane_val += _outer_product(
+                        x_values[x_block], y_values[y_block], brow_off, bcol_off, block_col, brow_count, bcol_count
+                    )
+        else:
+            for col in range(lane, col_count, tile_size):
+                x_block = col // wp.static(block_depth)
+                block_col = col - x_block * wp.static(block_depth)
+                x_block += block_beg
+
+                x_col = x_columns[x_block]
+                y_block = _bsr_block_index(x_col, mm_col, y_offsets, y_columns)
+
+                if y_block != -1:
+                    lane_val += _outer_product(
+                        x_values[x_block], y_values[y_block], brow_off, bcol_off, block_col, brow_count, bcol_count
+                    )
+
+        mm_val = wp.tile_sum(wp.tile(lane_val, preserve_type=True))[0]
+
+        for coef in range(lane, wp.static(subblock_cols * subblock_rows), tile_size):
+            br = coef // subblock_cols
+            bc = coef - br * subblock_cols
+            if br < brow_count and bc < bcol_count:
+                mm_values[mm_block, br + brow_off, bc + bcol_off] += mm_val[br, bc] * alpha
+
+    return bsr_mm_compute_values
+
+
 class bsr_mm_work_arrays:
     """Opaque structure for persisting :func:`bsr_mm` temporary work buffers across calls."""
 
@@ -1682,6 +1907,7 @@ class bsr_mm_work_arrays:
         self._mm_block_counts = None
         self._mm_rows = None
         self._mm_cols = None
+        self._mm_src_blocks = None
         self._old_z_values = None
         self._old_z_offsets = None
         self._old_z_columns = None
@@ -1717,6 +1943,8 @@ class bsr_mm_work_arrays:
             self._mm_rows = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
         if self._mm_cols is None or self._mm_cols.size < mm_nnz:
             self._mm_cols = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
+        if self._mm_src_blocks is None or self._mm_src_blocks.size < mm_nnz:
+            self._mm_src_blocks = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
 
 
 def bsr_mm(
@@ -1728,6 +1956,7 @@ def bsr_mm(
     masked: bool = False,
     work_arrays: Optional[bsr_mm_work_arrays] = None,
     reuse_topology: bool = False,
+    tile_size: int = 0,
 ) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
     """
     Perform the sparse matrix-matrix multiplication ``z := alpha * x @ y + beta * z`` on BSR matrices ``x``, ``y`` and ``z``, and return ``z``.
@@ -1750,6 +1979,9 @@ def bsr_mm(
          The matrices ``x``, ``y`` and ``z`` must be structurally similar to
          the previous call in which ``work_arrays`` were populated.
          This is necessary for ``bsr_mm`` to be captured in a CUDA graph.
+        tile_size: If a positive integer, use tiles of this size to compute the matrix-matrix product.
+          If negative, disable tile-based computation. Defaults to ``0``, which determines whether to
+          use tiles using using an heuristic based on the matrix shape and number of non-zeros..
     """
 
     x, x_scale = _extract_matrix_and_scale(x)
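A short usage sketch for the new tile_size argument of bsr_mm (not part of the diff; x and y are assumed to be existing BsrMatrix operands with compatible block shapes):

import warp.sparse as sparse

z = sparse.bsr_mm(x, y)                # tile_size=0: heuristic picks the tiled or per-block kernels
z = sparse.bsr_mm(x, y, tile_size=64)  # force the tiled variant with 64-lane tiles
z = sparse.bsr_mm(x, y, tile_size=-1)  # opt out of tile-based computation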
@@ -1835,11 +2067,17 @@ def bsr_mm(
        copied_z_nnz = work_arrays._copied_z_nnz
 
     # Prefix sum of number of (unmerged) mm blocks per row
+    # Use either a thread or a block per row depending on avg nnz/row
     work_arrays._mm_block_counts.zero_()
+    count_tile_size = 32
+    if not device.is_cuda or x.nnz < 3 * count_tile_size * x.nrow:
+        count_tile_size = 1
+
     wp.launch(
-        kernel=
+        kernel=make_bsr_mm_count_coeffs(count_tile_size),
         device=device,
-        dim=z.nrow,
+        dim=(z.nrow, count_tile_size),
+        block_dim=count_tile_size if count_tile_size > 1 else 256,
         inputs=[
             y.ncol,
             copied_z_nnz,
@@ -1851,7 +2089,7 @@ def bsr_mm(
             work_arrays._mm_block_counts,
         ],
     )
-    warp.utils.array_scan(work_arrays._mm_block_counts, work_arrays._mm_block_counts)
+    warp.utils.array_scan(work_arrays._mm_block_counts[: x.nnz + 1], work_arrays._mm_block_counts[: x.nnz + 1])
 
     # Get back total counts on host -- we need a synchronization here
     # Use pinned buffer from z, we are going to need it later anyway
@@ -1873,18 +2111,19 @@ def bsr_mm(
     # Copy z row and column indices
     wp.copy(dest=work_arrays._mm_cols, src=z.columns, count=copied_z_nnz)
     z.uncompress_rows(out=work_arrays._mm_rows)
+    work_arrays._mm_src_blocks[:copied_z_nnz].fill_(-1)
     if z_aliasing:
         # If z is aliasing with x or y, need to save topology as well
         wp.copy(src=z.columns, dest=work_arrays._old_z_columns, count=copied_z_nnz)
         wp.copy(src=z.offsets, dest=work_arrays._old_z_offsets, count=z.nrow + 1)
 
     # Fill unmerged mm blocks rows and columns
-    work_arrays._mm_rows[copied_z_nnz:].fill_(-1)
     wp.launch(
         kernel=_bsr_mm_list_coeffs,
         device=device,
-        dim=
+        dim=mm_nnz - copied_z_nnz,
         inputs=[
+            copied_z_nnz,
             x.nrow,
             x.offsets,
             x.columns,
@@ -1894,6 +2133,7 @@ def bsr_mm(
             work_arrays._mm_block_counts,
             work_arrays._mm_rows,
             work_arrays._mm_cols,
+            work_arrays._mm_src_blocks,
         ],
     )
 
@@ -1912,11 +2152,13 @@ def bsr_mm(
     from warp.context import runtime
 
     if device.is_cpu:
-        native_func = runtime.core.
+        native_func = runtime.core.wp_bsr_matrix_from_triplets_host
     else:
-        native_func = runtime.core.
+        native_func = runtime.core.wp_bsr_matrix_from_triplets_device
 
     nnz_buf, nnz_event = z._setup_nnz_transfer()
+    summed_triplet_offsets = wp.empty(shape=(mm_nnz,), dtype=wp.int32, device=device)
+    summed_triplet_indices = wp.empty(shape=(mm_nnz,), dtype=wp.int32, device=device)
 
     with wp.ScopedDevice(z.device):
         native_func(
@@ -1931,8 +2173,8 @@ def bsr_mm(
            None,  # triplet values
            0,  # zero_value_mask
            False,  # masked_topology
-
-
+            ctypes.cast(summed_triplet_offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
+            ctypes.cast(summed_triplet_indices.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(z.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(z.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
            _optional_ctypes_pointer(nnz_buf, ctype=ctypes.c_int32),
@@ -1952,7 +2194,7 @@ def bsr_mm(
        wp.launch(
            kernel=_bsr_axpy_add_block,
            device=device,
-            dim=copied_z_nnz,
+            dim=(copied_z_nnz, z.block_shape[0], z.block_shape[1]),
            inputs=[
                0,
                beta,
@@ -1960,11 +2202,61 @@ def bsr_mm(
                work_arrays._mm_cols,
                z.offsets,
                z.columns,
-                work_arrays._old_z_values,
-                z.
+                _as_3d_array(work_arrays._old_z_values, z.block_shape),
+                z.scalar_values,
            ],
        )
 
+    max_subblock_dim = 12
+    if tile_size > 0:
+        use_tiles = True
+    elif tile_size < 0:
+        use_tiles = False
+    else:
+        # Heuristic for using tiled variant: few or very large blocks
+        tile_size = 64
+        max_tiles_per_sm = 2048 // tile_size  # assume 64 resident warps per SM
+        use_tiles = device.is_cuda and (
+            max(x.block_size, y.block_size, z.block_size) > max_subblock_dim**2
+            or mm_nnz < max_tiles_per_sm * device.sm_count
+        )
+
+    if use_tiles:
+        subblock_rows = min(max_subblock_dim, z.block_shape[0])
+        subblock_cols = min(max_subblock_dim, z.block_shape[1])
+
+        wp.launch(
+            kernel=make_bsr_mm_compute_values_tiled_outer(
+                subblock_rows, subblock_cols, x.block_shape[1], z.scalar_type, tile_size
+            ),
+            device=device,
+            dim=(
+                z.nnz,
+                (z.block_shape[0] + subblock_rows - 1) // subblock_rows,
+                (z.block_shape[1] + subblock_cols - 1) // subblock_cols,
+                tile_size,
+            ),
+            block_dim=tile_size,
+            inputs=[
+                alpha,
+                work_arrays._old_z_offsets if x == z else x.offsets,
+                work_arrays._old_z_columns if x == z else x.columns,
+                _as_3d_array(work_arrays._old_z_values, z.block_shape) if x == z else x.scalar_values,
+                work_arrays._old_z_offsets if y == z else y.offsets,
+                work_arrays._old_z_columns if y == z else y.columns,
+                _as_3d_array(work_arrays._old_z_values, z.block_shape) if y == z else y.scalar_values,
+                None if masked else work_arrays._mm_row_min,
+                None if masked else summed_triplet_offsets,
+                None if masked else work_arrays._mm_src_blocks[summed_triplet_indices],
+                z.nrow,
+                z.offsets,
+                z.columns,
+                z.scalar_values,
+            ],
+        )
+
+        return z
+
     # Add mm blocks to z values
     if (type_is_matrix(x.values.dtype) or type_is_matrix(y.values.dtype)) and not (type_is_matrix(z.values.dtype)):
         # Result block type is scalar, but operands are matrices
@@ -1985,6 +2277,9 @@ def bsr_mm(
            work_arrays._old_z_offsets if y == z else y.offsets,
            work_arrays._old_z_columns if y == z else y.columns,
            work_arrays._old_z_values if y == z else y.values,
+            None if masked else work_arrays._mm_row_min,
+            None if masked else summed_triplet_offsets,
+            None if masked else work_arrays._mm_src_blocks[summed_triplet_indices],
            z.nrow,
            z.offsets,
            z.columns,
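The lane-strided loop followed by wp.tile_sum(wp.tile(...)) used by the bsr_mm kernels above (and by the bsr_mv kernels below) is a general warp tile-reduction pattern. A minimal standalone sketch, assuming only the public tile API:

import warp as wp

@wp.kernel
def row_sums(x: wp.array2d(dtype=float), out: wp.array(dtype=float)):
    row, lane = wp.tid()
    partial = float(0.0)
    # each of the 32 lanes strides over one row; the tile reduction then combines the lanes
    for j in range(lane, x.shape[1], 32):
        partial += x[row, j]
    total = wp.tile_sum(wp.tile(partial))[0]
    if lane == 0:
        out[row] = total

wp.init()
x = wp.ones(shape=(8, 100), dtype=float)
out = wp.zeros(shape=(8,), dtype=float)
wp.launch(row_sums, dim=(8, 32), block_dim=32, inputs=[x, out])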
@@ -1995,51 +2290,125 @@ def bsr_mm(
     return z
 
 
-
-
-    alpha: Any,
-    A_offsets: wp.array(dtype=int),
-    A_columns: wp.array(dtype=int),
-    A_values: wp.array(dtype=Any),
-    x: wp.array(dtype=Any),
-    beta: Any,
-    y: wp.array(dtype=Any),
-):
-    row = wp.tid()
+def make_bsr_mv_kernel(block_cols: int):
+    from warp.fem.cache import dynamic_kernel
 
-
-
-
+    @dynamic_kernel(suffix=f"{block_cols}", kernel_options={"enable_backward": False})
+    def bsr_mv_kernel(
+        alpha: Any,
+        A_offsets: wp.array(dtype=int),
+        A_columns: wp.array(dtype=int),
+        A_values: wp.array3d(dtype=Any),
+        x: wp.array(dtype=Any),
+        beta: Any,
+        y: wp.array(dtype=Any),
+    ):
+        row, subrow = wp.tid()
 
-
-    beg = A_offsets[row]
-    end = A_offsets[row + 1]
-    for block in range(beg, end):
-        v += A_values[block] * x[A_columns[block]]
-    v *= alpha
+        block_rows = A_values.shape[1]
 
-
-    v += beta * y[row]
+        yi = row * block_rows + subrow
 
-
+        # zero-initialize with type of y elements
+        scalar_zero = type(alpha)(0)
+        v = scalar_zero
 
+        if alpha != scalar_zero:
+            beg = A_offsets[row]
+            end = A_offsets[row + 1]
+            for block in range(beg, end):
+                xs = A_columns[block] * block_cols
+                for col in range(wp.static(block_cols)):
+                    v += A_values[block, subrow, col] * x[xs + col]
+            v *= alpha
 
-
-
-
-
-
-
-
-
-):
-
-
-
-
-
-
-    wp.
+        if beta != scalar_zero:
+            v += beta * y[yi]
+
+        y[yi] = v
+
+    return bsr_mv_kernel
+
+
+def make_bsr_mv_tiled_kernel(tile_size: int):
+    from warp.fem.cache import dynamic_kernel
+
+    @dynamic_kernel(suffix=f"{tile_size}", kernel_options={"enable_backward": False})
+    def bsr_mv_tiled_kernel(
+        alpha: Any,
+        A_offsets: wp.array(dtype=int),
+        A_columns: wp.array(dtype=int),
+        A_values: wp.array3d(dtype=Any),
+        x: wp.array(dtype=Any),
+        beta: Any,
+        y: wp.array(dtype=Any),
+    ):
+        row, subrow, lane = wp.tid()
+
+        scalar_zero = type(alpha)(0)
+        block_rows = A_values.shape[1]
+        block_cols = A_values.shape[2]
+
+        yi = row * block_rows + subrow
+
+        if beta == scalar_zero:
+            subrow_sum = wp.tile_zeros(shape=(1,), dtype=y.dtype)
+        else:
+            subrow_sum = beta * wp.tile_load(y, 1, yi)
+
+        if alpha != scalar_zero:
+            block_beg = A_offsets[row]
+            col_count = (A_offsets[row + 1] - block_beg) * block_cols
+
+            col = lane
+            lane_sum = y.dtype(0)
+
+            for col in range(lane, col_count, tile_size):
+                block = col // block_cols
+                block_col = col - block * block_cols
+                block += block_beg
+
+                xi = x[A_columns[block] * block_cols + block_col]
+                lane_sum += A_values[block, subrow, block_col] * xi
+
+            lane_sum *= alpha
+            subrow_sum += wp.tile_sum(wp.tile(lane_sum))
+
+        wp.tile_store(y, subrow_sum, yi)
+
+    return bsr_mv_tiled_kernel
+
+
+def make_bsr_mv_transpose_kernel(block_rows: int):
+    from warp.fem.cache import dynamic_kernel
+
+    @dynamic_kernel(suffix=f"{block_rows}", kernel_options={"enable_backward": False})
+    def bsr_mv_transpose_kernel(
+        alpha: Any,
+        A_row_count: int,
+        A_offsets: wp.array(dtype=int),
+        A_columns: wp.array(dtype=int),
+        A_values: wp.array3d(dtype=Any),
+        x: wp.array(dtype=Any),
+        y: wp.array(dtype=Any),
+    ):
+        block, subcol = wp.tid()
+
+        row = _bsr_row_index(A_offsets, A_row_count, block)
+        if row == -1:
+            return
+
+        block_cols = A_values.shape[2]
+
+        A_block = A_values[block]
+
+        col_sum = type(alpha)(0)
+        for subrow in range(wp.static(block_rows)):
+            col_sum += A_block[subrow, subcol] * x[row * block_rows + subrow]
+
+        wp.atomic_add(y, A_columns[block] * block_cols + subcol, alpha * col_sum)
+
+    return bsr_mv_transpose_kernel
 
 
 def _vec_array_view(array: wp.array, dtype: type, expected_scalar_count: int) -> wp.array:
@@ -2092,6 +2461,7 @@ def bsr_mv(
     beta: Scalar = 0.0,
     transpose: bool = False,
     work_buffer: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
+    tile_size: int = 0,
 ) -> "Array[Vector[Rows, Scalar] | Scalar]":
     """Perform the sparse matrix-vector product ``y := alpha * A * x + beta * y`` and return ``y``.
 
@@ -2107,6 +2477,9 @@ def bsr_mv(
        work_buffer: Temporary storage is required if and only if ``x`` and ``y`` are the same vector.
          If provided, the ``work_buffer`` array will be used for this purpose,
          otherwise a temporary allocation will be performed.
+        tile_size: If a positive integer, use tiles of this size to compute the matrix-matrix product.
+          If negative, disable tile-based computation. Defaults to ``0``, which determines whether to
+          use tiles using using an heuristic based on the matrix shape and number of non-zeros..
     """
 
     A, A_scale = _extract_matrix_and_scale(A)
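A matching usage sketch for the tile_size argument of bsr_mv (not part of the diff; A, x, and y are assumed to be an existing BsrMatrix and compatible vectors):

import warp.sparse as sparse

y = sparse.bsr_mv(A, x)                                        # heuristic: tiled kernel only for long rows on CUDA
y = sparse.bsr_mv(A, x, y, alpha=1.0, beta=1.0, tile_size=64)  # accumulate into y with forced 64-lane tiles
y = sparse.bsr_mv(A, x, tile_size=-1)                          # always use the one-thread-per-subrow kernel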
@@ -2129,6 +2502,7 @@ def bsr_mv(
     alpha = A.scalar_type(alpha)
     beta = A.scalar_type(beta)
 
+    device = A.values.device
     if A.values.device != x.device or A.values.device != y.device:
         raise ValueError(
             f"A, x, and y must reside on the same device, got {A.values.device}, {x.device} and {y.device}"
@@ -2149,23 +2523,24 @@ def bsr_mv(
        wp.copy(dest=work_buffer, src=y, count=y.size)
        x = work_buffer
 
-    # Promote scalar vectors to length-1 vecs and conversely
-    if type_is_matrix(A.values.dtype):
-        x_dtype = wp.vec(length=block_shape[1], dtype=A.scalar_type)
-        y_dtype = wp.vec(length=block_shape[0], dtype=A.scalar_type)
-    else:
-        x_dtype = A.scalar_type
-        y_dtype = A.scalar_type
-
     try:
-        x_view = _vec_array_view(x,
+        x_view = _vec_array_view(x, A.scalar_type, expected_scalar_count=ncol * block_shape[1])
     except ValueError as err:
        raise ValueError("Incompatible 'x' vector for bsr_mv") from err
     try:
-        y_view = _vec_array_view(y,
+        y_view = _vec_array_view(y, A.scalar_type, expected_scalar_count=nrow * block_shape[0])
     except ValueError as err:
        raise ValueError("Incompatible 'y' vector for bsr_mv") from err
 
+    # heuristic to use tiled version for long rows
+    if tile_size > 0:
+        use_tiles = True
+    elif tile_size < 0:
+        use_tiles = False
+    else:
+        tile_size = 64
+        use_tiles = device.is_cuda and A.nnz * A.block_size > 2 * tile_size * A.shape[0]
+
     if transpose:
        if beta.value == 0.0:
            y.zero_()
@@ -2173,22 +2548,30 @@ def bsr_mv(
            wp.launch(
                kernel=_bsr_scale_kernel,
                device=y.device,
-                dim=
-                inputs=[beta,
+                dim=y_view.shape[0],
+                inputs=[beta, y_view],
            )
        if alpha.value != 0.0:
            wp.launch(
-                kernel=
+                kernel=make_bsr_mv_transpose_kernel(block_rows=block_shape[1]),
                device=A.values.device,
-                dim=
-                inputs=[alpha, A.offsets, A.columns, A.
+                dim=(A.nnz, block_shape[0]),
+                inputs=[alpha, A.nrow, A.offsets, A.columns, A.scalar_values, x_view, y_view],
            )
+    elif use_tiles:
+        wp.launch(
+            kernel=make_bsr_mv_tiled_kernel(tile_size),
+            device=A.values.device,
+            dim=(nrow, block_shape[0], tile_size),
+            block_dim=tile_size,
+            inputs=[alpha, A.offsets, A.columns, A.scalar_values, x_view, beta, y_view],
+        )
     else:
        wp.launch(
-            kernel=
+            kernel=make_bsr_mv_kernel(block_cols=block_shape[1]),
            device=A.values.device,
-            dim=nrow,
-            inputs=[alpha, A.offsets, A.columns, A.
+            dim=(nrow, block_shape[0]),
+            inputs=[alpha, A.offsets, A.columns, A.scalar_values, x_view, beta, y_view],
        )
 
     return y