warp-lang 1.0.1__py3-none-manylinux2014_aarch64.whl → 1.1.0__py3-none-manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (346)
  1. warp/__init__.py +108 -97
  2. warp/__init__.pyi +1 -1
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +115 -113
  6. warp/build_dll.py +383 -375
  7. warp/builtins.py +3425 -3354
  8. warp/codegen.py +2878 -2792
  9. warp/config.py +40 -36
  10. warp/constants.py +45 -45
  11. warp/context.py +5194 -5102
  12. warp/dlpack.py +442 -442
  13. warp/examples/__init__.py +16 -16
  14. warp/examples/assets/bear.usd +0 -0
  15. warp/examples/assets/bunny.usd +0 -0
  16. warp/examples/assets/cartpole.urdf +110 -110
  17. warp/examples/assets/crazyflie.usd +0 -0
  18. warp/examples/assets/cube.usd +0 -0
  19. warp/examples/assets/nv_ant.xml +92 -92
  20. warp/examples/assets/nv_humanoid.xml +183 -183
  21. warp/examples/assets/quadruped.urdf +267 -267
  22. warp/examples/assets/rocks.nvdb +0 -0
  23. warp/examples/assets/rocks.usd +0 -0
  24. warp/examples/assets/sphere.usd +0 -0
  25. warp/examples/benchmarks/benchmark_api.py +383 -383
  26. warp/examples/benchmarks/benchmark_cloth.py +278 -279
  27. warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -88
  28. warp/examples/benchmarks/benchmark_cloth_jax.py +97 -100
  29. warp/examples/benchmarks/benchmark_cloth_numba.py +146 -142
  30. warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -77
  31. warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -86
  32. warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -112
  33. warp/examples/benchmarks/benchmark_cloth_warp.py +146 -146
  34. warp/examples/benchmarks/benchmark_launches.py +295 -295
  35. warp/examples/browse.py +29 -28
  36. warp/examples/core/example_dem.py +234 -221
  37. warp/examples/core/example_fluid.py +293 -267
  38. warp/examples/core/example_graph_capture.py +144 -129
  39. warp/examples/core/example_marching_cubes.py +188 -176
  40. warp/examples/core/example_mesh.py +174 -154
  41. warp/examples/core/example_mesh_intersect.py +205 -193
  42. warp/examples/core/example_nvdb.py +176 -169
  43. warp/examples/core/example_raycast.py +105 -89
  44. warp/examples/core/example_raymarch.py +199 -178
  45. warp/examples/core/example_render_opengl.py +185 -141
  46. warp/examples/core/example_sph.py +405 -389
  47. warp/examples/core/example_torch.py +222 -181
  48. warp/examples/core/example_wave.py +263 -249
  49. warp/examples/fem/bsr_utils.py +378 -380
  50. warp/examples/fem/example_apic_fluid.py +407 -391
  51. warp/examples/fem/example_convection_diffusion.py +182 -168
  52. warp/examples/fem/example_convection_diffusion_dg.py +219 -209
  53. warp/examples/fem/example_convection_diffusion_dg0.py +204 -194
  54. warp/examples/fem/example_deformed_geometry.py +177 -159
  55. warp/examples/fem/example_diffusion.py +201 -173
  56. warp/examples/fem/example_diffusion_3d.py +177 -152
  57. warp/examples/fem/example_diffusion_mgpu.py +221 -214
  58. warp/examples/fem/example_mixed_elasticity.py +244 -222
  59. warp/examples/fem/example_navier_stokes.py +259 -243
  60. warp/examples/fem/example_stokes.py +220 -192
  61. warp/examples/fem/example_stokes_transfer.py +265 -249
  62. warp/examples/fem/mesh_utils.py +133 -109
  63. warp/examples/fem/plot_utils.py +292 -287
  64. warp/examples/optim/example_bounce.py +260 -248
  65. warp/examples/optim/example_cloth_throw.py +222 -210
  66. warp/examples/optim/example_diffray.py +566 -535
  67. warp/examples/optim/example_drone.py +864 -835
  68. warp/examples/optim/example_inverse_kinematics.py +176 -169
  69. warp/examples/optim/example_inverse_kinematics_torch.py +185 -170
  70. warp/examples/optim/example_spring_cage.py +239 -234
  71. warp/examples/optim/example_trajectory.py +223 -201
  72. warp/examples/optim/example_walker.py +306 -292
  73. warp/examples/sim/example_cartpole.py +139 -128
  74. warp/examples/sim/example_cloth.py +196 -184
  75. warp/examples/sim/example_granular.py +124 -113
  76. warp/examples/sim/example_granular_collision_sdf.py +197 -185
  77. warp/examples/sim/example_jacobian_ik.py +236 -213
  78. warp/examples/sim/example_particle_chain.py +118 -106
  79. warp/examples/sim/example_quadruped.py +193 -179
  80. warp/examples/sim/example_rigid_chain.py +197 -189
  81. warp/examples/sim/example_rigid_contact.py +189 -176
  82. warp/examples/sim/example_rigid_force.py +127 -126
  83. warp/examples/sim/example_rigid_gyroscopic.py +109 -97
  84. warp/examples/sim/example_rigid_soft_contact.py +134 -124
  85. warp/examples/sim/example_soft_body.py +190 -178
  86. warp/fabric.py +337 -335
  87. warp/fem/__init__.py +60 -27
  88. warp/fem/cache.py +401 -388
  89. warp/fem/dirichlet.py +178 -179
  90. warp/fem/domain.py +262 -263
  91. warp/fem/field/__init__.py +100 -101
  92. warp/fem/field/field.py +148 -149
  93. warp/fem/field/nodal_field.py +298 -299
  94. warp/fem/field/restriction.py +22 -21
  95. warp/fem/field/test.py +180 -181
  96. warp/fem/field/trial.py +183 -183
  97. warp/fem/geometry/__init__.py +15 -19
  98. warp/fem/geometry/closest_point.py +69 -70
  99. warp/fem/geometry/deformed_geometry.py +270 -271
  100. warp/fem/geometry/element.py +744 -744
  101. warp/fem/geometry/geometry.py +184 -186
  102. warp/fem/geometry/grid_2d.py +380 -373
  103. warp/fem/geometry/grid_3d.py +441 -435
  104. warp/fem/geometry/hexmesh.py +953 -953
  105. warp/fem/geometry/partition.py +374 -376
  106. warp/fem/geometry/quadmesh_2d.py +532 -532
  107. warp/fem/geometry/tetmesh.py +840 -840
  108. warp/fem/geometry/trimesh_2d.py +577 -577
  109. warp/fem/integrate.py +1630 -1615
  110. warp/fem/operator.py +190 -191
  111. warp/fem/polynomial.py +214 -213
  112. warp/fem/quadrature/__init__.py +2 -2
  113. warp/fem/quadrature/pic_quadrature.py +243 -245
  114. warp/fem/quadrature/quadrature.py +295 -294
  115. warp/fem/space/__init__.py +294 -292
  116. warp/fem/space/basis_space.py +488 -489
  117. warp/fem/space/collocated_function_space.py +100 -105
  118. warp/fem/space/dof_mapper.py +236 -236
  119. warp/fem/space/function_space.py +148 -145
  120. warp/fem/space/grid_2d_function_space.py +267 -267
  121. warp/fem/space/grid_3d_function_space.py +305 -306
  122. warp/fem/space/hexmesh_function_space.py +350 -352
  123. warp/fem/space/partition.py +350 -350
  124. warp/fem/space/quadmesh_2d_function_space.py +368 -369
  125. warp/fem/space/restriction.py +158 -160
  126. warp/fem/space/shape/__init__.py +13 -15
  127. warp/fem/space/shape/cube_shape_function.py +738 -738
  128. warp/fem/space/shape/shape_function.py +102 -103
  129. warp/fem/space/shape/square_shape_function.py +611 -611
  130. warp/fem/space/shape/tet_shape_function.py +565 -567
  131. warp/fem/space/shape/triangle_shape_function.py +429 -429
  132. warp/fem/space/tetmesh_function_space.py +294 -292
  133. warp/fem/space/topology.py +297 -295
  134. warp/fem/space/trimesh_2d_function_space.py +223 -221
  135. warp/fem/types.py +77 -77
  136. warp/fem/utils.py +495 -495
  137. warp/jax.py +166 -141
  138. warp/jax_experimental.py +341 -339
  139. warp/native/array.h +1072 -1025
  140. warp/native/builtin.h +1560 -1560
  141. warp/native/bvh.cpp +398 -398
  142. warp/native/bvh.cu +525 -525
  143. warp/native/bvh.h +429 -429
  144. warp/native/clang/clang.cpp +495 -464
  145. warp/native/crt.cpp +31 -31
  146. warp/native/crt.h +334 -334
  147. warp/native/cuda_crt.h +1049 -1049
  148. warp/native/cuda_util.cpp +549 -540
  149. warp/native/cuda_util.h +288 -203
  150. warp/native/cutlass_gemm.cpp +34 -34
  151. warp/native/cutlass_gemm.cu +372 -372
  152. warp/native/error.cpp +66 -66
  153. warp/native/error.h +27 -27
  154. warp/native/fabric.h +228 -228
  155. warp/native/hashgrid.cpp +301 -278
  156. warp/native/hashgrid.cu +78 -77
  157. warp/native/hashgrid.h +227 -227
  158. warp/native/initializer_array.h +32 -32
  159. warp/native/intersect.h +1204 -1204
  160. warp/native/intersect_adj.h +365 -365
  161. warp/native/intersect_tri.h +322 -322
  162. warp/native/marching.cpp +2 -2
  163. warp/native/marching.cu +497 -497
  164. warp/native/marching.h +2 -2
  165. warp/native/mat.h +1498 -1498
  166. warp/native/matnn.h +333 -333
  167. warp/native/mesh.cpp +203 -203
  168. warp/native/mesh.cu +293 -293
  169. warp/native/mesh.h +1887 -1887
  170. warp/native/nanovdb/NanoVDB.h +4782 -4782
  171. warp/native/nanovdb/PNanoVDB.h +2553 -2553
  172. warp/native/nanovdb/PNanoVDBWrite.h +294 -294
  173. warp/native/noise.h +850 -850
  174. warp/native/quat.h +1084 -1084
  175. warp/native/rand.h +299 -299
  176. warp/native/range.h +108 -108
  177. warp/native/reduce.cpp +156 -156
  178. warp/native/reduce.cu +348 -348
  179. warp/native/runlength_encode.cpp +61 -61
  180. warp/native/runlength_encode.cu +46 -46
  181. warp/native/scan.cpp +30 -30
  182. warp/native/scan.cu +36 -36
  183. warp/native/scan.h +7 -7
  184. warp/native/solid_angle.h +442 -442
  185. warp/native/sort.cpp +94 -94
  186. warp/native/sort.cu +97 -97
  187. warp/native/sort.h +14 -14
  188. warp/native/sparse.cpp +337 -337
  189. warp/native/sparse.cu +544 -544
  190. warp/native/spatial.h +630 -630
  191. warp/native/svd.h +562 -562
  192. warp/native/temp_buffer.h +30 -30
  193. warp/native/vec.h +1132 -1132
  194. warp/native/volume.cpp +297 -297
  195. warp/native/volume.cu +32 -32
  196. warp/native/volume.h +538 -538
  197. warp/native/volume_builder.cu +425 -425
  198. warp/native/volume_builder.h +19 -19
  199. warp/native/warp.cpp +1057 -1052
  200. warp/native/warp.cu +2943 -2828
  201. warp/native/warp.h +313 -305
  202. warp/optim/__init__.py +9 -9
  203. warp/optim/adam.py +120 -120
  204. warp/optim/linear.py +1104 -939
  205. warp/optim/sgd.py +104 -92
  206. warp/render/__init__.py +10 -10
  207. warp/render/render_opengl.py +3217 -3204
  208. warp/render/render_usd.py +768 -749
  209. warp/render/utils.py +152 -150
  210. warp/sim/__init__.py +52 -59
  211. warp/sim/articulation.py +685 -685
  212. warp/sim/collide.py +1594 -1590
  213. warp/sim/import_mjcf.py +489 -481
  214. warp/sim/import_snu.py +220 -221
  215. warp/sim/import_urdf.py +536 -516
  216. warp/sim/import_usd.py +887 -881
  217. warp/sim/inertia.py +316 -317
  218. warp/sim/integrator.py +234 -233
  219. warp/sim/integrator_euler.py +1956 -1956
  220. warp/sim/integrator_featherstone.py +1910 -1991
  221. warp/sim/integrator_xpbd.py +3294 -3312
  222. warp/sim/model.py +4473 -4314
  223. warp/sim/particles.py +113 -112
  224. warp/sim/render.py +417 -403
  225. warp/sim/utils.py +413 -410
  226. warp/sparse.py +1227 -1227
  227. warp/stubs.py +2109 -2469
  228. warp/tape.py +1162 -225
  229. warp/tests/__init__.py +1 -1
  230. warp/tests/__main__.py +4 -4
  231. warp/tests/assets/torus.usda +105 -105
  232. warp/tests/aux_test_class_kernel.py +26 -26
  233. warp/tests/aux_test_compile_consts_dummy.py +10 -10
  234. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -21
  235. warp/tests/aux_test_dependent.py +22 -22
  236. warp/tests/aux_test_grad_customs.py +23 -23
  237. warp/tests/aux_test_reference.py +11 -11
  238. warp/tests/aux_test_reference_reference.py +10 -10
  239. warp/tests/aux_test_square.py +17 -17
  240. warp/tests/aux_test_unresolved_func.py +14 -14
  241. warp/tests/aux_test_unresolved_symbol.py +14 -14
  242. warp/tests/disabled_kinematics.py +239 -239
  243. warp/tests/run_coverage_serial.py +31 -31
  244. warp/tests/test_adam.py +157 -157
  245. warp/tests/test_arithmetic.py +1124 -1124
  246. warp/tests/test_array.py +2417 -2326
  247. warp/tests/test_array_reduce.py +150 -150
  248. warp/tests/test_async.py +668 -656
  249. warp/tests/test_atomic.py +141 -141
  250. warp/tests/test_bool.py +204 -149
  251. warp/tests/test_builtins_resolution.py +1292 -1292
  252. warp/tests/test_bvh.py +164 -171
  253. warp/tests/test_closest_point_edge_edge.py +228 -228
  254. warp/tests/test_codegen.py +566 -553
  255. warp/tests/test_compile_consts.py +97 -101
  256. warp/tests/test_conditional.py +246 -246
  257. warp/tests/test_copy.py +232 -215
  258. warp/tests/test_ctypes.py +632 -632
  259. warp/tests/test_dense.py +67 -67
  260. warp/tests/test_devices.py +91 -98
  261. warp/tests/test_dlpack.py +530 -529
  262. warp/tests/test_examples.py +400 -378
  263. warp/tests/test_fabricarray.py +955 -955
  264. warp/tests/test_fast_math.py +62 -54
  265. warp/tests/test_fem.py +1277 -1278
  266. warp/tests/test_fp16.py +130 -130
  267. warp/tests/test_func.py +338 -337
  268. warp/tests/test_generics.py +571 -571
  269. warp/tests/test_grad.py +746 -640
  270. warp/tests/test_grad_customs.py +333 -336
  271. warp/tests/test_hash_grid.py +210 -164
  272. warp/tests/test_import.py +39 -39
  273. warp/tests/test_indexedarray.py +1134 -1134
  274. warp/tests/test_intersect.py +67 -67
  275. warp/tests/test_jax.py +307 -307
  276. warp/tests/test_large.py +167 -164
  277. warp/tests/test_launch.py +354 -354
  278. warp/tests/test_lerp.py +261 -261
  279. warp/tests/test_linear_solvers.py +191 -171
  280. warp/tests/test_lvalue.py +421 -493
  281. warp/tests/test_marching_cubes.py +65 -65
  282. warp/tests/test_mat.py +1801 -1827
  283. warp/tests/test_mat_lite.py +115 -115
  284. warp/tests/test_mat_scalar_ops.py +2907 -2889
  285. warp/tests/test_math.py +126 -193
  286. warp/tests/test_matmul.py +500 -499
  287. warp/tests/test_matmul_lite.py +410 -410
  288. warp/tests/test_mempool.py +188 -190
  289. warp/tests/test_mesh.py +284 -324
  290. warp/tests/test_mesh_query_aabb.py +228 -241
  291. warp/tests/test_mesh_query_point.py +692 -702
  292. warp/tests/test_mesh_query_ray.py +292 -303
  293. warp/tests/test_mlp.py +276 -276
  294. warp/tests/test_model.py +110 -110
  295. warp/tests/test_modules_lite.py +39 -39
  296. warp/tests/test_multigpu.py +163 -163
  297. warp/tests/test_noise.py +248 -248
  298. warp/tests/test_operators.py +250 -250
  299. warp/tests/test_options.py +123 -125
  300. warp/tests/test_peer.py +133 -137
  301. warp/tests/test_pinned.py +78 -78
  302. warp/tests/test_print.py +54 -54
  303. warp/tests/test_quat.py +2086 -2086
  304. warp/tests/test_rand.py +288 -288
  305. warp/tests/test_reload.py +217 -217
  306. warp/tests/test_rounding.py +179 -179
  307. warp/tests/test_runlength_encode.py +190 -190
  308. warp/tests/test_sim_grad.py +243 -0
  309. warp/tests/test_sim_kinematics.py +91 -97
  310. warp/tests/test_smoothstep.py +168 -168
  311. warp/tests/test_snippet.py +305 -266
  312. warp/tests/test_sparse.py +468 -460
  313. warp/tests/test_spatial.py +2148 -2148
  314. warp/tests/test_streams.py +486 -473
  315. warp/tests/test_struct.py +710 -675
  316. warp/tests/test_tape.py +173 -148
  317. warp/tests/test_torch.py +743 -743
  318. warp/tests/test_transient_module.py +87 -87
  319. warp/tests/test_types.py +556 -659
  320. warp/tests/test_utils.py +490 -499
  321. warp/tests/test_vec.py +1264 -1268
  322. warp/tests/test_vec_lite.py +73 -73
  323. warp/tests/test_vec_scalar_ops.py +2099 -2099
  324. warp/tests/test_verify_fp.py +94 -94
  325. warp/tests/test_volume.py +737 -736
  326. warp/tests/test_volume_write.py +255 -265
  327. warp/tests/unittest_serial.py +37 -37
  328. warp/tests/unittest_suites.py +363 -359
  329. warp/tests/unittest_utils.py +603 -578
  330. warp/tests/unused_test_misc.py +71 -71
  331. warp/tests/walkthrough_debug.py +85 -85
  332. warp/thirdparty/appdirs.py +598 -598
  333. warp/thirdparty/dlpack.py +143 -143
  334. warp/thirdparty/unittest_parallel.py +566 -561
  335. warp/torch.py +321 -295
  336. warp/types.py +4504 -4450
  337. warp/utils.py +1008 -821
  338. {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/LICENSE.md +126 -126
  339. {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/METADATA +338 -400
  340. warp_lang-1.1.0.dist-info/RECORD +352 -0
  341. warp/examples/assets/cube.usda +0 -42
  342. warp/examples/assets/sphere.usda +0 -56
  343. warp/examples/assets/torus.usda +0 -105
  344. warp_lang-1.0.1.dist-info/RECORD +0 -352
  345. {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/WHEEL +0 -0
  346. {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/top_level.txt +0 -0
warp/sparse.py CHANGED
@@ -1,1227 +1,1227 @@
1
- from typing import Any, Generic, Optional, Tuple, TypeVar, Union
2
-
3
- import warp as wp
4
- import warp.types
5
- import warp.utils
6
- from warp.types import Array, Cols, Matrix, Rows, Scalar, Vector
7
-
8
# typing hints

# NOTE(review): the TypeVar is registered under the name "BlockType" although it is bound to
# ``_BlockType``; presumably so that rendered signatures read ``BsrMatrix[BlockType]`` -- confirm
# before renaming it.
_BlockType = TypeVar("BlockType")


class _MatrixBlockType(Matrix):
    """Placeholder for matrix-valued (BSR) block types, used to build the ``BlockType`` hint."""


class _ScalarBlockType(Generic[Scalar]):
    """Placeholder for scalar-valued (CSR) block types, used to build the ``BlockType`` hint."""


# A block is either a (Rows x Cols) matrix of Scalar, or a single Scalar
BlockType = Union[_MatrixBlockType[Rows, Cols, Scalar], _ScalarBlockType[Scalar]]

# Generated BSR struct types, keyed by a string derived from the block type (see bsr_matrix_t)
_struct_cache = {}
class BsrMatrix(Generic[_BlockType]):
    """Untyped base class for BSR and CSR matrices.

    Should not be constructed directly but through functions such as :func:`bsr_zeros`.

    Attributes:
        nrow (int): Number of rows of blocks
        ncol (int): Number of columns of blocks
        nnz (int): Number of non-zero blocks: must be equal to ``offsets[nrow]``, cached on host for convenience
        offsets (Array[int]): Array of size at least ``1 + nrow`` such that the start and end indices of the blocks of row ``r`` are ``offsets[r]`` and ``offsets[r+1]``, respectively.
        columns (Array[int]): Array of size at least equal to ``nnz`` containing block column indices
        values (Array[BlockType]): Array of size at least equal to ``nnz`` containing block values
    """

    @property
    def scalar_type(self) -> Scalar:
        """Scalar type for individual block coefficients. For CSR matrices, this is the same as the block type"""
        return warp.types.type_scalar_type(self.values.dtype)

    @property
    def block_shape(self) -> Tuple[int, int]:
        """Shape of the individual blocks; ``(1, 1)`` for block types without a ``_shape_`` (i.e. scalar/CSR blocks)"""
        return getattr(self.values.dtype, "_shape_", (1, 1))

    @property
    def block_size(self) -> int:
        """Size of the individual blocks, i.e. number of rows per block times number of columns per block"""
        return warp.types.type_length(self.values.dtype)

    @property
    def shape(self) -> Tuple[int, int]:
        """Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/columns per block"""
        block_shape = self.block_shape
        return (self.nrow * block_shape[0], self.ncol * block_shape[1])

    @property
    def dtype(self) -> type:
        """Data type for individual block values"""
        return self.values.dtype

    @property
    def device(self) -> wp.context.Device:
        """Device on which offsets, columns and values are allocated -- assumed to be the same for all three arrays"""
        return self.values.device
def bsr_matrix_t(dtype: BlockType):
    """Returns the warp struct type describing BSR/CSR matrices whose blocks are of type ``dtype``.

    Struct types are cached by block type, so repeated calls with the same ``dtype`` return the same type
    without re-creating the underlying class.

    Args:
        dtype: Block type; either a warp matrix type (BSR) or a warp scalar type (CSR)

    Raises:
        ValueError: If ``dtype`` is neither a warp matrix type nor a warp scalar type
    """
    dtype = wp.types.type_to_warp(dtype)

    if not warp.types.type_is_matrix(dtype) and dtype not in warp.types.scalar_types:
        raise ValueError(
            f"BsrMatrix block type must be either warp matrix or scalar; got {warp.types.type_repr(dtype)}"
        )

    # Build the cache key first so that cache hits skip class creation and module lookup entirely
    if hasattr(dtype, "_shape_"):
        type_str = f"{warp.types.type_scalar_type(dtype).__name__}_{dtype._shape_[0]}_{dtype._shape_[1]}"
    else:
        type_str = dtype.__name__
    key = f"{BsrMatrix.__qualname__}_{type_str}"

    if key in _struct_cache:
        return _struct_cache[key]

    class BsrMatrixTyped(BsrMatrix):
        # Number of rows of blocks
        nrow: int
        # Number of columns of blocks
        ncol: int
        # Number of non-zero blocks: equal to offsets[nrow], cached on host for convenience
        nnz: int
        # Array of size at least 1 + nrow
        offsets: wp.array(dtype=int)
        # Array of size at least equal to nnz
        columns: wp.array(dtype=int)
        # Array of size at least equal to nnz, holding the block values
        values: wp.array(dtype=dtype)

    module = wp.get_module(BsrMatrix.__module__)
    struct_type = wp.codegen.Struct(
        cls=BsrMatrixTyped,
        key=key,
        module=module,
    )
    _struct_cache[key] = struct_type

    return struct_type
def bsr_zeros(
    rows_of_blocks: int,
    cols_of_blocks: int,
    block_type: BlockType,
    device: wp.context.Devicelike = None,
) -> BsrMatrix:
    """
    Constructs and returns an empty BSR or CSR matrix with the given shape

    Args:
        rows_of_blocks: Number of rows of blocks
        cols_of_blocks: Number of columns of blocks
        block_type: Type of individual blocks. For CSR matrices, this should be a scalar type;
          for BSR matrices, this should be a matrix type (e.g. from :func:`warp.mat`)
        device: Device on which to allocate the matrix arrays

    Returns:
        A matrix with ``nnz == 0``, empty ``columns`` and ``values`` arrays, and a zero-filled
        ``offsets`` array of length ``rows_of_blocks + 1``.
    """

    bsr = bsr_matrix_t(block_type)()

    bsr.nrow = rows_of_blocks
    bsr.ncol = cols_of_blocks
    bsr.nnz = 0
    bsr.columns = wp.empty(shape=(0,), dtype=int, device=device)
    bsr.values = wp.empty(shape=(0,), dtype=block_type, device=device)
    # All-zero offsets: every row's block range [offsets[r], offsets[r+1]) is empty
    bsr.offsets = wp.zeros(shape=(bsr.nrow + 1,), dtype=int, device=device)

    return bsr
def _bsr_ensure_fits(bsr: BsrMatrix, nrow: Optional[int] = None, nnz: Optional[int] = None):
    """Grows the storage arrays of ``bsr`` so they can hold ``nrow`` rows of blocks and ``nnz`` non-zero blocks.

    Arrays that are already large enough are left untouched (they are never shrunk). Arrays that are too small
    are replaced by new arrays from ``wp.empty`` on the same device, so their previous content is NOT preserved.

    Args:
        bsr: Matrix whose ``offsets``, ``columns`` and ``values`` arrays should be grown
        nrow: Number of rows of blocks to fit; defaults to ``bsr.nrow``
        nnz: Number of non-zero blocks to fit; defaults to ``bsr.nnz``
    """
    if nrow is None:
        nrow = bsr.nrow
    if nnz is None:
        nnz = bsr.nnz

    # offsets needs one extra entry to store the end of the last row
    if bsr.offsets.size < nrow + 1:
        bsr.offsets = wp.empty(shape=(nrow + 1,), dtype=int, device=bsr.offsets.device)
    if bsr.columns.size < nnz:
        bsr.columns = wp.empty(shape=(nnz,), dtype=int, device=bsr.columns.device)
    if bsr.values.size < nnz:
        bsr.values = wp.empty(shape=(nnz,), dtype=bsr.values.dtype, device=bsr.values.device)
def bsr_set_zero(bsr: BsrMatrix, rows_of_blocks: Optional[int] = None, cols_of_blocks: Optional[int] = None):
    """
    Turns ``bsr`` into an empty (all-zero) matrix, optionally changing its dimensions.

    Args:
        bsr: The BSR or CSR matrix to clear
        rows_of_blocks: New number of rows of blocks; the current one is kept when ``None``
        cols_of_blocks: New number of columns of blocks; the current one is kept when ``None``
    """

    # Only overwrite the dimensions that were explicitly requested
    for field_name, requested in (("nrow", rows_of_blocks), ("ncol", cols_of_blocks)):
        if requested is not None:
            setattr(bsr, field_name, requested)

    bsr.nnz = 0

    # Offsets must hold nrow + 1 entries, which may require growing them if the row count increased;
    # zeroing them then marks every row as having no blocks.
    _bsr_ensure_fits(bsr)
    bsr.offsets.zero_()
def bsr_set_from_triplets(
    dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    rows: "Array[int]",
    columns: "Array[int]",
    values: "Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]",
):
    """
    Fills a BSR matrix with values defined by coordinate-oriented (COO) triplets, discarding existing blocks.

    The first dimension of the three input arrays must match, and determines the number of non-zeros in the constructed matrix.

    Args:
        dest: Sparse matrix to populate
        rows: Row index for each non-zero
        columns: Columns index for each non-zero
        values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
          to the `dest` matrix's block type, or a 3d array with data type equal to the `dest` matrix's scalar type.

    Raises:
        ValueError: If the arrays reside on different devices, have different lengths, or if ``values`` has an
          incompatible dtype, shape, dimensionality or layout
        NotImplementedError: If the scalar type of ``dest`` is neither ``wp.float32`` nor ``wp.float64``
          (``dest`` is left unmodified in that case)
    """

    if values.device != columns.device or values.device != rows.device or values.device != dest.values.device:
        raise ValueError("All arguments must reside on the same device")

    if values.shape[0] != rows.shape[0] or values.shape[0] != columns.shape[0]:
        raise ValueError("All triplet arrays must have the same length")

    # Accept either array1d(dtype) or contiguous array3d(scalar_type) as values
    if values.ndim == 1:
        if values.dtype != dest.values.dtype:
            raise ValueError("Values array type must correspond to that of dest matrix")
    elif values.ndim == 3:
        if values.shape[1:] != dest.block_shape:
            raise ValueError(
                f"Last two dimensions in values array ({values.shape[1:]}) should correspond to matrix block shape {(dest.block_shape)})"
            )

        if warp.types.type_scalar_type(values.dtype) != dest.scalar_type:
            raise ValueError("Scalar type of values array should correspond to that of matrix")

        if not values.is_contiguous:
            raise ValueError("Multi-dimensional values array should be contiguous")
    else:
        raise ValueError("Number of dimension for values array should be 1 or 3")

    nnz = rows.shape[0]
    if nnz == 0:
        bsr_set_zero(dest)
        return

    device = dest.values.device
    scalar_type = dest.scalar_type
    from warp.context import runtime

    # Pick the native implementation for this device and scalar type. native_func must start as None:
    # otherwise any other scalar type would raise UnboundLocalError instead of NotImplementedError below.
    native_func = None
    if device.is_cpu:
        if scalar_type == wp.float32:
            native_func = runtime.core.bsr_matrix_from_triplets_float_host
        elif scalar_type == wp.float64:
            native_func = runtime.core.bsr_matrix_from_triplets_double_host
    else:
        if scalar_type == wp.float32:
            native_func = runtime.core.bsr_matrix_from_triplets_float_device
        elif scalar_type == wp.float64:
            native_func = runtime.core.bsr_matrix_from_triplets_double_device

    if not native_func:
        raise NotImplementedError(f"bsr_from_triplets not implemented for scalar type {scalar_type}")

    # Increase dest array sizes if needed -- only once we know the operation is supported,
    # so that an unsupported type does not leave dest with reallocated (uninitialized) arrays
    _bsr_ensure_fits(dest, nnz=nnz)

    # The native function's return value becomes the new block count; it may presumably be smaller
    # than the number of triplets (e.g. if duplicate coordinates are merged) -- confirm in native sources
    dest.nnz = native_func(
        dest.block_shape[0],
        dest.block_shape[1],
        dest.nrow,
        nnz,
        rows.ptr,
        columns.ptr,
        values.ptr,
        dest.offsets.ptr,
        dest.columns.ptr,
        dest.values.ptr,
    )
def bsr_assign(dest: BsrMatrix[BlockType[Rows, Cols, Scalar]], src: BsrMatrix[BlockType[Rows, Cols, Any]]):
    """Copies the content of the `src` matrix into `dest`.

    Block values are cast when the two matrices use distinct scalar types. The storage arrays of `dest`
    are grown first if they are too small to hold the content of `src`.

    Args:
        dest: Matrix receiving the copy; must reside on the same device and use the same block shape as `src`
        src: Matrix to copy from

    Raises:
        ValueError: If the matrices reside on different devices or have different block shapes
    """

    if dest.values.device != src.values.device:
        raise ValueError("Source and destination matrices must reside on the same device")

    if dest.block_shape != src.block_shape:
        raise ValueError("Source and destination matrices must have the same block shape")

    dest.nrow, dest.ncol, dest.nnz = src.nrow, src.ncol, src.nnz

    _bsr_ensure_fits(dest)

    # Row offsets are always copied, even for an empty source
    wp.copy(dest=dest.offsets, src=src.offsets, count=src.nrow + 1)

    block_count = src.nnz
    if block_count <= 0:
        return

    wp.copy(dest=dest.columns, src=src.columns, count=block_count)
    warp.utils.array_cast(out_array=dest.values, in_array=src.values, count=block_count)
def bsr_copy(A: BsrMatrix, scalar_type: Optional[Scalar] = None):
    """Returns a new matrix holding a copy of ``A``, optionally converted to another scalar type.

    Args:
        A: Matrix to copy
        scalar_type: If provided, the returned matrix will use this scalar type instead of the one from `A`.
    """
    if scalar_type is not None:
        # Scalar (CSR) blocks stay scalars; matrix blocks keep their shape with the new scalar type
        if A.block_shape == (1, 1):
            block_type = scalar_type
        else:
            block_type = wp.types.matrix(shape=A.block_shape, dtype=scalar_type)
    else:
        block_type = A.values.dtype

    result = bsr_zeros(rows_of_blocks=A.nrow, cols_of_blocks=A.ncol, block_type=block_type, device=A.values.device)
    bsr_assign(dest=result, src=A)
    return result
def bsr_set_transpose(dest: BsrMatrix[BlockType[Cols, Rows, Scalar]], src: BsrMatrix[BlockType[Rows, Cols, Scalar]]):
    """Assigns the transposed matrix `src` to matrix `dest`

    Args:
        dest: Matrix receiving the transpose; its block shape must be the transpose of the block shape of `src`
        src: Matrix to transpose

    Raises:
        ValueError: If the matrices reside on different devices, use different scalar types,
          or if `dest` does not have the transposed block shape
        NotImplementedError: If the scalar type is neither ``wp.float32`` nor ``wp.float64`` and `src` is not empty
    """

    if dest.values.device != src.values.device:
        raise ValueError("All arguments must reside on the same device")

    if dest.scalar_type != src.scalar_type:
        raise ValueError("All arguments must have the same scalar type")

    transpose_block_shape = src.block_shape[::-1]

    if dest.block_shape != transpose_block_shape:
        raise ValueError(f"Destination block shape must be {transpose_block_shape}")

    dest.nrow = src.ncol
    dest.ncol = src.nrow
    dest.nnz = src.nnz

    # Increase dest array sizes if needed
    _bsr_ensure_fits(dest)

    if src.nnz == 0:
        # The offsets must still describe dest.nrow empty rows; without this, stale offsets left
        # over from dest's previous content would contradict nnz == 0.
        dest.offsets.zero_()
        return

    from warp.context import runtime

    # native_func must start as None, otherwise unsupported scalar types would raise
    # UnboundLocalError instead of the NotImplementedError below
    native_func = None
    if dest.values.device.is_cpu:
        if dest.scalar_type == wp.float32:
            native_func = runtime.core.bsr_transpose_float_host
        elif dest.scalar_type == wp.float64:
            native_func = runtime.core.bsr_transpose_double_host
    else:
        if dest.scalar_type == wp.float32:
            native_func = runtime.core.bsr_transpose_float_device
        elif dest.scalar_type == wp.float64:
            native_func = runtime.core.bsr_transpose_double_device

    if not native_func:
        raise NotImplementedError(f"bsr_set_transpose not implemented for scalar type {dest.scalar_type}")

    native_func(
        src.block_shape[0],
        src.block_shape[1],
        src.nrow,
        src.ncol,
        src.nnz,
        src.offsets.ptr,
        src.columns.ptr,
        src.values.ptr,
        dest.offsets.ptr,
        dest.columns.ptr,
        dest.values.ptr,
    )
def bsr_transposed(A: BsrMatrix):
    """Returns a new matrix holding the transpose of `A`.

    Scalar (CSR) blocks keep their type; matrix blocks get the transposed block shape with the same scalar type.
    """

    block_type = (
        A.values.dtype
        if A.block_shape == (1, 1)
        else wp.types.matrix(shape=A.block_shape[::-1], dtype=A.scalar_type)
    )

    result = bsr_zeros(rows_of_blocks=A.ncol, cols_of_blocks=A.nrow, block_type=block_type, device=A.values.device)
    bsr_set_transpose(dest=result, src=A)
    return result
# Copies the diagonal block A[row, row] of a BSR matrix into out[row], for each thread index ``row``.
# Rows without a stored diagonal block leave out[row] untouched.
# NOTE(review): wp.lower_bound is a binary search, so this assumes the column indices of each row
# are sorted in ascending order -- presumably guaranteed by the BSR construction routines; confirm.
@wp.kernel
def _bsr_get_diag_kernel(
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
    out: wp.array(dtype=Any),
):
    # The thread index is used both as the block row and as the diagonal column index
    row = wp.tid()
    # Blocks of this row occupy the index range [beg, end) of A_columns / A_values
    beg = A_offsets[row]
    end = A_offsets[row + 1]

    # First index in [beg, end) whose column is not less than ``row``; equals ``end`` if none
    diag = wp.lower_bound(A_columns, beg, end, row)
    if diag < end:
        # Only copy if the found block is exactly on the diagonal
        if A_columns[diag] == row:
            out[row] = A_values[diag]
def bsr_get_diag(A: BsrMatrix[_BlockType], out: "Optional[Array[BlockType]]" = None) -> "Array[BlockType]":
    """Extracts the diagonal blocks of a sparse matrix.

    Args:
        A: the sparse matrix from which to extract the diagonal
        out: if provided, the array into which to store the diagonal blocks. Entries for rows that have
          no stored diagonal block are left unchanged, so pass a zeroed array to get zeros there.

    Returns:
        The array of ``min(A.nrow, A.ncol)`` diagonal blocks (``out`` itself when it was provided; otherwise
        a newly allocated zero-initialized array on the device of ``A``).

    Raises:
        ValueError: If ``out`` has the wrong dtype, resides on another device, or is too short
    """

    device = A.values.device
    dim = min(A.nrow, A.ncol)

    if out is None:
        out = wp.zeros(shape=(dim,), dtype=A.values.dtype, device=device)
    elif out.dtype != A.values.dtype:
        raise ValueError(f"Output array must have type {A.values.dtype}")
    elif out.device != A.values.device:
        raise ValueError(f"Output array must reside on device {A.values.device}")
    elif out.shape[0] < dim:
        raise ValueError(f"Output array must be of length at least {dim}")

    # One thread per diagonal index
    wp.launch(kernel=_bsr_get_diag_kernel, dim=dim, device=device, inputs=[A.offsets, A.columns, A.values, out])

    return out
# Writes a block-diagonal sparsity structure: block row ``row`` gets exactly one block, located at
# column ``row``, whose value is ``diag[row]``.
# Thread ``row`` writes A_offsets[row + 1] and thread 0 additionally writes A_offsets[0], so a launch
# over ``n`` threads fills A_offsets[0..n], A_columns[0..n-1] and A_values[0..n-1].
@wp.kernel
def _bsr_set_diag_kernel(
    diag: wp.array(dtype=Any),
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
):
    row = wp.tid()
    # Row ``row`` holds the single block at index ``row``, i.e. its range is [row, row + 1)
    A_offsets[row + 1] = row + 1
    A_columns[row] = row
    A_values[row] = diag[row]

    # The leading offset is not covered by any thread's ``row + 1`` write
    if row == 0:
        A_offsets[0] = 0
# Same as _bsr_set_diag_kernel, but every diagonal block receives the same value ``diag_value``:
# block row ``row`` gets exactly one block, at column ``row``.
# Thread ``row`` writes A_offsets[row + 1] and thread 0 additionally writes A_offsets[0], so a launch
# over ``n`` threads fills A_offsets[0..n], A_columns[0..n-1] and A_values[0..n-1].
@wp.kernel
def _bsr_set_diag_constant_kernel(
    diag_value: Any,
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
):
    row = wp.tid()
    # Row ``row`` holds the single block at index ``row``, i.e. its range is [row, row + 1)
    A_offsets[row + 1] = row + 1
    A_columns[row] = row
    A_values[row] = diag_value

    # The leading offset is not covered by any thread's ``row + 1`` write
    if row == 0:
        A_offsets[0] = 0
- def bsr_set_diag(
441
- A: BsrMatrix[BlockType],
442
- diag: "Union[BlockType, Array[BlockType]]",
443
- rows_of_blocks: Optional[int] = None,
444
- cols_of_blocks: Optional[int] = None,
445
- ):
446
- """Sets `A` as a block-diagonal matrix
447
-
448
- Args:
449
- A: the sparse matrix to modify
450
- diag: Either a warp array of type ``A.values.dtype``, in which case each element will define one block of the diagonal,
451
- or a constant value of type ``A.values.dtype``, in which case it will get assigned to all diagonal blocks.
452
- rows_of_blocks: If not ``None``, the new number of rows of blocks
453
- cols_of_blocks: If not ``None``, the new number of columns of blocks
454
-
455
- The shape of the matrix will be defined one of the following, in that order:
456
- - `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal.
457
- - the first dimension of `diag`, if `diag` is an array
458
- - the current dimensions of `A` otherwise
459
- """
460
-
461
- if rows_of_blocks is None and cols_of_blocks is not None:
462
- rows_of_blocks = cols_of_blocks
463
- if cols_of_blocks is None and rows_of_blocks is not None:
464
- cols_of_blocks = rows_of_blocks
465
-
466
- if warp.types.is_array(diag):
467
- if rows_of_blocks is None:
468
- rows_of_blocks = diag.shape[0]
469
- cols_of_blocks = diag.shape[0]
470
-
471
- if rows_of_blocks is not None:
472
- A.nrow = rows_of_blocks
473
- A.ncol = cols_of_blocks
474
-
475
- A.nnz = min(A.nrow, A.ncol)
476
- _bsr_ensure_fits(A)
477
-
478
- if warp.types.is_array(diag):
479
- wp.launch(
480
- kernel=_bsr_set_diag_kernel,
481
- dim=A.nnz,
482
- device=A.values.device,
483
- inputs=[diag, A.offsets, A.columns, A.values],
484
- )
485
- else:
486
- if not warp.types.type_is_value(type(diag)):
487
- # Cast to launchable type
488
- diag = A.values.dtype(diag)
489
- wp.launch(
490
- kernel=_bsr_set_diag_constant_kernel,
491
- dim=A.nnz,
492
- device=A.values.device,
493
- inputs=[diag, A.offsets, A.columns, A.values],
494
- )
495
-
496
-
497
- def bsr_diag(
498
- diag: "Union[BlockType, Array[BlockType]]",
499
- rows_of_blocks: Optional[int] = None,
500
- cols_of_blocks: Optional[int] = None,
501
- ) -> BsrMatrix["BlockType"]:
502
- """Creates and returns a block-diagonal BSR matrix from an given block value or array of block values.
503
-
504
- Args:
505
- diag: Either a warp array of type ``A.values.dtype``, in which case each element will define one block of the diagonal,
506
- or a constant value of type ``A.values.dtype``, in which case it will get assigned to all diagonal blocks.
507
- rows_of_blocks: If not ``None``, the new number of rows of blocks
508
- cols_of_blocks: If not ``None``, the new number of columns of blocks
509
-
510
- The shape of the matrix will be defined one of the following, in that order:
511
- - `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal.
512
- - the first dimension of `diag`, if `diag` is an array
513
- """
514
-
515
- if rows_of_blocks is None and cols_of_blocks is not None:
516
- rows_of_blocks = cols_of_blocks
517
- if cols_of_blocks is None and rows_of_blocks is not None:
518
- cols_of_blocks = rows_of_blocks
519
-
520
- if warp.types.is_array(diag):
521
- if rows_of_blocks is None:
522
- rows_of_blocks = diag.shape[0]
523
- cols_of_blocks = diag.shape[0]
524
-
525
- A = bsr_zeros(
526
- rows_of_blocks,
527
- cols_of_blocks,
528
- block_type=diag.dtype,
529
- device=diag.device,
530
- )
531
- else:
532
- if rows_of_blocks is None:
533
- raise ValueError(
534
- "rows_of_blocks and/or cols_of_blocks must be provided for constructing a diagonal matrix with uniform diagonal"
535
- )
536
-
537
- block_type = type(diag)
538
- if not warp.types.type_is_matrix(block_type) and len(getattr(diag, "shape", ())) == 2:
539
- block_type = wp.mat(shape=diag.shape, dtype=diag.dtype)
540
-
541
- A = bsr_zeros(
542
- rows_of_blocks,
543
- cols_of_blocks,
544
- block_type=block_type,
545
- )
546
-
547
- bsr_set_diag(A, diag)
548
- return A
549
-
550
-
551
- def bsr_set_identity(A: BsrMatrix, rows_of_blocks: Optional[int] = None):
552
- """Sets `A` as the identity matrix
553
-
554
- Args:
555
- A: the sparse matrix to modify
556
- rows_of_blocks: if provided, the matrix will be resized as a square matrix with `rows_of_blocks` rows and columns.
557
- """
558
-
559
- if A.block_shape == (1, 1):
560
- identity = A.scalar_type(1.0)
561
- else:
562
- from numpy import eye
563
-
564
- identity = eye(A.block_shape[0])
565
-
566
- bsr_set_diag(A, diag=identity, rows_of_blocks=rows_of_blocks, cols_of_blocks=rows_of_blocks)
567
-
568
-
569
- def bsr_identity(
570
- rows_of_blocks: int, block_type: BlockType[Rows, Rows, Scalar], device: wp.context.Devicelike = None
571
- ) -> BsrMatrix[BlockType[Rows, Rows, Scalar]]:
572
- """Creates and returns a square identity matrix.
573
-
574
- Args:
575
- rows_of_blocks: Number of rows and columns of blocks in the created matrix.
576
- block_type: Block type for the newly created matrix -- must be square
577
- device: Device onto which to allocate the data arrays
578
- """
579
- A = bsr_zeros(rows_of_blocks=rows_of_blocks, cols_of_blocks=rows_of_blocks, block_type=block_type, device=device)
580
- bsr_set_identity(A)
581
- return A
582
-
583
-
584
- @wp.kernel
585
- def _bsr_scale_kernel(
586
- alpha: Any,
587
- values: wp.array(dtype=Any),
588
- ):
589
- values[wp.tid()] = alpha * values[wp.tid()]
590
-
591
-
592
- def bsr_scale(x: BsrMatrix, alpha: Scalar) -> BsrMatrix:
593
- """
594
- Performs the operation ``x := alpha * x`` on BSR matrix `x` and returns `x`
595
- """
596
-
597
- if alpha != 1.0 and x.nnz > 0:
598
- if alpha == 0.0:
599
- bsr_set_zero(x)
600
- else:
601
- if not isinstance(alpha, x.scalar_type):
602
- alpha = x.scalar_type(alpha)
603
-
604
- wp.launch(kernel=_bsr_scale_kernel, dim=x.nnz, device=x.values.device, inputs=[alpha, x.values])
605
-
606
- return x
607
-
608
-
609
- @wp.kernel
610
- def _bsr_get_block_row(dest_offset: int, bsr_offsets: wp.array(dtype=int), rows: wp.array(dtype=int)):
611
- i = wp.tid()
612
-
613
- row = wp.lower_bound(bsr_offsets, i + 1) - 1
614
- rows[dest_offset + i] = row
615
-
616
-
617
- @wp.kernel
618
- def _bsr_axpy_add_block(
619
- src_offset: int,
620
- scale: Any,
621
- rows: wp.array(dtype=int),
622
- cols: wp.array(dtype=int),
623
- dst_offsets: wp.array(dtype=int),
624
- dst_columns: wp.array(dtype=int),
625
- src_values: wp.array(dtype=Any),
626
- dst_values: wp.array(dtype=Any),
627
- ):
628
- i = wp.tid()
629
- row = rows[i + src_offset]
630
- col = cols[i + src_offset]
631
- beg = dst_offsets[row]
632
- end = dst_offsets[row + 1]
633
-
634
- block = wp.lower_bound(dst_columns, beg, end, col)
635
-
636
- dst_values[block] = dst_values[block] + scale * src_values[i]
637
-
638
-
639
- class bsr_axpy_work_arrays:
640
- """Opaque structure for persisting :func:`bsr_axpy` temporary work buffers across calls"""
641
-
642
- def __init__(self):
643
- self._reset(None)
644
-
645
- def _reset(self, device):
646
- self.device = device
647
- self._sum_rows = None
648
- self._sum_cols = None
649
- self._old_y_values = None
650
- self._old_x_values = None
651
-
652
- def _allocate(self, device, y: BsrMatrix, sum_nnz: int):
653
- if self.device != device:
654
- self._reset(device)
655
-
656
- if self._sum_rows is None or self._sum_rows.size < sum_nnz:
657
- self._sum_rows = wp.empty(shape=(sum_nnz), dtype=int, device=self.device)
658
- if self._sum_cols is None or self._sum_cols.size < sum_nnz:
659
- self._sum_cols = wp.empty(shape=(sum_nnz), dtype=int, device=self.device)
660
-
661
- if self._old_y_values is None or self._old_y_values.size < y.nnz:
662
- self._old_y_values = wp.empty(shape=(y.nnz), dtype=y.values.dtype, device=self.device)
663
-
664
-
665
- def bsr_axpy(
666
- x: BsrMatrix[BlockType[Rows, Cols, Scalar]],
667
- y: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
668
- alpha: Scalar = 1.0,
669
- beta: Scalar = 1.0,
670
- work_arrays: Optional[bsr_axpy_work_arrays] = None,
671
- ) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
672
- """
673
- Performs the sparse matrix addition ``y := alpha * X + beta * y`` on BSR matrices `x` and `y` and returns `y`.
674
-
675
- The `x` and `y` matrices are allowed to alias.
676
-
677
- Args:
678
- x: Read-only right-hand-side.
679
- y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero.
680
- alpha: Uniform scaling factor for `x`
681
- beta: Uniform scaling factor for `y`
682
- work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_axpy_work_arrays` in `work_arrays`.
683
- """
684
-
685
- if y is None:
686
- # If not output matrix is provided, allocate it for convenience
687
- y = bsr_zeros(x.nrow, x.ncol, block_type=x.values.dtype, device=x.values.device)
688
- beta = 0.0
689
-
690
- # Handle easy cases first
691
- if beta == 0.0 or y.nnz == 0:
692
- bsr_assign(src=x, dest=y)
693
- return bsr_scale(y, alpha=alpha)
694
-
695
- if alpha == 0.0 or x.nnz == 0:
696
- return bsr_scale(y, alpha=beta)
697
-
698
- if not isinstance(alpha, y.scalar_type):
699
- alpha = y.scalar_type(alpha)
700
- if not isinstance(beta, y.scalar_type):
701
- beta = y.scalar_type(beta)
702
-
703
- if x == y:
704
- # Aliasing case
705
- return bsr_scale(y, alpha=alpha.value + beta.value)
706
-
707
- # General case
708
-
709
- if x.values.device != y.values.device:
710
- raise ValueError("All arguments must reside on the same device")
711
-
712
- if x.scalar_type != y.scalar_type or x.block_shape != y.block_shape:
713
- raise ValueError("Matrices must have the same block type")
714
-
715
- if x.nrow != y.nrow or x.ncol != y.ncol:
716
- raise ValueError("Matrices must have the same number of rows and columns")
717
-
718
- if work_arrays is None:
719
- work_arrays = bsr_axpy_work_arrays()
720
-
721
- sum_nnz = x.nnz + y.nnz
722
- device = y.values.device
723
- work_arrays._allocate(device, y, sum_nnz)
724
-
725
- wp.copy(work_arrays._sum_cols, y.columns, 0, 0, y.nnz)
726
- wp.launch(kernel=_bsr_get_block_row, device=device, dim=y.nnz, inputs=[0, y.offsets, work_arrays._sum_rows])
727
-
728
- wp.copy(work_arrays._sum_cols, x.columns, y.nnz, 0, x.nnz)
729
- wp.launch(kernel=_bsr_get_block_row, device=device, dim=x.nnz, inputs=[y.nnz, x.offsets, work_arrays._sum_rows])
730
-
731
- # Save old y values before overwriting matrix
732
- wp.copy(dest=work_arrays._old_y_values, src=y.values, count=y.nnz)
733
-
734
- # Increase dest array sizes if needed
735
- if y.columns.shape[0] < sum_nnz:
736
- y.columns = wp.empty(shape=(sum_nnz,), dtype=int, device=device)
737
-
738
- from warp.context import runtime
739
-
740
- if device.is_cpu:
741
- native_func = runtime.core.bsr_matrix_from_triplets_float_host
742
- else:
743
- native_func = runtime.core.bsr_matrix_from_triplets_float_device
744
-
745
- old_y_nnz = y.nnz
746
- y.nnz = native_func(
747
- y.block_shape[0],
748
- y.block_shape[1],
749
- y.nrow,
750
- sum_nnz,
751
- work_arrays._sum_rows.ptr,
752
- work_arrays._sum_cols.ptr,
753
- 0,
754
- y.offsets.ptr,
755
- y.columns.ptr,
756
- 0,
757
- )
758
-
759
- _bsr_ensure_fits(y)
760
- y.values.zero_()
761
-
762
- wp.launch(
763
- kernel=_bsr_axpy_add_block,
764
- device=device,
765
- dim=old_y_nnz,
766
- inputs=[
767
- 0,
768
- beta,
769
- work_arrays._sum_rows,
770
- work_arrays._sum_cols,
771
- y.offsets,
772
- y.columns,
773
- work_arrays._old_y_values,
774
- y.values,
775
- ],
776
- )
777
-
778
- wp.launch(
779
- kernel=_bsr_axpy_add_block,
780
- device=device,
781
- dim=x.nnz,
782
- inputs=[
783
- old_y_nnz,
784
- alpha,
785
- work_arrays._sum_rows,
786
- work_arrays._sum_cols,
787
- y.offsets,
788
- y.columns,
789
- x.values,
790
- y.values,
791
- ],
792
- )
793
-
794
- return y
795
-
796
-
797
- @wp.kernel
798
- def _bsr_mm_count_coeffs(
799
- z_nnz: int,
800
- x_offsets: wp.array(dtype=int),
801
- x_columns: wp.array(dtype=int),
802
- y_offsets: wp.array(dtype=int),
803
- counts: wp.array(dtype=int),
804
- ):
805
- row = wp.tid()
806
- count = int(0)
807
-
808
- x_beg = x_offsets[row]
809
- x_end = x_offsets[row + 1]
810
-
811
- for x_block in range(x_beg, x_end):
812
- x_col = x_columns[x_block]
813
- count += y_offsets[x_col + 1] - y_offsets[x_col]
814
-
815
- counts[row + 1] = count
816
-
817
- if row == 0:
818
- counts[0] = z_nnz
819
-
820
-
821
- @wp.kernel
822
- def _bsr_mm_list_coeffs(
823
- x_offsets: wp.array(dtype=int),
824
- x_columns: wp.array(dtype=int),
825
- y_offsets: wp.array(dtype=int),
826
- y_columns: wp.array(dtype=int),
827
- mm_offsets: wp.array(dtype=int),
828
- mm_rows: wp.array(dtype=int),
829
- mm_cols: wp.array(dtype=int),
830
- ):
831
- row = wp.tid()
832
- mm_block = mm_offsets[row]
833
-
834
- x_beg = x_offsets[row]
835
- x_end = x_offsets[row + 1]
836
-
837
- for x_block in range(x_beg, x_end):
838
- x_col = x_columns[x_block]
839
-
840
- y_beg = y_offsets[x_col]
841
- y_end = y_offsets[x_col + 1]
842
- for y_block in range(y_beg, y_end):
843
- mm_cols[mm_block] = y_columns[y_block]
844
- mm_rows[mm_block] = row
845
- mm_block += 1
846
-
847
-
848
- @wp.kernel
849
- def _bsr_mm_compute_values(
850
- alpha: Any,
851
- x_offsets: wp.array(dtype=int),
852
- x_columns: wp.array(dtype=int),
853
- x_values: wp.array(dtype=Any),
854
- y_offsets: wp.array(dtype=int),
855
- y_columns: wp.array(dtype=int),
856
- y_values: wp.array(dtype=Any),
857
- mm_offsets: wp.array(dtype=int),
858
- mm_cols: wp.array(dtype=int),
859
- mm_values: wp.array(dtype=Any),
860
- ):
861
- row = wp.tid()
862
- mm_beg = mm_offsets[row]
863
- mm_end = mm_offsets[row + 1]
864
-
865
- x_beg = x_offsets[row]
866
- x_end = x_offsets[row + 1]
867
- for x_block in range(x_beg, x_end):
868
- x_col = x_columns[x_block]
869
- ax_val = alpha * x_values[x_block]
870
-
871
- y_beg = y_offsets[x_col]
872
- y_end = y_offsets[x_col + 1]
873
-
874
- for y_block in range(y_beg, y_end):
875
- mm_block = wp.lower_bound(mm_cols, mm_beg, mm_end, y_columns[y_block])
876
- mm_values[mm_block] = mm_values[mm_block] + ax_val * y_values[y_block]
877
-
878
-
879
- class bsr_mm_work_arrays:
880
- """Opaque structure for persisting :func:`bsr_mm` temporary work buffers across calls"""
881
-
882
- def __init__(self):
883
- self._reset(None)
884
-
885
- def _reset(self, device):
886
- self.device = device
887
- self._pinned_count_buffer = None
888
- self._mm_row_counts = None
889
- self._mm_rows = None
890
- self._mm_cols = None
891
- self._old_z_values = None
892
- self._old_z_offsets = None
893
- self._old_z_columns = None
894
-
895
- def _allocate_stage_1(self, device, z: BsrMatrix, copied_z_nnz: int, z_aliasing: bool):
896
- if self.device != device:
897
- self._reset(device)
898
-
899
- # Allocations that do not depend on any computation
900
- if self.device.is_cuda:
901
- if self._pinned_count_buffer is None:
902
- self._pinned_count_buffer = wp.empty(shape=(1,), dtype=int, pinned=True, device="cpu")
903
-
904
- if self._mm_row_counts is None or self._mm_row_counts.size < z.nrow + 1:
905
- self._mm_row_counts = wp.empty(shape=(z.nrow + 1,), dtype=int, device=self.device)
906
-
907
- if copied_z_nnz > 0:
908
- if self._old_z_values is None or self._old_z_values.size < copied_z_nnz:
909
- self._old_z_values = wp.empty(shape=(copied_z_nnz,), dtype=z.values.dtype, device=self.device)
910
-
911
- if z_aliasing:
912
- if self._old_z_columns is None or self._old_z_columns.size < z.nnz:
913
- self._old_z_columns = wp.empty(shape=(z.nnz,), dtype=z.columns.dtype, device=self.device)
914
- if self._old_z_offsets is None or self._old_z_offsets.size < z.nrow + 1:
915
- self._old_z_offsets = wp.empty(shape=(z.nrow + 1,), dtype=z.offsets.dtype, device=self.device)
916
-
917
- def _allocate_stage_2(self, mm_nnz: int):
918
- # Allocations that depend on unmerged nnz estimate
919
- if self._mm_rows is None or self._mm_rows.size < mm_nnz:
920
- self._mm_rows = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
921
- if self._mm_cols is None or self._mm_cols.size < mm_nnz:
922
- self._mm_cols = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
923
-
924
-
925
- def bsr_mm(
926
- x: BsrMatrix[BlockType[Rows, Any, Scalar]],
927
- y: BsrMatrix[BlockType[Any, Cols, Scalar]],
928
- z: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
929
- alpha: Scalar = 1.0,
930
- beta: Scalar = 0.0,
931
- work_arrays: Optional[bsr_mm_work_arrays] = None,
932
- ) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
933
- """
934
- Performs the sparse matrix-matrix multiplication ``z := alpha * x * y + beta * z`` on BSR matrices `x`, `y` and `z`, and returns `z`.
935
-
936
- The `x`, `y` and `z` matrices are allowed to alias.
937
- If the matrix `z` is not provided as input, it will be allocated and treated as zero.
938
-
939
- Args:
940
- x: Read-only left factor of the matrix-matrix product.
941
- y: Read-only right factor of the matrix-matrix product.
942
- z: Mutable left-hand-side. If `z` is not provided, it will be allocated and treated as zero.
943
- alpha: Uniform scaling factor for the ``x * y`` product
944
- beta: Uniform scaling factor for `z`
945
- work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_mm_work_arrays` in `work_arrays`.
946
- """
947
-
948
- if z is None:
949
- # If not output matrix is provided, allocate it for convenience
950
- z_block_shape = (x.block_shape[0], y.block_shape[1])
951
- if z_block_shape == (1, 1):
952
- z_block_type = x.scalar_type
953
- else:
954
- z_block_type = wp.types.matrix(shape=z_block_shape, dtype=x.scalar_type)
955
- z = bsr_zeros(x.nrow, y.ncol, block_type=z_block_type, device=x.values.device)
956
- beta = 0.0
957
-
958
- if x.values.device != y.values.device or x.values.device != z.values.device:
959
- raise ValueError("All arguments must reside on the same device")
960
-
961
- if x.scalar_type != y.scalar_type or x.scalar_type != z.scalar_type:
962
- raise ValueError("Matrices must have the same scalar type")
963
-
964
- if (
965
- x.block_shape[0] != z.block_shape[0]
966
- or y.block_shape[1] != z.block_shape[1]
967
- or x.block_shape[1] != y.block_shape[0]
968
- ):
969
- raise ValueError("Incompatible block sizes for matrix multiplication")
970
-
971
- if x.nrow != z.nrow or z.ncol != y.ncol or x.ncol != y.nrow:
972
- raise ValueError("Incompatible number of rows/columns for matrix multiplication")
973
-
974
- device = z.values.device
975
-
976
- if alpha == 0.0 or x.nnz == 0 or y.nnz == 0:
977
- # Easy case
978
- return bsr_scale(z, beta)
979
-
980
- if not isinstance(alpha, z.scalar_type):
981
- alpha = z.scalar_type(alpha)
982
- if not isinstance(beta, z.scalar_type):
983
- beta = z.scalar_type(beta)
984
-
985
- if work_arrays is None:
986
- work_arrays = bsr_mm_work_arrays()
987
-
988
- z_aliasing = z == x or z == y
989
- copied_z_nnz = z.nnz if beta != 0.0 or z_aliasing else 0
990
-
991
- work_arrays._allocate_stage_1(device, z, copied_z_nnz, z_aliasing)
992
-
993
- # Prefix sum of number of (unmerged) mm blocks per row
994
- wp.launch(
995
- kernel=_bsr_mm_count_coeffs,
996
- device=device,
997
- dim=z.nrow,
998
- inputs=[copied_z_nnz, x.offsets, x.columns, y.offsets, work_arrays._mm_row_counts],
999
- )
1000
- warp.utils.array_scan(work_arrays._mm_row_counts, work_arrays._mm_row_counts)
1001
-
1002
- # Get back total counts on host
1003
- if device.is_cuda:
1004
- wp.copy(dest=work_arrays._pinned_count_buffer, src=work_arrays._mm_row_counts, src_offset=z.nrow, count=1)
1005
- wp.synchronize_stream(wp.get_stream(device))
1006
- mm_nnz = int(work_arrays._pinned_count_buffer.numpy()[0])
1007
- else:
1008
- mm_nnz = int(work_arrays._mm_row_counts.numpy()[z.nrow])
1009
-
1010
- work_arrays._allocate_stage_2(mm_nnz)
1011
-
1012
- # If z has a non-zero scale, save current data before overwriting it
1013
- if copied_z_nnz > 0:
1014
- # Copy z row and column indices
1015
- wp.copy(dest=work_arrays._mm_cols, src=z.columns, count=copied_z_nnz)
1016
- wp.launch(
1017
- kernel=_bsr_get_block_row, device=device, dim=copied_z_nnz, inputs=[0, z.offsets, work_arrays._mm_rows]
1018
- )
1019
- # Save current z values in temporary buffer
1020
- wp.copy(src=z.values, dest=work_arrays._old_z_values, count=copied_z_nnz)
1021
- if z_aliasing:
1022
- # If z is aliasing with x or y, need to save topology as well
1023
- wp.copy(src=z.columns, dest=work_arrays._old_z_columns, count=copied_z_nnz)
1024
- wp.copy(src=z.offsets, dest=work_arrays._old_z_offsets, count=z.nrow + 1)
1025
-
1026
- # Fill unmerged mm blocks rows and columns
1027
- wp.launch(
1028
- kernel=_bsr_mm_list_coeffs,
1029
- device=device,
1030
- dim=z.nrow,
1031
- inputs=[
1032
- x.offsets,
1033
- x.columns,
1034
- y.offsets,
1035
- y.columns,
1036
- work_arrays._mm_row_counts,
1037
- work_arrays._mm_rows,
1038
- work_arrays._mm_cols,
1039
- ],
1040
- )
1041
-
1042
- # Increase dest array size if needed
1043
- if z.columns.shape[0] < mm_nnz:
1044
- z.columns = wp.empty(shape=(mm_nnz,), dtype=int, device=device)
1045
-
1046
- from warp.context import runtime
1047
-
1048
- if device.is_cpu:
1049
- native_func = runtime.core.bsr_matrix_from_triplets_float_host
1050
- else:
1051
- native_func = runtime.core.bsr_matrix_from_triplets_float_device
1052
-
1053
- z.nnz = native_func(
1054
- z.block_shape[0],
1055
- z.block_shape[1],
1056
- z.nrow,
1057
- mm_nnz,
1058
- work_arrays._mm_rows.ptr,
1059
- work_arrays._mm_cols.ptr,
1060
- 0,
1061
- z.offsets.ptr,
1062
- z.columns.ptr,
1063
- 0,
1064
- )
1065
-
1066
- _bsr_ensure_fits(z)
1067
- z.values.zero_()
1068
-
1069
- if copied_z_nnz > 0:
1070
- # Add back original z values
1071
- wp.launch(
1072
- kernel=_bsr_axpy_add_block,
1073
- device=device,
1074
- dim=copied_z_nnz,
1075
- inputs=[
1076
- 0,
1077
- beta,
1078
- work_arrays._mm_rows,
1079
- work_arrays._mm_cols,
1080
- z.offsets,
1081
- z.columns,
1082
- work_arrays._old_z_values,
1083
- z.values,
1084
- ],
1085
- )
1086
-
1087
- # Add mm blocks to z values
1088
- if (warp.types.type_is_matrix(x.values.dtype) or warp.types.type_is_matrix(y.values.dtype)) and not (
1089
- warp.types.type_is_matrix(z.values.dtype)
1090
- ):
1091
- # Result block type is scalar, but operands are matrices
1092
- # Cast result to (1x1) matrix to perform multiplication
1093
- mm_values = z.values.view(wp.types.matrix(shape=(1, 1), dtype=z.scalar_type))
1094
- else:
1095
- mm_values = z.values
1096
-
1097
- wp.launch(
1098
- kernel=_bsr_mm_compute_values,
1099
- device=device,
1100
- dim=z.nrow,
1101
- inputs=[
1102
- alpha,
1103
- work_arrays._old_z_offsets if x == z else x.offsets,
1104
- work_arrays._old_z_columns if x == z else x.columns,
1105
- work_arrays._old_z_values if x == z else x.values,
1106
- work_arrays._old_z_offsets if y == z else y.offsets,
1107
- work_arrays._old_z_columns if y == z else y.columns,
1108
- work_arrays._old_z_values if y == z else y.values,
1109
- z.offsets,
1110
- z.columns,
1111
- mm_values,
1112
- ],
1113
- )
1114
-
1115
- return z
1116
-
1117
-
1118
- @wp.kernel
1119
- def _bsr_mv_kernel(
1120
- alpha: Any,
1121
- A_offsets: wp.array(dtype=int),
1122
- A_columns: wp.array(dtype=int),
1123
- A_values: wp.array(dtype=Any),
1124
- x: wp.array(dtype=Any),
1125
- beta: Any,
1126
- y: wp.array(dtype=Any),
1127
- ):
1128
- row = wp.tid()
1129
-
1130
- # zero-initialize with type of y elements
1131
- scalar_zero = type(alpha)(0)
1132
- v = y.dtype(scalar_zero)
1133
-
1134
- if alpha != scalar_zero:
1135
- beg = A_offsets[row]
1136
- end = A_offsets[row + 1]
1137
- for block in range(beg, end):
1138
- v += A_values[block] * x[A_columns[block]]
1139
- v *= alpha
1140
-
1141
- if beta != scalar_zero:
1142
- v += beta * y[row]
1143
-
1144
- y[row] = v
1145
-
1146
-
1147
- def bsr_mv(
1148
- A: BsrMatrix[BlockType[Rows, Cols, Scalar]],
1149
- x: "Array[Vector[Cols, Scalar] | Scalar]",
1150
- y: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
1151
- alpha: Scalar = 1.0,
1152
- beta: Scalar = 0.0,
1153
- work_buffer: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
1154
- ) -> "Array[Vector[Rows, Scalar] | Scalar]":
1155
- """
1156
- Performs the sparse matrix-vector product ``y := alpha * A * x + beta * y`` and returns `y`.
1157
-
1158
- The `x` and `y` vectors are allowed to alias.
1159
-
1160
- Args:
1161
- A: Read-only, left matrix factor of the matrix-vector product.
1162
- x: Read-only, right vector factor of the matrix-vector product.
1163
- y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero.
1164
- alpha: Uniform scaling factor for `x`. If zero, `x` will not be read and may be left uninitialized.
1165
- beta: Uniform scaling factor for `y`. If zero, `y` will not be read and may be left uninitialized.
1166
- work_buffer: Temporary storage is required if and only if `x` and `y` are the same vector. If provided the `work_buffer` array
1167
- will be used for this purpose, otherwise a temporary allocation will be performed.
1168
- """
1169
-
1170
- if y is None:
1171
- # If no output array is provided, allocate one for convenience
1172
- y_vec_len = A.block_shape[0]
1173
- y_dtype = A.scalar_type if y_vec_len == 1 else wp.vec(length=y_vec_len, dtype=A.scalar_type)
1174
- y = wp.empty(shape=(A.nrow,), device=A.values.device, dtype=y_dtype)
1175
- y.zero_()
1176
- beta = 0.0
1177
-
1178
- if not isinstance(alpha, A.scalar_type):
1179
- alpha = A.scalar_type(alpha)
1180
- if not isinstance(beta, A.scalar_type):
1181
- beta = A.scalar_type(beta)
1182
-
1183
- if A.values.device != x.device or A.values.device != y.device:
1184
- raise ValueError("A, x and y must reside on the same device")
1185
-
1186
- if x.shape[0] != A.ncol:
1187
- raise ValueError("Number of columns of A must match number of rows of x")
1188
- if y.shape[0] != A.nrow:
1189
- raise ValueError("Number of rows of A must match number of rows of y")
1190
-
1191
- if x == y:
1192
- # Aliasing case, need temporary storage
1193
- if work_buffer is None:
1194
- work_buffer = wp.empty_like(y)
1195
- elif work_buffer.size < y.size:
1196
- raise ValueError(f"Work buffer size is insufficient, needs to be at least {y.size}")
1197
- elif not wp.types.types_equal(work_buffer.dtype, y.dtype):
1198
- raise ValueError(f"Work buffer must have same data type as y, {wp.types.type_repr(y.dtype)}")
1199
-
1200
- # Save old y values before overwriting vector
1201
- wp.copy(dest=work_buffer, src=y, count=y.size)
1202
- x = work_buffer
1203
-
1204
- # Promote scalar vectors to length-1 vecs and conversely
1205
- if warp.types.type_is_matrix(A.values.dtype):
1206
- if A.block_shape[0] == 1:
1207
- if y.dtype == A.scalar_type:
1208
- y = y.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
1209
- if A.block_shape[1] == 1:
1210
- if x.dtype == A.scalar_type:
1211
- x = x.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
1212
- else:
1213
- if A.block_shape[0] == 1:
1214
- if y.dtype != A.scalar_type:
1215
- y = y.view(dtype=A.scalar_type)
1216
- if A.block_shape[1] == 1:
1217
- if x.dtype != A.scalar_type:
1218
- x = x.view(dtype=A.scalar_type)
1219
-
1220
- wp.launch(
1221
- kernel=_bsr_mv_kernel,
1222
- device=A.values.device,
1223
- dim=A.nrow,
1224
- inputs=[alpha, A.offsets, A.columns, A.values, x, beta, y],
1225
- )
1226
-
1227
- return y
1
+ from typing import Any, Generic, Optional, Tuple, TypeVar, Union
2
+
3
+ import warp as wp
4
+ import warp.types
5
+ import warp.utils
6
+ from warp.types import Array, Cols, Rows, Scalar, Vector
7
+
8
+ # typing hints
9
+
10
+ _BlockType = TypeVar("BlockType")
11
+
12
+
13
+ class _MatrixBlockType(Generic[Rows, Cols, Scalar]):
14
+ pass
15
+
16
+
17
+ class _ScalarBlockType(Generic[Scalar]):
18
+ pass
19
+
20
+
21
+ BlockType = Union[_MatrixBlockType[Rows, Cols, Scalar], _ScalarBlockType[Scalar]]
22
+
23
+ _struct_cache = {}
24
+
25
+
26
+ class BsrMatrix(Generic[_BlockType]):
27
+ """Untyped base class for BSR and CSR matrices.
28
+
29
+ Should not be constructed directly but through functions such as :func:`bsr_zeros`.
30
+
31
+ Attributes:
32
+ nrow (int): Number of rows of blocks
33
+ ncol (int): Number of columns of blocks
34
+ nnz (int): Number of non-zero blocks: must be equal to ``offsets[nrow-1]``, cached on host for convenience
35
+ offsets (Array[int]): Array of size at least ``1 + nrows`` such that the start and end indices of the blocks of row ``r`` are ``offsets[r]`` and ``offsets[r+1]``, respectively.
36
+ columns (Array[int]): Array of size at least equal to ``nnz`` containing block column indices
37
+ values (Array[BlockType]): Array of size at least equal to ``nnz`` containing block values
38
+ """
39
+
40
+ @property
41
+ def scalar_type(self) -> Scalar:
42
+ """Scalar type for individual block coefficients. For CSR matrices, this is the same as the block type"""
43
+ return warp.types.type_scalar_type(self.values.dtype)
44
+
45
+ @property
46
+ def block_shape(self) -> Tuple[int, int]:
47
+ """Shape of the individual blocks"""
48
+ return getattr(self.values.dtype, "_shape_", (1, 1))
49
+
50
+ @property
51
+ def block_size(self) -> int:
52
+ """Size of the individual blocks, i.e. number of rows per block times number of columns per block"""
53
+ return warp.types.type_length(self.values.dtype)
54
+
55
+ @property
56
+ def shape(self) -> Tuple[int, int]:
57
+ """Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/columns per block"""
58
+ block_shape = self.block_shape
59
+ return (self.nrow * block_shape[0], self.ncol * block_shape[1])
60
+
61
+ @property
62
+ def dtype(self) -> type:
63
+ """Data type for individual block values"""
64
+ return self.values.dtype
65
+
66
+ @property
67
+ def device(self) -> wp.context.Device:
68
+ """Device on which offsets, columns and values are allocated -- assumed to be the same for all three arrays"""
69
+ return self.values.device
70
+
71
+
72
+ def bsr_matrix_t(dtype: BlockType):
73
+ dtype = wp.types.type_to_warp(dtype)
74
+
75
+ if not warp.types.type_is_matrix(dtype) and dtype not in warp.types.scalar_types:
76
+ raise ValueError(
77
+ f"BsrMatrix block type must be either warp matrix or scalar; got {warp.types.type_repr(dtype)}"
78
+ )
79
+
80
+ class BsrMatrixTyped(BsrMatrix):
81
+ nrow: int
82
+ """Number of rows of blocks"""
83
+ ncol: int
84
+ """Number of columns of blocks"""
85
+ nnz: int
86
+ """Number of non-zero blocks: equal to offsets[-1], cached on host for convenience"""
87
+ offsets: wp.array(dtype=int)
88
+ """Array of size at least 1 + nrows"""
89
+ columns: wp.array(dtype=int)
90
+ """Array of size at least equal to nnz"""
91
+ values: wp.array(dtype=dtype)
92
+
93
+ module = wp.get_module(BsrMatrix.__module__)
94
+
95
+ if hasattr(dtype, "_shape_"):
96
+ type_str = f"{warp.types.type_scalar_type(dtype).__name__}_{dtype._shape_[0]}_{dtype._shape_[1]}"
97
+ else:
98
+ type_str = dtype.__name__
99
+ key = f"{BsrMatrix.__qualname__}_{type_str}"
100
+
101
+ if key not in _struct_cache:
102
+ _struct_cache[key] = wp.codegen.Struct(
103
+ cls=BsrMatrixTyped,
104
+ key=key,
105
+ module=module,
106
+ )
107
+
108
+ return _struct_cache[key]
109
+
110
+
111
def bsr_zeros(
    rows_of_blocks: int,
    cols_of_blocks: int,
    block_type: BlockType,
    device: wp.context.Devicelike = None,
) -> BsrMatrix:
    """
    Constructs and returns an empty BSR or CSR matrix with the given shape

    Args:
        rows_of_blocks: Number of rows of blocks
        cols_of_blocks: Number of columns of blocks
        block_type: Type of individual blocks. For CSR matrices, this should be a scalar type;
          for BSR matrices, this should be a matrix type (e.g. from :func:`warp.mat`)
        device: Device on which to allocate the matrix arrays

    Returns:
        A matrix with no stored blocks: ``nnz`` is 0, ``columns`` and ``values`` are empty arrays,
        and ``offsets`` holds ``rows_of_blocks + 1`` zeros.
    """

    bsr = bsr_matrix_t(block_type)()

    bsr.nrow = rows_of_blocks
    bsr.ncol = cols_of_blocks
    bsr.nnz = 0
    bsr.columns = wp.empty(shape=(0,), dtype=int, device=device)
    bsr.values = wp.empty(shape=(0,), dtype=block_type, device=device)
    # All-zero row offsets encode "every row is empty"
    bsr.offsets = wp.zeros(shape=(bsr.nrow + 1,), dtype=int, device=device)

    return bsr
139
+
140
+
141
def _bsr_ensure_fits(bsr: BsrMatrix, nrow: int = None, nnz: int = None):
    """Grows the storage arrays of `bsr` so they can hold `nrow` rows and `nnz` blocks.

    ``nrow`` and ``nnz`` default to the matrix's current ``nrow`` and ``nnz``.
    Arrays that are already large enough are kept as-is (they are never shrunk);
    arrays that are too small are replaced by new, uninitialized arrays on the same device.
    """
    row_count = bsr.nrow if nrow is None else nrow
    block_count = bsr.nnz if nnz is None else nnz

    required_offsets = row_count + 1
    if bsr.offsets.size < required_offsets:
        bsr.offsets = wp.empty(shape=(required_offsets,), dtype=int, device=bsr.offsets.device)

    if bsr.columns.size < block_count:
        bsr.columns = wp.empty(shape=(block_count,), dtype=int, device=bsr.columns.device)

    if bsr.values.size < block_count:
        bsr.values = wp.empty(shape=(block_count,), dtype=bsr.values.dtype, device=bsr.values.device)
153
+
154
+
155
def bsr_set_zero(bsr: BsrMatrix, rows_of_blocks: Optional[int] = None, cols_of_blocks: Optional[int] = None):
    """
    Removes every stored block of a BSR matrix, optionally resizing it.

    Args:
        bsr: The BSR or CSR matrix to clear
        rows_of_blocks: If not ``None``, the new number of rows of blocks
        cols_of_blocks: If not ``None``, the new number of columns of blocks
    """

    for attr_name, new_size in (("nrow", rows_of_blocks), ("ncol", cols_of_blocks)):
        if new_size is not None:
            setattr(bsr, attr_name, new_size)

    # An empty matrix has no blocks, and every row offset equals zero
    bsr.nnz = 0
    _bsr_ensure_fits(bsr)
    bsr.offsets.zero_()
172
+
173
+
174
def bsr_set_from_triplets(
    dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    rows: "Array[int]",
    columns: "Array[int]",
    values: "Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]",
):
    """
    Fills a BSR matrix with values defined by coordinate-oriented (COO) triplets, discarding existing blocks.

    The first dimension of the three input arrays must match, and determines the number of non-zeros in the constructed matrix.

    Args:
        dest: Sparse matrix to populate
        rows: Row index for each non-zero
        columns: Column index for each non-zero
        values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
          to the `dest` matrix's block type, or a 3d array with data type equal to the `dest` matrix's scalar type.

    Raises:
        ValueError: If the arrays do not all reside on the device of `dest`, if the triplet arrays have
          different lengths, or if `values` has an incompatible type, shape, rank or memory layout.
        NotImplementedError: If the scalar type of `dest` is neither ``wp.float32`` nor ``wp.float64``.
    """

    if values.device != columns.device or values.device != rows.device or values.device != dest.values.device:
        raise ValueError("All arguments must reside on the same device")

    if values.shape[0] != rows.shape[0] or values.shape[0] != columns.shape[0]:
        raise ValueError("All triplet arrays must have the same length")

    # Accept either array1d(dtype) or contiguous array3d(scalar_type) as values
    if values.ndim == 1:
        if values.dtype != dest.values.dtype:
            raise ValueError("Values array type must correspond to that of dest matrix")
    elif values.ndim == 3:
        if values.shape[1:] != dest.block_shape:
            raise ValueError(
                f"Last two dimensions in values array ({values.shape[1:]}) should correspond to matrix block shape {(dest.block_shape)})"
            )

        if warp.types.type_scalar_type(values.dtype) != dest.scalar_type:
            raise ValueError("Scalar type of values array should correspond to that of matrix")

        if not values.is_contiguous:
            raise ValueError("Multi-dimensional values array should be contiguous")
    else:
        raise ValueError("Number of dimension for values array should be 1 or 3")

    nnz = rows.shape[0]
    if nnz == 0:
        bsr_set_zero(dest)
        return

    device = dest.values.device
    scalar_type = dest.scalar_type
    from warp.context import runtime

    # Resolve the native implementation before touching `dest`. Initializing to None
    # lets unsupported scalar types reach the NotImplementedError below instead of
    # failing with an UnboundLocalError, and leaves `dest` unmodified in that case.
    native_func = None
    if device.is_cpu:
        if scalar_type == wp.float32:
            native_func = runtime.core.bsr_matrix_from_triplets_float_host
        elif scalar_type == wp.float64:
            native_func = runtime.core.bsr_matrix_from_triplets_double_host
    else:
        if scalar_type == wp.float32:
            native_func = runtime.core.bsr_matrix_from_triplets_float_device
        elif scalar_type == wp.float64:
            native_func = runtime.core.bsr_matrix_from_triplets_double_device

    if native_func is None:
        raise NotImplementedError(f"bsr_from_triplets not implemented for scalar type {scalar_type}")

    # Increase dest array sizes if needed
    _bsr_ensure_fits(dest, nnz=nnz)

    # The native routine fills dest's offsets, columns and values from the triplets
    # and returns the resulting number of stored blocks
    dest.nnz = native_func(
        dest.block_shape[0],
        dest.block_shape[1],
        dest.nrow,
        nnz,
        rows.ptr,
        columns.ptr,
        values.ptr,
        dest.offsets.ptr,
        dest.columns.ptr,
        dest.values.ptr,
    )
255
+
256
+
257
def bsr_assign(dest: BsrMatrix[BlockType[Rows, Cols, Scalar]], src: BsrMatrix[BlockType[Rows, Cols, Any]]):
    """Copies the content of the `src` matrix to `dest`, casting the block values if the two matrices use distinct scalar types.

    Args:
        dest: Matrix that receives the copy; it takes the dimensions and topology of `src`
        src: Matrix to copy from; must share the device and block shape of `dest`

    Raises:
        ValueError: If the matrices reside on different devices or have different block shapes.
    """

    if dest.values.device != src.values.device:
        raise ValueError("Source and destination matrices must reside on the same device")
    if dest.block_shape != src.block_shape:
        raise ValueError("Source and destination matrices must have the same block shape")

    # Adopt the source's dimensions, then grow dest storage to fit them
    dest.nrow, dest.ncol, dest.nnz = src.nrow, src.ncol, src.nnz
    _bsr_ensure_fits(dest)

    offset_count = src.nrow + 1
    wp.copy(dest=dest.offsets, src=src.offsets, count=offset_count)

    block_count = src.nnz
    if block_count <= 0:
        return

    wp.copy(dest=dest.columns, src=src.columns, count=block_count)
    # array_cast performs the scalar-type conversion when dest and src differ
    warp.utils.array_cast(out_array=dest.values, in_array=src.values, count=block_count)
276
+
277
+
278
def bsr_copy(A: BsrMatrix, scalar_type: Optional[Scalar] = None):
    """Returns a copy of matrix ``A``, possibly changing its scalar type.

    Args:
        A: Matrix to copy
        scalar_type: If provided, the returned matrix will use this scalar type instead of the one from `A`.

    Returns:
        A newly allocated matrix on the same device as `A`, with the same dimensions and topology.
    """
    if scalar_type is not None:
        if A.block_shape == (1, 1):
            # 1x1 blocks are stored as plain scalars
            block_type = scalar_type
        else:
            block_type = wp.types.matrix(shape=A.block_shape, dtype=scalar_type)
    else:
        block_type = A.values.dtype

    result = bsr_zeros(
        rows_of_blocks=A.nrow,
        cols_of_blocks=A.ncol,
        block_type=block_type,
        device=A.values.device,
    )
    bsr_assign(dest=result, src=A)
    return result
294
+
295
+
296
def bsr_set_transpose(dest: BsrMatrix[BlockType[Cols, Rows, Scalar]], src: BsrMatrix[BlockType[Rows, Cols, Scalar]]):
    """Assigns the transposed matrix `src` to matrix `dest`

    Args:
        dest: Matrix that receives the transpose. Its block shape must be the transpose of the block shape of `src`.
        src: Matrix to transpose

    Raises:
        ValueError: If the matrices reside on different devices, have different scalar types,
          or have incompatible block shapes.
        NotImplementedError: If the scalar type is neither ``wp.float32`` nor ``wp.float64``.
    """

    if dest.values.device != src.values.device:
        raise ValueError("All arguments must reside on the same device")

    if dest.scalar_type != src.scalar_type:
        raise ValueError("All arguments must have the same scalar type")

    transpose_block_shape = src.block_shape[::-1]

    if dest.block_shape != transpose_block_shape:
        raise ValueError(f"Destination block shape must be {transpose_block_shape}")

    if src.nnz == 0:
        # Resize dest and reset its row offsets. Simply setting nnz = 0 would leave stale
        # offsets from dest's previous content, or too few offsets for the new row count.
        bsr_set_zero(dest, rows_of_blocks=src.ncol, cols_of_blocks=src.nrow)
        return

    from warp.context import runtime

    # Initialize to None so unsupported scalar types raise NotImplementedError below
    # (instead of UnboundLocalError), before dest is modified
    native_func = None
    if dest.values.device.is_cpu:
        if dest.scalar_type == wp.float32:
            native_func = runtime.core.bsr_transpose_float_host
        elif dest.scalar_type == wp.float64:
            native_func = runtime.core.bsr_transpose_double_host
    else:
        if dest.scalar_type == wp.float32:
            native_func = runtime.core.bsr_transpose_float_device
        elif dest.scalar_type == wp.float64:
            native_func = runtime.core.bsr_transpose_double_device

    if native_func is None:
        raise NotImplementedError(f"bsr_set_transpose not implemented for scalar type {dest.scalar_type}")

    dest.nrow = src.ncol
    dest.ncol = src.nrow
    dest.nnz = src.nnz

    # Increase dest array sizes if needed
    _bsr_ensure_fits(dest)

    native_func(
        src.block_shape[0],
        src.block_shape[1],
        src.nrow,
        src.ncol,
        src.nnz,
        src.offsets.ptr,
        src.columns.ptr,
        src.values.ptr,
        dest.offsets.ptr,
        dest.columns.ptr,
        dest.values.ptr,
    )
349
+
350
+
351
def bsr_transposed(A: BsrMatrix):
    """Returns a copy of the transposed matrix `A`

    Args:
        A: Matrix to transpose

    Returns:
        A newly allocated matrix with ``A.ncol`` rows and ``A.nrow`` columns of blocks, on the same device as `A`.
    """

    if A.block_shape != (1, 1):
        block_type = wp.types.matrix(shape=A.block_shape[::-1], dtype=A.scalar_type)
    else:
        # Scalar (1x1) blocks are their own transpose
        block_type = A.values.dtype

    result = bsr_zeros(rows_of_blocks=A.ncol, cols_of_blocks=A.nrow, block_type=block_type, device=A.values.device)
    bsr_set_transpose(dest=result, src=A)
    return result
362
+
363
+
364
# Copies the diagonal blocks of a BSR matrix into `out`; thread `row` handles diagonal entry `row`.
# Rows whose diagonal block is not stored leave `out[row]` unchanged.
@wp.kernel
def _bsr_get_diag_kernel(
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
    out: wp.array(dtype=Any),
):
    row = wp.tid()
    # Blocks of this row occupy the index range [beg, end) of A_columns / A_values
    beg = A_offsets[row]
    end = A_offsets[row + 1]

    # Binary search for column `row` among this row's column indices
    # (assumes column indices are sorted within each row -- TODO confirm invariant)
    diag = wp.lower_bound(A_columns, beg, end, row)
    if diag < end:
        # lower_bound may land on a larger column when the diagonal block is absent
        if A_columns[diag] == row:
            out[row] = A_values[diag]
379
+
380
+
381
def bsr_get_diag(A: BsrMatrix[_BlockType], out: "Optional[Array[BlockType]]" = None) -> "Array[BlockType]":
    """Returns the array of blocks that constitute the diagonal of a sparse matrix.

    Args:
        A: the sparse matrix from which to extract the diagonal
        out: if provided, the array into which to store the diagonal blocks. Entries for rows that have
          no stored diagonal block are left unchanged; a freshly allocated output is zero there.

    Returns:
        The array holding the ``min(A.nrow, A.ncol)`` diagonal blocks (`out` itself when it is provided).

    Raises:
        ValueError: If `out` has the wrong dtype, resides on another device, or is too short.
    """

    dim = min(A.nrow, A.ncol)
    device = A.values.device

    if out is not None:
        if out.dtype != A.values.dtype:
            raise ValueError(f"Output array must have type {A.values.dtype}")
        if out.device != A.values.device:
            raise ValueError(f"Output array must reside on device {A.values.device}")
        if out.shape[0] < dim:
            raise ValueError(f"Output array must be of length at least {dim}")
    else:
        out = wp.zeros(shape=(dim,), dtype=A.values.dtype, device=device)

    wp.launch(
        kernel=_bsr_get_diag_kernel,
        dim=dim,
        device=device,
        inputs=[A.offsets, A.columns, A.values, out],
    )

    return out
406
+
407
+
408
# Writes a block-diagonal topology with per-row values: thread `row` stores block `row`
# at column `row` with value `diag[row]`, and sets offsets so each row holds exactly one block.
# Only offsets[0 .. launch_dim] are written; callers must initialize any remaining offsets.
@wp.kernel
def _bsr_set_diag_kernel(
    diag: wp.array(dtype=Any),
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
):
    row = wp.tid()
    # Row `row` spans the block range [row, row + 1)
    A_offsets[row + 1] = row + 1
    A_columns[row] = row
    A_values[row] = diag[row]

    # A single thread initializes the leading offset
    if row == 0:
        A_offsets[0] = 0
422
+
423
+
424
# Same as _bsr_set_diag_kernel, but every diagonal block receives the same `diag_value`.
# Only offsets[0 .. launch_dim] are written; callers must initialize any remaining offsets.
@wp.kernel
def _bsr_set_diag_constant_kernel(
    diag_value: Any,
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
):
    row = wp.tid()
    # Row `row` spans the block range [row, row + 1)
    A_offsets[row + 1] = row + 1
    A_columns[row] = row
    A_values[row] = diag_value

    # A single thread initializes the leading offset
    if row == 0:
        A_offsets[0] = 0
438
+
439
+
440
def bsr_set_diag(
    A: BsrMatrix[BlockType],
    diag: "Union[BlockType, Array[BlockType]]",
    rows_of_blocks: Optional[int] = None,
    cols_of_blocks: Optional[int] = None,
):
    """Sets `A` as a block-diagonal matrix

    Args:
        A: the sparse matrix to modify
        diag: Either a warp array of type ``A.values.dtype``, in which case each element will define one block of the diagonal,
          or a constant value of type ``A.values.dtype``, in which case it will get assigned to all diagonal blocks.
        rows_of_blocks: If not ``None``, the new number of rows of blocks
        cols_of_blocks: If not ``None``, the new number of columns of blocks

    The shape of the matrix will be defined one of the following, in that order:
      - `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal.
      - the first dimension of `diag`, if `diag` is an array
      - the current dimensions of `A` otherwise

    For non-square shapes the diagonal holds ``min(rows, cols)`` blocks; the remaining rows are empty.

    Raises:
        ValueError: If `diag` is an array with fewer elements than the number of diagonal blocks.
    """

    if rows_of_blocks is None and cols_of_blocks is not None:
        rows_of_blocks = cols_of_blocks
    if cols_of_blocks is None and rows_of_blocks is not None:
        cols_of_blocks = rows_of_blocks

    diag_is_array = warp.types.is_array(diag)

    if diag_is_array and rows_of_blocks is None:
        rows_of_blocks = diag.shape[0]
        cols_of_blocks = diag.shape[0]

    if rows_of_blocks is None:
        diag_len = min(A.nrow, A.ncol)
    else:
        diag_len = min(rows_of_blocks, cols_of_blocks)

    # Validate before modifying A: a too-short array would be read out of bounds by the kernel
    if diag_is_array and diag.shape[0] < diag_len:
        raise ValueError(f"Diagonal array must have at least {diag_len} elements")

    if rows_of_blocks is not None:
        A.nrow = rows_of_blocks
        A.ncol = cols_of_blocks

    A.nnz = diag_len
    _bsr_ensure_fits(A)

    # The kernels only write offsets[0 .. nnz]. Rows past the last diagonal block (nrow > ncol)
    # must still get offsets equal to nnz, and offsets[0] must be 0 when nnz == 0 (no launch
    # happens then). Storage may have been freshly allocated uninitialized by _bsr_ensure_fits.
    A.offsets[A.nnz : A.nrow + 1].fill_(A.nnz)

    if diag_is_array:
        wp.launch(
            kernel=_bsr_set_diag_kernel,
            dim=A.nnz,
            device=A.values.device,
            inputs=[diag, A.offsets, A.columns, A.values],
        )
    else:
        if not warp.types.type_is_value(type(diag)):
            # Cast to launchable type
            diag = A.values.dtype(diag)
        wp.launch(
            kernel=_bsr_set_diag_constant_kernel,
            dim=A.nnz,
            device=A.values.device,
            inputs=[diag, A.offsets, A.columns, A.values],
        )
495
+
496
+
497
def bsr_diag(
    diag: "Union[BlockType, Array[BlockType]]",
    rows_of_blocks: Optional[int] = None,
    cols_of_blocks: Optional[int] = None,
) -> BsrMatrix["BlockType"]:
    """Creates and returns a block-diagonal BSR matrix from a given block value or array of block values.

    Args:
        diag: Either a warp array, in which case each element defines one block of the diagonal and the
          array's dtype and device are used for the new matrix, or a single block value that is assigned
          to all diagonal blocks.
        rows_of_blocks: If not ``None``, the number of rows of blocks
        cols_of_blocks: If not ``None``, the number of columns of blocks

    The shape of the matrix is determined by the first applicable rule:
      - `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the other is assumed equal.
      - the first dimension of `diag`, if `diag` is an array

    Raises:
        ValueError: If `diag` is a single value and neither `rows_of_blocks` nor `cols_of_blocks` is given.
    """

    # Mirror whichever dimension was provided onto the missing one
    if rows_of_blocks is None:
        rows_of_blocks = cols_of_blocks
    if cols_of_blocks is None:
        cols_of_blocks = rows_of_blocks

    if warp.types.is_array(diag):
        if rows_of_blocks is None:
            rows_of_blocks = cols_of_blocks = diag.shape[0]
        block_type = diag.dtype
        device = diag.device
    else:
        if rows_of_blocks is None:
            raise ValueError(
                "rows_of_blocks and/or cols_of_blocks must be provided for constructing a diagonal matrix with uniform diagonal"
            )
        block_type = type(diag)
        if not warp.types.type_is_matrix(block_type) and len(getattr(diag, "shape", ())) == 2:
            # 2d array-like value with `shape`/`dtype`: use the matching warp matrix type
            block_type = wp.mat(shape=diag.shape, dtype=diag.dtype)
        # Uniform diagonals are allocated with device=None (warp's default device)
        device = None

    A = bsr_zeros(rows_of_blocks, cols_of_blocks, block_type=block_type, device=device)
    bsr_set_diag(A, diag)
    return A
549
+
550
+
551
def bsr_set_identity(A: BsrMatrix, rows_of_blocks: Optional[int] = None):
    """Sets `A` as the identity matrix

    Args:
        A: the sparse matrix to modify
        rows_of_blocks: if provided, the matrix will be resized as a square matrix with `rows_of_blocks` rows and columns.
    """

    if A.block_shape != (1, 1):
        # Dense identity block; bsr_set_diag casts it to A's block type
        from numpy import eye

        identity = eye(A.block_shape[0])
    else:
        identity = A.scalar_type(1.0)

    bsr_set_diag(A, diag=identity, rows_of_blocks=rows_of_blocks, cols_of_blocks=rows_of_blocks)
567
+
568
+
569
def bsr_identity(
    rows_of_blocks: int, block_type: BlockType[Rows, Rows, Scalar], device: wp.context.Devicelike = None
) -> BsrMatrix[BlockType[Rows, Rows, Scalar]]:
    """Creates and returns a square identity matrix.

    Args:
        rows_of_blocks: Number of rows and columns of blocks in the created matrix.
        block_type: Block type for the newly created matrix -- must be square
        device: Device onto which to allocate the data arrays
    """
    identity = bsr_zeros(
        rows_of_blocks=rows_of_blocks,
        cols_of_blocks=rows_of_blocks,
        block_type=block_type,
        device=device,
    )
    bsr_set_identity(identity)
    return identity
582
+
583
+
584
# Multiplies each stored block in place by `alpha`; one thread per stored block.
@wp.kernel
def _bsr_scale_kernel(
    alpha: Any,
    values: wp.array(dtype=Any),
):
    values[wp.tid()] = alpha * values[wp.tid()]
590
+
591
+
592
def bsr_scale(x: BsrMatrix, alpha: Scalar) -> BsrMatrix:
    """
    Performs the in-place operation ``x := alpha * x`` on BSR matrix `x` and returns `x`

    Args:
        x: Matrix to scale in place
        alpha: Uniform scaling factor. A factor of zero removes all stored blocks.
    """

    # Scaling by one, or scaling a matrix without stored blocks, is a no-op
    if alpha == 1.0 or x.nnz == 0:
        return x

    if alpha == 0.0:
        bsr_set_zero(x)
        return x

    if not isinstance(alpha, x.scalar_type):
        alpha = x.scalar_type(alpha)

    wp.launch(kernel=_bsr_scale_kernel, dim=x.nnz, device=x.values.device, inputs=[alpha, x.values])
    return x
607
+
608
+
609
# For each block index i of a BSR matrix, writes the index of the row owning that block
# into rows[dest_offset + i], i.e. expands row offsets into per-block (COO) row indices.
@wp.kernel
def _bsr_get_block_row(dest_offset: int, bsr_offsets: wp.array(dtype=int), rows: wp.array(dtype=int)):
    i = wp.tid()

    # The owning row is the last r with bsr_offsets[r] <= i: one before the first offset >= i + 1.
    # Empty rows (repeated offsets) are skipped correctly by this search.
    # NOTE(review): searches the whole `bsr_offsets` array, which may be longer than nrow + 1
    # since storage is only ever grown; assumes the trailing entries keep it sorted -- confirm.
    row = wp.lower_bound(bsr_offsets, i + 1) - 1
    rows[dest_offset + i] = row
615
+
616
+
617
# Accumulates `scale * src_values[i]` into the destination block at
# (rows[src_offset + i], cols[src_offset + i]); one thread per source block.
# The destination block is located by binary search within its row of the destination
# topology, so it must already exist there (this is not checked).
@wp.kernel
def _bsr_axpy_add_block(
    src_offset: int,
    scale: Any,
    rows: wp.array(dtype=int),
    cols: wp.array(dtype=int),
    dst_offsets: wp.array(dtype=int),
    dst_columns: wp.array(dtype=int),
    src_values: wp.array(dtype=Any),
    dst_values: wp.array(dtype=Any),
):
    i = wp.tid()
    # Coordinates are read with an offset, but source values are indexed from 0
    row = rows[i + src_offset]
    col = cols[i + src_offset]
    beg = dst_offsets[row]
    end = dst_offsets[row + 1]

    block = wp.lower_bound(dst_columns, beg, end, col)

    dst_values[block] = dst_values[block] + scale * src_values[i]
637
+
638
+
639
class bsr_axpy_work_arrays:
    """Opaque structure for persisting :func:`bsr_axpy` temporary work buffers across calls"""

    def __init__(self):
        self._reset(None)

    def _reset(self, device):
        # Forget all cached buffers; `_allocate` lazily recreates them on `device`
        self.device = device
        self._sum_rows = self._sum_cols = None
        self._old_y_values = self._old_x_values = None

    def _allocate(self, device, y: BsrMatrix, sum_nnz: int):
        # Buffers are tied to a device: switching device discards them
        if self.device != device:
            self._reset(device)

        def too_small(buffer, size):
            return buffer is None or buffer.size < size

        # Row / column indices of the concatenated (y, x) block coordinates
        if too_small(self._sum_rows, sum_nnz):
            self._sum_rows = wp.empty(shape=(sum_nnz,), dtype=int, device=self.device)
        if too_small(self._sum_cols, sum_nnz):
            self._sum_cols = wp.empty(shape=(sum_nnz,), dtype=int, device=self.device)

        # Backup of y's block values, since y is rebuilt in place
        if too_small(self._old_y_values, y.nnz):
            self._old_y_values = wp.empty(shape=(y.nnz,), dtype=y.values.dtype, device=self.device)
663
+
664
+
665
def bsr_axpy(
    x: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    y: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
    alpha: Scalar = 1.0,
    beta: Scalar = 1.0,
    work_arrays: Optional[bsr_axpy_work_arrays] = None,
) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
    """
    Performs the sparse matrix addition ``y := alpha * x + beta * y`` on BSR matrices `x` and `y` and returns `y`.

    The `x` and `y` matrices are allowed to alias.

    Args:
        x: Read-only right-hand-side.
        y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero.
        alpha: Uniform scaling factor for `x`
        beta: Uniform scaling factor for `y`
        work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_axpy_work_arrays` in `work_arrays`.

    Raises:
        ValueError: In the general case, if the matrices reside on different devices, have different
          block types, or have different numbers of rows/columns of blocks.
    """

    if y is None:
        # If no output matrix is provided, allocate it for convenience
        y = bsr_zeros(x.nrow, x.ncol, block_type=x.values.dtype, device=x.values.device)
        beta = 0.0

    # Handle easy cases first
    # y is (effectively) zero: the result is alpha * x, with y taking x's dimensions and topology
    if beta == 0.0 or y.nnz == 0:
        bsr_assign(src=x, dest=y)
        return bsr_scale(y, alpha=alpha)

    # x contributes nothing: the result is beta * y
    if alpha == 0.0 or x.nnz == 0:
        return bsr_scale(y, alpha=beta)

    if not isinstance(alpha, y.scalar_type):
        alpha = y.scalar_type(alpha)
    if not isinstance(beta, y.scalar_type):
        beta = y.scalar_type(beta)

    if x == y:
        # Aliasing case: alpha * y + beta * y == (alpha + beta) * y
        return bsr_scale(y, alpha=alpha.value + beta.value)

    # General case

    if x.values.device != y.values.device:
        raise ValueError("All arguments must reside on the same device")

    if x.scalar_type != y.scalar_type or x.block_shape != y.block_shape:
        raise ValueError("Matrices must have the same block type")

    if x.nrow != y.nrow or x.ncol != y.ncol:
        raise ValueError("Matrices must have the same number of rows and columns")

    if work_arrays is None:
        work_arrays = bsr_axpy_work_arrays()

    sum_nnz = x.nnz + y.nnz
    device = y.values.device
    work_arrays._allocate(device, y, sum_nnz)

    # Concatenate the block coordinates of y (slots [0, y.nnz)) and x (slots [y.nnz, sum_nnz))
    wp.copy(work_arrays._sum_cols, y.columns, 0, 0, y.nnz)
    wp.launch(kernel=_bsr_get_block_row, device=device, dim=y.nnz, inputs=[0, y.offsets, work_arrays._sum_rows])

    wp.copy(work_arrays._sum_cols, x.columns, y.nnz, 0, x.nnz)
    wp.launch(kernel=_bsr_get_block_row, device=device, dim=x.nnz, inputs=[y.nnz, x.offsets, work_arrays._sum_rows])

    # Save old y values before overwriting matrix
    wp.copy(dest=work_arrays._old_y_values, src=y.values, count=y.nnz)

    # Increase dest array sizes if needed
    if y.columns.shape[0] < sum_nnz:
        y.columns = wp.empty(shape=(sum_nnz,), dtype=int, device=device)

    from warp.context import runtime

    # Only the topology is rebuilt here (both value pointers below are 0), which is presumably
    # why the float variant is used regardless of the scalar type -- confirm against native code
    if device.is_cpu:
        native_func = runtime.core.bsr_matrix_from_triplets_float_host
    else:
        native_func = runtime.core.bsr_matrix_from_triplets_float_device

    old_y_nnz = y.nnz
    # Rebuild y's offsets/columns from the concatenated coordinates; returns the new block count
    y.nnz = native_func(
        y.block_shape[0],
        y.block_shape[1],
        y.nrow,
        sum_nnz,
        work_arrays._sum_rows.ptr,
        work_arrays._sum_cols.ptr,
        0,
        y.offsets.ptr,
        y.columns.ptr,
        0,
    )

    _bsr_ensure_fits(y)
    y.values.zero_()

    # Accumulate beta * (old y) into the new topology
    wp.launch(
        kernel=_bsr_axpy_add_block,
        device=device,
        dim=old_y_nnz,
        inputs=[
            0,
            beta,
            work_arrays._sum_rows,
            work_arrays._sum_cols,
            y.offsets,
            y.columns,
            work_arrays._old_y_values,
            y.values,
        ],
    )

    # Accumulate alpha * x; x's coordinates start at slot old_y_nnz
    wp.launch(
        kernel=_bsr_axpy_add_block,
        device=device,
        dim=x.nnz,
        inputs=[
            old_y_nnz,
            alpha,
            work_arrays._sum_rows,
            work_arrays._sum_cols,
            y.offsets,
            y.columns,
            x.values,
            y.values,
        ],
    )

    return y
795
+
796
+
797
# First pass of the sparse product x * y: thread `row` counts the (unmerged) number of
# product blocks for that row, i.e. the sum over x's blocks (row, k) of the number of
# blocks in row k of y. counts[row + 1] receives that count and counts[0] receives `z_nnz`
# (existing z blocks to keep). After the prefix sum performed by bsr_mm, counts[row]
# presumably gives the first slot of `row`, after `z_nnz` reserved slots -- confirm scan mode.
@wp.kernel
def _bsr_mm_count_coeffs(
    z_nnz: int,
    x_offsets: wp.array(dtype=int),
    x_columns: wp.array(dtype=int),
    y_offsets: wp.array(dtype=int),
    counts: wp.array(dtype=int),
):
    row = wp.tid()
    count = int(0)

    x_beg = x_offsets[row]
    x_end = x_offsets[row + 1]

    for x_block in range(x_beg, x_end):
        x_col = x_columns[x_block]
        # Each block of row x_col of y yields one product block
        count += y_offsets[x_col + 1] - y_offsets[x_col]

    counts[row + 1] = count

    if row == 0:
        counts[0] = z_nnz
819
+
820
+
821
# Second pass of the sparse product x * y: thread `row` writes the (row, column) coordinates
# of every unmerged product block x[row, k] * y[k, j], starting at slot mm_offsets[row].
# Duplicate coordinates are left as-is here; bsr_mm merges them when rebuilding z's topology.
@wp.kernel
def _bsr_mm_list_coeffs(
    x_offsets: wp.array(dtype=int),
    x_columns: wp.array(dtype=int),
    y_offsets: wp.array(dtype=int),
    y_columns: wp.array(dtype=int),
    mm_offsets: wp.array(dtype=int),
    mm_rows: wp.array(dtype=int),
    mm_cols: wp.array(dtype=int),
):
    row = wp.tid()
    # Next output slot for this row
    mm_block = mm_offsets[row]

    x_beg = x_offsets[row]
    x_end = x_offsets[row + 1]

    for x_block in range(x_beg, x_end):
        x_col = x_columns[x_block]

        y_beg = y_offsets[x_col]
        y_end = y_offsets[x_col + 1]
        for y_block in range(y_beg, y_end):
            mm_cols[mm_block] = y_columns[y_block]
            mm_rows[mm_block] = row
            mm_block += 1
846
+
847
+
848
# Third pass of the sparse product: thread `row` accumulates alpha * x[row, k] * y[k, j]
# into the result block (row, j), which is located by binary search over the merged result
# topology (mm_offsets / mm_cols). Since a single thread handles a whole row, accumulations
# into the same block happen sequentially within that thread.
@wp.kernel
def _bsr_mm_compute_values(
    alpha: Any,
    x_offsets: wp.array(dtype=int),
    x_columns: wp.array(dtype=int),
    x_values: wp.array(dtype=Any),
    y_offsets: wp.array(dtype=int),
    y_columns: wp.array(dtype=int),
    y_values: wp.array(dtype=Any),
    mm_offsets: wp.array(dtype=int),
    mm_cols: wp.array(dtype=int),
    mm_values: wp.array(dtype=Any),
):
    row = wp.tid()
    # Blocks of this row in the result occupy [mm_beg, mm_end)
    mm_beg = mm_offsets[row]
    mm_end = mm_offsets[row + 1]

    x_beg = x_offsets[row]
    x_end = x_offsets[row + 1]
    for x_block in range(x_beg, x_end):
        x_col = x_columns[x_block]
        # Scale the left factor once per x block rather than once per product
        ax_val = alpha * x_values[x_block]

        y_beg = y_offsets[x_col]
        y_end = y_offsets[x_col + 1]

        for y_block in range(y_beg, y_end):
            mm_block = wp.lower_bound(mm_cols, mm_beg, mm_end, y_columns[y_block])
            mm_values[mm_block] = mm_values[mm_block] + ax_val * y_values[y_block]
877
+
878
+
879
class bsr_mm_work_arrays:
    """Opaque structure for persisting :func:`bsr_mm` temporary work buffers across calls"""

    def __init__(self):
        self._reset(None)

    def _reset(self, device):
        # Forget every cached buffer; the `_allocate_stage_*` methods lazily recreate them on `device`
        self.device = device
        for buffer_name in (
            "_pinned_count_buffer",
            "_mm_row_counts",
            "_mm_rows",
            "_mm_cols",
            "_old_z_values",
            "_old_z_offsets",
            "_old_z_columns",
        ):
            setattr(self, buffer_name, None)

    @staticmethod
    def _needs_growth(buffer, size) -> bool:
        # True when the buffer is missing or smaller than `size` elements
        return buffer is None or buffer.size < size

    def _allocate_stage_1(self, device, z: BsrMatrix, copied_z_nnz: int, z_aliasing: bool):
        # Buffers are tied to a device: switching device discards them
        if self.device != device:
            self._reset(device)

        # Allocations that do not depend on any computation
        if self.device.is_cuda and self._pinned_count_buffer is None:
            # Pinned host buffer used to read back the product block count from the device
            self._pinned_count_buffer = wp.empty(shape=(1,), dtype=int, pinned=True, device="cpu")

        offsets_len = z.nrow + 1
        if self._needs_growth(self._mm_row_counts, offsets_len):
            self._mm_row_counts = wp.empty(shape=(offsets_len,), dtype=int, device=self.device)

        if copied_z_nnz > 0 and self._needs_growth(self._old_z_values, copied_z_nnz):
            self._old_z_values = wp.empty(shape=(copied_z_nnz,), dtype=z.values.dtype, device=self.device)

        if z_aliasing:
            # z's topology must also be backed up when z is one of the operands
            if self._needs_growth(self._old_z_columns, z.nnz):
                self._old_z_columns = wp.empty(shape=(z.nnz,), dtype=z.columns.dtype, device=self.device)
            if self._needs_growth(self._old_z_offsets, offsets_len):
                self._old_z_offsets = wp.empty(shape=(offsets_len,), dtype=z.offsets.dtype, device=self.device)

    def _allocate_stage_2(self, mm_nnz: int):
        # Allocations that depend on the unmerged product block count
        if self._needs_growth(self._mm_rows, mm_nnz):
            self._mm_rows = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
        if self._needs_growth(self._mm_cols, mm_nnz):
            self._mm_cols = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
923
+
924
+
925
def bsr_mm(
    x: BsrMatrix[BlockType[Rows, Any, Scalar]],
    y: BsrMatrix[BlockType[Any, Cols, Scalar]],
    z: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
    alpha: Scalar = 1.0,
    beta: Scalar = 0.0,
    work_arrays: Optional[bsr_mm_work_arrays] = None,
) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
    """
    Performs the sparse matrix-matrix multiplication ``z := alpha * x * y + beta * z`` on BSR matrices `x`, `y` and `z`, and returns `z`.

    The `x`, `y` and `z` matrices are allowed to alias.
    If the matrix `z` is not provided as input, it will be allocated and treated as zero.

    Args:
        x: Read-only left factor of the matrix-matrix product.
        y: Read-only right factor of the matrix-matrix product.
        z: Mutable left-hand-side. If `z` is not provided, it will be allocated and treated as zero.
        alpha: Uniform scaling factor for the ``x * y`` product
        beta: Uniform scaling factor for `z`
        work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_mm_work_arrays` in `work_arrays`.

    Raises:
        ValueError: If the matrices reside on different devices, have different scalar types,
          or have incompatible block shapes or numbers of rows/columns.
    """

    if z is None:
        # If no output matrix is provided, allocate it for convenience
        z_block_shape = (x.block_shape[0], y.block_shape[1])
        if z_block_shape == (1, 1):
            z_block_type = x.scalar_type
        else:
            z_block_type = wp.types.matrix(shape=z_block_shape, dtype=x.scalar_type)
        z = bsr_zeros(x.nrow, y.ncol, block_type=z_block_type, device=x.values.device)
        beta = 0.0

    if x.values.device != y.values.device or x.values.device != z.values.device:
        raise ValueError("All arguments must reside on the same device")

    if x.scalar_type != y.scalar_type or x.scalar_type != z.scalar_type:
        raise ValueError("Matrices must have the same scalar type")

    if (
        x.block_shape[0] != z.block_shape[0]
        or y.block_shape[1] != z.block_shape[1]
        or x.block_shape[1] != y.block_shape[0]
    ):
        raise ValueError("Incompatible block sizes for matrix multiplication")

    if x.nrow != z.nrow or z.ncol != y.ncol or x.ncol != y.nrow:
        raise ValueError("Incompatible number of rows/columns for matrix multiplication")

    device = z.values.device

    if alpha == 0.0 or x.nnz == 0 or y.nnz == 0:
        # Easy case: the product term vanishes, leaving beta * z
        return bsr_scale(z, beta)

    if not isinstance(alpha, z.scalar_type):
        alpha = z.scalar_type(alpha)
    if not isinstance(beta, z.scalar_type):
        beta = z.scalar_type(beta)

    if work_arrays is None:
        work_arrays = bsr_mm_work_arrays()

    # Existing z blocks must be preserved if they contribute (beta != 0) or if z is also an operand
    z_aliasing = z == x or z == y
    copied_z_nnz = z.nnz if beta != 0.0 or z_aliasing else 0

    work_arrays._allocate_stage_1(device, z, copied_z_nnz, z_aliasing)

    # Prefix sum of number of (unmerged) mm blocks per row
    # (the first copied_z_nnz slots are reserved for the existing z blocks)
    wp.launch(
        kernel=_bsr_mm_count_coeffs,
        device=device,
        dim=z.nrow,
        inputs=[copied_z_nnz, x.offsets, x.columns, y.offsets, work_arrays._mm_row_counts],
    )
    warp.utils.array_scan(work_arrays._mm_row_counts, work_arrays._mm_row_counts)

    # Get back total counts on host (needed to size the stage-2 buffers)
    if device.is_cuda:
        wp.copy(dest=work_arrays._pinned_count_buffer, src=work_arrays._mm_row_counts, src_offset=z.nrow, count=1)
        wp.synchronize_stream(wp.get_stream(device))
        mm_nnz = int(work_arrays._pinned_count_buffer.numpy()[0])
    else:
        mm_nnz = int(work_arrays._mm_row_counts.numpy()[z.nrow])

    work_arrays._allocate_stage_2(mm_nnz)

    # If z has a non-zero scale, save current data before overwriting it
    if copied_z_nnz > 0:
        # Copy z row and column indices
        wp.copy(dest=work_arrays._mm_cols, src=z.columns, count=copied_z_nnz)
        wp.launch(
            kernel=_bsr_get_block_row, device=device, dim=copied_z_nnz, inputs=[0, z.offsets, work_arrays._mm_rows]
        )
        # Save current z values in temporary buffer
        wp.copy(src=z.values, dest=work_arrays._old_z_values, count=copied_z_nnz)
        if z_aliasing:
            # If z is aliasing with x or y, need to save topology as well
            wp.copy(src=z.columns, dest=work_arrays._old_z_columns, count=copied_z_nnz)
            wp.copy(src=z.offsets, dest=work_arrays._old_z_offsets, count=z.nrow + 1)

    # Fill unmerged mm blocks rows and columns
    wp.launch(
        kernel=_bsr_mm_list_coeffs,
        device=device,
        dim=z.nrow,
        inputs=[
            x.offsets,
            x.columns,
            y.offsets,
            y.columns,
            work_arrays._mm_row_counts,
            work_arrays._mm_rows,
            work_arrays._mm_cols,
        ],
    )

    # Increase dest array size if needed
    if z.columns.shape[0] < mm_nnz:
        z.columns = wp.empty(shape=(mm_nnz,), dtype=int, device=device)

    from warp.context import runtime

    # Only the topology is rebuilt here (both value pointers below are 0), which is presumably
    # why the float variant is used regardless of the scalar type -- confirm against native code
    if device.is_cpu:
        native_func = runtime.core.bsr_matrix_from_triplets_float_host
    else:
        native_func = runtime.core.bsr_matrix_from_triplets_float_device

    # Rebuild z's offsets/columns from the unmerged coordinates; returns the new block count
    z.nnz = native_func(
        z.block_shape[0],
        z.block_shape[1],
        z.nrow,
        mm_nnz,
        work_arrays._mm_rows.ptr,
        work_arrays._mm_cols.ptr,
        0,
        z.offsets.ptr,
        z.columns.ptr,
        0,
    )

    _bsr_ensure_fits(z)
    z.values.zero_()

    if copied_z_nnz > 0:
        # Add back original z values (scaled by beta) into the new topology
        wp.launch(
            kernel=_bsr_axpy_add_block,
            device=device,
            dim=copied_z_nnz,
            inputs=[
                0,
                beta,
                work_arrays._mm_rows,
                work_arrays._mm_cols,
                z.offsets,
                z.columns,
                work_arrays._old_z_values,
                z.values,
            ],
        )

    # Add mm blocks to z values
    if (warp.types.type_is_matrix(x.values.dtype) or warp.types.type_is_matrix(y.values.dtype)) and not (
        warp.types.type_is_matrix(z.values.dtype)
    ):
        # Result block type is scalar, but operands are matrices
        # Cast result to (1x1) matrix to perform multiplication
        mm_values = z.values.view(wp.types.matrix(shape=(1, 1), dtype=z.scalar_type))
    else:
        mm_values = z.values

    # Operands aliasing z are read from the saved copies, since z was overwritten above
    wp.launch(
        kernel=_bsr_mm_compute_values,
        device=device,
        dim=z.nrow,
        inputs=[
            alpha,
            work_arrays._old_z_offsets if x == z else x.offsets,
            work_arrays._old_z_columns if x == z else x.columns,
            work_arrays._old_z_values if x == z else x.values,
            work_arrays._old_z_offsets if y == z else y.offsets,
            work_arrays._old_z_columns if y == z else y.columns,
            work_arrays._old_z_values if y == z else y.values,
            z.offsets,
            z.columns,
            mm_values,
        ],
    )

    return z
1116
+
1117
+
1118
# Warp kernel evaluating y := alpha * A * x + beta * y for a BSR matrix A stored as
# (A_offsets, A_columns, A_values). Each thread handles one block row `row`; the caller
# in this file (bsr_mv) launches it with dim=A.nrow.
# Because a thread reads arbitrary entries x[A_columns[block]] while other threads write
# their own y[row], x must not alias y (bsr_mv snapshots y into a work buffer in that case).
@wp.kernel
def _bsr_mv_kernel(
    alpha: Any,
    A_offsets: wp.array(dtype=int),  # row start indices into A_columns/A_values; row r spans [A_offsets[r], A_offsets[r + 1])
    A_columns: wp.array(dtype=int),  # block column index of each stored block
    A_values: wp.array(dtype=Any),  # stored blocks (matrix or scalar type)
    x: wp.array(dtype=Any),
    beta: Any,
    y: wp.array(dtype=Any),  # read (only if beta != 0) and written at index `row`
):
    row = wp.tid()

    # zero-initialize with type of y elements
    scalar_zero = type(alpha)(0)
    v = y.dtype(scalar_zero)

    # When alpha is zero, neither A nor x is read, so x may be left uninitialized.
    if alpha != scalar_zero:
        beg = A_offsets[row]
        end = A_offsets[row + 1]
        # Accumulate the product of each stored block in this row with the matching x entry.
        for block in range(beg, end):
            v += A_values[block] * x[A_columns[block]]
        v *= alpha

    # When beta is zero, y is not read, so it may be left uninitialized.
    if beta != scalar_zero:
        v += beta * y[row]

    # Unconditional write: every row of y is overwritten by this launch.
    y[row] = v
1145
+
1146
+
1147
def bsr_mv(
    A: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    x: "Array[Vector[Cols, Scalar] | Scalar]",
    y: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
    alpha: Scalar = 1.0,
    beta: Scalar = 0.0,
    work_buffer: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
) -> "Array[Vector[Rows, Scalar] | Scalar]":
    """
    Performs the sparse matrix-vector product ``y := alpha * A * x + beta * y`` and returns `y`.

    The `x` and `y` vectors are allowed to alias.

    Args:
        A: Read-only, left matrix factor of the matrix-vector product.
        x: Read-only, right vector factor of the matrix-vector product.
        y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero.
        alpha: Uniform scaling factor for `x`. If zero, `x` will not be read and may be left uninitialized.
        beta: Uniform scaling factor for `y`. If zero, `y` will not be read and may be left uninitialized.
        work_buffer: Temporary storage is required if and only if `x` and `y` are the same vector. If provided the `work_buffer` array
            will be used for this purpose, otherwise a temporary allocation will be performed.
            A provided `work_buffer` must reside on the same device as `y`, hold at least `y.size`
            elements, and have the same data type as `y`.

    Returns:
        The array holding the result. This is `y` itself, or a reinterpreting view of it when scalar
        elements are promoted to length-1 vectors (or conversely) to match the block type of `A`.

    Raises:
        ValueError: If `A`, `x`, `y` (or a provided `work_buffer`) do not reside on the same device,
            if the dimensions of `A`, `x` and `y` are inconsistent, or if `work_buffer` is too small
            or has a data type different from `y`.
    """

    if y is None:
        # If no output array is provided, allocate one for convenience.
        # No zero-fill is needed: beta is forced to zero so the kernel never reads `y`,
        # and the kernel unconditionally writes each of its A.nrow entries.
        y_vec_len = A.block_shape[0]
        y_dtype = A.scalar_type if y_vec_len == 1 else wp.vec(length=y_vec_len, dtype=A.scalar_type)
        y = wp.empty(shape=(A.nrow,), device=A.values.device, dtype=y_dtype)
        beta = 0.0

    # Kernel scalars must match the matrix scalar type.
    if not isinstance(alpha, A.scalar_type):
        alpha = A.scalar_type(alpha)
    if not isinstance(beta, A.scalar_type):
        beta = A.scalar_type(beta)

    if A.values.device != x.device or A.values.device != y.device:
        raise ValueError("A, x and y must reside on the same device")

    if x.shape[0] != A.ncol:
        raise ValueError("Number of columns of A must match number of rows of x")
    if y.shape[0] != A.nrow:
        raise ValueError("Number of rows of A must match number of rows of y")

    if x == y:
        # Aliasing case: each kernel thread reads arbitrary entries of x while other threads
        # overwrite their own entry of y, so the input values must be snapshotted first.
        if work_buffer is None:
            work_buffer = wp.empty_like(y)
        elif work_buffer.size < y.size:
            raise ValueError(f"Work buffer size is insufficient, needs to be at least {y.size}")
        elif not wp.types.types_equal(work_buffer.dtype, y.dtype):
            raise ValueError(f"Work buffer must have same data type as y, {wp.types.type_repr(y.dtype)}")
        elif work_buffer.device != y.device:
            # The work buffer replaces `x` below, bypassing the A/x/y device check above.
            raise ValueError("Work buffer must reside on the same device as y")

        # Save old y values before overwriting vector
        wp.copy(dest=work_buffer, src=y, count=y.size)
        x = work_buffer

    # Promote scalar vectors to length-1 vecs and conversely, so that the products
    # inside the kernel are between compatible block and vector types.
    if warp.types.type_is_matrix(A.values.dtype):
        if A.block_shape[0] == 1:
            if y.dtype == A.scalar_type:
                y = y.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
        if A.block_shape[1] == 1:
            if x.dtype == A.scalar_type:
                x = x.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
    else:
        if A.block_shape[0] == 1:
            if y.dtype != A.scalar_type:
                y = y.view(dtype=A.scalar_type)
        if A.block_shape[1] == 1:
            if x.dtype != A.scalar_type:
                x = x.view(dtype=A.scalar_type)

    # One thread per block row of A.
    wp.launch(
        kernel=_bsr_mv_kernel,
        device=A.values.device,
        dim=A.nrow,
        inputs=[alpha, A.offsets, A.columns, A.values, x, beta, y],
    )

    return y