warp-lang 1.0.1__py3-none-macosx_10_13_universal2.whl → 1.1.0__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +108 -97
- warp/__init__.pyi +1 -1
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +115 -113
- warp/build_dll.py +383 -375
- warp/builtins.py +3425 -3354
- warp/codegen.py +2878 -2792
- warp/config.py +40 -36
- warp/constants.py +45 -45
- warp/context.py +5194 -5102
- warp/dlpack.py +442 -442
- warp/examples/__init__.py +16 -16
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -92
- warp/examples/assets/nv_humanoid.xml +183 -183
- warp/examples/assets/quadruped.urdf +267 -267
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +383 -383
- warp/examples/benchmarks/benchmark_cloth.py +278 -279
- warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -88
- warp/examples/benchmarks/benchmark_cloth_jax.py +97 -100
- warp/examples/benchmarks/benchmark_cloth_numba.py +146 -142
- warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -77
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -86
- warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -112
- warp/examples/benchmarks/benchmark_cloth_warp.py +146 -146
- warp/examples/benchmarks/benchmark_launches.py +295 -295
- warp/examples/browse.py +29 -28
- warp/examples/core/example_dem.py +234 -221
- warp/examples/core/example_fluid.py +293 -267
- warp/examples/core/example_graph_capture.py +144 -129
- warp/examples/core/example_marching_cubes.py +188 -176
- warp/examples/core/example_mesh.py +174 -154
- warp/examples/core/example_mesh_intersect.py +205 -193
- warp/examples/core/example_nvdb.py +176 -169
- warp/examples/core/example_raycast.py +105 -89
- warp/examples/core/example_raymarch.py +199 -178
- warp/examples/core/example_render_opengl.py +185 -141
- warp/examples/core/example_sph.py +405 -389
- warp/examples/core/example_torch.py +222 -181
- warp/examples/core/example_wave.py +263 -249
- warp/examples/fem/bsr_utils.py +378 -380
- warp/examples/fem/example_apic_fluid.py +407 -391
- warp/examples/fem/example_convection_diffusion.py +182 -168
- warp/examples/fem/example_convection_diffusion_dg.py +219 -209
- warp/examples/fem/example_convection_diffusion_dg0.py +204 -194
- warp/examples/fem/example_deformed_geometry.py +177 -159
- warp/examples/fem/example_diffusion.py +201 -173
- warp/examples/fem/example_diffusion_3d.py +177 -152
- warp/examples/fem/example_diffusion_mgpu.py +221 -214
- warp/examples/fem/example_mixed_elasticity.py +244 -222
- warp/examples/fem/example_navier_stokes.py +259 -243
- warp/examples/fem/example_stokes.py +220 -192
- warp/examples/fem/example_stokes_transfer.py +265 -249
- warp/examples/fem/mesh_utils.py +133 -109
- warp/examples/fem/plot_utils.py +292 -287
- warp/examples/optim/example_bounce.py +260 -248
- warp/examples/optim/example_cloth_throw.py +222 -210
- warp/examples/optim/example_diffray.py +566 -535
- warp/examples/optim/example_drone.py +864 -835
- warp/examples/optim/example_inverse_kinematics.py +176 -169
- warp/examples/optim/example_inverse_kinematics_torch.py +185 -170
- warp/examples/optim/example_spring_cage.py +239 -234
- warp/examples/optim/example_trajectory.py +223 -201
- warp/examples/optim/example_walker.py +306 -292
- warp/examples/sim/example_cartpole.py +139 -128
- warp/examples/sim/example_cloth.py +196 -184
- warp/examples/sim/example_granular.py +124 -113
- warp/examples/sim/example_granular_collision_sdf.py +197 -185
- warp/examples/sim/example_jacobian_ik.py +236 -213
- warp/examples/sim/example_particle_chain.py +118 -106
- warp/examples/sim/example_quadruped.py +193 -179
- warp/examples/sim/example_rigid_chain.py +197 -189
- warp/examples/sim/example_rigid_contact.py +189 -176
- warp/examples/sim/example_rigid_force.py +127 -126
- warp/examples/sim/example_rigid_gyroscopic.py +109 -97
- warp/examples/sim/example_rigid_soft_contact.py +134 -124
- warp/examples/sim/example_soft_body.py +190 -178
- warp/fabric.py +337 -335
- warp/fem/__init__.py +60 -27
- warp/fem/cache.py +401 -388
- warp/fem/dirichlet.py +178 -179
- warp/fem/domain.py +262 -263
- warp/fem/field/__init__.py +100 -101
- warp/fem/field/field.py +148 -149
- warp/fem/field/nodal_field.py +298 -299
- warp/fem/field/restriction.py +22 -21
- warp/fem/field/test.py +180 -181
- warp/fem/field/trial.py +183 -183
- warp/fem/geometry/__init__.py +15 -19
- warp/fem/geometry/closest_point.py +69 -70
- warp/fem/geometry/deformed_geometry.py +270 -271
- warp/fem/geometry/element.py +744 -744
- warp/fem/geometry/geometry.py +184 -186
- warp/fem/geometry/grid_2d.py +380 -373
- warp/fem/geometry/grid_3d.py +441 -435
- warp/fem/geometry/hexmesh.py +953 -953
- warp/fem/geometry/partition.py +374 -376
- warp/fem/geometry/quadmesh_2d.py +532 -532
- warp/fem/geometry/tetmesh.py +840 -840
- warp/fem/geometry/trimesh_2d.py +577 -577
- warp/fem/integrate.py +1630 -1615
- warp/fem/operator.py +190 -191
- warp/fem/polynomial.py +214 -213
- warp/fem/quadrature/__init__.py +2 -2
- warp/fem/quadrature/pic_quadrature.py +243 -245
- warp/fem/quadrature/quadrature.py +295 -294
- warp/fem/space/__init__.py +294 -292
- warp/fem/space/basis_space.py +488 -489
- warp/fem/space/collocated_function_space.py +100 -105
- warp/fem/space/dof_mapper.py +236 -236
- warp/fem/space/function_space.py +148 -145
- warp/fem/space/grid_2d_function_space.py +267 -267
- warp/fem/space/grid_3d_function_space.py +305 -306
- warp/fem/space/hexmesh_function_space.py +350 -352
- warp/fem/space/partition.py +350 -350
- warp/fem/space/quadmesh_2d_function_space.py +368 -369
- warp/fem/space/restriction.py +158 -160
- warp/fem/space/shape/__init__.py +13 -15
- warp/fem/space/shape/cube_shape_function.py +738 -738
- warp/fem/space/shape/shape_function.py +102 -103
- warp/fem/space/shape/square_shape_function.py +611 -611
- warp/fem/space/shape/tet_shape_function.py +565 -567
- warp/fem/space/shape/triangle_shape_function.py +429 -429
- warp/fem/space/tetmesh_function_space.py +294 -292
- warp/fem/space/topology.py +297 -295
- warp/fem/space/trimesh_2d_function_space.py +223 -221
- warp/fem/types.py +77 -77
- warp/fem/utils.py +495 -495
- warp/jax.py +166 -141
- warp/jax_experimental.py +341 -339
- warp/native/array.h +1072 -1025
- warp/native/builtin.h +1560 -1560
- warp/native/bvh.cpp +398 -398
- warp/native/bvh.cu +525 -525
- warp/native/bvh.h +429 -429
- warp/native/clang/clang.cpp +495 -464
- warp/native/crt.cpp +31 -31
- warp/native/crt.h +334 -334
- warp/native/cuda_crt.h +1049 -1049
- warp/native/cuda_util.cpp +549 -540
- warp/native/cuda_util.h +288 -203
- warp/native/cutlass_gemm.cpp +34 -34
- warp/native/cutlass_gemm.cu +372 -372
- warp/native/error.cpp +66 -66
- warp/native/error.h +27 -27
- warp/native/fabric.h +228 -228
- warp/native/hashgrid.cpp +301 -278
- warp/native/hashgrid.cu +78 -77
- warp/native/hashgrid.h +227 -227
- warp/native/initializer_array.h +32 -32
- warp/native/intersect.h +1204 -1204
- warp/native/intersect_adj.h +365 -365
- warp/native/intersect_tri.h +322 -322
- warp/native/marching.cpp +2 -2
- warp/native/marching.cu +497 -497
- warp/native/marching.h +2 -2
- warp/native/mat.h +1498 -1498
- warp/native/matnn.h +333 -333
- warp/native/mesh.cpp +203 -203
- warp/native/mesh.cu +293 -293
- warp/native/mesh.h +1887 -1887
- warp/native/nanovdb/NanoVDB.h +4782 -4782
- warp/native/nanovdb/PNanoVDB.h +2553 -2553
- warp/native/nanovdb/PNanoVDBWrite.h +294 -294
- warp/native/noise.h +850 -850
- warp/native/quat.h +1084 -1084
- warp/native/rand.h +299 -299
- warp/native/range.h +108 -108
- warp/native/reduce.cpp +156 -156
- warp/native/reduce.cu +348 -348
- warp/native/runlength_encode.cpp +61 -61
- warp/native/runlength_encode.cu +46 -46
- warp/native/scan.cpp +30 -30
- warp/native/scan.cu +36 -36
- warp/native/scan.h +7 -7
- warp/native/solid_angle.h +442 -442
- warp/native/sort.cpp +94 -94
- warp/native/sort.cu +97 -97
- warp/native/sort.h +14 -14
- warp/native/sparse.cpp +337 -337
- warp/native/sparse.cu +544 -544
- warp/native/spatial.h +630 -630
- warp/native/svd.h +562 -562
- warp/native/temp_buffer.h +30 -30
- warp/native/vec.h +1132 -1132
- warp/native/volume.cpp +297 -297
- warp/native/volume.cu +32 -32
- warp/native/volume.h +538 -538
- warp/native/volume_builder.cu +425 -425
- warp/native/volume_builder.h +19 -19
- warp/native/warp.cpp +1057 -1052
- warp/native/warp.cu +2943 -2828
- warp/native/warp.h +313 -305
- warp/optim/__init__.py +9 -9
- warp/optim/adam.py +120 -120
- warp/optim/linear.py +1104 -939
- warp/optim/sgd.py +104 -92
- warp/render/__init__.py +10 -10
- warp/render/render_opengl.py +3217 -3204
- warp/render/render_usd.py +768 -749
- warp/render/utils.py +152 -150
- warp/sim/__init__.py +52 -59
- warp/sim/articulation.py +685 -685
- warp/sim/collide.py +1594 -1590
- warp/sim/import_mjcf.py +489 -481
- warp/sim/import_snu.py +220 -221
- warp/sim/import_urdf.py +536 -516
- warp/sim/import_usd.py +887 -881
- warp/sim/inertia.py +316 -317
- warp/sim/integrator.py +234 -233
- warp/sim/integrator_euler.py +1956 -1956
- warp/sim/integrator_featherstone.py +1910 -1991
- warp/sim/integrator_xpbd.py +3294 -3312
- warp/sim/model.py +4473 -4314
- warp/sim/particles.py +113 -112
- warp/sim/render.py +417 -403
- warp/sim/utils.py +413 -410
- warp/sparse.py +1227 -1227
- warp/stubs.py +2109 -2469
- warp/tape.py +1162 -225
- warp/tests/__init__.py +1 -1
- warp/tests/__main__.py +4 -4
- warp/tests/assets/torus.usda +105 -105
- warp/tests/aux_test_class_kernel.py +26 -26
- warp/tests/aux_test_compile_consts_dummy.py +10 -10
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -21
- warp/tests/aux_test_dependent.py +22 -22
- warp/tests/aux_test_grad_customs.py +23 -23
- warp/tests/aux_test_reference.py +11 -11
- warp/tests/aux_test_reference_reference.py +10 -10
- warp/tests/aux_test_square.py +17 -17
- warp/tests/aux_test_unresolved_func.py +14 -14
- warp/tests/aux_test_unresolved_symbol.py +14 -14
- warp/tests/disabled_kinematics.py +239 -239
- warp/tests/run_coverage_serial.py +31 -31
- warp/tests/test_adam.py +157 -157
- warp/tests/test_arithmetic.py +1124 -1124
- warp/tests/test_array.py +2417 -2326
- warp/tests/test_array_reduce.py +150 -150
- warp/tests/test_async.py +668 -656
- warp/tests/test_atomic.py +141 -141
- warp/tests/test_bool.py +204 -149
- warp/tests/test_builtins_resolution.py +1292 -1292
- warp/tests/test_bvh.py +164 -171
- warp/tests/test_closest_point_edge_edge.py +228 -228
- warp/tests/test_codegen.py +566 -553
- warp/tests/test_compile_consts.py +97 -101
- warp/tests/test_conditional.py +246 -246
- warp/tests/test_copy.py +232 -215
- warp/tests/test_ctypes.py +632 -632
- warp/tests/test_dense.py +67 -67
- warp/tests/test_devices.py +91 -98
- warp/tests/test_dlpack.py +530 -529
- warp/tests/test_examples.py +400 -378
- warp/tests/test_fabricarray.py +955 -955
- warp/tests/test_fast_math.py +62 -54
- warp/tests/test_fem.py +1277 -1278
- warp/tests/test_fp16.py +130 -130
- warp/tests/test_func.py +338 -337
- warp/tests/test_generics.py +571 -571
- warp/tests/test_grad.py +746 -640
- warp/tests/test_grad_customs.py +333 -336
- warp/tests/test_hash_grid.py +210 -164
- warp/tests/test_import.py +39 -39
- warp/tests/test_indexedarray.py +1134 -1134
- warp/tests/test_intersect.py +67 -67
- warp/tests/test_jax.py +307 -307
- warp/tests/test_large.py +167 -164
- warp/tests/test_launch.py +354 -354
- warp/tests/test_lerp.py +261 -261
- warp/tests/test_linear_solvers.py +191 -171
- warp/tests/test_lvalue.py +421 -493
- warp/tests/test_marching_cubes.py +65 -65
- warp/tests/test_mat.py +1801 -1827
- warp/tests/test_mat_lite.py +115 -115
- warp/tests/test_mat_scalar_ops.py +2907 -2889
- warp/tests/test_math.py +126 -193
- warp/tests/test_matmul.py +500 -499
- warp/tests/test_matmul_lite.py +410 -410
- warp/tests/test_mempool.py +188 -190
- warp/tests/test_mesh.py +284 -324
- warp/tests/test_mesh_query_aabb.py +228 -241
- warp/tests/test_mesh_query_point.py +692 -702
- warp/tests/test_mesh_query_ray.py +292 -303
- warp/tests/test_mlp.py +276 -276
- warp/tests/test_model.py +110 -110
- warp/tests/test_modules_lite.py +39 -39
- warp/tests/test_multigpu.py +163 -163
- warp/tests/test_noise.py +248 -248
- warp/tests/test_operators.py +250 -250
- warp/tests/test_options.py +123 -125
- warp/tests/test_peer.py +133 -137
- warp/tests/test_pinned.py +78 -78
- warp/tests/test_print.py +54 -54
- warp/tests/test_quat.py +2086 -2086
- warp/tests/test_rand.py +288 -288
- warp/tests/test_reload.py +217 -217
- warp/tests/test_rounding.py +179 -179
- warp/tests/test_runlength_encode.py +190 -190
- warp/tests/test_sim_grad.py +243 -0
- warp/tests/test_sim_kinematics.py +91 -97
- warp/tests/test_smoothstep.py +168 -168
- warp/tests/test_snippet.py +305 -266
- warp/tests/test_sparse.py +468 -460
- warp/tests/test_spatial.py +2148 -2148
- warp/tests/test_streams.py +486 -473
- warp/tests/test_struct.py +710 -675
- warp/tests/test_tape.py +173 -148
- warp/tests/test_torch.py +743 -743
- warp/tests/test_transient_module.py +87 -87
- warp/tests/test_types.py +556 -659
- warp/tests/test_utils.py +490 -499
- warp/tests/test_vec.py +1264 -1268
- warp/tests/test_vec_lite.py +73 -73
- warp/tests/test_vec_scalar_ops.py +2099 -2099
- warp/tests/test_verify_fp.py +94 -94
- warp/tests/test_volume.py +737 -736
- warp/tests/test_volume_write.py +255 -265
- warp/tests/unittest_serial.py +37 -37
- warp/tests/unittest_suites.py +363 -359
- warp/tests/unittest_utils.py +603 -578
- warp/tests/unused_test_misc.py +71 -71
- warp/tests/walkthrough_debug.py +85 -85
- warp/thirdparty/appdirs.py +598 -598
- warp/thirdparty/dlpack.py +143 -143
- warp/thirdparty/unittest_parallel.py +566 -561
- warp/torch.py +321 -295
- warp/types.py +4504 -4450
- warp/utils.py +1008 -821
- {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/LICENSE.md +126 -126
- {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/METADATA +338 -400
- warp_lang-1.1.0.dist-info/RECORD +352 -0
- warp/examples/assets/cube.usda +0 -42
- warp/examples/assets/sphere.usda +0 -56
- warp/examples/assets/torus.usda +0 -105
- warp_lang-1.0.1.dist-info/RECORD +0 -352
- {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/top_level.txt +0 -0
warp/sparse.py
CHANGED
|
@@ -1,1227 +1,1227 @@
|
|
|
1
|
-
from typing import Any, Generic, Optional, Tuple, TypeVar, Union
|
|
2
|
-
|
|
3
|
-
import warp as wp
|
|
4
|
-
import warp.types
|
|
5
|
-
import warp.utils
|
|
6
|
-
from warp.types import Array, Cols,
|
|
7
|
-
|
|
8
|
-
# typing hints

# Type variable standing for the block type of a sparse matrix. Its name string matches the
# variable it is bound to, as required by static type checkers.
_BlockType = TypeVar("_BlockType")


class _MatrixBlockType(Generic[Rows, Cols, Scalar]):
    """Typing-only marker for matrix-valued blocks (BSR) of shape ``Rows x Cols`` with ``Scalar`` coefficients."""


class _ScalarBlockType(Generic[Scalar]):
    """Typing-only marker for scalar-valued blocks (CSR)."""


# Block type accepted by the sparse-matrix API: either a matrix block (BSR) or a scalar block (CSR)
BlockType = Union[_MatrixBlockType[Rows, Cols, Scalar], _ScalarBlockType[Scalar]]

# Cache of generated matrix struct types, keyed by a string derived from the block type (see ``bsr_matrix_t``)
_struct_cache = {}
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
class BsrMatrix(Generic[_BlockType]):
    """Untyped base class for BSR and CSR matrices.

    Should not be constructed directly but through functions such as :func:`bsr_zeros`.

    Attributes:
        nrow (int): Number of rows of blocks
        ncol (int): Number of columns of blocks
        nnz (int): Number of non-zero blocks: must be equal to ``offsets[nrow]``, cached on host for convenience
        offsets (Array[int]): Array of size at least ``1 + nrow`` such that the start and end indices of the blocks of row ``r`` are ``offsets[r]`` and ``offsets[r+1]``, respectively.
        columns (Array[int]): Array of size at least equal to ``nnz`` containing block column indices
        values (Array[BlockType]): Array of size at least equal to ``nnz`` containing block values
    """

    @property
    def scalar_type(self) -> Scalar:
        """Scalar type for individual block coefficients. For CSR matrices, this is the same as the block type"""
        return warp.types.type_scalar_type(self.values.dtype)

    @property
    def block_shape(self) -> Tuple[int, int]:
        """Shape of the individual blocks"""
        # Block types without a ``_shape_`` attribute (scalar blocks, i.e. CSR) are reported as 1x1
        return getattr(self.values.dtype, "_shape_", (1, 1))

    @property
    def block_size(self) -> int:
        """Size of the individual blocks, i.e. number of rows per block times number of columns per block"""
        return warp.types.type_length(self.values.dtype)

    @property
    def shape(self) -> Tuple[int, int]:
        """Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/columns per block"""
        block_shape = self.block_shape
        return (self.nrow * block_shape[0], self.ncol * block_shape[1])

    @property
    def dtype(self) -> type:
        """Data type for individual block values"""
        return self.values.dtype

    @property
    def device(self) -> wp.context.Device:
        """Device on which offsets, columns and values are allocated -- assumed to be the same for all three arrays"""
        return self.values.device
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
def bsr_matrix_t(dtype: BlockType):
    """Returns the warp struct type used to store BSR/CSR matrices with the given block type.

    Struct types are generated once per block type and cached in ``_struct_cache``, so repeated
    calls with the same ``dtype`` return the same type object.

    Args:
        dtype: Block type; either a warp matrix type (BSR) or a warp scalar type (CSR).
            It is first normalized with :func:`warp.types.type_to_warp`.

    Raises:
        ValueError: If ``dtype`` is neither a warp matrix type nor a warp scalar type.
    """
    dtype = wp.types.type_to_warp(dtype)

    if not warp.types.type_is_matrix(dtype) and dtype not in warp.types.scalar_types:
        raise ValueError(
            f"BsrMatrix block type must be either warp matrix or scalar; got {warp.types.type_repr(dtype)}"
        )

    # Build a cache key that uniquely identifies the block type (scalar type + block shape for matrices)
    if hasattr(dtype, "_shape_"):
        type_str = f"{warp.types.type_scalar_type(dtype).__name__}_{dtype._shape_[0]}_{dtype._shape_[1]}"
    else:
        type_str = dtype.__name__
    key = f"{BsrMatrix.__qualname__}_{type_str}"

    if key not in _struct_cache:
        # Only build the typed struct class on a cache miss

        class BsrMatrixTyped(BsrMatrix):
            nrow: int
            """Number of rows of blocks"""
            ncol: int
            """Number of columns of blocks"""
            nnz: int
            """Number of non-zero blocks: equal to offsets[nrow], cached on host for convenience"""
            offsets: wp.array(dtype=int)
            """Array of size at least 1 + nrow"""
            columns: wp.array(dtype=int)
            """Array of size at least equal to nnz"""
            values: wp.array(dtype=dtype)

        module = wp.get_module(BsrMatrix.__module__)
        _struct_cache[key] = wp.codegen.Struct(
            cls=BsrMatrixTyped,
            key=key,
            module=module,
        )

    return _struct_cache[key]
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
def bsr_zeros(
    rows_of_blocks: int,
    cols_of_blocks: int,
    block_type: BlockType,
    device: wp.context.Devicelike = None,
) -> BsrMatrix:
    """
    Constructs and returns an empty BSR or CSR matrix with the given shape

    Args:
        rows_of_blocks: Number of rows of blocks
        cols_of_blocks: Number of columns of blocks
        block_type: Type of individual blocks. For CSR matrices, this should be a scalar type;
          for BSR matrices, this should be a matrix type (e.g. from :func:`warp.mat`)
        device: Device on which to allocate the matrix arrays

    Returns:
        A matrix with ``nnz == 0``, empty ``columns`` and ``values`` arrays, and an all-zero
        ``offsets`` array of length ``rows_of_blocks + 1``.
    """

    bsr = bsr_matrix_t(block_type)()

    bsr.nrow = rows_of_blocks
    bsr.ncol = cols_of_blocks
    bsr.nnz = 0
    bsr.columns = wp.empty(shape=(0,), dtype=int, device=bsr_device_placeholder) if False else wp.empty(shape=(0,), dtype=int, device=device)
    bsr.values = wp.empty(shape=(0,), dtype=block_type, device=device)
    # All-zero offsets: every row starts and ends at block index 0, i.e. no row stores any block
    bsr.offsets = wp.zeros(shape=(bsr.nrow + 1,), dtype=int, device=device)

    return bsr
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
def _bsr_ensure_fits(bsr: BsrMatrix, nrow: int = None, nnz: int = None):
    """Grow the storage arrays of ``bsr`` in place so they can describe ``nrow`` rows and ``nnz`` blocks.

    Arrays that are already large enough are kept as-is. Arrays that are too small are replaced by
    new arrays obtained from ``wp.empty`` on the same device and with the same dtype; their previous
    content is not carried over. When ``nrow`` / ``nnz`` are omitted, the matrix's current ``nrow`` /
    ``nnz`` are used.
    """
    row_count = bsr.nrow if nrow is None else nrow
    block_count = bsr.nnz if nnz is None else nnz

    # ``offsets`` needs one extra entry to close the last row
    offsets_needed = row_count + 1
    if offsets_needed > bsr.offsets.size:
        bsr.offsets = wp.empty(shape=(offsets_needed,), dtype=int, device=bsr.offsets.device)

    if block_count > bsr.columns.size:
        bsr.columns = wp.empty(shape=(block_count,), dtype=int, device=bsr.columns.device)

    if block_count > bsr.values.size:
        bsr.values = wp.empty(shape=(block_count,), dtype=bsr.values.dtype, device=bsr.values.device)
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
def bsr_set_zero(bsr: BsrMatrix, rows_of_blocks: Optional[int] = None, cols_of_blocks: Optional[int] = None):
    """
    Resets ``bsr`` so that it stores no blocks, optionally resizing it first.

    Args:
        bsr: The BSR or CSR matrix to clear
        rows_of_blocks: New number of rows of blocks, or ``None`` to keep the current one
        cols_of_blocks: New number of columns of blocks, or ``None`` to keep the current one
    """
    # Keep the current dimensions unless new ones were requested
    bsr.nrow = bsr.nrow if rows_of_blocks is None else rows_of_blocks
    bsr.ncol = bsr.ncol if cols_of_blocks is None else cols_of_blocks
    bsr.nnz = 0

    # Make sure ``offsets`` can describe ``nrow`` rows, then mark every row as empty
    _bsr_ensure_fits(bsr)
    bsr.offsets.zero_()
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
def bsr_set_from_triplets(
    dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    rows: "Array[int]",
    columns: "Array[int]",
    values: "Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]",
):
    """
    Fills a BSR matrix with values defined by coordinate-oriented (COO) triplets, discarding existing blocks.

    The first dimension of the three input arrays must match, and determines the number of triplets
    passed to the native builder; ``dest.nnz`` is set to the count that builder returns.

    Args:
        dest: Sparse matrix to populate
        rows: Row index for each non-zero
        columns: Columns index for each non-zero
        values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
          to the `dest` matrix's block type, or a 3d array with data type equal to the `dest` matrix's scalar type.

    Raises:
        ValueError: If the arrays do not all live on the device of ``dest``, have mismatched lengths,
          or ``values`` has an incompatible type, shape, layout or dimensionality.
        NotImplementedError: If the scalar type of ``dest`` is neither ``wp.float32`` nor ``wp.float64``.
          In that case ``dest`` is left unmodified.
    """

    if values.device != columns.device or values.device != rows.device or values.device != dest.values.device:
        raise ValueError("All arguments must reside on the same device")

    if values.shape[0] != rows.shape[0] or values.shape[0] != columns.shape[0]:
        raise ValueError("All triplet arrays must have the same length")

    # Accept either array1d(dtype) or contiguous array3d(scalar_type) as values
    if values.ndim == 1:
        if values.dtype != dest.values.dtype:
            raise ValueError("Values array type must correspond to that of dest matrix")
    elif values.ndim == 3:
        if values.shape[1:] != dest.block_shape:
            raise ValueError(
                f"Last two dimensions in values array ({values.shape[1:]}) should correspond to matrix block shape {(dest.block_shape)})"
            )

        if warp.types.type_scalar_type(values.dtype) != dest.scalar_type:
            raise ValueError("Scalar type of values array should correspond to that of matrix")

        if not values.is_contiguous:
            raise ValueError("Multi-dimensional values array should be contiguous")
    else:
        raise ValueError("Number of dimension for values array should be 1 or 3")

    nnz = rows.shape[0]
    if nnz == 0:
        bsr_set_zero(dest)
        return

    device = dest.values.device
    scalar_type = dest.scalar_type
    from warp.context import runtime

    # Select the native builder BEFORE touching ``dest``: previously an unsupported scalar type left
    # ``native_func`` unbound (UnboundLocalError) after ``dest``'s arrays had already been reallocated.
    native_func = None
    if device.is_cpu:
        if scalar_type == wp.float32:
            native_func = runtime.core.bsr_matrix_from_triplets_float_host
        elif scalar_type == wp.float64:
            native_func = runtime.core.bsr_matrix_from_triplets_double_host
    else:
        if scalar_type == wp.float32:
            native_func = runtime.core.bsr_matrix_from_triplets_float_device
        elif scalar_type == wp.float64:
            native_func = runtime.core.bsr_matrix_from_triplets_double_device

    if not native_func:
        raise NotImplementedError(f"bsr_from_triplets not implemented for scalar type {scalar_type}")

    # Increase dest array sizes if needed
    _bsr_ensure_fits(dest, nnz=nnz)

    dest.nnz = native_func(
        dest.block_shape[0],
        dest.block_shape[1],
        dest.nrow,
        nnz,
        rows.ptr,
        columns.ptr,
        values.ptr,
        dest.offsets.ptr,
        dest.columns.ptr,
        dest.values.ptr,
    )
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
def bsr_assign(dest: BsrMatrix[BlockType[Rows, Cols, Scalar]], src: BsrMatrix[BlockType[Rows, Cols, Any]]):
    """Overwrites ``dest`` with the structure and values of ``src``.

    Block values are converted with ``warp.utils.array_cast``, so the two matrices may use different
    scalar types. ``dest``'s storage is grown if needed.

    Args:
        dest: Destination matrix; must live on the same device and use the same block shape as ``src``
        src: Source matrix

    Raises:
        ValueError: If the two matrices live on different devices or have different block shapes
    """
    if dest.values.device != src.values.device:
        raise ValueError("Source and destination matrices must reside on the same device")
    if dest.block_shape != src.block_shape:
        raise ValueError("Source and destination matrices must have the same block shape")

    dest.nrow, dest.ncol, dest.nnz = src.nrow, src.ncol, src.nnz
    _bsr_ensure_fits(dest)

    # Row ranges: one entry per row plus the closing entry
    wp.copy(dest=dest.offsets, src=src.offsets, count=src.nrow + 1)

    if src.nnz <= 0:
        return

    wp.copy(dest=dest.columns, src=src.columns, count=src.nnz)
    warp.utils.array_cast(out_array=dest.values, in_array=src.values, count=src.nnz)
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
def bsr_copy(A: BsrMatrix, scalar_type: Optional[Scalar] = None):
    """Creates a new matrix holding the same blocks as ``A``, optionally with a different scalar type.

    Args:
        A: Matrix to copy
        scalar_type: If provided, the returned matrix will use this scalar type instead of the one from `A`.

    Returns:
        A new matrix on the same device as ``A``, filled through :func:`bsr_assign`.
    """
    if scalar_type is None:
        # Reuse the source block type unchanged
        block_type = A.values.dtype
    elif A.block_shape != (1, 1):
        # Matrix blocks: same block shape, requested scalar type
        block_type = wp.types.matrix(shape=A.block_shape, dtype=scalar_type)
    else:
        # 1x1 blocks are stored as plain scalars
        block_type = scalar_type

    result = bsr_zeros(
        rows_of_blocks=A.nrow,
        cols_of_blocks=A.ncol,
        block_type=block_type,
        device=A.values.device,
    )
    bsr_assign(dest=result, src=A)
    return result
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
def bsr_set_transpose(dest: BsrMatrix[BlockType[Cols, Rows, Scalar]], src: BsrMatrix[BlockType[Rows, Cols, Scalar]]):
    """Assigns the transposed matrix `src` to matrix `dest`

    Args:
        dest: Matrix receiving the transpose. Must live on the same device and use the same scalar type
          as `src`, with the transposed block shape. Its storage is grown if needed.
        src: Matrix to transpose

    Raises:
        ValueError: If devices or scalar types differ, or `dest` does not have the transposed block shape
        NotImplementedError: If the scalar type is neither ``wp.float32`` nor ``wp.float64``
          (`dest` is left unmodified in that case)
    """

    if dest.values.device != src.values.device:
        raise ValueError("All arguments must reside on the same device")

    if dest.scalar_type != src.scalar_type:
        raise ValueError("All arguments must have the same scalar type")

    transpose_block_shape = src.block_shape[::-1]

    if dest.block_shape != transpose_block_shape:
        raise ValueError(f"Destination block shape must be {transpose_block_shape}")

    if src.nnz == 0:
        # Empty source: dest becomes an empty matrix with transposed dimensions. Its offsets must be
        # resized and cleared; previously they kept stale row ranges from dest's former content.
        bsr_set_zero(dest, rows_of_blocks=src.ncol, cols_of_blocks=src.nrow)
        return

    from warp.context import runtime

    # Select the native routine before mutating dest; previously an unsupported scalar type left
    # ``native_func`` unbound (UnboundLocalError) instead of raising NotImplementedError.
    native_func = None
    if dest.values.device.is_cpu:
        if dest.scalar_type == wp.float32:
            native_func = runtime.core.bsr_transpose_float_host
        elif dest.scalar_type == wp.float64:
            native_func = runtime.core.bsr_transpose_double_host
    else:
        if dest.scalar_type == wp.float32:
            native_func = runtime.core.bsr_transpose_float_device
        elif dest.scalar_type == wp.float64:
            native_func = runtime.core.bsr_transpose_double_device

    if not native_func:
        raise NotImplementedError(f"bsr_set_transpose not implemented for scalar type {dest.scalar_type}")

    dest.nrow = src.ncol
    dest.ncol = src.nrow
    dest.nnz = src.nnz

    # Increase dest array sizes if needed
    _bsr_ensure_fits(dest)

    native_func(
        src.block_shape[0],
        src.block_shape[1],
        src.nrow,
        src.ncol,
        src.nnz,
        src.offsets.ptr,
        src.columns.ptr,
        src.values.ptr,
        dest.offsets.ptr,
        dest.columns.ptr,
        dest.values.ptr,
    )
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
def bsr_transposed(A: BsrMatrix):
    """Builds and returns a new matrix equal to the transpose of ``A``.

    1x1 blocks keep ``A``'s block type. Other blocks use a matrix type with the transposed block
    shape and ``A``'s scalar type. The result lives on the same device as ``A``.
    """
    block_rows, block_cols = A.block_shape
    has_unit_blocks = (block_rows, block_cols) == (1, 1)
    transposed_block_type = (
        A.values.dtype
        if has_unit_blocks
        else wp.types.matrix(shape=(block_cols, block_rows), dtype=A.scalar_type)
    )

    result = bsr_zeros(
        rows_of_blocks=A.ncol,
        cols_of_blocks=A.nrow,
        block_type=transposed_block_type,
        device=A.values.device,
    )
    bsr_set_transpose(dest=result, src=A)
    return result
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
# Warp kernel: each thread handles one block row ``row = wp.tid()`` and, if that row stores a block
# at column ``row`` (the diagonal block), copies it into ``out[row]``. When no such block is stored,
# ``out[row]`` is left untouched.
@wp.kernel
def _bsr_get_diag_kernel(
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
    out: wp.array(dtype=Any),
):
    row = wp.tid()
    # Range [beg, end) of block indices belonging to this row
    beg = A_offsets[row]
    end = A_offsets[row + 1]

    # Binary search for column ``row`` among this row's column indices.
    # NOTE(review): assumes column indices are sorted within each row -- confirm all producers guarantee this.
    diag = wp.lower_bound(A_columns, beg, end, row)
    # The bounds check is kept separate so A_columns[diag] is only read when diag < end
    if diag < end:
        if A_columns[diag] == row:
            out[row] = A_values[diag]
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
def bsr_get_diag(A: BsrMatrix[_BlockType], out: "Optional[Array[BlockType]]" = None) -> "Array[BlockType]":
    """Extracts the diagonal blocks of a sparse matrix into a 1D array.

    Args:
        A: the sparse matrix from which to extract the diagonal
        out: optional destination array. It must have ``A``'s block type, live on ``A``'s device and
            hold at least ``min(A.nrow, A.ncol)`` entries. Entries whose row stores no diagonal block
            are not written, so they keep their previous content. When omitted, a zero-initialized
            array is allocated.

    Returns:
        The array holding the diagonal blocks (``out`` if it was provided).

    Raises:
        ValueError: If ``out`` has the wrong dtype, device, or is too short
    """
    dim = min(A.nrow, A.ncol)
    block_dtype = A.values.dtype
    device = A.values.device

    if out is None:
        out = wp.zeros(shape=(dim,), dtype=block_dtype, device=device)
    elif out.dtype != block_dtype:
        raise ValueError(f"Output array must have type {A.values.dtype}")
    elif out.device != device:
        raise ValueError(f"Output array must reside on device {A.values.device}")
    elif out.shape[0] < dim:
        raise ValueError(f"Output array must be of length at least {dim}")

    kernel_inputs = [A.offsets, A.columns, A.values, out]
    wp.launch(kernel=_bsr_get_diag_kernel, dim=dim, device=device, inputs=kernel_inputs)

    return out
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
# Warp kernel: writes a purely block-diagonal structure into A's arrays. Thread ``row`` makes block
# row ``row`` hold exactly one block, stored at index ``row`` and at column ``row``, with value ``diag[row]``.
# NOTE(review): the caller is presumably responsible for sizing A_offsets/A_columns/A_values for the launch dimension -- confirm.
@wp.kernel
def _bsr_set_diag_kernel(
    diag: wp.array(dtype=Any),
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
):
    row = wp.tid()
    # Row ``row`` ends at block index row + 1 (one block per row)
    A_offsets[row + 1] = row + 1
    A_columns[row] = row
    A_values[row] = diag[row]

    # Each thread writes offsets[row + 1]; thread 0 additionally writes the leading offsets[0]
    if row == 0:
        A_offsets[0] = 0
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
# Same as _bsr_set_diag_kernel, but assigns the same value `diag_value` to every diagonal block.
# Only offsets[0 .. launch_dim] are written by this kernel.
@wp.kernel
def _bsr_set_diag_constant_kernel(
    diag_value: Any,
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
):
    row = wp.tid()
    # One block per row, so row `row` ends right after block index `row`
    A_offsets[row + 1] = row + 1
    A_columns[row] = row
    A_values[row] = diag_value

    # offsets[0] is not covered by the `row + 1` writes above
    if row == 0:
        A_offsets[0] = 0
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
@wp.kernel
def _bsr_set_diag_tail_offsets_kernel(nnz: int, A_offsets: wp.array(dtype=int)):
    # Thread i sets A_offsets[nnz + i] = nnz, i.e. marks every row past the last diagonal block as empty
    i = wp.tid()
    A_offsets[nnz + i] = nnz


def bsr_set_diag(
    A: BsrMatrix[BlockType],
    diag: "Union[BlockType, Array[BlockType]]",
    rows_of_blocks: Optional[int] = None,
    cols_of_blocks: Optional[int] = None,
):
    """Sets `A` as a block-diagonal matrix

    Args:
        A: the sparse matrix to modify
        diag: Either a warp array of type ``A.values.dtype``, in which case each element will define one block of the diagonal,
              or a constant value of type ``A.values.dtype``, in which case it will get assigned to all diagonal blocks.
        rows_of_blocks: If not ``None``, the new number of rows of blocks
        cols_of_blocks: If not ``None``, the new number of columns of blocks

    The shape of the matrix will be defined one of the following, in that order:
      - `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal.
      - the first dimension of `diag`, if `diag` is an array
      - the current dimensions of `A` otherwise

    The resulting matrix stores exactly ``min(nrow, ncol)`` blocks, one per diagonal entry.
    Any remaining rows (when ``nrow > ncol``) are empty.

    Raises:
        ValueError: if `diag` is an array with fewer than ``min(nrow, ncol)`` elements; `A` is left unmodified in that case.
    """

    if rows_of_blocks is None and cols_of_blocks is not None:
        rows_of_blocks = cols_of_blocks
    if cols_of_blocks is None and rows_of_blocks is not None:
        cols_of_blocks = rows_of_blocks

    diag_is_array = warp.types.is_array(diag)

    if diag_is_array and rows_of_blocks is None:
        rows_of_blocks = diag.shape[0]
        cols_of_blocks = diag.shape[0]

    # Resolve the target shape before touching A, so that invalid input leaves it unmodified
    nrow = A.nrow if rows_of_blocks is None else rows_of_blocks
    ncol = A.ncol if cols_of_blocks is None else cols_of_blocks
    diag_len = min(nrow, ncol)

    if diag_is_array and diag.shape[0] < diag_len:
        # The diagonal kernel reads one element of `diag` per diagonal block
        raise ValueError(f"Diagonal array must be of length at least {diag_len}")

    if rows_of_blocks is not None:
        A.nrow = rows_of_blocks
        A.ncol = cols_of_blocks

    A.nnz = diag_len
    _bsr_ensure_fits(A)

    device = A.values.device

    if diag_is_array:
        wp.launch(
            kernel=_bsr_set_diag_kernel,
            dim=A.nnz,
            device=device,
            inputs=[diag, A.offsets, A.columns, A.values],
        )
    else:
        if not warp.types.type_is_value(type(diag)):
            # Cast to launchable type
            diag = A.values.dtype(diag)
        wp.launch(
            kernel=_bsr_set_diag_constant_kernel,
            dim=A.nnz,
            device=device,
            inputs=[diag, A.offsets, A.columns, A.values],
        )

    # The kernels above only write offsets[0 .. nnz] (and nothing at all when nnz == 0).
    # Rows past the last diagonal block must be marked empty; otherwise matrices with
    # nrow > ncol keep stale offsets. This launch also sets offsets[0] = 0 when nnz == 0.
    # Relies on A.offsets holding at least nrow + 1 entries (BsrMatrix invariant).
    wp.launch(
        kernel=_bsr_set_diag_tail_offsets_kernel,
        dim=A.nrow + 1 - A.nnz,
        device=device,
        inputs=[A.nnz, A.offsets],
    )
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
def bsr_diag(
    diag: "Union[BlockType, Array[BlockType]]",
    rows_of_blocks: Optional[int] = None,
    cols_of_blocks: Optional[int] = None,
) -> BsrMatrix["BlockType"]:
    """Creates and returns a block-diagonal BSR matrix from a given block value or array of block values.

    Args:
        diag: Either a warp array of block values, in which case each element will define one block of the diagonal,
              or a constant block value, in which case it will get assigned to all diagonal blocks.
        rows_of_blocks: If not ``None``, the new number of rows of blocks
        cols_of_blocks: If not ``None``, the new number of columns of blocks

    The shape of the matrix will be defined one of the following, in that order:
      - `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal.
      - the first dimension of `diag`, if `diag` is an array

    Raises:
        ValueError: if `diag` is a constant value and neither `rows_of_blocks` nor `cols_of_blocks` is given,
            or if `diag` is an array with fewer than ``min(rows_of_blocks, cols_of_blocks)`` elements.
    """

    if rows_of_blocks is None and cols_of_blocks is not None:
        rows_of_blocks = cols_of_blocks
    if cols_of_blocks is None and rows_of_blocks is not None:
        cols_of_blocks = rows_of_blocks

    if warp.types.is_array(diag):
        if rows_of_blocks is None:
            rows_of_blocks = diag.shape[0]
            cols_of_blocks = diag.shape[0]
        elif diag.shape[0] < min(rows_of_blocks, cols_of_blocks):
            # One element of `diag` is read per diagonal block
            raise ValueError(f"Diagonal array must be of length at least {min(rows_of_blocks, cols_of_blocks)}")

        A = bsr_zeros(
            rows_of_blocks,
            cols_of_blocks,
            block_type=diag.dtype,
            device=diag.device,
        )
    else:
        if rows_of_blocks is None:
            raise ValueError(
                "rows_of_blocks and/or cols_of_blocks must be provided for constructing a diagonal matrix with uniform diagonal"
            )

        block_type = type(diag)
        if not warp.types.type_is_matrix(block_type) and len(getattr(diag, "shape", ())) == 2:
            block_type = wp.mat(shape=diag.shape, dtype=diag.dtype)

        A = bsr_zeros(
            rows_of_blocks,
            cols_of_blocks,
            block_type=block_type,
        )

    # Pass the dimensions explicitly: otherwise bsr_set_diag re-derives them from the length
    # of an array `diag`, silently discarding user-provided rows_of_blocks / cols_of_blocks
    bsr_set_diag(A, diag, rows_of_blocks=rows_of_blocks, cols_of_blocks=cols_of_blocks)
    return A
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
def bsr_set_identity(A: BsrMatrix, rows_of_blocks: Optional[int] = None):
    """Sets `A` as the identity matrix

    Every diagonal block is set to the identity block (a scalar one for 1x1 blocks), and all other blocks are removed.

    Args:
        A: the sparse matrix to modify
        rows_of_blocks: if provided, the matrix will be resized as a square matrix with `rows_of_blocks` rows and columns.
    """
    if A.block_shape != (1, 1):
        from numpy import eye

        # Dense identity block, cast to the block type by bsr_set_diag
        unit_block = eye(A.block_shape[0])
    else:
        unit_block = A.scalar_type(1.0)

    bsr_set_diag(A, diag=unit_block, rows_of_blocks=rows_of_blocks, cols_of_blocks=rows_of_blocks)
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
def bsr_identity(
    rows_of_blocks: int, block_type: BlockType[Rows, Rows, Scalar], device: wp.context.Devicelike = None
) -> BsrMatrix[BlockType[Rows, Rows, Scalar]]:
    """Creates and returns a square identity matrix.

    Args:
        rows_of_blocks: Number of rows and columns of blocks in the created matrix.
        block_type: Block type for the newly created matrix -- must be square
        device: Device onto which to allocate the data arrays

    Returns:
        A new ``rows_of_blocks x rows_of_blocks`` block matrix with identity blocks on its diagonal.
    """
    identity = bsr_zeros(
        rows_of_blocks=rows_of_blocks,
        cols_of_blocks=rows_of_blocks,
        block_type=block_type,
        device=device,
    )
    bsr_set_identity(identity)
    return identity
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
# In-place scaling of block values: values[i] := alpha * values[i], one thread per block.
@wp.kernel
def _bsr_scale_kernel(
    alpha: Any,
    values: wp.array(dtype=Any),
):
    values[wp.tid()] = alpha * values[wp.tid()]
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
def bsr_scale(x: BsrMatrix, alpha: Scalar) -> BsrMatrix:
    """
    Performs the operation ``x := alpha * x`` on BSR matrix `x` and returns `x`

    Scaling by one or scaling an empty matrix is a no-op. Scaling by zero clears the matrix
    through :func:`bsr_set_zero` instead of storing zero-valued blocks.
    """

    # Nothing to do for identity scaling or when there are no stored blocks
    if alpha == 1.0 or x.nnz <= 0:
        return x

    if alpha == 0.0:
        bsr_set_zero(x)
        return x

    # Kernel arguments must have the matrix scalar type
    if not isinstance(alpha, x.scalar_type):
        alpha = x.scalar_type(alpha)

    wp.launch(kernel=_bsr_scale_kernel, dim=x.nnz, device=x.values.device, inputs=[alpha, x.values])
    return x
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
# For the i-th stored block of a BSR matrix, writes the index of the row containing it into rows[dest_offset + i]
# (callers launch one thread per stored block).
@wp.kernel
def _bsr_get_block_row(dest_offset: int, bsr_offsets: wp.array(dtype=int), rows: wp.array(dtype=int)):
    i = wp.tid()

    # lower_bound returns the first index r with bsr_offsets[r] >= i + 1; block i belongs to the row just before.
    # NOTE(review): the binary search spans the whole `bsr_offsets` array, not just offsets[0 .. nrow];
    # this relies on the entire array being sorted, including any over-allocated tail -- confirm.
    row = wp.lower_bound(bsr_offsets, i + 1) - 1
    rows[dest_offset + i] = row
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
# Accumulates scale * src_values[i] into the destination block at (rows[src_offset + i], cols[src_offset + i]),
# one thread per source block. The destination block is found by binary search among the destination row's columns.
# Note that src_values is indexed by i, while rows/cols are indexed by i + src_offset.
# NOTE(review): the destination block is assumed to exist (the lower_bound result is not checked), and the
# read-modify-write is not atomic. All source entries of a single launch must therefore map to distinct
# destination blocks; callers appear to ensure this by launching each source matrix separately.
@wp.kernel
def _bsr_axpy_add_block(
    src_offset: int,
    scale: Any,
    rows: wp.array(dtype=int),
    cols: wp.array(dtype=int),
    dst_offsets: wp.array(dtype=int),
    dst_columns: wp.array(dtype=int),
    src_values: wp.array(dtype=Any),
    dst_values: wp.array(dtype=Any),
):
    i = wp.tid()
    row = rows[i + src_offset]
    col = cols[i + src_offset]
    beg = dst_offsets[row]
    end = dst_offsets[row + 1]

    block = wp.lower_bound(dst_columns, beg, end, col)

    dst_values[block] = dst_values[block] + scale * src_values[i]
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
class bsr_axpy_work_arrays:
    """Opaque structure for persisting :func:`bsr_axpy` temporary work buffers across calls"""

    def __init__(self):
        self._reset(None)

    def _reset(self, device):
        # Drops every cached buffer and binds the structure to `device`
        self.device = device
        self._sum_rows = None
        self._sum_cols = None
        self._old_y_values = None
        # NOTE(review): not allocated or read by the code visible here
        self._old_x_values = None

    def _allocate(self, device, y: BsrMatrix, sum_nnz: int):
        """Ensures the work buffers live on `device` and are large enough for the current operation.

        Args:
            device: device on which the buffers must reside; cached buffers are dropped if it changed
            y: the destination matrix, whose current values will be saved
            sum_nnz: total number of (unmerged) blocks of ``x + y``
        """
        if self.device != device:
            self._reset(device)

        if self._sum_rows is None or self._sum_rows.size < sum_nnz:
            self._sum_rows = wp.empty(shape=(sum_nnz), dtype=int, device=self.device)
        if self._sum_cols is None or self._sum_cols.size < sum_nnz:
            self._sum_cols = wp.empty(shape=(sum_nnz), dtype=int, device=self.device)

        # Also reallocate when the block type changed since a previous call: the buffer receives
        # a copy of y.values and is read back as the same dtype
        if (
            self._old_y_values is None
            or self._old_y_values.size < y.nnz
            or not wp.types.types_equal(self._old_y_values.dtype, y.values.dtype)
        ):
            self._old_y_values = wp.empty(shape=(y.nnz), dtype=y.values.dtype, device=self.device)
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
def bsr_axpy(
    x: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    y: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
    alpha: Scalar = 1.0,
    beta: Scalar = 1.0,
    work_arrays: Optional[bsr_axpy_work_arrays] = None,
) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
    """
    Performs the sparse matrix addition ``y := alpha * x + beta * y`` on BSR matrices `x` and `y` and returns `y`.

    The `x` and `y` matrices are allowed to alias. The topology of `y` becomes the union of the
    topologies of `x` and `y`.

    Args:
        x: Read-only right-hand-side.
        y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero.
        alpha: Uniform scaling factor for `x`
        beta: Uniform scaling factor for `y`
        work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_axpy_work_arrays` in `work_arrays`.

    Raises:
        ValueError: if, in the general case, the matrices reside on different devices or have mismatched
            block types or dimensions
    """

    if y is None:
        # If no output matrix is provided, allocate it for convenience
        y = bsr_zeros(x.nrow, x.ncol, block_type=x.values.dtype, device=x.values.device)
        beta = 0.0

    # Handle easy cases first
    # NOTE(review): these early returns happen before the device/block-type/shape validation below
    if beta == 0.0 or y.nnz == 0:
        # y is (treated as) zero: the result is alpha * x
        bsr_assign(src=x, dest=y)
        return bsr_scale(y, alpha=alpha)

    if alpha == 0.0 or x.nnz == 0:
        # x contributes nothing: the result is beta * y
        return bsr_scale(y, alpha=beta)

    if not isinstance(alpha, y.scalar_type):
        alpha = y.scalar_type(alpha)
    if not isinstance(beta, y.scalar_type):
        beta = y.scalar_type(beta)

    if x == y:
        # Aliasing case
        return bsr_scale(y, alpha=alpha.value + beta.value)

    # General case

    if x.values.device != y.values.device:
        raise ValueError("All arguments must reside on the same device")

    if x.scalar_type != y.scalar_type or x.block_shape != y.block_shape:
        raise ValueError("Matrices must have the same block type")

    if x.nrow != y.nrow or x.ncol != y.ncol:
        raise ValueError("Matrices must have the same number of rows and columns")

    if work_arrays is None:
        work_arrays = bsr_axpy_work_arrays()

    sum_nnz = x.nnz + y.nnz
    device = y.values.device
    work_arrays._allocate(device, y, sum_nnz)

    # Build the concatenated (row, column) list of all blocks: y's blocks in [0, y.nnz), then x's in [y.nnz, sum_nnz)
    wp.copy(work_arrays._sum_cols, y.columns, 0, 0, y.nnz)
    wp.launch(kernel=_bsr_get_block_row, device=device, dim=y.nnz, inputs=[0, y.offsets, work_arrays._sum_rows])

    wp.copy(work_arrays._sum_cols, x.columns, y.nnz, 0, x.nnz)
    wp.launch(kernel=_bsr_get_block_row, device=device, dim=x.nnz, inputs=[y.nnz, x.offsets, work_arrays._sum_rows])

    # Save old y values before overwriting matrix
    wp.copy(dest=work_arrays._old_y_values, src=y.values, count=y.nnz)

    # Increase dest array sizes if needed
    if y.columns.shape[0] < sum_nnz:
        y.columns = wp.empty(shape=(sum_nnz,), dtype=int, device=device)

    from warp.context import runtime

    if device.is_cpu:
        native_func = runtime.core.bsr_matrix_from_triplets_float_host
    else:
        native_func = runtime.core.bsr_matrix_from_triplets_float_device

    # Merge the concatenated block list into y's topology (y.offsets / y.columns); the return value
    # becomes the new y.nnz. The two 0 arguments are passed where no value data is supplied, presumably
    # so that only the topology is built -- values are accumulated by the kernels below.
    old_y_nnz = y.nnz
    y.nnz = native_func(
        y.block_shape[0],
        y.block_shape[1],
        y.nrow,
        sum_nnz,
        work_arrays._sum_rows.ptr,
        work_arrays._sum_cols.ptr,
        0,
        y.offsets.ptr,
        y.columns.ptr,
        0,
    )

    _bsr_ensure_fits(y)
    y.values.zero_()

    # Accumulate beta * old y values into the merged topology
    wp.launch(
        kernel=_bsr_axpy_add_block,
        device=device,
        dim=old_y_nnz,
        inputs=[
            0,
            beta,
            work_arrays._sum_rows,
            work_arrays._sum_cols,
            y.offsets,
            y.columns,
            work_arrays._old_y_values,
            y.values,
        ],
    )

    # Then accumulate alpha * x values (their row/column entries start at old_y_nnz)
    wp.launch(
        kernel=_bsr_axpy_add_block,
        device=device,
        dim=x.nnz,
        inputs=[
            old_y_nnz,
            alpha,
            work_arrays._sum_rows,
            work_arrays._sum_cols,
            y.offsets,
            y.columns,
            x.values,
            y.values,
        ],
    )

    return y
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
# Counts, for each row of x, the number of (unmerged) blocks of the product x * y: the sum over the row's
# blocks (row, k) of the number of blocks in row k of y. The count goes into counts[row + 1], and counts[0]
# receives z_nnz so that a subsequent prefix sum offsets all product blocks past z_nnz reserved entries.
@wp.kernel
def _bsr_mm_count_coeffs(
    z_nnz: int,
    x_offsets: wp.array(dtype=int),
    x_columns: wp.array(dtype=int),
    y_offsets: wp.array(dtype=int),
    counts: wp.array(dtype=int),
):
    row = wp.tid()
    # Explicit int() constructor, presumably so warp treats `count` as mutable inside the dynamic loop
    count = int(0)

    x_beg = x_offsets[row]
    x_end = x_offsets[row + 1]

    for x_block in range(x_beg, x_end):
        x_col = x_columns[x_block]
        # Every block of row x_col of y produces one product block
        count += y_offsets[x_col + 1] - y_offsets[x_col]

    counts[row + 1] = count

    if row == 0:
        counts[0] = z_nnz
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
# Lists the (row, column) coordinates of all unmerged product blocks of x * y, one thread per row.
# Row `row` writes its entries contiguously starting at mm_offsets[row].
@wp.kernel
def _bsr_mm_list_coeffs(
    x_offsets: wp.array(dtype=int),
    x_columns: wp.array(dtype=int),
    y_offsets: wp.array(dtype=int),
    y_columns: wp.array(dtype=int),
    mm_offsets: wp.array(dtype=int),
    mm_rows: wp.array(dtype=int),
    mm_cols: wp.array(dtype=int),
):
    row = wp.tid()
    # Output cursor for this row
    mm_block = mm_offsets[row]

    x_beg = x_offsets[row]
    x_end = x_offsets[row + 1]

    for x_block in range(x_beg, x_end):
        x_col = x_columns[x_block]

        # Each block (x_col, c) of y contributes a product block at (row, c)
        y_beg = y_offsets[x_col]
        y_end = y_offsets[x_col + 1]
        for y_block in range(y_beg, y_end):
            mm_cols[mm_block] = y_columns[y_block]
            mm_rows[mm_block] = row
            mm_block += 1
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
# Accumulates alpha * x * y into the already-merged product topology (mm_offsets / mm_cols), one thread per row.
# Each product contribution is located in the row's merged columns by binary search. Since a single thread
# owns all blocks of its row, the non-atomic accumulation does not race with other threads.
@wp.kernel
def _bsr_mm_compute_values(
    alpha: Any,
    x_offsets: wp.array(dtype=int),
    x_columns: wp.array(dtype=int),
    x_values: wp.array(dtype=Any),
    y_offsets: wp.array(dtype=int),
    y_columns: wp.array(dtype=int),
    y_values: wp.array(dtype=Any),
    mm_offsets: wp.array(dtype=int),
    mm_cols: wp.array(dtype=int),
    mm_values: wp.array(dtype=Any),
):
    row = wp.tid()
    # Range of merged product blocks for this row
    mm_beg = mm_offsets[row]
    mm_end = mm_offsets[row + 1]

    x_beg = x_offsets[row]
    x_end = x_offsets[row + 1]
    for x_block in range(x_beg, x_end):
        x_col = x_columns[x_block]
        # Apply alpha once per x block rather than per product term
        ax_val = alpha * x_values[x_block]

        y_beg = y_offsets[x_col]
        y_end = y_offsets[x_col + 1]

        for y_block in range(y_beg, y_end):
            mm_block = wp.lower_bound(mm_cols, mm_beg, mm_end, y_columns[y_block])
            mm_values[mm_block] = mm_values[mm_block] + ax_val * y_values[y_block]
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
class bsr_mm_work_arrays:
    """Opaque structure for persisting :func:`bsr_mm` temporary work buffers across calls"""

    def __init__(self):
        self._reset(None)

    def _reset(self, device):
        # Drops every cached buffer and binds the structure to `device`
        self.device = device
        self._pinned_count_buffer = None
        self._mm_row_counts = None
        self._mm_rows = None
        self._mm_cols = None
        self._old_z_values = None
        self._old_z_offsets = None
        self._old_z_columns = None

    def _allocate_stage_1(self, device, z: BsrMatrix, copied_z_nnz: int, z_aliasing: bool):
        """Allocates the buffers whose sizes are known before the product topology is computed.

        Args:
            device: device on which the buffers must reside; cached buffers are dropped if it changed
            z: the output matrix
            copied_z_nnz: number of blocks of `z` whose values must be saved before `z` is overwritten
            z_aliasing: whether `z` aliases one of the factors, in which case its topology is saved too
        """
        if self.device != device:
            self._reset(device)

        # Allocations that do not depend on any computation
        if self.device.is_cuda:
            if self._pinned_count_buffer is None:
                # Host-side buffer used to read back the total product block count
                self._pinned_count_buffer = wp.empty(shape=(1,), dtype=int, pinned=True, device="cpu")

        if self._mm_row_counts is None or self._mm_row_counts.size < z.nrow + 1:
            self._mm_row_counts = wp.empty(shape=(z.nrow + 1,), dtype=int, device=self.device)

        if copied_z_nnz > 0:
            # Also reallocate when the block type changed since a previous call: this buffer receives
            # a copy of z.values and may stand in for a factor's values when z aliases it
            if (
                self._old_z_values is None
                or self._old_z_values.size < copied_z_nnz
                or not wp.types.types_equal(self._old_z_values.dtype, z.values.dtype)
            ):
                self._old_z_values = wp.empty(shape=(copied_z_nnz,), dtype=z.values.dtype, device=self.device)

        if z_aliasing:
            if self._old_z_columns is None or self._old_z_columns.size < z.nnz:
                self._old_z_columns = wp.empty(shape=(z.nnz,), dtype=z.columns.dtype, device=self.device)
            if self._old_z_offsets is None or self._old_z_offsets.size < z.nrow + 1:
                self._old_z_offsets = wp.empty(shape=(z.nrow + 1,), dtype=z.offsets.dtype, device=self.device)

    def _allocate_stage_2(self, mm_nnz: int):
        """Allocates the (row, column) lists of unmerged product blocks, of which there are `mm_nnz`."""
        # Allocations that depend on unmerged nnz estimate
        if self._mm_rows is None or self._mm_rows.size < mm_nnz:
            self._mm_rows = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
        if self._mm_cols is None or self._mm_cols.size < mm_nnz:
            self._mm_cols = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
def bsr_mm(
    x: BsrMatrix[BlockType[Rows, Any, Scalar]],
    y: BsrMatrix[BlockType[Any, Cols, Scalar]],
    z: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
    alpha: Scalar = 1.0,
    beta: Scalar = 0.0,
    work_arrays: Optional[bsr_mm_work_arrays] = None,
) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
    """
    Performs the sparse matrix-matrix multiplication ``z := alpha * x * y + beta * z`` on BSR matrices `x`, `y` and `z`, and returns `z`.

    The `x`, `y` and `z` matrices are allowed to alias.
    If the matrix `z` is not provided as input, it will be allocated and treated as zero.

    Args:
        x: Read-only left factor of the matrix-matrix product.
        y: Read-only right factor of the matrix-matrix product.
        z: Mutable left-hand-side. If `z` is not provided, it will be allocated and treated as zero.
        alpha: Uniform scaling factor for the ``x * y`` product
        beta: Uniform scaling factor for `z`
        work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_mm_work_arrays` in `work_arrays`.

    Raises:
        ValueError: if the matrices reside on different devices, have different scalar types,
            or have incompatible block sizes or dimensions
    """

    if z is None:
        # If no output matrix is provided, allocate it for convenience
        z_block_shape = (x.block_shape[0], y.block_shape[1])
        if z_block_shape == (1, 1):
            z_block_type = x.scalar_type
        else:
            z_block_type = wp.types.matrix(shape=z_block_shape, dtype=x.scalar_type)
        z = bsr_zeros(x.nrow, y.ncol, block_type=z_block_type, device=x.values.device)
        beta = 0.0

    if x.values.device != y.values.device or x.values.device != z.values.device:
        raise ValueError("All arguments must reside on the same device")

    if x.scalar_type != y.scalar_type or x.scalar_type != z.scalar_type:
        raise ValueError("Matrices must have the same scalar type")

    if (
        x.block_shape[0] != z.block_shape[0]
        or y.block_shape[1] != z.block_shape[1]
        or x.block_shape[1] != y.block_shape[0]
    ):
        raise ValueError("Incompatible block sizes for matrix multiplication")

    if x.nrow != z.nrow or z.ncol != y.ncol or x.ncol != y.nrow:
        raise ValueError("Incompatible number of rows/columns for matrix multiplication")

    device = z.values.device

    if alpha == 0.0 or x.nnz == 0 or y.nnz == 0:
        # Easy case
        return bsr_scale(z, beta)

    if not isinstance(alpha, z.scalar_type):
        alpha = z.scalar_type(alpha)
    if not isinstance(beta, z.scalar_type):
        beta = z.scalar_type(beta)

    if work_arrays is None:
        work_arrays = bsr_mm_work_arrays()

    # When z aliases a factor, its blocks are saved even if beta == 0: the factor's data must
    # survive z's topology and values being overwritten below
    z_aliasing = z == x or z == y
    copied_z_nnz = z.nnz if beta != 0.0 or z_aliasing else 0

    work_arrays._allocate_stage_1(device, z, copied_z_nnz, z_aliasing)

    # Prefix sum of number of (unmerged) mm blocks per row
    # counts[0] holds copied_z_nnz, so that after the scan row r's product blocks start at counts[r],
    # past the copied z entries (this assumes array_scan is inclusive by default -- confirm)
    wp.launch(
        kernel=_bsr_mm_count_coeffs,
        device=device,
        dim=z.nrow,
        inputs=[copied_z_nnz, x.offsets, x.columns, y.offsets, work_arrays._mm_row_counts],
    )
    warp.utils.array_scan(work_arrays._mm_row_counts, work_arrays._mm_row_counts)

    # Get back total counts on host
    if device.is_cuda:
        wp.copy(dest=work_arrays._pinned_count_buffer, src=work_arrays._mm_row_counts, src_offset=z.nrow, count=1)
        # Wait for the copy before reading the pinned buffer on the host
        wp.synchronize_stream(wp.get_stream(device))
        mm_nnz = int(work_arrays._pinned_count_buffer.numpy()[0])
    else:
        mm_nnz = int(work_arrays._mm_row_counts.numpy()[z.nrow])

    work_arrays._allocate_stage_2(mm_nnz)

    # If z has a non-zero scale, save current data before overwriting it
    if copied_z_nnz > 0:
        # Copy z row and column indices
        wp.copy(dest=work_arrays._mm_cols, src=z.columns, count=copied_z_nnz)
        wp.launch(
            kernel=_bsr_get_block_row, device=device, dim=copied_z_nnz, inputs=[0, z.offsets, work_arrays._mm_rows]
        )
        # Save current z values in temporary buffer
        wp.copy(src=z.values, dest=work_arrays._old_z_values, count=copied_z_nnz)
        if z_aliasing:
            # If z is aliasing with x or y, need to save topology as well
            wp.copy(src=z.columns, dest=work_arrays._old_z_columns, count=copied_z_nnz)
            wp.copy(src=z.offsets, dest=work_arrays._old_z_offsets, count=z.nrow + 1)

    # Fill unmerged mm blocks rows and columns
    wp.launch(
        kernel=_bsr_mm_list_coeffs,
        device=device,
        dim=z.nrow,
        inputs=[
            x.offsets,
            x.columns,
            y.offsets,
            y.columns,
            work_arrays._mm_row_counts,
            work_arrays._mm_rows,
            work_arrays._mm_cols,
        ],
    )

    # Increase dest array size if needed
    if z.columns.shape[0] < mm_nnz:
        z.columns = wp.empty(shape=(mm_nnz,), dtype=int, device=device)

    from warp.context import runtime

    if device.is_cpu:
        native_func = runtime.core.bsr_matrix_from_triplets_float_host
    else:
        native_func = runtime.core.bsr_matrix_from_triplets_float_device

    # Merge the (copied z + unmerged product) block list into z's topology; the return value becomes z.nnz.
    # The two 0 arguments are passed where no value data is supplied, presumably so that only the
    # topology is built -- values are accumulated by the kernels below.
    z.nnz = native_func(
        z.block_shape[0],
        z.block_shape[1],
        z.nrow,
        mm_nnz,
        work_arrays._mm_rows.ptr,
        work_arrays._mm_cols.ptr,
        0,
        z.offsets.ptr,
        z.columns.ptr,
        0,
    )

    _bsr_ensure_fits(z)
    z.values.zero_()

    if copied_z_nnz > 0:
        # Add back original z values
        wp.launch(
            kernel=_bsr_axpy_add_block,
            device=device,
            dim=copied_z_nnz,
            inputs=[
                0,
                beta,
                work_arrays._mm_rows,
                work_arrays._mm_cols,
                z.offsets,
                z.columns,
                work_arrays._old_z_values,
                z.values,
            ],
        )

    # Add mm blocks to z values
    if (warp.types.type_is_matrix(x.values.dtype) or warp.types.type_is_matrix(y.values.dtype)) and not (
        warp.types.type_is_matrix(z.values.dtype)
    ):
        # Result block type is scalar, but operands are matrices
        # Cast result to (1x1) matrix to perform multiplication
        mm_values = z.values.view(wp.types.matrix(shape=(1, 1), dtype=z.scalar_type))
    else:
        mm_values = z.values

    # A factor aliased with z is read from the saved copies, since z has been overwritten by now
    wp.launch(
        kernel=_bsr_mm_compute_values,
        device=device,
        dim=z.nrow,
        inputs=[
            alpha,
            work_arrays._old_z_offsets if x == z else x.offsets,
            work_arrays._old_z_columns if x == z else x.columns,
            work_arrays._old_z_values if x == z else x.values,
            work_arrays._old_z_offsets if y == z else y.offsets,
            work_arrays._old_z_columns if y == z else y.columns,
            work_arrays._old_z_values if y == z else y.values,
            z.offsets,
            z.columns,
            mm_values,
        ],
    )

    return z
|
|
1116
|
-
|
|
1117
|
-
|
|
1118
|
-
# Computes y[row] := alpha * sum_k A[row, k] * x[k] + beta * y[row], one thread per block row.
# A is not read when alpha is zero, and y[row] is not read when beta is zero.
@wp.kernel
def _bsr_mv_kernel(
    alpha: Any,
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    beta: Any,
    y: wp.array(dtype=Any),
):
    row = wp.tid()

    # zero-initialize with type of y elements
    scalar_zero = type(alpha)(0)
    v = y.dtype(scalar_zero)

    if alpha != scalar_zero:
        beg = A_offsets[row]
        end = A_offsets[row + 1]
        for block in range(beg, end):
            v += A_values[block] * x[A_columns[block]]
        v *= alpha

    if beta != scalar_zero:
        v += beta * y[row]

    y[row] = v
|
|
1145
|
-
|
|
1146
|
-
|
|
1147
|
-
def bsr_mv(
    A: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    x: "Array[Vector[Cols, Scalar] | Scalar]",
    y: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
    alpha: Scalar = 1.0,
    beta: Scalar = 0.0,
    work_buffer: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
) -> "Array[Vector[Rows, Scalar] | Scalar]":
    """
    Performs the sparse matrix-vector product ``y := alpha * A * x + beta * y`` and returns `y`.

    The `x` and `y` vectors are allowed to alias.

    Args:
        A: Read-only, left matrix factor of the matrix-vector product.
        x: Read-only, right vector factor of the matrix-vector product.
        y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero.
        alpha: Uniform scaling factor for `x`. If zero, `x` will not be read and may be left uninitialized.
        beta: Uniform scaling factor for `y`. If zero, `y` will not be read and may be left uninitialized.
        work_buffer: Temporary storage is required if and only if `x` and `y` are the same vector. If provided the `work_buffer` array
            will be used for this purpose, otherwise a temporary allocation will be performed.

    Raises:
        ValueError: on device mismatch, incompatible lengths, or an unsuitable `work_buffer`
    """

    if y is None:
        # If no output array is provided, allocate one for convenience
        y_vec_len = A.block_shape[0]
        y_dtype = A.scalar_type if y_vec_len == 1 else wp.vec(length=y_vec_len, dtype=A.scalar_type)
        y = wp.empty(shape=(A.nrow,), device=A.values.device, dtype=y_dtype)
        y.zero_()
        beta = 0.0

    if not isinstance(alpha, A.scalar_type):
        alpha = A.scalar_type(alpha)
    if not isinstance(beta, A.scalar_type):
        beta = A.scalar_type(beta)

    if A.values.device != x.device or A.values.device != y.device:
        raise ValueError("A, x and y must reside on the same device")

    if x.shape[0] != A.ncol:
        raise ValueError("Number of columns of A must match number of rows of x")
    if y.shape[0] != A.nrow:
        raise ValueError("Number of rows of A must match number of rows of y")

    if x == y:
        # Aliasing case, need temporary storage
        if work_buffer is None:
            work_buffer = wp.empty_like(y)
        elif work_buffer.size < y.size:
            raise ValueError(f"Work buffer size is insufficient, needs to be at least {y.size}")
        elif not wp.types.types_equal(work_buffer.dtype, y.dtype):
            raise ValueError(f"Work buffer must have same data type as y, {wp.types.type_repr(y.dtype)}")

        # Save old y values before overwriting vector
        wp.copy(dest=work_buffer, src=y, count=y.size)
        # Read the right-hand side from the copy, since y is written in place
        x = work_buffer

    # Promote scalar vectors to length-1 vecs and conversely
    # (so that block * element products type-check in the kernel)
    # NOTE(review): when y is re-viewed here, the returned array is that view, not the caller's original object
    if warp.types.type_is_matrix(A.values.dtype):
        if A.block_shape[0] == 1:
            if y.dtype == A.scalar_type:
                y = y.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
        if A.block_shape[1] == 1:
            if x.dtype == A.scalar_type:
                x = x.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
    else:
        if A.block_shape[0] == 1:
            if y.dtype != A.scalar_type:
                y = y.view(dtype=A.scalar_type)
        if A.block_shape[1] == 1:
            if x.dtype != A.scalar_type:
                x = x.view(dtype=A.scalar_type)

    wp.launch(
        kernel=_bsr_mv_kernel,
        device=A.values.device,
        dim=A.nrow,
        inputs=[alpha, A.offsets, A.columns, A.values, x, beta, y],
    )

    return y
|
|
1
|
+
from typing import Any, Generic, Optional, Tuple, TypeVar, Union
|
|
2
|
+
|
|
3
|
+
import warp as wp
|
|
4
|
+
import warp.types
|
|
5
|
+
import warp.utils
|
|
6
|
+
from warp.types import Array, Cols, Rows, Scalar, Vector
|
|
7
|
+
|
|
8
|
+
# typing hints
|
|
9
|
+
|
|
10
|
+
# Generic parameter for the block type of a BsrMatrix.
# NOTE(review): the TypeVar is deliberately(?) named "BlockType" rather than "_BlockType";
# presumably so that it renders as ``BlockType`` in generated signatures -- confirm before renaming.
_BlockType = TypeVar("BlockType")


class _MatrixBlockType(Generic[Rows, Cols, Scalar]):
    """Annotation-only marker for matrix-valued blocks (``Rows x Cols`` matrices of ``Scalar``)."""


class _ScalarBlockType(Generic[Scalar]):
    """Annotation-only marker for scalar-valued blocks (CSR matrices)."""


# Annotation alias: a block is either a Rows x Cols matrix of Scalar, or a plain Scalar
BlockType = Union[_MatrixBlockType[Rows, Cols, Scalar], _ScalarBlockType[Scalar]]

# Generated struct types, keyed by a string derived from the block type (filled by bsr_matrix_t)
_struct_cache = {}
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class BsrMatrix(Generic[_BlockType]):
    """Untyped base class for BSR and CSR matrices.

    Should not be constructed directly but through functions such as :func:`bsr_zeros`.

    Attributes:
        nrow (int): Number of rows of blocks
        ncol (int): Number of columns of blocks
        nnz (int): Number of non-zero blocks: must be equal to ``offsets[nrow]``, cached on host for convenience
        offsets (Array[int]): Array of size at least ``1 + nrow`` such that the start and end indices of the blocks of row ``r`` are ``offsets[r]`` and ``offsets[r+1]``, respectively.
        columns (Array[int]): Array of size at least equal to ``nnz`` containing block column indices
        values (Array[BlockType]): Array of size at least equal to ``nnz`` containing block values
    """

    @property
    def scalar_type(self) -> Scalar:
        """Scalar type for individual block coefficients. For CSR matrices, this is the same as the block type"""
        return warp.types.type_scalar_type(self.values.dtype)

    @property
    def block_shape(self) -> Tuple[int, int]:
        """Shape of the individual blocks"""
        # Scalar (CSR) block types have no ``_shape_`` attribute and are treated as 1x1 blocks
        return getattr(self.values.dtype, "_shape_", (1, 1))

    @property
    def block_size(self) -> int:
        """Size of the individual blocks, i.e. number of rows per block times number of columns per block"""
        return warp.types.type_length(self.values.dtype)

    @property
    def shape(self) -> Tuple[int, int]:
        """Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/columns per block"""
        block_shape = self.block_shape
        return (self.nrow * block_shape[0], self.ncol * block_shape[1])

    @property
    def dtype(self) -> type:
        """Data type for individual block values"""
        return self.values.dtype

    @property
    def device(self) -> wp.context.Device:
        """Device on which offsets, columns and values are allocated -- assumed to be the same for all three arrays"""
        return self.values.device
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def bsr_matrix_t(dtype: BlockType):
    """Returns the warp struct type describing a BSR/CSR matrix whose blocks are of type ``dtype``.

    One struct type is generated per distinct block type; later requests are served from ``_struct_cache``.

    Args:
        dtype: Block type; after conversion with ``wp.types.type_to_warp`` it must be a warp matrix
            type (BSR) or a warp scalar type (CSR).

    Raises:
        ValueError: if ``dtype`` is neither a warp matrix type nor a warp scalar type
    """
    block_type = wp.types.type_to_warp(dtype)

    is_supported = warp.types.type_is_matrix(block_type) or block_type in warp.types.scalar_types
    if not is_supported:
        raise ValueError(
            f"BsrMatrix block type must be either warp matrix or scalar; got {warp.types.type_repr(block_type)}"
        )

    class BsrMatrixTyped(BsrMatrix):
        nrow: int  # Number of rows of blocks
        ncol: int  # Number of columns of blocks
        nnz: int  # Number of non-zero blocks: equal to offsets[nrow], cached on host for convenience
        offsets: wp.array(dtype=int)  # Array of size at least 1 + nrow
        columns: wp.array(dtype=int)  # Array of size at least equal to nnz
        values: wp.array(dtype=block_type)

    module = wp.get_module(BsrMatrix.__module__)

    # Matrix blocks are keyed by scalar type and shape, scalar blocks by their type name
    type_str = (
        f"{warp.types.type_scalar_type(block_type).__name__}_{block_type._shape_[0]}_{block_type._shape_[1]}"
        if hasattr(block_type, "_shape_")
        else block_type.__name__
    )
    key = f"{BsrMatrix.__qualname__}_{type_str}"

    struct_type = _struct_cache.get(key)
    if struct_type is None:
        struct_type = wp.codegen.Struct(cls=BsrMatrixTyped, key=key, module=module)
        _struct_cache[key] = struct_type
    return struct_type
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def bsr_zeros(
    rows_of_blocks: int,
    cols_of_blocks: int,
    block_type: BlockType,
    device: wp.context.Devicelike = None,
) -> BsrMatrix:
    """
    Constructs and returns an empty BSR or CSR matrix with the given shape

    Args:
        rows_of_blocks: Number of rows of blocks
        cols_of_blocks: Number of columns of blocks
        block_type: Type of individual blocks. For CSR matrices, this should be a scalar type;
                    for BSR matrices, this should be a matrix type (e.g. from :func:`warp.mat`)
        device: Device on which to allocate the matrix arrays

    Returns:
        A matrix with ``nnz == 0``, zero-length ``columns`` and ``values`` arrays, and an all-zero
        ``offsets`` array of length ``rows_of_blocks + 1``
    """

    bsr = bsr_matrix_t(block_type)()

    bsr.nrow = rows_of_blocks
    bsr.ncol = cols_of_blocks
    bsr.nnz = 0
    bsr.columns = wp.empty(shape=(0,), dtype=int, device=device)
    bsr.values = wp.empty(shape=(0,), dtype=block_type, device=device)
    # All-zero offsets: every row's block range [offsets[r], offsets[r+1]) is empty
    bsr.offsets = wp.zeros(shape=(bsr.nrow + 1,), dtype=int, device=device)

    return bsr
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def _bsr_ensure_fits(bsr: BsrMatrix, nrow: int = None, nnz: int = None):
    """Grows the storage of ``bsr`` so it can describe ``nrow`` rows and ``nnz`` non-zero blocks.

    ``None`` for either argument means "use the matrix's current value". Arrays are only ever
    replaced by larger ones, never shrunk; replacement arrays come from ``wp.empty`` and the
    previous content is not copied over.
    """
    row_count = bsr.nrow if nrow is None else nrow
    block_count = bsr.nnz if nnz is None else nnz

    if bsr.offsets.size < row_count + 1:
        bsr.offsets = wp.empty(shape=(row_count + 1,), dtype=int, device=bsr.offsets.device)
    if bsr.columns.size < block_count:
        bsr.columns = wp.empty(shape=(block_count,), dtype=int, device=bsr.columns.device)
    if bsr.values.size < block_count:
        bsr.values = wp.empty(shape=(block_count,), dtype=bsr.values.dtype, device=bsr.values.device)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def bsr_set_zero(bsr: BsrMatrix, rows_of_blocks: Optional[int] = None, cols_of_blocks: Optional[int] = None):
    """
    Sets a BSR matrix to zero, possibly changing its size

    Args:
        bsr: The BSR or CSR matrix to set to zero
        rows_of_blocks: If not ``None``, the new number of rows of blocks
        cols_of_blocks: If not ``None``, the new number of columns of blocks
    """

    for attribute, new_value in (("nrow", rows_of_blocks), ("ncol", cols_of_blocks)):
        if new_value is not None:
            setattr(bsr, attribute, new_value)

    # No blocks: make sure offsets can hold nrow + 1 entries, then clear all of them
    bsr.nnz = 0
    _bsr_ensure_fits(bsr)
    bsr.offsets.zero_()
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def bsr_set_from_triplets(
    dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    rows: "Array[int]",
    columns: "Array[int]",
    values: "Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]",
):
    """
    Fills a BSR matrix with values defined by coordinate-oriented (COO) triplets, discarding existing blocks.

    The first dimension of the three input arrays must match, and determines the number of non-zeros in the constructed matrix.

    Args:
        dest: Sparse matrix to populate
        rows: Row index for each non-zero
        columns: Columns index for each non-zero
        values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
          to the `dest` matrix's block type, or a 3d array with data type equal to the `dest` matrix's scalar type.

    Raises:
        ValueError: if the arrays reside on different devices, have different lengths, or if `values`
          has an incompatible data type, shape or memory layout
        NotImplementedError: if the scalar type of `dest` is neither ``wp.float32`` nor ``wp.float64``
    """

    if values.device != columns.device or values.device != rows.device or values.device != dest.values.device:
        raise ValueError("All arguments must reside on the same device")

    if values.shape[0] != rows.shape[0] or values.shape[0] != columns.shape[0]:
        raise ValueError("All triplet arrays must have the same length")

    # Accept either array1d(dtype) or contiguous array3d(scalar_type) as values
    if values.ndim == 1:
        if values.dtype != dest.values.dtype:
            raise ValueError("Values array type must correspond to that of dest matrix")
    elif values.ndim == 3:
        if values.shape[1:] != dest.block_shape:
            raise ValueError(
                f"Last two dimensions in values array ({values.shape[1:]}) should correspond to matrix block shape {(dest.block_shape)})"
            )

        if warp.types.type_scalar_type(values.dtype) != dest.scalar_type:
            raise ValueError("Scalar type of values array should correspond to that of matrix")

        if not values.is_contiguous:
            raise ValueError("Multi-dimensional values array should be contiguous")
    else:
        raise ValueError("Number of dimension for values array should be 1 or 3")

    nnz = rows.shape[0]
    if nnz == 0:
        bsr_set_zero(dest)
        return

    device = dest.values.device
    scalar_type = dest.scalar_type
    from warp.context import runtime

    # Select the native implementation for this device and scalar type. `native_func` must be
    # initialized up front: other scalar types (e.g. half precision) match none of the branches,
    # and would otherwise raise UnboundLocalError instead of NotImplementedError.
    native_func = None
    if device.is_cpu:
        if scalar_type == wp.float32:
            native_func = runtime.core.bsr_matrix_from_triplets_float_host
        elif scalar_type == wp.float64:
            native_func = runtime.core.bsr_matrix_from_triplets_double_host
    else:
        if scalar_type == wp.float32:
            native_func = runtime.core.bsr_matrix_from_triplets_float_device
        elif scalar_type == wp.float64:
            native_func = runtime.core.bsr_matrix_from_triplets_double_device

    # Fail before touching `dest`, so an unsupported type leaves the matrix unmodified
    if not native_func:
        raise NotImplementedError(f"bsr_from_triplets not implemented for scalar type {scalar_type}")

    # Increase dest array sizes if needed
    _bsr_ensure_fits(dest, nnz=nnz)

    # The native call fills offsets/columns/values and returns the resulting block count.
    # NOTE(review): this may be smaller than `nnz` if duplicate triplets are merged -- confirm in the native code.
    dest.nnz = native_func(
        dest.block_shape[0],
        dest.block_shape[1],
        dest.nrow,
        nnz,
        rows.ptr,
        columns.ptr,
        values.ptr,
        dest.offsets.ptr,
        dest.columns.ptr,
        dest.values.ptr,
    )
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def bsr_assign(dest: BsrMatrix[BlockType[Rows, Cols, Scalar]], src: BsrMatrix[BlockType[Rows, Cols, Any]]):
    """Copies the content of the `src` matrix to `dest`, casting the block values if the two matrices use distinct scalar types.

    Raises:
        ValueError: if the matrices reside on different devices or have different block shapes
    """

    if dest.values.device != src.values.device:
        raise ValueError("Source and destination matrices must reside on the same device")
    if dest.block_shape != src.block_shape:
        raise ValueError("Source and destination matrices must have the same block shape")

    # Take over the source dimensions, then grow the destination storage as needed
    dest.nrow, dest.ncol, dest.nnz = src.nrow, src.ncol, src.nnz
    _bsr_ensure_fits(dest)

    wp.copy(dest=dest.offsets, src=src.offsets, count=src.nrow + 1)
    if src.nnz <= 0:
        return

    wp.copy(dest=dest.columns, src=src.columns, count=src.nnz)
    # Element-wise conversion handles a possible scalar type change
    warp.utils.array_cast(out_array=dest.values, in_array=src.values, count=src.nnz)
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def bsr_copy(A: BsrMatrix, scalar_type: Optional[Scalar] = None):
    """Returns a copy of matrix ``A``, possibly changing its scalar type.

    Args:
        scalar_type: If provided, the returned matrix will use this scalar type instead of the one from `A`.
    """
    if scalar_type is not None and A.block_shape != (1, 1):
        # Same block shape, new coefficient type
        block_type = wp.types.matrix(shape=A.block_shape, dtype=scalar_type)
    elif scalar_type is not None:
        # 1x1 blocks are stored as plain scalars
        block_type = scalar_type
    else:
        block_type = A.values.dtype

    result = bsr_zeros(rows_of_blocks=A.nrow, cols_of_blocks=A.ncol, block_type=block_type, device=A.values.device)
    bsr_assign(dest=result, src=A)
    return result
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def bsr_set_transpose(dest: BsrMatrix[BlockType[Cols, Rows, Scalar]], src: BsrMatrix[BlockType[Rows, Cols, Scalar]]):
    """Assigns the transposed matrix `src` to matrix `dest`

    Args:
        dest: Matrix receiving the transpose; its block shape must be the transpose of `src`'s block shape
        src: Matrix to transpose

    Raises:
        ValueError: if the matrices reside on different devices, use different scalar types,
            or `dest` does not have the transposed block shape
        NotImplementedError: if `src` is not empty and its scalar type is neither ``wp.float32`` nor ``wp.float64``
    """

    if dest.values.device != src.values.device:
        raise ValueError("All arguments must reside on the same device")

    if dest.scalar_type != src.scalar_type:
        raise ValueError("All arguments must have the same scalar type")

    transpose_block_shape = src.block_shape[::-1]

    if dest.block_shape != transpose_block_shape:
        raise ValueError(f"Destination block shape must be {transpose_block_shape}")

    from warp.context import runtime

    # Select the native implementation; initialized to None so unsupported scalar types
    # reach the NotImplementedError below instead of raising UnboundLocalError
    native_func = None
    if dest.values.device.is_cpu:
        if dest.scalar_type == wp.float32:
            native_func = runtime.core.bsr_transpose_float_host
        elif dest.scalar_type == wp.float64:
            native_func = runtime.core.bsr_transpose_double_host
    else:
        if dest.scalar_type == wp.float32:
            native_func = runtime.core.bsr_transpose_float_device
        elif dest.scalar_type == wp.float64:
            native_func = runtime.core.bsr_transpose_double_device

    # Empty matrices never reach the native call, so they are accepted for any scalar type
    if not native_func and src.nnz > 0:
        raise NotImplementedError(f"bsr_set_transpose not implemented for scalar type {dest.scalar_type}")

    dest.nrow = src.ncol
    dest.ncol = src.nrow
    dest.nnz = src.nnz

    if src.nnz == 0:
        # Nothing to transpose, but `dest` may still hold offsets from its previous content, or an
        # offsets array too short for its new row count: reset it to a properly sized empty matrix
        bsr_set_zero(dest)
        return

    # Increase dest array sizes if needed
    _bsr_ensure_fits(dest)

    native_func(
        src.block_shape[0],
        src.block_shape[1],
        src.nrow,
        src.ncol,
        src.nnz,
        src.offsets.ptr,
        src.columns.ptr,
        src.values.ptr,
        dest.offsets.ptr,
        dest.columns.ptr,
        dest.values.ptr,
    )
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
def bsr_transposed(A: BsrMatrix):
    """Returns a copy of the transposed matrix `A`

    The result has ``A.ncol`` rows and ``A.nrow`` columns of blocks, each block being transposed.
    """
    block_shape = A.block_shape
    transposed_block_type = (
        A.values.dtype
        if block_shape == (1, 1)
        else wp.types.matrix(shape=block_shape[::-1], dtype=A.scalar_type)
    )

    result = bsr_zeros(
        rows_of_blocks=A.ncol, cols_of_blocks=A.nrow, block_type=transposed_block_type, device=A.values.device
    )
    bsr_set_transpose(dest=result, src=A)
    return result
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
@wp.kernel
def _bsr_get_diag_kernel(
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
    out: wp.array(dtype=Any),
):
    # One thread per diagonal index: copy block (row, row) of A into out[row], if it is stored.
    # Rows without a stored diagonal block leave out[row] untouched.
    row = wp.tid()
    beg = A_offsets[row]
    end = A_offsets[row + 1]

    # Search the row's block range [beg, end) for column `row`
    # NOTE(review): relies on column indices being sorted within each row -- confirm this invariant.
    diag = wp.lower_bound(A_columns, beg, end, row)
    if diag < end:
        # lower_bound may land on a larger column; only copy on an exact match
        if A_columns[diag] == row:
            out[row] = A_values[diag]
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
def bsr_get_diag(A: BsrMatrix[_BlockType], out: "Optional[Array[BlockType]]" = None) -> "Array[BlockType]":
    """Returns the array of blocks that constitute the diagonal of a sparse matrix.

    Args:
        A: the sparse matrix from which to extract the diagonal
        out: if provided, the array into which to store the diagonal blocks

    Returns:
        An array of length at least ``min(A.nrow, A.ncol)``. Entries whose diagonal block is not stored
        in `A` are left untouched (they are zero when the array is allocated by this function).

    Raises:
        ValueError: if `out` has the wrong data type, resides on another device, or is too short
    """
    device = A.values.device
    block_type = A.values.dtype
    dim = min(A.nrow, A.ncol)

    if out is None:
        out = wp.zeros(shape=(dim,), dtype=block_type, device=device)
    elif out.dtype != block_type:
        raise ValueError(f"Output array must have type {A.values.dtype}")
    elif out.device != device:
        raise ValueError(f"Output array must reside on device {A.values.device}")
    elif out.shape[0] < dim:
        raise ValueError(f"Output array must be of length at least {dim}")

    wp.launch(kernel=_bsr_get_diag_kernel, dim=dim, device=device, inputs=[A.offsets, A.columns, A.values, out])

    return out
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
@wp.kernel
def _bsr_set_diag_kernel(
    diag: wp.array(dtype=Any),
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
):
    # One thread per diagonal block: row `row` gets exactly one block, stored at index `row`,
    # located in column `row` and holding diag[row].
    # Only offsets[0 .. launch dim] are written; any later offsets are left untouched.
    row = wp.tid()
    A_offsets[row + 1] = row + 1
    A_columns[row] = row
    A_values[row] = diag[row]

    # The leading offset is written by a single thread
    if row == 0:
        A_offsets[0] = 0
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
@wp.kernel
def _bsr_set_diag_constant_kernel(
    diag_value: Any,
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
):
    # Same layout as _bsr_set_diag_kernel, but every diagonal block receives the same `diag_value`.
    # Only offsets[0 .. launch dim] are written; any later offsets are left untouched.
    row = wp.tid()
    A_offsets[row + 1] = row + 1
    A_columns[row] = row
    A_values[row] = diag_value

    # The leading offset is written by a single thread
    if row == 0:
        A_offsets[0] = 0
|
|
438
|
+
|
|
439
|
+
|
|
440
|
+
def bsr_set_diag(
    A: BsrMatrix[BlockType],
    diag: "Union[BlockType, Array[BlockType]]",
    rows_of_blocks: Optional[int] = None,
    cols_of_blocks: Optional[int] = None,
):
    """Sets `A` as a block-diagonal matrix

    Args:
        A: the sparse matrix to modify
        diag: Either a warp array of type ``A.values.dtype``, in which case each element will define one block of the diagonal,
              or a constant value of type ``A.values.dtype``, in which case it will get assigned to all diagonal blocks.
        rows_of_blocks: If not ``None``, the new number of rows of blocks
        cols_of_blocks: If not ``None``, the new number of columns of blocks

    The shape of the matrix will be defined one of the following, in that order:
      - `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal.
      - the first dimension of `diag`, if `diag` is an array
      - the current dimensions of `A` otherwise

    Raises:
        ValueError: if `diag` is an array residing on another device than `A`, or holding fewer than
            ``min(nrow, ncol)`` blocks
    """

    if rows_of_blocks is None and cols_of_blocks is not None:
        rows_of_blocks = cols_of_blocks
    if cols_of_blocks is None and rows_of_blocks is not None:
        cols_of_blocks = rows_of_blocks

    diag_is_array = warp.types.is_array(diag)

    if diag_is_array and rows_of_blocks is None:
        rows_of_blocks = diag.shape[0]
        cols_of_blocks = diag.shape[0]

    if rows_of_blocks is not None:
        A.nrow = rows_of_blocks
        A.ncol = cols_of_blocks

    # One diagonal block per row, up to the smaller dimension
    A.nnz = min(A.nrow, A.ncol)

    if diag_is_array:
        if diag.device != A.values.device:
            raise ValueError("Diagonal array and matrix must reside on the same device")
        if diag.shape[0] < A.nnz:
            # Otherwise the kernel would read past the end of `diag`
            raise ValueError(f"Diagonal array must hold at least {A.nnz} blocks")

    if A.nnz == 0:
        # A zero dimension means no diagonal block; the kernels would not run at all, so explicitly
        # reset the offsets to describe an empty matrix
        bsr_set_zero(A)
        return

    _bsr_ensure_fits(A)

    if diag_is_array:
        wp.launch(
            kernel=_bsr_set_diag_kernel,
            dim=A.nnz,
            device=A.values.device,
            inputs=[diag, A.offsets, A.columns, A.values],
        )
    else:
        if not warp.types.type_is_value(type(diag)):
            # Cast to launchable type
            diag = A.values.dtype(diag)
        wp.launch(
            kernel=_bsr_set_diag_constant_kernel,
            dim=A.nnz,
            device=A.values.device,
            inputs=[diag, A.offsets, A.columns, A.values],
        )

    if A.nrow > A.nnz:
        # More rows than columns: rows past the last diagonal block hold no blocks. The kernels only
        # write offsets[0 .. nnz], so close the remaining row ranges explicitly instead of leaving
        # uninitialized or stale offsets behind.
        A.offsets[A.nnz + 1 : A.nrow + 1].fill_(A.nnz)
|
|
495
|
+
|
|
496
|
+
|
|
497
|
+
def bsr_diag(
    diag: "Union[BlockType, Array[BlockType]]",
    rows_of_blocks: Optional[int] = None,
    cols_of_blocks: Optional[int] = None,
) -> BsrMatrix["BlockType"]:
    """Creates and returns a block-diagonal BSR matrix from an given block value or array of block values.

    Args:
        diag: Either a warp array of type ``A.values.dtype``, in which case each element will define one block of the diagonal,
              or a constant value of type ``A.values.dtype``, in which case it will get assigned to all diagonal blocks.
        rows_of_blocks: If not ``None``, the new number of rows of blocks
        cols_of_blocks: If not ``None``, the new number of columns of blocks

    The shape of the matrix will be defined one of the following, in that order:
      - `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal.
      - the first dimension of `diag`, if `diag` is an array

    Raises:
        ValueError: if `diag` is a constant value and neither dimension is provided
    """

    # A single given dimension makes the matrix square
    rows_of_blocks = cols_of_blocks if rows_of_blocks is None else rows_of_blocks
    cols_of_blocks = rows_of_blocks if cols_of_blocks is None else cols_of_blocks

    if warp.types.is_array(diag):
        if rows_of_blocks is None:
            rows_of_blocks = cols_of_blocks = diag.shape[0]

        matrix = bsr_zeros(rows_of_blocks, cols_of_blocks, block_type=diag.dtype, device=diag.device)
    else:
        if rows_of_blocks is None:
            raise ValueError(
                "rows_of_blocks and/or cols_of_blocks must be provided for constructing a diagonal matrix with uniform diagonal"
            )

        # 2d array-like values (e.g. nested sequences with a shape) are promoted to a warp matrix type
        value_type = type(diag)
        if not warp.types.type_is_matrix(value_type) and len(getattr(diag, "shape", ())) == 2:
            value_type = wp.mat(shape=diag.shape, dtype=diag.dtype)

        matrix = bsr_zeros(rows_of_blocks, cols_of_blocks, block_type=value_type)

    bsr_set_diag(matrix, diag)
    return matrix
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
def bsr_set_identity(A: BsrMatrix, rows_of_blocks: Optional[int] = None):
    """Sets `A` as the identity matrix

    Args:
        A: the sparse matrix to modify
        rows_of_blocks: if provided, the matrix will be resized as a square matrix with `rows_of_blocks` rows and columns.
    """

    if A.block_shape != (1, 1):
        from numpy import eye

        # Identity block sized from the block row count, cast to the block type by bsr_set_diag
        identity = eye(A.block_shape[0])
    else:
        identity = A.scalar_type(1.0)

    bsr_set_diag(A, diag=identity, rows_of_blocks=rows_of_blocks, cols_of_blocks=rows_of_blocks)
|
|
567
|
+
|
|
568
|
+
|
|
569
|
+
def bsr_identity(
    rows_of_blocks: int, block_type: BlockType[Rows, Rows, Scalar], device: wp.context.Devicelike = None
) -> BsrMatrix[BlockType[Rows, Rows, Scalar]]:
    """Creates and returns a square identity matrix.

    Args:
        rows_of_blocks: Number of rows and columns of blocks in the created matrix.
        block_type: Block type for the newly created matrix -- must be square
        device: Device onto which to allocate the data arrays
    """
    identity_matrix = bsr_zeros(
        rows_of_blocks=rows_of_blocks,
        cols_of_blocks=rows_of_blocks,
        block_type=block_type,
        device=device,
    )
    bsr_set_identity(identity_matrix)
    return identity_matrix
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
@wp.kernel
def _bsr_scale_kernel(
    alpha: Any,
    values: wp.array(dtype=Any),
):
    # In-place scaling: one thread per stored block, values[i] <- alpha * values[i]
    values[wp.tid()] = alpha * values[wp.tid()]
|
|
590
|
+
|
|
591
|
+
|
|
592
|
+
def bsr_scale(x: BsrMatrix, alpha: Scalar) -> BsrMatrix:
    """
    Performs the operation ``x := alpha * x`` on BSR matrix `x` and returns `x`

    Scaling by one or scaling an empty matrix is a no-op. Scaling by zero drops all blocks
    (the matrix keeps its dimensions).
    """

    needs_update = alpha != 1.0 and x.nnz > 0
    if not needs_update:
        return x

    if alpha == 0.0:
        bsr_set_zero(x)
        return x

    # Kernel arguments must match the matrix scalar type
    if not isinstance(alpha, x.scalar_type):
        alpha = x.scalar_type(alpha)

    wp.launch(kernel=_bsr_scale_kernel, dim=x.nnz, device=x.values.device, inputs=[alpha, x.values])
    return x
|
|
607
|
+
|
|
608
|
+
|
|
609
|
+
@wp.kernel
def _bsr_get_block_row(dest_offset: int, bsr_offsets: wp.array(dtype=int), rows: wp.array(dtype=int)):
    # One thread per stored block `i`: recover the row owning that block and store it at rows[dest_offset + i].
    i = wp.tid()

    # The first offset strictly greater than i sits at index row + 1.
    # NOTE(review): the search spans the whole `bsr_offsets` array, so it must be sorted over its full
    # length; callers should pass a view limited to the first nrow + 1 entries (offsets arrays are
    # never shrunk, and entries past nrow may be stale).
    row = wp.lower_bound(bsr_offsets, i + 1) - 1
    rows[dest_offset + i] = row
|
|
615
|
+
|
|
616
|
+
|
|
617
|
+
@wp.kernel
def _bsr_axpy_add_block(
    src_offset: int,
    scale: Any,
    rows: wp.array(dtype=int),
    cols: wp.array(dtype=int),
    dst_offsets: wp.array(dtype=int),
    dst_columns: wp.array(dtype=int),
    src_values: wp.array(dtype=Any),
    dst_values: wp.array(dtype=Any),
):
    # One thread per source block: accumulate scale * src_values[i] into the destination block
    # with the same (row, col) coordinates.
    # `rows`/`cols` hold coordinates for several sources concatenated together, hence the `src_offset`
    # shift, whereas `src_values` is indexed directly by `i`.
    i = wp.tid()
    row = rows[i + src_offset]
    col = cols[i + src_offset]
    beg = dst_offsets[row]
    end = dst_offsets[row + 1]

    # NOTE(review): assumes (row, col) is present in the destination sparsity pattern and that
    # columns are sorted within each row; there is no check that the found block matches `col`.
    block = wp.lower_bound(dst_columns, beg, end, col)

    dst_values[block] = dst_values[block] + scale * src_values[i]
|
|
637
|
+
|
|
638
|
+
|
|
639
|
+
class bsr_axpy_work_arrays:
    """Opaque structure for persisting :func:`bsr_axpy` temporary work buffers across calls"""

    def __init__(self):
        self._reset(None)

    def _reset(self, device):
        # Drop every buffer and bind the (empty) structure to `device`
        self.device = device
        self._sum_rows = None
        self._sum_cols = None
        self._old_y_values = None
        self._old_x_values = None

    def _allocate(self, device, y: BsrMatrix, sum_nnz: int):
        """Makes sure the buffers reside on `device` and are large enough for an addition producing
        up to `sum_nnz` blocks into matrix `y`. Existing buffers are reused whenever possible."""
        if self.device != device:
            self._reset(device)

        if self._sum_rows is None or self._sum_rows.size < sum_nnz:
            self._sum_rows = wp.empty(shape=(sum_nnz,), dtype=int, device=self.device)
        if self._sum_cols is None or self._sum_cols.size < sum_nnz:
            self._sum_cols = wp.empty(shape=(sum_nnz,), dtype=int, device=self.device)

        # The saved copy of y's values must share y's block type: reallocate when the work arrays
        # are reused with a matrix of a different block type than in a previous call
        if (
            self._old_y_values is None
            or self._old_y_values.size < y.nnz
            or not wp.types.types_equal(self._old_y_values.dtype, y.values.dtype)
        ):
            self._old_y_values = wp.empty(shape=(y.nnz,), dtype=y.values.dtype, device=self.device)
|
|
663
|
+
|
|
664
|
+
|
|
665
|
+
def bsr_axpy(
    x: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    y: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
    alpha: Scalar = 1.0,
    beta: Scalar = 1.0,
    work_arrays: Optional[bsr_axpy_work_arrays] = None,
) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
    """
    Performs the sparse matrix addition ``y := alpha * X + beta * y`` on BSR matrices `x` and `y` and returns `y`.

    The `x` and `y` matrices are allowed to alias.

    Args:
        x: Read-only right-hand-side.
        y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero.
        alpha: Uniform scaling factor for `x`
        beta: Uniform scaling factor for `y`
        work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_axpy_work_arrays` in `work_arrays`.

    Raises:
        ValueError: in the general case, if the matrices reside on different devices, have different
            block types, or different dimensions
    """

    if y is None:
        # If not output matrix is provided, allocate it for convenience
        y = bsr_zeros(x.nrow, x.ncol, block_type=x.values.dtype, device=x.values.device)
        beta = 0.0

    # Handle easy cases first
    if beta == 0.0 or y.nnz == 0:
        bsr_assign(src=x, dest=y)
        return bsr_scale(y, alpha=alpha)

    if alpha == 0.0 or x.nnz == 0:
        return bsr_scale(y, alpha=beta)

    if not isinstance(alpha, y.scalar_type):
        alpha = y.scalar_type(alpha)
    if not isinstance(beta, y.scalar_type):
        beta = y.scalar_type(beta)

    if x == y:
        # Aliasing case
        return bsr_scale(y, alpha=alpha.value + beta.value)

    # General case

    if x.values.device != y.values.device:
        raise ValueError("All arguments must reside on the same device")

    if x.scalar_type != y.scalar_type or x.block_shape != y.block_shape:
        raise ValueError("Matrices must have the same block type")

    if x.nrow != y.nrow or x.ncol != y.ncol:
        raise ValueError("Matrices must have the same number of rows and columns")

    if work_arrays is None:
        work_arrays = bsr_axpy_work_arrays()

    sum_nnz = x.nnz + y.nnz
    device = y.values.device
    work_arrays._allocate(device, y, sum_nnz)

    # Only the first nrow + 1 offsets are meaningful: offsets arrays are never shrunk, so trailing
    # entries may be stale (or zeroed) and would break the binary search in _bsr_get_block_row
    y_offsets = y.offsets[: y.nrow + 1]
    x_offsets = x.offsets[: x.nrow + 1]

    # Concatenate the (row, col) coordinates of y's blocks followed by those of x's blocks
    wp.copy(work_arrays._sum_cols, y.columns, 0, 0, y.nnz)
    wp.launch(kernel=_bsr_get_block_row, device=device, dim=y.nnz, inputs=[0, y_offsets, work_arrays._sum_rows])

    wp.copy(work_arrays._sum_cols, x.columns, y.nnz, 0, x.nnz)
    wp.launch(kernel=_bsr_get_block_row, device=device, dim=x.nnz, inputs=[y.nnz, x_offsets, work_arrays._sum_rows])

    # Save old y values before overwriting matrix
    wp.copy(dest=work_arrays._old_y_values, src=y.values, count=y.nnz)

    # Increase dest array sizes if needed
    if y.columns.shape[0] < sum_nnz:
        y.columns = wp.empty(shape=(sum_nnz,), dtype=int, device=device)

    from warp.context import runtime

    # Only the sparsity pattern is built here (null value pointers below), so the float variant
    # is used regardless of the actual scalar type
    if device.is_cpu:
        native_func = runtime.core.bsr_matrix_from_triplets_float_host
    else:
        native_func = runtime.core.bsr_matrix_from_triplets_float_device

    old_y_nnz = y.nnz
    y.nnz = native_func(
        y.block_shape[0],
        y.block_shape[1],
        y.nrow,
        sum_nnz,
        work_arrays._sum_rows.ptr,
        work_arrays._sum_cols.ptr,
        0,
        y.offsets.ptr,
        y.columns.ptr,
        0,
    )

    _bsr_ensure_fits(y)
    y.values.zero_()

    # Accumulate beta * old y blocks into the new pattern...
    wp.launch(
        kernel=_bsr_axpy_add_block,
        device=device,
        dim=old_y_nnz,
        inputs=[
            0,
            beta,
            work_arrays._sum_rows,
            work_arrays._sum_cols,
            y.offsets,
            y.columns,
            work_arrays._old_y_values,
            y.values,
        ],
    )

    # ...then alpha * x blocks, whose coordinates follow y's in the work arrays
    wp.launch(
        kernel=_bsr_axpy_add_block,
        device=device,
        dim=x.nnz,
        inputs=[
            old_y_nnz,
            alpha,
            work_arrays._sum_rows,
            work_arrays._sum_cols,
            y.offsets,
            y.columns,
            x.values,
            y.values,
        ],
    )

    return y
|
|
795
|
+
|
|
796
|
+
|
|
797
|
+
# One thread per row of x. Counts the unmerged blocks of the product x * y that land in that row:
# each block (row, k) of x contributes the number of blocks in row k of y.
# The count is stored shifted by one (counts[row + 1]), and counts[0] receives `z_nnz`.
# A subsequent prefix sum therefore gives per-row start offsets placed after the first `z_nnz` entries.
@wp.kernel
def _bsr_mm_count_coeffs(
    z_nnz: int,
    x_offsets: wp.array(dtype=int),
    x_columns: wp.array(dtype=int),
    y_offsets: wp.array(dtype=int),
    counts: wp.array(dtype=int),
):
    row = wp.tid()

    # Only a single thread writes the leading entry (row + 1 >= 1, so there is no overlap with other writes)
    if row == 0:
        counts[0] = z_nnz

    # Explicit int() so the accumulator is a mutable kernel variable inside the dynamic loop
    product_blocks = int(0)
    for x_block in range(x_offsets[row], x_offsets[row + 1]):
        k = x_columns[x_block]
        product_blocks += y_offsets[k + 1] - y_offsets[k]

    counts[row + 1] = product_blocks
|
|
819
|
+
|
|
820
|
+
|
|
821
|
+
# One thread per row of x. Writes the (row, column) coordinates of every unmerged block of the product x * y
# that lands in that row.
# Entries are written consecutively, starting at mm_offsets[row].
@wp.kernel
def _bsr_mm_list_coeffs(
    x_offsets: wp.array(dtype=int),
    x_columns: wp.array(dtype=int),
    y_offsets: wp.array(dtype=int),
    y_columns: wp.array(dtype=int),
    mm_offsets: wp.array(dtype=int),
    mm_rows: wp.array(dtype=int),
    mm_cols: wp.array(dtype=int),
):
    row = wp.tid()
    out_block = mm_offsets[row]

    for x_block in range(x_offsets[row], x_offsets[row + 1]):
        k = x_columns[x_block]
        # Every block of row k of y yields one product block in column y_columns[y_block]
        for y_block in range(y_offsets[k], y_offsets[k + 1]):
            mm_rows[out_block] = row
            mm_cols[out_block] = y_columns[y_block]
            out_block += 1
|
|
846
|
+
|
|
847
|
+
|
|
848
|
+
# One thread per output row. Accumulates alpha * x[row, k] * y[k, col] into the matching block of the output.
# The output topology (mm_offsets / mm_cols) must already be built.
# The target block is found with wp.lower_bound restricted to the column range of the current row.
# This assumes the columns within a row are sorted.
# Each thread only touches blocks of its own row.
@wp.kernel
def _bsr_mm_compute_values(
    alpha: Any,
    x_offsets: wp.array(dtype=int),
    x_columns: wp.array(dtype=int),
    x_values: wp.array(dtype=Any),
    y_offsets: wp.array(dtype=int),
    y_columns: wp.array(dtype=int),
    y_values: wp.array(dtype=Any),
    mm_offsets: wp.array(dtype=int),
    mm_cols: wp.array(dtype=int),
    mm_values: wp.array(dtype=Any),
):
    row = wp.tid()
    row_first = mm_offsets[row]
    row_last = mm_offsets[row + 1]

    for x_block in range(x_offsets[row], x_offsets[row + 1]):
        k = x_columns[x_block]
        scaled_x = alpha * x_values[x_block]

        for y_block in range(y_offsets[k], y_offsets[k + 1]):
            target = wp.lower_bound(mm_cols, row_first, row_last, y_columns[y_block])
            mm_values[target] = mm_values[target] + scaled_x * y_values[y_block]
|
|
877
|
+
|
|
878
|
+
|
|
879
|
+
class bsr_mm_work_arrays:
    """Opaque structure for persisting :func:`bsr_mm` temporary work buffers across calls"""

    def __init__(self):
        self._reset(None)

    def _reset(self, device):
        # Drop every cached buffer and bind the structure to `device`
        self.device = device
        self._pinned_count_buffer = None  # single host-pinned int (CUDA only)
        self._mm_row_counts = None  # per-row counts of unmerged product blocks (nrow + 1 entries)
        self._mm_rows = None  # row index of each unmerged block
        self._mm_cols = None  # column index of each unmerged block
        self._old_z_values = None  # saved copy of the values of z
        self._old_z_offsets = None  # saved copy of the row offsets of z (only when z aliases an operand)
        self._old_z_columns = None  # saved copy of the column indices of z (only when z aliases an operand)

    def _allocate_stage_1(self, device, z: BsrMatrix, copied_z_nnz: int, z_aliasing: bool):
        """Allocate or grow the buffers whose size is known before any computation.

        Args:
            device: Device on which the product is computed; switching device discards all cached buffers.
            z: Output matrix of the product.
            copied_z_nnz: Number of existing blocks of `z` that must be saved (0 if none).
            z_aliasing: Whether `z` is the same matrix as one of the operands.
        """
        if self.device != device:
            self._reset(device)

        # Allocations that do not depend on any computation
        if self.device.is_cuda:
            if self._pinned_count_buffer is None:
                self._pinned_count_buffer = wp.empty(shape=(1,), dtype=int, pinned=True, device="cpu")

        if self._mm_row_counts is None or self._mm_row_counts.size < z.nrow + 1:
            self._mm_row_counts = wp.empty(shape=(z.nrow + 1,), dtype=int, device=self.device)

        if copied_z_nnz > 0:
            # Besides growing, the buffer must be reallocated when the value type of z differs from the one
            # it was created with (work arrays reused across matrices with different block/scalar types).
            # The saved values are copied from, and later read as, z's value type.
            if (
                self._old_z_values is None
                or self._old_z_values.size < copied_z_nnz
                or not wp.types.types_equal(self._old_z_values.dtype, z.values.dtype)
            ):
                self._old_z_values = wp.empty(shape=(copied_z_nnz,), dtype=z.values.dtype, device=self.device)

            if z_aliasing:
                if self._old_z_columns is None or self._old_z_columns.size < z.nnz:
                    self._old_z_columns = wp.empty(shape=(z.nnz,), dtype=z.columns.dtype, device=self.device)
                if self._old_z_offsets is None or self._old_z_offsets.size < z.nrow + 1:
                    self._old_z_offsets = wp.empty(shape=(z.nrow + 1,), dtype=z.offsets.dtype, device=self.device)

    def _allocate_stage_2(self, mm_nnz: int):
        """Allocate or grow the row/column index buffers for `mm_nnz` unmerged product blocks."""
        # Allocations that depend on unmerged nnz estimate
        if self._mm_rows is None or self._mm_rows.size < mm_nnz:
            self._mm_rows = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
        if self._mm_cols is None or self._mm_cols.size < mm_nnz:
            self._mm_cols = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
|
|
923
|
+
|
|
924
|
+
|
|
925
|
+
def bsr_mm(
    x: BsrMatrix[BlockType[Rows, Any, Scalar]],
    y: BsrMatrix[BlockType[Any, Cols, Scalar]],
    z: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
    alpha: Scalar = 1.0,
    beta: Scalar = 0.0,
    work_arrays: Optional[bsr_mm_work_arrays] = None,
) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
    """
    Computes the sparse matrix-matrix product ``z := alpha * x * y + beta * z`` for BSR matrices `x`, `y` and `z`, and returns `z`.

    Any of `x`, `y` and `z` may refer to the same matrix.
    When `z` is omitted, a new matrix is allocated and treated as zero.

    Args:
        x: Left factor of the product (read-only unless it is the same matrix as `z`).
        y: Right factor of the product (read-only unless it is the same matrix as `z`).
        z: Output matrix, updated in place. When not provided, it is allocated and treated as zero.
        alpha: Uniform scaling factor applied to the ``x * y`` product
        beta: Uniform scaling factor applied to the previous contents of `z`
        work_arrays: Optional :class:`bsr_mm_work_arrays` holding temporary buffers. Passing the same instance
            across calls lets that storage be reused instead of reallocated.

    Returns:
        The updated matrix `z`.
    """

    if z is None:
        # No output matrix was given: allocate an empty one with the block shape of the product
        rows_per_block = x.block_shape[0]
        cols_per_block = y.block_shape[1]
        if (rows_per_block, cols_per_block) == (1, 1):
            z_block_type = x.scalar_type
        else:
            z_block_type = wp.types.matrix(shape=(rows_per_block, cols_per_block), dtype=x.scalar_type)
        z = bsr_zeros(x.nrow, y.ncol, block_type=z_block_type, device=x.values.device)
        beta = 0.0

    x_device = x.values.device
    if x_device != y.values.device or x_device != z.values.device:
        raise ValueError("All arguments must reside on the same device")

    x_scalar = x.scalar_type
    if x_scalar != y.scalar_type or x_scalar != z.scalar_type:
        raise ValueError("Matrices must have the same scalar type")

    block_shapes_match = (
        x.block_shape[0] == z.block_shape[0]
        and y.block_shape[1] == z.block_shape[1]
        and x.block_shape[1] == y.block_shape[0]
    )
    if not block_shapes_match:
        raise ValueError("Incompatible block sizes for matrix multiplication")

    dims_match = x.nrow == z.nrow and z.ncol == y.ncol and x.ncol == y.nrow
    if not dims_match:
        raise ValueError("Incompatible number of rows/columns for matrix multiplication")

    device = z.values.device

    if alpha == 0.0 or x.nnz == 0 or y.nnz == 0:
        # The product term vanishes: only the scaling of z remains
        return bsr_scale(z, beta)

    if not isinstance(alpha, z.scalar_type):
        alpha = z.scalar_type(alpha)
    if not isinstance(beta, z.scalar_type):
        beta = z.scalar_type(beta)

    wa = bsr_mm_work_arrays() if work_arrays is None else work_arrays

    # Existing z blocks are kept in the triplet list whenever their values matter (beta != 0),
    # or when z doubles as an operand
    z_aliasing = z == x or z == y
    copied_z_nnz = z.nnz if (beta != 0.0 or z_aliasing) else 0

    wa._allocate_stage_1(device, z, copied_z_nnz, z_aliasing)
    row_counts = wa._mm_row_counts

    # Per-row number of unmerged blocks (shifted past the preserved z blocks), turned into offsets by a scan
    wp.launch(
        kernel=_bsr_mm_count_coeffs,
        device=device,
        dim=z.nrow,
        inputs=[copied_z_nnz, x.offsets, x.columns, y.offsets, row_counts],
    )
    warp.utils.array_scan(row_counts, row_counts)

    # The total (last offset) is needed on the host to size the triplet buffers
    if device.is_cuda:
        wp.copy(dest=wa._pinned_count_buffer, src=row_counts, src_offset=z.nrow, count=1)
        wp.synchronize_stream(wp.get_stream(device))
        mm_nnz = int(wa._pinned_count_buffer.numpy()[0])
    else:
        mm_nnz = int(row_counts.numpy()[z.nrow])

    wa._allocate_stage_2(mm_nnz)
    mm_rows = wa._mm_rows
    mm_cols = wa._mm_cols

    if copied_z_nnz > 0:
        # Record coordinates and values of the current z blocks before its topology gets rebuilt
        wp.copy(dest=mm_cols, src=z.columns, count=copied_z_nnz)
        wp.launch(kernel=_bsr_get_block_row, device=device, dim=copied_z_nnz, inputs=[0, z.offsets, mm_rows])
        wp.copy(src=z.values, dest=wa._old_z_values, count=copied_z_nnz)
        if z_aliasing:
            # z is also an operand: keep its old topology for the product computation below
            wp.copy(src=z.columns, dest=wa._old_z_columns, count=copied_z_nnz)
            wp.copy(src=z.offsets, dest=wa._old_z_offsets, count=z.nrow + 1)

    # Append the (row, column) coordinates of every unmerged product block
    wp.launch(
        kernel=_bsr_mm_list_coeffs,
        device=device,
        dim=z.nrow,
        inputs=[x.offsets, x.columns, y.offsets, y.columns, row_counts, mm_rows, mm_cols],
    )

    # Grow the column index storage of z when it cannot hold every triplet
    if z.columns.shape[0] < mm_nnz:
        z.columns = wp.empty(shape=(mm_nnz,), dtype=int, device=device)

    from warp.context import runtime

    # Rebuild the topology of z from the triplet coordinates.
    # The `0` arguments presumably stand for absent (null) value pointers, since only coordinates are passed.
    core = runtime.core
    build_topology = (
        core.bsr_matrix_from_triplets_float_host if device.is_cpu else core.bsr_matrix_from_triplets_float_device
    )
    z.nnz = build_topology(
        z.block_shape[0],
        z.block_shape[1],
        z.nrow,
        mm_nnz,
        mm_rows.ptr,
        mm_cols.ptr,
        0,
        z.offsets.ptr,
        z.columns.ptr,
        0,
    )

    _bsr_ensure_fits(z)
    z.values.zero_()

    if copied_z_nnz > 0:
        # Accumulate beta * (previous z blocks) into the rebuilt matrix
        wp.launch(
            kernel=_bsr_axpy_add_block,
            device=device,
            dim=copied_z_nnz,
            inputs=[0, beta, mm_rows, mm_cols, z.offsets, z.columns, wa._old_z_values, z.values],
        )

    # Scalar-valued z with matrix-valued operands: view the values of z as 1x1 matrices for the product
    operands_are_matrices = warp.types.type_is_matrix(x.values.dtype) or warp.types.type_is_matrix(y.values.dtype)
    if operands_are_matrices and not warp.types.type_is_matrix(z.values.dtype):
        mm_values = z.values.view(wp.types.matrix(shape=(1, 1), dtype=z.scalar_type))
    else:
        mm_values = z.values

    # Operands that alias z must be read from the saved copies, since z has just been overwritten
    saved_z = (wa._old_z_offsets, wa._old_z_columns, wa._old_z_values)
    x_data = saved_z if x == z else (x.offsets, x.columns, x.values)
    y_data = saved_z if y == z else (y.offsets, y.columns, y.values)

    wp.launch(
        kernel=_bsr_mm_compute_values,
        device=device,
        dim=z.nrow,
        inputs=[alpha, *x_data, *y_data, z.offsets, z.columns, mm_values],
    )

    return z
|
|
1116
|
+
|
|
1117
|
+
|
|
1118
|
+
# One thread per row of A. Computes y[row] = alpha * (A x)[row] + beta * y[row].
# When alpha (resp. beta) is zero, the corresponding term is skipped entirely.
# In that case x (resp. y[row]) is never read.
@wp.kernel
def _bsr_mv_kernel(
    alpha: Any,
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    beta: Any,
    y: wp.array(dtype=Any),
):
    row = wp.tid()

    # Zero of the scalar type, then zero of the element type of y
    zero = type(alpha)(0)
    result = y.dtype(zero)

    if alpha != zero:
        for block in range(A_offsets[row], A_offsets[row + 1]):
            result += A_values[block] * x[A_columns[block]]
        result *= alpha

    if beta != zero:
        result += beta * y[row]

    y[row] = result
|
|
1145
|
+
|
|
1146
|
+
|
|
1147
|
+
def bsr_mv(
    A: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    x: "Array[Vector[Cols, Scalar] | Scalar]",
    y: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
    alpha: Scalar = 1.0,
    beta: Scalar = 0.0,
    work_buffer: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
) -> "Array[Vector[Rows, Scalar] | Scalar]":
    """
    Computes the sparse matrix-vector product ``y := alpha * A * x + beta * y`` and returns `y`.

    `x` and `y` may be the same array.

    Args:
        A: Left matrix factor of the product (read-only).
        x: Right vector factor of the product (read-only unless it is the same array as `y`).
        y: Output vector, updated in place. When not provided, it is allocated and treated as zero.
        alpha: Uniform scaling factor for `x`. If zero, `x` will not be read and may be left uninitialized.
        beta: Uniform scaling factor for `y`. If zero, `y` will not be read and may be left uninitialized.
        work_buffer: Temporary storage, needed only when `x` and `y` are the same array.
            When provided, it must have at least as many elements as `y` and the same data type.
            Otherwise, a temporary array is allocated.

    Returns:
        The updated vector `y`. It may be returned as a view with a length-1 vector or scalar data type
        matching the blocks of `A`.
    """

    device = A.values.device
    scalar_type = A.scalar_type

    if y is None:
        # No output array was given: allocate a zero-filled one, so `beta` no longer matters
        out_len = A.block_shape[0]
        y_dtype = scalar_type if out_len == 1 else wp.vec(length=out_len, dtype=scalar_type)
        y = wp.empty(shape=(A.nrow,), device=device, dtype=y_dtype)
        y.zero_()
        beta = 0.0

    if not isinstance(alpha, scalar_type):
        alpha = scalar_type(alpha)
    if not isinstance(beta, scalar_type):
        beta = scalar_type(beta)

    if any(device != arr.device for arr in (x, y)):
        raise ValueError("A, x and y must reside on the same device")

    if x.shape[0] != A.ncol:
        raise ValueError("Number of columns of A must match number of rows of x")
    if y.shape[0] != A.nrow:
        raise ValueError("Number of rows of A must match number of rows of y")

    if x == y:
        # In-place product: the kernel overwrites y while still reading x, so x must come from a snapshot
        if work_buffer is None:
            work_buffer = wp.empty_like(y)
        else:
            if work_buffer.size < y.size:
                raise ValueError(f"Work buffer size is insufficient, needs to be at least {y.size}")
            if not wp.types.types_equal(work_buffer.dtype, y.dtype):
                raise ValueError(f"Work buffer must have same data type as y, {wp.types.type_repr(y.dtype)}")

        wp.copy(dest=work_buffer, src=y, count=y.size)
        x = work_buffer

    # Reinterpret scalar arrays as length-1 vectors (or the reverse) so they agree with the block type of A
    if warp.types.type_is_matrix(A.values.dtype):
        if A.block_shape[0] == 1 and y.dtype == scalar_type:
            y = y.view(dtype=wp.vec(length=1, dtype=scalar_type))
        if A.block_shape[1] == 1 and x.dtype == scalar_type:
            x = x.view(dtype=wp.vec(length=1, dtype=scalar_type))
    else:
        if A.block_shape[0] == 1 and y.dtype != scalar_type:
            y = y.view(dtype=scalar_type)
        if A.block_shape[1] == 1 and x.dtype != scalar_type:
            x = x.view(dtype=scalar_type)

    wp.launch(
        kernel=_bsr_mv_kernel,
        device=device,
        dim=A.nrow,
        inputs=[alpha, A.offsets, A.columns, A.values, x, beta, y],
    )

    return y
|