warp-lang 1.0.2__py3-none-manylinux2014_x86_64.whl → 1.2.0__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (356) hide show
  1. warp/__init__.py +108 -97
  2. warp/__init__.pyi +1 -1
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +88 -113
  6. warp/build_dll.py +383 -375
  7. warp/builtins.py +3693 -3354
  8. warp/codegen.py +2925 -2792
  9. warp/config.py +40 -36
  10. warp/constants.py +49 -45
  11. warp/context.py +5409 -5102
  12. warp/dlpack.py +442 -442
  13. warp/examples/__init__.py +16 -16
  14. warp/examples/assets/bear.usd +0 -0
  15. warp/examples/assets/bunny.usd +0 -0
  16. warp/examples/assets/cartpole.urdf +110 -110
  17. warp/examples/assets/crazyflie.usd +0 -0
  18. warp/examples/assets/cube.usd +0 -0
  19. warp/examples/assets/nv_ant.xml +92 -92
  20. warp/examples/assets/nv_humanoid.xml +183 -183
  21. warp/examples/assets/quadruped.urdf +267 -267
  22. warp/examples/assets/rocks.nvdb +0 -0
  23. warp/examples/assets/rocks.usd +0 -0
  24. warp/examples/assets/sphere.usd +0 -0
  25. warp/examples/benchmarks/benchmark_api.py +381 -383
  26. warp/examples/benchmarks/benchmark_cloth.py +278 -277
  27. warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -88
  28. warp/examples/benchmarks/benchmark_cloth_jax.py +97 -100
  29. warp/examples/benchmarks/benchmark_cloth_numba.py +146 -142
  30. warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -77
  31. warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -86
  32. warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -112
  33. warp/examples/benchmarks/benchmark_cloth_warp.py +145 -146
  34. warp/examples/benchmarks/benchmark_launches.py +293 -295
  35. warp/examples/browse.py +29 -29
  36. warp/examples/core/example_dem.py +232 -219
  37. warp/examples/core/example_fluid.py +291 -267
  38. warp/examples/core/example_graph_capture.py +142 -126
  39. warp/examples/core/example_marching_cubes.py +186 -174
  40. warp/examples/core/example_mesh.py +172 -155
  41. warp/examples/core/example_mesh_intersect.py +203 -193
  42. warp/examples/core/example_nvdb.py +174 -170
  43. warp/examples/core/example_raycast.py +103 -90
  44. warp/examples/core/example_raymarch.py +197 -178
  45. warp/examples/core/example_render_opengl.py +183 -141
  46. warp/examples/core/example_sph.py +403 -387
  47. warp/examples/core/example_torch.py +219 -181
  48. warp/examples/core/example_wave.py +261 -248
  49. warp/examples/fem/bsr_utils.py +378 -380
  50. warp/examples/fem/example_apic_fluid.py +432 -389
  51. warp/examples/fem/example_burgers.py +262 -0
  52. warp/examples/fem/example_convection_diffusion.py +180 -168
  53. warp/examples/fem/example_convection_diffusion_dg.py +217 -209
  54. warp/examples/fem/example_deformed_geometry.py +175 -159
  55. warp/examples/fem/example_diffusion.py +199 -173
  56. warp/examples/fem/example_diffusion_3d.py +178 -152
  57. warp/examples/fem/example_diffusion_mgpu.py +219 -214
  58. warp/examples/fem/example_mixed_elasticity.py +242 -222
  59. warp/examples/fem/example_navier_stokes.py +257 -243
  60. warp/examples/fem/example_stokes.py +218 -192
  61. warp/examples/fem/example_stokes_transfer.py +263 -249
  62. warp/examples/fem/mesh_utils.py +133 -109
  63. warp/examples/fem/plot_utils.py +292 -287
  64. warp/examples/optim/example_bounce.py +258 -246
  65. warp/examples/optim/example_cloth_throw.py +220 -209
  66. warp/examples/optim/example_diffray.py +564 -536
  67. warp/examples/optim/example_drone.py +862 -835
  68. warp/examples/optim/example_inverse_kinematics.py +174 -168
  69. warp/examples/optim/example_inverse_kinematics_torch.py +183 -169
  70. warp/examples/optim/example_spring_cage.py +237 -231
  71. warp/examples/optim/example_trajectory.py +221 -199
  72. warp/examples/optim/example_walker.py +304 -293
  73. warp/examples/sim/example_cartpole.py +137 -129
  74. warp/examples/sim/example_cloth.py +194 -186
  75. warp/examples/sim/example_granular.py +122 -111
  76. warp/examples/sim/example_granular_collision_sdf.py +195 -186
  77. warp/examples/sim/example_jacobian_ik.py +234 -214
  78. warp/examples/sim/example_particle_chain.py +116 -105
  79. warp/examples/sim/example_quadruped.py +191 -180
  80. warp/examples/sim/example_rigid_chain.py +195 -187
  81. warp/examples/sim/example_rigid_contact.py +187 -177
  82. warp/examples/sim/example_rigid_force.py +125 -125
  83. warp/examples/sim/example_rigid_gyroscopic.py +107 -95
  84. warp/examples/sim/example_rigid_soft_contact.py +132 -122
  85. warp/examples/sim/example_soft_body.py +188 -177
  86. warp/fabric.py +337 -335
  87. warp/fem/__init__.py +61 -27
  88. warp/fem/cache.py +403 -388
  89. warp/fem/dirichlet.py +178 -179
  90. warp/fem/domain.py +262 -263
  91. warp/fem/field/__init__.py +100 -101
  92. warp/fem/field/field.py +148 -149
  93. warp/fem/field/nodal_field.py +298 -299
  94. warp/fem/field/restriction.py +22 -21
  95. warp/fem/field/test.py +180 -181
  96. warp/fem/field/trial.py +183 -183
  97. warp/fem/geometry/__init__.py +16 -19
  98. warp/fem/geometry/closest_point.py +69 -70
  99. warp/fem/geometry/deformed_geometry.py +270 -271
  100. warp/fem/geometry/element.py +748 -744
  101. warp/fem/geometry/geometry.py +184 -186
  102. warp/fem/geometry/grid_2d.py +380 -373
  103. warp/fem/geometry/grid_3d.py +437 -435
  104. warp/fem/geometry/hexmesh.py +953 -953
  105. warp/fem/geometry/nanogrid.py +455 -0
  106. warp/fem/geometry/partition.py +374 -376
  107. warp/fem/geometry/quadmesh_2d.py +532 -532
  108. warp/fem/geometry/tetmesh.py +840 -840
  109. warp/fem/geometry/trimesh_2d.py +577 -577
  110. warp/fem/integrate.py +1684 -1615
  111. warp/fem/operator.py +190 -191
  112. warp/fem/polynomial.py +214 -213
  113. warp/fem/quadrature/__init__.py +2 -2
  114. warp/fem/quadrature/pic_quadrature.py +243 -245
  115. warp/fem/quadrature/quadrature.py +295 -294
  116. warp/fem/space/__init__.py +179 -292
  117. warp/fem/space/basis_space.py +522 -489
  118. warp/fem/space/collocated_function_space.py +100 -105
  119. warp/fem/space/dof_mapper.py +236 -236
  120. warp/fem/space/function_space.py +148 -145
  121. warp/fem/space/grid_2d_function_space.py +148 -267
  122. warp/fem/space/grid_3d_function_space.py +167 -306
  123. warp/fem/space/hexmesh_function_space.py +253 -352
  124. warp/fem/space/nanogrid_function_space.py +202 -0
  125. warp/fem/space/partition.py +350 -350
  126. warp/fem/space/quadmesh_2d_function_space.py +261 -369
  127. warp/fem/space/restriction.py +161 -160
  128. warp/fem/space/shape/__init__.py +90 -15
  129. warp/fem/space/shape/cube_shape_function.py +728 -738
  130. warp/fem/space/shape/shape_function.py +102 -103
  131. warp/fem/space/shape/square_shape_function.py +611 -611
  132. warp/fem/space/shape/tet_shape_function.py +565 -567
  133. warp/fem/space/shape/triangle_shape_function.py +429 -429
  134. warp/fem/space/tetmesh_function_space.py +224 -292
  135. warp/fem/space/topology.py +297 -295
  136. warp/fem/space/trimesh_2d_function_space.py +153 -221
  137. warp/fem/types.py +77 -77
  138. warp/fem/utils.py +495 -495
  139. warp/jax.py +166 -141
  140. warp/jax_experimental.py +341 -339
  141. warp/native/array.h +1081 -1025
  142. warp/native/builtin.h +1603 -1560
  143. warp/native/bvh.cpp +402 -398
  144. warp/native/bvh.cu +533 -525
  145. warp/native/bvh.h +430 -429
  146. warp/native/clang/clang.cpp +496 -464
  147. warp/native/crt.cpp +42 -32
  148. warp/native/crt.h +352 -335
  149. warp/native/cuda_crt.h +1049 -1049
  150. warp/native/cuda_util.cpp +549 -540
  151. warp/native/cuda_util.h +288 -203
  152. warp/native/cutlass_gemm.cpp +34 -34
  153. warp/native/cutlass_gemm.cu +372 -372
  154. warp/native/error.cpp +66 -66
  155. warp/native/error.h +27 -27
  156. warp/native/exports.h +187 -0
  157. warp/native/fabric.h +228 -228
  158. warp/native/hashgrid.cpp +301 -278
  159. warp/native/hashgrid.cu +78 -77
  160. warp/native/hashgrid.h +227 -227
  161. warp/native/initializer_array.h +32 -32
  162. warp/native/intersect.h +1204 -1204
  163. warp/native/intersect_adj.h +365 -365
  164. warp/native/intersect_tri.h +322 -322
  165. warp/native/marching.cpp +2 -2
  166. warp/native/marching.cu +497 -497
  167. warp/native/marching.h +2 -2
  168. warp/native/mat.h +1545 -1498
  169. warp/native/matnn.h +333 -333
  170. warp/native/mesh.cpp +203 -203
  171. warp/native/mesh.cu +292 -293
  172. warp/native/mesh.h +1887 -1887
  173. warp/native/nanovdb/GridHandle.h +366 -0
  174. warp/native/nanovdb/HostBuffer.h +590 -0
  175. warp/native/nanovdb/NanoVDB.h +6624 -4782
  176. warp/native/nanovdb/PNanoVDB.h +3390 -2553
  177. warp/native/noise.h +850 -850
  178. warp/native/quat.h +1112 -1085
  179. warp/native/rand.h +303 -299
  180. warp/native/range.h +108 -108
  181. warp/native/reduce.cpp +156 -156
  182. warp/native/reduce.cu +348 -348
  183. warp/native/runlength_encode.cpp +61 -61
  184. warp/native/runlength_encode.cu +46 -46
  185. warp/native/scan.cpp +30 -30
  186. warp/native/scan.cu +36 -36
  187. warp/native/scan.h +7 -7
  188. warp/native/solid_angle.h +442 -442
  189. warp/native/sort.cpp +94 -94
  190. warp/native/sort.cu +97 -97
  191. warp/native/sort.h +14 -14
  192. warp/native/sparse.cpp +337 -337
  193. warp/native/sparse.cu +544 -544
  194. warp/native/spatial.h +630 -630
  195. warp/native/svd.h +562 -562
  196. warp/native/temp_buffer.h +30 -30
  197. warp/native/vec.h +1177 -1133
  198. warp/native/volume.cpp +529 -297
  199. warp/native/volume.cu +58 -32
  200. warp/native/volume.h +960 -538
  201. warp/native/volume_builder.cu +446 -425
  202. warp/native/volume_builder.h +34 -19
  203. warp/native/volume_impl.h +61 -0
  204. warp/native/warp.cpp +1057 -1052
  205. warp/native/warp.cu +2949 -2828
  206. warp/native/warp.h +321 -305
  207. warp/optim/__init__.py +9 -9
  208. warp/optim/adam.py +120 -120
  209. warp/optim/linear.py +1104 -939
  210. warp/optim/sgd.py +104 -92
  211. warp/render/__init__.py +10 -10
  212. warp/render/render_opengl.py +3356 -3204
  213. warp/render/render_usd.py +768 -749
  214. warp/render/utils.py +152 -150
  215. warp/sim/__init__.py +52 -59
  216. warp/sim/articulation.py +685 -685
  217. warp/sim/collide.py +1594 -1590
  218. warp/sim/import_mjcf.py +489 -481
  219. warp/sim/import_snu.py +220 -221
  220. warp/sim/import_urdf.py +536 -516
  221. warp/sim/import_usd.py +887 -881
  222. warp/sim/inertia.py +316 -317
  223. warp/sim/integrator.py +234 -233
  224. warp/sim/integrator_euler.py +1956 -1956
  225. warp/sim/integrator_featherstone.py +1917 -1991
  226. warp/sim/integrator_xpbd.py +3288 -3312
  227. warp/sim/model.py +4473 -4314
  228. warp/sim/particles.py +113 -112
  229. warp/sim/render.py +417 -403
  230. warp/sim/utils.py +413 -410
  231. warp/sparse.py +1289 -1227
  232. warp/stubs.py +2192 -2469
  233. warp/tape.py +1162 -225
  234. warp/tests/__init__.py +1 -1
  235. warp/tests/__main__.py +4 -4
  236. warp/tests/assets/test_index_grid.nvdb +0 -0
  237. warp/tests/assets/torus.usda +105 -105
  238. warp/tests/aux_test_class_kernel.py +26 -26
  239. warp/tests/aux_test_compile_consts_dummy.py +10 -10
  240. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -21
  241. warp/tests/aux_test_dependent.py +20 -22
  242. warp/tests/aux_test_grad_customs.py +21 -23
  243. warp/tests/aux_test_reference.py +9 -11
  244. warp/tests/aux_test_reference_reference.py +8 -10
  245. warp/tests/aux_test_square.py +15 -17
  246. warp/tests/aux_test_unresolved_func.py +14 -14
  247. warp/tests/aux_test_unresolved_symbol.py +14 -14
  248. warp/tests/disabled_kinematics.py +237 -239
  249. warp/tests/run_coverage_serial.py +31 -31
  250. warp/tests/test_adam.py +155 -157
  251. warp/tests/test_arithmetic.py +1088 -1124
  252. warp/tests/test_array.py +2415 -2326
  253. warp/tests/test_array_reduce.py +148 -150
  254. warp/tests/test_async.py +666 -656
  255. warp/tests/test_atomic.py +139 -141
  256. warp/tests/test_bool.py +212 -149
  257. warp/tests/test_builtins_resolution.py +1290 -1292
  258. warp/tests/test_bvh.py +162 -171
  259. warp/tests/test_closest_point_edge_edge.py +227 -228
  260. warp/tests/test_codegen.py +562 -553
  261. warp/tests/test_compile_consts.py +217 -101
  262. warp/tests/test_conditional.py +244 -246
  263. warp/tests/test_copy.py +230 -215
  264. warp/tests/test_ctypes.py +630 -632
  265. warp/tests/test_dense.py +65 -67
  266. warp/tests/test_devices.py +89 -98
  267. warp/tests/test_dlpack.py +528 -529
  268. warp/tests/test_examples.py +403 -378
  269. warp/tests/test_fabricarray.py +952 -955
  270. warp/tests/test_fast_math.py +60 -54
  271. warp/tests/test_fem.py +1298 -1278
  272. warp/tests/test_fp16.py +128 -130
  273. warp/tests/test_func.py +336 -337
  274. warp/tests/test_generics.py +596 -571
  275. warp/tests/test_grad.py +885 -640
  276. warp/tests/test_grad_customs.py +331 -336
  277. warp/tests/test_hash_grid.py +208 -164
  278. warp/tests/test_import.py +37 -39
  279. warp/tests/test_indexedarray.py +1132 -1134
  280. warp/tests/test_intersect.py +65 -67
  281. warp/tests/test_jax.py +305 -307
  282. warp/tests/test_large.py +169 -164
  283. warp/tests/test_launch.py +352 -354
  284. warp/tests/test_lerp.py +217 -261
  285. warp/tests/test_linear_solvers.py +189 -171
  286. warp/tests/test_lvalue.py +419 -493
  287. warp/tests/test_marching_cubes.py +63 -65
  288. warp/tests/test_mat.py +1799 -1827
  289. warp/tests/test_mat_lite.py +113 -115
  290. warp/tests/test_mat_scalar_ops.py +2905 -2889
  291. warp/tests/test_math.py +124 -193
  292. warp/tests/test_matmul.py +498 -499
  293. warp/tests/test_matmul_lite.py +408 -410
  294. warp/tests/test_mempool.py +186 -190
  295. warp/tests/test_mesh.py +281 -324
  296. warp/tests/test_mesh_query_aabb.py +226 -241
  297. warp/tests/test_mesh_query_point.py +690 -702
  298. warp/tests/test_mesh_query_ray.py +290 -303
  299. warp/tests/test_mlp.py +274 -276
  300. warp/tests/test_model.py +108 -110
  301. warp/tests/test_module_hashing.py +111 -0
  302. warp/tests/test_modules_lite.py +36 -39
  303. warp/tests/test_multigpu.py +161 -163
  304. warp/tests/test_noise.py +244 -248
  305. warp/tests/test_operators.py +248 -250
  306. warp/tests/test_options.py +121 -125
  307. warp/tests/test_peer.py +131 -137
  308. warp/tests/test_pinned.py +76 -78
  309. warp/tests/test_print.py +52 -54
  310. warp/tests/test_quat.py +2084 -2086
  311. warp/tests/test_rand.py +324 -288
  312. warp/tests/test_reload.py +207 -217
  313. warp/tests/test_rounding.py +177 -179
  314. warp/tests/test_runlength_encode.py +188 -190
  315. warp/tests/test_sim_grad.py +241 -0
  316. warp/tests/test_sim_kinematics.py +89 -97
  317. warp/tests/test_smoothstep.py +166 -168
  318. warp/tests/test_snippet.py +303 -266
  319. warp/tests/test_sparse.py +466 -460
  320. warp/tests/test_spatial.py +2146 -2148
  321. warp/tests/test_special_values.py +362 -0
  322. warp/tests/test_streams.py +484 -473
  323. warp/tests/test_struct.py +708 -675
  324. warp/tests/test_tape.py +171 -148
  325. warp/tests/test_torch.py +741 -743
  326. warp/tests/test_transient_module.py +85 -87
  327. warp/tests/test_types.py +554 -659
  328. warp/tests/test_utils.py +488 -499
  329. warp/tests/test_vec.py +1262 -1268
  330. warp/tests/test_vec_lite.py +71 -73
  331. warp/tests/test_vec_scalar_ops.py +2097 -2099
  332. warp/tests/test_verify_fp.py +92 -94
  333. warp/tests/test_volume.py +961 -736
  334. warp/tests/test_volume_write.py +338 -265
  335. warp/tests/unittest_serial.py +38 -37
  336. warp/tests/unittest_suites.py +367 -359
  337. warp/tests/unittest_utils.py +434 -578
  338. warp/tests/unused_test_misc.py +69 -71
  339. warp/tests/walkthrough_debug.py +85 -85
  340. warp/thirdparty/appdirs.py +598 -598
  341. warp/thirdparty/dlpack.py +143 -143
  342. warp/thirdparty/unittest_parallel.py +563 -561
  343. warp/torch.py +321 -295
  344. warp/types.py +4941 -4450
  345. warp/utils.py +1008 -821
  346. {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/LICENSE.md +126 -126
  347. {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/METADATA +365 -400
  348. warp_lang-1.2.0.dist-info/RECORD +359 -0
  349. warp/examples/assets/cube.usda +0 -42
  350. warp/examples/assets/sphere.usda +0 -56
  351. warp/examples/assets/torus.usda +0 -105
  352. warp/examples/fem/example_convection_diffusion_dg0.py +0 -194
  353. warp/native/nanovdb/PNanoVDBWrite.h +0 -295
  354. warp_lang-1.0.2.dist-info/RECORD +0 -352
  355. {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/WHEEL +0 -0
  356. {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/top_level.txt +0 -0
warp/sparse.py CHANGED
@@ -1,1227 +1,1289 @@
1
- from typing import Any, Generic, Optional, Tuple, TypeVar, Union
2
-
3
- import warp as wp
4
- import warp.types
5
- import warp.utils
6
- from warp.types import Array, Cols, Matrix, Rows, Scalar, Vector
7
-
8
# typing hints

# Type variable parameterizing ``BsrMatrix`` by its block type.
# The name passed to TypeVar matches the variable it is bound to, as static type checkers require.
_BlockType = TypeVar("_BlockType")
11
-
12
-
13
class _MatrixBlockType(Matrix):
    """Typing-only placeholder for matrix-valued blocks, used to spell the ``BlockType`` hint."""
15
-
16
-
17
class _ScalarBlockType(Generic[Scalar]):
    """Typing-only placeholder for scalar blocks (CSR case), used to spell the ``BlockType`` hint."""
19
-
20
-
21
# Type hint for an individual sparse-matrix block: either a ``Rows x Cols`` matrix of ``Scalar``
# coefficients (BSR), or a single ``Scalar`` (CSR)
BlockType = Union[_MatrixBlockType[Rows, Cols, Scalar], _ScalarBlockType[Scalar]]

# Generated matrix struct types, keyed by a string describing the block type
_struct_cache = {}
24
-
25
-
26
class BsrMatrix(Generic[_BlockType]):
    """Untyped base class for BSR and CSR matrices.

    Should not be constructed directly but through functions such as :func:`bsr_zeros`.

    Attributes:
        nrow (int): Number of rows of blocks
        ncol (int): Number of columns of blocks
        nnz (int): Number of non-zero blocks: must be equal to ``offsets[nrow]``, cached on host for convenience
        offsets (Array[int]): Array of size at least ``1 + nrow`` such that the start and end indices of the blocks of row ``r`` are ``offsets[r]`` and ``offsets[r+1]``, respectively.
        columns (Array[int]): Array of size at least equal to ``nnz`` containing block column indices
        values (Array[BlockType]): Array of size at least equal to ``nnz`` containing block values
    """

    @property
    def scalar_type(self) -> Scalar:
        """Scalar type for individual block coefficients. For CSR matrices, this is the same as the block type"""
        return warp.types.type_scalar_type(self.values.dtype)

    @property
    def block_shape(self) -> Tuple[int, int]:
        """Shape of the individual blocks; ``(1, 1)`` when the block type has no ``_shape_`` (scalar blocks)"""
        return getattr(self.values.dtype, "_shape_", (1, 1))

    @property
    def block_size(self) -> int:
        """Size of the individual blocks, i.e. number of rows per block times number of columns per block"""
        return warp.types.type_length(self.values.dtype)

    @property
    def shape(self) -> Tuple[int, int]:
        """Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/columns per block"""
        block_shape = self.block_shape
        return (self.nrow * block_shape[0], self.ncol * block_shape[1])

    @property
    def dtype(self) -> type:
        """Data type for individual block values"""
        return self.values.dtype

    @property
    def device(self) -> wp.context.Device:
        """Device on which offsets, columns and values are allocated -- assumed to be the same for all three arrays"""
        return self.values.device
70
-
71
-
72
def bsr_matrix_t(dtype: BlockType):
    """Returns the warp struct type describing a BSR or CSR matrix with the given block type.

    Struct types are cached per block type: repeated calls with the same block type return the
    same struct object, and the underlying struct class is only built on the first request.

    Args:
        dtype: Type of individual blocks, first normalized with ``wp.types.type_to_warp``. Must be
            either a warp matrix type (BSR) or a scalar type (CSR).

    Raises:
        ValueError: if ``dtype`` is neither a warp matrix type nor a scalar type
    """
    dtype = wp.types.type_to_warp(dtype)

    if not warp.types.type_is_matrix(dtype) and dtype not in warp.types.scalar_types:
        raise ValueError(
            f"BsrMatrix block type must be either warp matrix or scalar; got {warp.types.type_repr(dtype)}"
        )

    # Cache key: scalar type name, plus the block shape for matrix blocks
    if hasattr(dtype, "_shape_"):
        type_str = f"{warp.types.type_scalar_type(dtype).__name__}_{dtype._shape_[0]}_{dtype._shape_[1]}"
    else:
        type_str = dtype.__name__
    key = f"{BsrMatrix.__qualname__}_{type_str}"

    # Fast path: do not rebuild the struct class for an already requested block type
    if key in _struct_cache:
        return _struct_cache[key]

    class BsrMatrixTyped(BsrMatrix):
        nrow: int  # Number of rows of blocks
        ncol: int  # Number of columns of blocks
        nnz: int  # Number of non-zero blocks: equal to offsets[nrow], cached on host for convenience
        offsets: wp.array(dtype=int)  # Array of size at least 1 + nrow
        columns: wp.array(dtype=int)  # Array of size at least equal to nnz
        values: wp.array(dtype=dtype)  # Array of size at least equal to nnz

    module = wp.get_module(BsrMatrix.__module__)
    struct_type = wp.codegen.Struct(
        cls=BsrMatrixTyped,
        key=key,
        module=module,
    )
    _struct_cache[key] = struct_type
    return struct_type
109
-
110
-
111
def bsr_zeros(
    rows_of_blocks: int,
    cols_of_blocks: int,
    block_type: BlockType,
    device: wp.context.Devicelike = None,
) -> BsrMatrix:
    """
    Constructs and returns an empty BSR or CSR matrix with the given shape

    Args:
        rows_of_blocks: Number of rows of blocks
        cols_of_blocks: Number of columns of blocks
        block_type: Type of individual blocks. For CSR matrices, this should be a scalar type;
          for BSR matrices, this should be a matrix type (e.g. from :func:`warp.mat`)
        device: Device on which to allocate the matrix arrays

    Returns:
        A matrix without any stored block: ``nnz`` is zero, ``columns`` and ``values`` are empty
        arrays, and ``offsets`` contains ``rows_of_blocks + 1`` zeros.
    """

    bsr = bsr_matrix_t(block_type)()

    bsr.nrow = rows_of_blocks
    bsr.ncol = cols_of_blocks
    bsr.nnz = 0
    bsr.columns = wp.empty(shape=(0,), dtype=int, device=device)
    bsr.values = wp.empty(shape=(0,), dtype=block_type, device=device)
    # Zeroed offsets make every row empty
    bsr.offsets = wp.zeros(shape=(bsr.nrow + 1,), dtype=int, device=device)

    return bsr
139
-
140
-
141
def _bsr_ensure_fits(bsr: BsrMatrix, nrow: Optional[int] = None, nnz: Optional[int] = None):
    """Grows the storage arrays of ``bsr`` so they can hold ``nrow`` rows and ``nnz`` blocks.

    Arrays that are already large enough are kept as-is (they are never shrunk). Arrays that are too
    small are replaced by new ``wp.empty`` arrays on the same device and with the same dtype, so their
    previous contents are not preserved. The ``nrow``/``ncol``/``nnz`` attributes are not modified.

    Args:
        bsr: Matrix whose arrays should be enlarged if needed
        nrow: Required number of rows of blocks; defaults to ``bsr.nrow``
        nnz: Required number of non-zero blocks; defaults to ``bsr.nnz``
    """
    if nrow is None:
        nrow = bsr.nrow
    if nnz is None:
        nnz = bsr.nnz

    # offsets needs one extra entry holding the end index of the last row
    if bsr.offsets.size < nrow + 1:
        bsr.offsets = wp.empty(shape=(nrow + 1,), dtype=int, device=bsr.offsets.device)
    if bsr.columns.size < nnz:
        bsr.columns = wp.empty(shape=(nnz,), dtype=int, device=bsr.columns.device)
    if bsr.values.size < nnz:
        bsr.values = wp.empty(shape=(nnz,), dtype=bsr.values.dtype, device=bsr.values.device)
153
-
154
-
155
def bsr_set_zero(bsr: BsrMatrix, rows_of_blocks: Optional[int] = None, cols_of_blocks: Optional[int] = None):
    """
    Removes all blocks from a BSR matrix, optionally resizing it

    Args:
        bsr: The BSR or CSR matrix to clear
        rows_of_blocks: New number of rows of blocks, or ``None`` to keep the current one
        cols_of_blocks: New number of columns of blocks, or ``None`` to keep the current one
    """

    for attr_name, new_dim in (("nrow", rows_of_blocks), ("ncol", cols_of_blocks)):
        if new_dim is not None:
            setattr(bsr, attr_name, new_dim)

    # With no blocks, only the offsets array matters: make it fit the row count and zero it
    bsr.nnz = 0
    _bsr_ensure_fits(bsr)
    bsr.offsets.zero_()
172
-
173
-
174
def bsr_set_from_triplets(
    dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    rows: "Array[int]",
    columns: "Array[int]",
    values: "Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]",
):
    """
    Fills a BSR matrix with values defined by coordinate-oriented (COO) triplets, discarding existing blocks.

    The first dimension of the three input arrays must match, and determines the number of non-zeros in the constructed matrix.

    Args:
        dest: Sparse matrix to populate
        rows: Row index for each non-zero
        columns: Columns index for each non-zero
        values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
          to the `dest` matrix's block type, or a 3d array with data type equal to the `dest` matrix's scalar type.

    Raises:
        ValueError: if the arguments reside on different devices, have mismatched lengths, or if ``values``
          does not have a compatible dtype, shape or memory layout
        NotImplementedError: if the scalar type of ``dest`` is neither ``float32`` nor ``float64``
    """

    if values.device != columns.device or values.device != rows.device or values.device != dest.values.device:
        raise ValueError("All arguments must reside on the same device")

    if values.shape[0] != rows.shape[0] or values.shape[0] != columns.shape[0]:
        raise ValueError("All triplet arrays must have the same length")

    # Accept either array1d(dtype) or contiguous array3d(scalar_type) as values
    if values.ndim == 1:
        if values.dtype != dest.values.dtype:
            raise ValueError("Values array type must correspond to that of dest matrix")
    elif values.ndim == 3:
        if values.shape[1:] != dest.block_shape:
            raise ValueError(
                f"Last two dimensions in values array ({values.shape[1:]}) should correspond to matrix block shape {dest.block_shape}"
            )

        if warp.types.type_scalar_type(values.dtype) != dest.scalar_type:
            raise ValueError("Scalar type of values array should correspond to that of matrix")

        if not values.is_contiguous:
            raise ValueError("Multi-dimensional values array should be contiguous")
    else:
        raise ValueError("Number of dimension for values array should be 1 or 3")

    nnz = rows.shape[0]
    if nnz == 0:
        bsr_set_zero(dest)
        return

    # Increase dest array sizes if needed
    _bsr_ensure_fits(dest, nnz=nnz)

    device = dest.values.device
    scalar_type = dest.scalar_type
    from warp.context import runtime

    # Only float32/float64 have native implementations; any other scalar type must reach the
    # NotImplementedError below rather than fail on an unbound name
    native_func = None
    if device.is_cpu:
        if scalar_type == wp.float32:
            native_func = runtime.core.bsr_matrix_from_triplets_float_host
        elif scalar_type == wp.float64:
            native_func = runtime.core.bsr_matrix_from_triplets_double_host
    else:
        if scalar_type == wp.float32:
            native_func = runtime.core.bsr_matrix_from_triplets_float_device
        elif scalar_type == wp.float64:
            native_func = runtime.core.bsr_matrix_from_triplets_double_device

    if not native_func:
        raise NotImplementedError(f"bsr_from_triplets not implemented for scalar type {scalar_type}")

    # The native routine fills offsets/columns/values and returns the resulting number of blocks
    dest.nnz = native_func(
        dest.block_shape[0],
        dest.block_shape[1],
        dest.nrow,
        nnz,
        rows.ptr,
        columns.ptr,
        values.ptr,
        dest.offsets.ptr,
        dest.columns.ptr,
        dest.values.ptr,
    )
255
-
256
-
257
def bsr_assign(dest: BsrMatrix[BlockType[Rows, Cols, Scalar]], src: BsrMatrix[BlockType[Rows, Cols, Any]]):
    """Overwrites `dest` with the structure and block values of `src`.

    Block values are converted with ``warp.utils.array_cast`` when the two matrices have different
    scalar types; block shapes and devices must match.
    """

    if dest.values.device != src.values.device:
        raise ValueError("Source and destination matrices must reside on the same device")

    if dest.block_shape != src.block_shape:
        raise ValueError("Source and destination matrices must have the same block shape")

    dest.nrow, dest.ncol, dest.nnz = src.nrow, src.ncol, src.nnz

    _bsr_ensure_fits(dest)

    # Row offsets are always copied; columns and values only exist when there are blocks
    wp.copy(dest=dest.offsets, src=src.offsets, count=src.nrow + 1)
    if src.nnz <= 0:
        return

    wp.copy(dest=dest.columns, src=src.columns, count=src.nnz)
    warp.utils.array_cast(out_array=dest.values, in_array=src.values, count=src.nnz)
276
-
277
-
278
def bsr_copy(A: BsrMatrix, scalar_type: Optional[Scalar] = None):
    """Creates and returns a new matrix with the same content as ``A``.

    Args:
        scalar_type: If provided, the returned matrix will use this scalar type instead of the one from `A`.
    """
    if scalar_type is not None and A.block_shape != (1, 1):
        block_type = wp.types.matrix(shape=A.block_shape, dtype=scalar_type)
    elif scalar_type is not None:
        # 1x1 blocks: the copy is a CSR matrix with the requested scalar type
        block_type = scalar_type
    else:
        block_type = A.values.dtype

    result = bsr_zeros(rows_of_blocks=A.nrow, cols_of_blocks=A.ncol, block_type=block_type, device=A.values.device)
    bsr_assign(dest=result, src=A)
    return result
294
-
295
-
296
def bsr_set_transpose(dest: BsrMatrix[BlockType[Cols, Rows, Scalar]], src: BsrMatrix[BlockType[Rows, Cols, Scalar]]):
    """Assigns the transposed matrix `src` to matrix `dest`

    Args:
        dest: Matrix receiving the transpose; its block shape must be the transpose of ``src``'s block shape
        src: Matrix to transpose

    Raises:
        ValueError: if the matrices reside on different devices, have different scalar types,
          or if ``dest`` has an incompatible block shape
        NotImplementedError: if the scalar type is neither ``float32`` nor ``float64``
    """

    if dest.values.device != src.values.device:
        raise ValueError("All arguments must reside on the same device")

    if dest.scalar_type != src.scalar_type:
        raise ValueError("All arguments must have the same scalar type")

    transpose_block_shape = src.block_shape[::-1]

    if dest.block_shape != transpose_block_shape:
        raise ValueError(f"Destination block shape must be {transpose_block_shape}")

    dest.nrow = src.ncol
    dest.ncol = src.nrow
    dest.nnz = src.nnz

    if src.nnz == 0:
        # Nothing to transpose, but dest must still be a valid empty matrix: its offsets array
        # may be too small for the new row count or contain stale entries from earlier content
        bsr_set_zero(dest)
        return

    # Increase dest array sizes if needed
    _bsr_ensure_fits(dest)

    from warp.context import runtime

    # Only float32/float64 have native implementations; any other scalar type must reach the
    # NotImplementedError below rather than fail on an unbound name
    native_func = None
    if dest.values.device.is_cpu:
        if dest.scalar_type == wp.float32:
            native_func = runtime.core.bsr_transpose_float_host
        elif dest.scalar_type == wp.float64:
            native_func = runtime.core.bsr_transpose_double_host
    else:
        if dest.scalar_type == wp.float32:
            native_func = runtime.core.bsr_transpose_float_device
        elif dest.scalar_type == wp.float64:
            native_func = runtime.core.bsr_transpose_double_device

    if not native_func:
        raise NotImplementedError(f"bsr_set_transpose not implemented for scalar type {dest.scalar_type}")

    native_func(
        src.block_shape[0],
        src.block_shape[1],
        src.nrow,
        src.ncol,
        src.nnz,
        src.offsets.ptr,
        src.columns.ptr,
        src.values.ptr,
        dest.offsets.ptr,
        dest.columns.ptr,
        dest.values.ptr,
    )
349
-
350
-
351
def bsr_transposed(A: BsrMatrix):
    """Allocates and returns a new matrix holding the transpose of `A`"""

    # Scalar (1x1) blocks keep their type; matrix blocks get their shape swapped
    block_type = (
        A.values.dtype
        if A.block_shape == (1, 1)
        else wp.types.matrix(shape=A.block_shape[::-1], dtype=A.scalar_type)
    )

    result = bsr_zeros(rows_of_blocks=A.ncol, cols_of_blocks=A.nrow, block_type=block_type, device=A.values.device)
    bsr_set_transpose(dest=result, src=A)
    return result
362
-
363
-
364
# Copies the diagonal block of each row of A into ``out[row]``, one thread per row (``row = wp.tid()``).
# Rows that do not store a block in column ``row`` leave ``out[row]`` untouched.
# ``wp.lower_bound`` searches a sorted range, so this presumably assumes column indices are
# sorted within each row -- TODO confirm this invariant is guaranteed by all producers.
@wp.kernel
def _bsr_get_diag_kernel(
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
    out: wp.array(dtype=Any),
):
    row = wp.tid()
    # Blocks of this row occupy indices [beg, end)
    beg = A_offsets[row]
    end = A_offsets[row + 1]

    # Position of the first column index >= row within [beg, end)
    diag = wp.lower_bound(A_columns, beg, end, row)
    # Kept as nested ifs so that A_columns[diag] is only read when diag is in range
    if diag < end:
        if A_columns[diag] == row:
            out[row] = A_values[diag]
379
-
380
-
381
def bsr_get_diag(A: BsrMatrix[_BlockType], out: "Optional[Array[BlockType]]" = None) -> "Array[BlockType]":
    """Extracts the diagonal blocks of a sparse matrix.

    Args:
        A: the sparse matrix whose diagonal is read
        out: optional destination array for the diagonal blocks; must match ``A``'s block type and device
          and hold at least ``min(A.nrow, A.ncol)`` entries. Entries for rows without a stored diagonal
          block are left unchanged (a newly allocated output starts zeroed).

    Returns:
        The array holding the diagonal blocks (``out`` if it was provided)
    """

    dim = min(A.nrow, A.ncol)
    block_dtype = A.values.dtype
    block_device = A.values.device

    if out is None:
        out = wp.zeros(shape=(dim,), dtype=block_dtype, device=block_device)
    elif out.dtype != block_dtype:
        raise ValueError(f"Output array must have type {block_dtype}")
    elif out.device != block_device:
        raise ValueError(f"Output array must reside on device {block_device}")
    elif out.shape[0] < dim:
        raise ValueError(f"Output array must be of length at least {dim}")

    wp.launch(
        kernel=_bsr_get_diag_kernel,
        dim=dim,
        device=block_device,
        inputs=[A.offsets, A.columns, A.values, out],
    )

    return out
406
-
407
-
408
# Writes the structure and values of a block-diagonal matrix, one thread per row (``row = wp.tid()``):
# row ``row`` holds exactly one block, in column ``row``, with value ``diag[row]``.
@wp.kernel
def _bsr_set_diag_kernel(
    diag: wp.array(dtype=Any),
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
):
    row = wp.tid()
    # One block per row, so the end offset of row r is r + 1
    A_offsets[row + 1] = row + 1
    A_columns[row] = row
    A_values[row] = diag[row]

    # The leading offset is written once, by thread 0
    if row == 0:
        A_offsets[0] = 0
422
-
423
-
424
# Same as _bsr_set_diag_kernel, except that every diagonal block receives the same value
# ``diag_value``. Launched with one thread per diagonal block (see bsr_set_diag).
# Only offsets[0 : launch_dim + 1] are written here.
@wp.kernel
def _bsr_set_diag_constant_kernel(
    diag_value: Any,
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
):
    row = wp.tid()
    # Row ``row`` holds exactly one block (block index ``row``), so its end offset is row + 1
    A_offsets[row + 1] = row + 1
    A_columns[row] = row
    A_values[row] = diag_value

    # The start offset of the first row is written by a single thread
    if row == 0:
        A_offsets[0] = 0
438
-
439
-
440
def bsr_set_diag(
    A: BsrMatrix[BlockType],
    diag: "Union[BlockType, Array[BlockType]]",
    rows_of_blocks: Optional[int] = None,
    cols_of_blocks: Optional[int] = None,
):
    """Sets `A` as a block-diagonal matrix

    Args:
        A: the sparse matrix to modify
        diag: Either a warp array of type ``A.values.dtype``, in which case each element will define one block of the diagonal,
          or a constant value of type ``A.values.dtype``, in which case it will get assigned to all diagonal blocks.
        rows_of_blocks: If not ``None``, the new number of rows of blocks
        cols_of_blocks: If not ``None``, the new number of columns of blocks

    The shape of the matrix will be defined by one of the following, in that order:
      - `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal.
      - the first dimension of `diag`, if `diag` is an array
      - the current dimensions of `A` otherwise

    For rectangular shapes only the first ``min(rows, cols)`` rows hold a diagonal block;
    the remaining rows are empty.

    Raises:
        ValueError: if `diag` is an array with fewer than ``min(rows, cols)`` elements
            (`A` is left unmodified in that case).
    """

    if rows_of_blocks is None and cols_of_blocks is not None:
        rows_of_blocks = cols_of_blocks
    if cols_of_blocks is None and rows_of_blocks is not None:
        cols_of_blocks = rows_of_blocks

    diag_is_array = warp.types.is_array(diag)

    if diag_is_array and rows_of_blocks is None:
        rows_of_blocks = diag.shape[0]
        cols_of_blocks = diag.shape[0]

    # Resolve the new shape before touching A so that a validation failure leaves A intact
    nrow = A.nrow if rows_of_blocks is None else rows_of_blocks
    ncol = A.ncol if cols_of_blocks is None else cols_of_blocks
    nnz = min(nrow, ncol)

    if diag_is_array and diag.shape[0] < nnz:
        # The kernel reads diag[row] for every row < nnz; a shorter array would be read out of bounds
        raise ValueError(f"Diagonal array must have at least {nnz} elements, got {diag.shape[0]}")

    A.nrow = nrow
    A.ncol = ncol
    A.nnz = nnz
    _bsr_ensure_fits(A)

    if diag_is_array:
        wp.launch(
            kernel=_bsr_set_diag_kernel,
            dim=A.nnz,
            device=A.values.device,
            inputs=[diag, A.offsets, A.columns, A.values],
        )
    else:
        if not warp.types.type_is_value(type(diag)):
            # Cast to launchable type
            diag = A.values.dtype(diag)
        wp.launch(
            kernel=_bsr_set_diag_constant_kernel,
            dim=A.nnz,
            device=A.values.device,
            inputs=[diag, A.offsets, A.columns, A.values],
        )

    if nnz == 0 or A.nrow > nnz:
        # The kernels above only write offsets[0 : nnz + 1], and nothing at all when nnz == 0.
        # Rows past the last diagonal block (tall matrices) are empty, so their offsets must all equal nnz.
        A.offsets[nnz : A.nrow + 1].fill_(nnz)
495
-
496
-
497
def bsr_diag(
    diag: "Union[BlockType, Array[BlockType]]",
    rows_of_blocks: Optional[int] = None,
    cols_of_blocks: Optional[int] = None,
) -> BsrMatrix["BlockType"]:
    """Creates and returns a block-diagonal BSR matrix from a given block value or array of block values.

    Args:
        diag: Either a warp array of type ``A.values.dtype``, in which case each element will define one block of the diagonal,
          or a constant value of type ``A.values.dtype``, in which case it will get assigned to all diagonal blocks.
        rows_of_blocks: If not ``None``, the new number of rows of blocks
        cols_of_blocks: If not ``None``, the new number of columns of blocks

    The shape of the matrix will be defined by one of the following, in that order:
      - `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal.
      - the first dimension of `diag`, if `diag` is an array

    Raises:
        ValueError: if `diag` is a constant and neither `rows_of_blocks` nor `cols_of_blocks` is given
    """

    # A single given dimension is mirrored onto the other one
    if rows_of_blocks is None:
        rows_of_blocks = cols_of_blocks
    elif cols_of_blocks is None:
        cols_of_blocks = rows_of_blocks

    if warp.types.is_array(diag):
        if rows_of_blocks is None:
            rows_of_blocks = cols_of_blocks = diag.shape[0]
        # Array diagonals determine both the block type and the device
        matrix = bsr_zeros(rows_of_blocks, cols_of_blocks, block_type=diag.dtype, device=diag.device)
    else:
        if rows_of_blocks is None:
            raise ValueError(
                "rows_of_blocks and/or cols_of_blocks must be provided for constructing a diagonal matrix with uniform diagonal"
            )

        block_type = type(diag)
        # Constant values exposing a 2-D ``shape`` (e.g. numpy arrays) get a matching warp matrix type
        needs_matrix_type = not warp.types.type_is_matrix(block_type) and len(getattr(diag, "shape", ())) == 2
        if needs_matrix_type:
            block_type = wp.mat(shape=diag.shape, dtype=diag.dtype)

        matrix = bsr_zeros(rows_of_blocks, cols_of_blocks, block_type=block_type)

    bsr_set_diag(matrix, diag)
    return matrix
549
-
550
-
551
def bsr_set_identity(A: BsrMatrix, rows_of_blocks: Optional[int] = None):
    """Sets `A` as the identity matrix

    Args:
        A: the sparse matrix to modify
        rows_of_blocks: if provided, the matrix will be resized as a square matrix with `rows_of_blocks` rows and columns.
    """

    if A.block_shape != (1, 1):
        # Matrix blocks: an identity block of the block's row count, cast to the block type by bsr_set_diag
        from numpy import eye

        identity = eye(A.block_shape[0])
    else:
        # Scalar (CSR) blocks: the identity block is simply one
        identity = A.scalar_type(1.0)

    bsr_set_diag(A, diag=identity, rows_of_blocks=rows_of_blocks, cols_of_blocks=rows_of_blocks)
567
-
568
-
569
def bsr_identity(
    rows_of_blocks: int, block_type: BlockType[Rows, Rows, Scalar], device: wp.context.Devicelike = None
) -> BsrMatrix[BlockType[Rows, Rows, Scalar]]:
    """Creates and returns a square identity matrix.

    Args:
        rows_of_blocks: Number of rows and columns of blocks in the created matrix.
        block_type: Block type for the newly created matrix -- must be square
        device: Device onto which to allocate the data arrays

    Returns:
        A new ``rows_of_blocks x rows_of_blocks`` block identity matrix.
    """
    identity = bsr_zeros(
        rows_of_blocks=rows_of_blocks,
        cols_of_blocks=rows_of_blocks,
        block_type=block_type,
        device=device,
    )
    bsr_set_identity(identity)
    return identity
582
-
583
-
584
# Multiplies every stored block by ``alpha`` in place; one thread per stored block (see bsr_scale).
@wp.kernel
def _bsr_scale_kernel(
    alpha: Any,
    values: wp.array(dtype=Any),
):
    values[wp.tid()] = alpha * values[wp.tid()]
590
-
591
-
592
def bsr_scale(x: BsrMatrix, alpha: Scalar) -> BsrMatrix:
    """
    Performs the operation ``x := alpha * x`` on BSR matrix `x` and returns `x`

    Scaling by zero clears the matrix topology (via :func:`bsr_set_zero`) rather than
    storing explicit zero blocks.
    """

    needs_update = alpha != 1.0 and x.nnz > 0
    if not needs_update:
        return x

    if alpha == 0.0:
        bsr_set_zero(x)
        return x

    # The kernel needs the factor expressed in the matrix scalar type
    scale = alpha if isinstance(alpha, x.scalar_type) else x.scalar_type(alpha)
    wp.launch(kernel=_bsr_scale_kernel, dim=x.nnz, device=x.values.device, inputs=[scale, x.values])
    return x
607
-
608
-
609
# Expands row offsets into one row index per stored block: one thread per block ``i``,
# writing the block row of block ``i`` to ``rows[dest_offset + i]``.
# ``dest_offset`` lets several matrices' row indices be concatenated in one ``rows`` array.
@wp.kernel
def _bsr_get_block_row(dest_offset: int, bsr_offsets: wp.array(dtype=int), rows: wp.array(dtype=int)):
    i = wp.tid()

    # First offset index with offsets[idx] >= i + 1 (i.e. > i), minus one: the last row starting at or before i.
    # NOTE(review): searches the whole ``bsr_offsets`` array, so it presumably relies on every entry
    # (including any past nrow + 1) being non-decreasing -- TODO confirm for reused, oversized offset arrays.
    row = wp.lower_bound(bsr_offsets, i + 1) - 1
    rows[dest_offset + i] = row
615
-
616
-
617
# Accumulates ``scale * src_values[i]`` into the destination block located at
# (rows[src_offset + i], cols[src_offset + i]); one thread per source block ``i``.
# ``rows``/``cols`` may hold the coordinates of several sources concatenated: ``src_offset``
# selects this source's slice, while ``src_values`` is indexed from 0.
@wp.kernel
def _bsr_axpy_add_block(
    src_offset: int,
    scale: Any,
    rows: wp.array(dtype=int),
    cols: wp.array(dtype=int),
    dst_offsets: wp.array(dtype=int),
    dst_columns: wp.array(dtype=int),
    src_values: wp.array(dtype=Any),
    dst_values: wp.array(dtype=Any),
):
    i = wp.tid()
    row = rows[i + src_offset]
    col = cols[i + src_offset]
    beg = dst_offsets[row]
    end = dst_offsets[row + 1]

    # Binary search for ``col`` among the destination row's column indices (assumed sorted).
    # NOTE(review): the result is used without checking ``block < end`` / ``dst_columns[block] == col``;
    # callers must guarantee (row, col) exists in the destination topology.
    block = wp.lower_bound(dst_columns, beg, end, col)

    # Plain (non-atomic) read-modify-write: presumably no two source blocks of a single launch
    # map to the same destination block -- TODO confirm at call sites.
    dst_values[block] = dst_values[block] + scale * src_values[i]
637
-
638
-
639
class bsr_axpy_work_arrays:
    """Opaque structure for persisting :func:`bsr_axpy` temporary work buffers across calls"""

    def __init__(self):
        self._reset(None)

    def _reset(self, device):
        # Drop every cached buffer and bind the structure to ``device``
        self.device = device
        self._sum_rows = None
        self._sum_cols = None
        self._old_y_values = None
        self._old_x_values = None

    def _allocate(self, device, y: BsrMatrix, sum_nnz: int):
        # Ensure buffers large enough for ``sum_nnz`` (row, column) pairs and a copy of ``y``'s values.
        # Buffers only grow; a device change discards everything.
        if self.device != device:
            self._reset(device)

        if self._sum_rows is None or self._sum_rows.size < sum_nnz:
            self._sum_rows = wp.empty(shape=(sum_nnz), dtype=int, device=self.device)
        if self._sum_cols is None or self._sum_cols.size < sum_nnz:
            self._sum_cols = wp.empty(shape=(sum_nnz), dtype=int, device=self.device)

        # The saved-values buffer must also match y's block type: the same work arrays may be
        # reused with matrices of different block types, and copying into a buffer of another
        # dtype would fail or corrupt the saved values.
        if (
            self._old_y_values is None
            or self._old_y_values.size < y.nnz
            or not wp.types.types_equal(self._old_y_values.dtype, y.values.dtype)
        ):
            self._old_y_values = wp.empty(shape=(y.nnz), dtype=y.values.dtype, device=self.device)
663
-
664
-
665
def bsr_axpy(
    x: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    y: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
    alpha: Scalar = 1.0,
    beta: Scalar = 1.0,
    work_arrays: Optional[bsr_axpy_work_arrays] = None,
) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
    """
    Performs the sparse matrix addition ``y := alpha * x + beta * y`` on BSR matrices `x` and `y` and returns `y`.

    The `x` and `y` matrices are allowed to alias.

    In the general case the topology of `y` is rebuilt from the combined (row, column) pairs of
    `y` and `x`, so the arrays of `y` may be reallocated.

    Args:
        x: Read-only right-hand-side.
        y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero.
        alpha: Uniform scaling factor for `x`
        beta: Uniform scaling factor for `y`
        work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_axpy_work_arrays` in `work_arrays`.

    Returns:
        The updated matrix `y` (newly allocated if `y` was not provided).

    Raises:
        ValueError: if `x` and `y` reside on different devices, have different block types, or different
            numbers of block rows/columns. These checks only run on the general path (see NOTE below).
    """

    if y is None:
        # If not output matrix is provided, allocate it for convenience
        y = bsr_zeros(x.nrow, x.ncol, block_type=x.values.dtype, device=x.values.device)
        beta = 0.0

    # Handle easy cases first
    # NOTE(review): these shortcuts run before the device/shape/block-type checks below
    if beta == 0.0 or y.nnz == 0:
        bsr_assign(src=x, dest=y)
        return bsr_scale(y, alpha=alpha)

    if alpha == 0.0 or x.nnz == 0:
        return bsr_scale(y, alpha=beta)

    # Convert scaling factors to the matrix scalar type so they can be passed to kernels
    if not isinstance(alpha, y.scalar_type):
        alpha = y.scalar_type(alpha)
    if not isinstance(beta, y.scalar_type):
        beta = y.scalar_type(beta)

    if x == y:
        # Aliasing case
        return bsr_scale(y, alpha=alpha.value + beta.value)

    # General case

    if x.values.device != y.values.device:
        raise ValueError("All arguments must reside on the same device")

    if x.scalar_type != y.scalar_type or x.block_shape != y.block_shape:
        raise ValueError("Matrices must have the same block type")

    if x.nrow != y.nrow or x.ncol != y.ncol:
        raise ValueError("Matrices must have the same number of rows and columns")

    if work_arrays is None:
        work_arrays = bsr_axpy_work_arrays()

    # Upper bound on the number of blocks of the result (before merging duplicate coordinates)
    sum_nnz = x.nnz + y.nnz
    device = y.values.device
    work_arrays._allocate(device, y, sum_nnz)

    # Build the list of (row, column) pairs: y's blocks occupy [0, y.nnz), x's blocks [y.nnz, sum_nnz)
    wp.copy(work_arrays._sum_cols, y.columns, 0, 0, y.nnz)
    wp.launch(kernel=_bsr_get_block_row, device=device, dim=y.nnz, inputs=[0, y.offsets, work_arrays._sum_rows])

    wp.copy(work_arrays._sum_cols, x.columns, y.nnz, 0, x.nnz)
    wp.launch(kernel=_bsr_get_block_row, device=device, dim=x.nnz, inputs=[y.nnz, x.offsets, work_arrays._sum_rows])

    # Save old y values before overwriting matrix
    wp.copy(dest=work_arrays._old_y_values, src=y.values, count=y.nnz)

    # Increase dest array sizes if needed
    if y.columns.shape[0] < sum_nnz:
        y.columns = wp.empty(shape=(sum_nnz,), dtype=int, device=device)

    from warp.context import runtime

    if device.is_cpu:
        native_func = runtime.core.bsr_matrix_from_triplets_float_host
    else:
        native_func = runtime.core.bsr_matrix_from_triplets_float_device

    # Rebuild y's offsets/columns from the combined pairs; the native call returns the resulting block count.
    # The two ``0`` arguments are passed where value pointers would presumably go, so only the topology
    # is built here -- values are accumulated by the kernels below (TODO confirm native signature).
    old_y_nnz = y.nnz
    y.nnz = native_func(
        y.block_shape[0],
        y.block_shape[1],
        y.nrow,
        sum_nnz,
        work_arrays._sum_rows.ptr,
        work_arrays._sum_cols.ptr,
        0,
        y.offsets.ptr,
        y.columns.ptr,
        0,
    )

    _bsr_ensure_fits(y)
    y.values.zero_()

    # Accumulate beta * (saved y blocks), using the first old_y_nnz coordinate pairs
    wp.launch(
        kernel=_bsr_axpy_add_block,
        device=device,
        dim=old_y_nnz,
        inputs=[
            0,
            beta,
            work_arrays._sum_rows,
            work_arrays._sum_cols,
            y.offsets,
            y.columns,
            work_arrays._old_y_values,
            y.values,
        ],
    )

    # Accumulate alpha * (x blocks), whose coordinates start at offset old_y_nnz
    wp.launch(
        kernel=_bsr_axpy_add_block,
        device=device,
        dim=x.nnz,
        inputs=[
            old_y_nnz,
            alpha,
            work_arrays._sum_rows,
            work_arrays._sum_cols,
            y.offsets,
            y.columns,
            x.values,
            y.values,
        ],
    )

    return y
795
-
796
-
797
# Counts, for each row of x, the number of (unmerged) blocks of the product x * y in that row:
# the sum over the row's blocks x[row, k] of the number of blocks in row k of y.
# One thread per row. Counts are stored shifted by one (counts[row + 1]) and counts[0] = z_nnz,
# so that a subsequent prefix sum (as done in bsr_mm) yields per-row start offsets that
# leave the first z_nnz slots free; bsr_mm stores the copied blocks of z there.
@wp.kernel
def _bsr_mm_count_coeffs(
    z_nnz: int,
    x_offsets: wp.array(dtype=int),
    x_columns: wp.array(dtype=int),
    y_offsets: wp.array(dtype=int),
    counts: wp.array(dtype=int),
):
    row = wp.tid()
    count = int(0)

    x_beg = x_offsets[row]
    x_end = x_offsets[row + 1]

    for x_block in range(x_beg, x_end):
        x_col = x_columns[x_block]
        # Every block of row x_col of y combines with x[row, x_col]
        count += y_offsets[x_col + 1] - y_offsets[x_col]

    counts[row + 1] = count

    if row == 0:
        counts[0] = z_nnz
819
-
820
-
821
# Lists the (row, column) coordinates of every unmerged block of the product x * y.
# One thread per row; the thread writes its coordinates contiguously starting at mm_offsets[row]
# (the prefix-summed counts from _bsr_mm_count_coeffs), so threads never overlap.
@wp.kernel
def _bsr_mm_list_coeffs(
    x_offsets: wp.array(dtype=int),
    x_columns: wp.array(dtype=int),
    y_offsets: wp.array(dtype=int),
    y_columns: wp.array(dtype=int),
    mm_offsets: wp.array(dtype=int),
    mm_rows: wp.array(dtype=int),
    mm_cols: wp.array(dtype=int),
):
    row = wp.tid()
    mm_block = mm_offsets[row]

    x_beg = x_offsets[row]
    x_end = x_offsets[row + 1]

    for x_block in range(x_beg, x_end):
        x_col = x_columns[x_block]

        # x[row, x_col] * y[x_col, c] contributes to block (row, c) for every stored c in row x_col of y
        y_beg = y_offsets[x_col]
        y_end = y_offsets[x_col + 1]
        for y_block in range(y_beg, y_end):
            mm_cols[mm_block] = y_columns[y_block]
            mm_rows[mm_block] = row
            mm_block += 1
846
-
847
-
848
# Accumulates alpha * x * y into the values of a result whose topology has already been built.
# One thread per row: each thread only writes blocks inside its own row range [mm_beg, mm_end).
@wp.kernel
def _bsr_mm_compute_values(
    alpha: Any,
    x_offsets: wp.array(dtype=int),
    x_columns: wp.array(dtype=int),
    x_values: wp.array(dtype=Any),
    y_offsets: wp.array(dtype=int),
    y_columns: wp.array(dtype=int),
    y_values: wp.array(dtype=Any),
    mm_offsets: wp.array(dtype=int),
    mm_cols: wp.array(dtype=int),
    mm_values: wp.array(dtype=Any),
):
    row = wp.tid()
    mm_beg = mm_offsets[row]
    mm_end = mm_offsets[row + 1]

    x_beg = x_offsets[row]
    x_end = x_offsets[row + 1]
    for x_block in range(x_beg, x_end):
        x_col = x_columns[x_block]
        # Scale once per x block rather than once per product
        ax_val = alpha * x_values[x_block]

        y_beg = y_offsets[x_col]
        y_end = y_offsets[x_col + 1]

        for y_block in range(y_beg, y_end):
            # Locate result block (row, y_columns[y_block]) by binary search over the row's columns
            # (assumed sorted); it is assumed to exist since the topology was built from these products
            mm_block = wp.lower_bound(mm_cols, mm_beg, mm_end, y_columns[y_block])
            mm_values[mm_block] = mm_values[mm_block] + ax_val * y_values[y_block]
877
-
878
-
879
class bsr_mm_work_arrays:
    """Opaque structure for persisting :func:`bsr_mm` temporary work buffers across calls"""

    def __init__(self):
        self._reset(None)

    def _reset(self, device):
        # Drop every cached buffer and bind the structure to ``device``
        self.device = device
        self._pinned_count_buffer = None
        self._mm_row_counts = None
        self._mm_rows = None
        self._mm_cols = None
        self._old_z_values = None
        self._old_z_offsets = None
        self._old_z_columns = None

    def _allocate_stage_1(self, device, z: BsrMatrix, copied_z_nnz: int, z_aliasing: bool):
        # Buffers whose sizes are known before the product block count has been computed.
        # Buffers only grow; a device change discards everything.
        if self.device != device:
            self._reset(device)

        # Allocations that do not depend on any computation
        if self.device.is_cuda:
            if self._pinned_count_buffer is None:
                # Host-side pinned buffer used to read back the total block count from the device
                self._pinned_count_buffer = wp.empty(shape=(1,), dtype=int, pinned=True, device="cpu")

        if self._mm_row_counts is None or self._mm_row_counts.size < z.nrow + 1:
            self._mm_row_counts = wp.empty(shape=(z.nrow + 1,), dtype=int, device=self.device)

        if copied_z_nnz > 0:
            # The saved-values buffer must also match z's block type, since the same work arrays
            # may be reused with matrices of different block types.
            if (
                self._old_z_values is None
                or self._old_z_values.size < copied_z_nnz
                or not wp.types.types_equal(self._old_z_values.dtype, z.values.dtype)
            ):
                self._old_z_values = wp.empty(shape=(copied_z_nnz,), dtype=z.values.dtype, device=self.device)

        if z_aliasing:
            # z's topology must be preserved as well when it is also an operand of the product
            if self._old_z_columns is None or self._old_z_columns.size < z.nnz:
                self._old_z_columns = wp.empty(shape=(z.nnz,), dtype=z.columns.dtype, device=self.device)
            if self._old_z_offsets is None or self._old_z_offsets.size < z.nrow + 1:
                self._old_z_offsets = wp.empty(shape=(z.nrow + 1,), dtype=z.offsets.dtype, device=self.device)

    def _allocate_stage_2(self, mm_nnz: int):
        # Allocations that depend on unmerged nnz estimate
        if self._mm_rows is None or self._mm_rows.size < mm_nnz:
            self._mm_rows = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
        if self._mm_cols is None or self._mm_cols.size < mm_nnz:
            self._mm_cols = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
923
-
924
-
925
def bsr_mm(
    x: BsrMatrix[BlockType[Rows, Any, Scalar]],
    y: BsrMatrix[BlockType[Any, Cols, Scalar]],
    z: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
    alpha: Scalar = 1.0,
    beta: Scalar = 0.0,
    work_arrays: Optional[bsr_mm_work_arrays] = None,
) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
    """
    Performs the sparse matrix-matrix multiplication ``z := alpha * x * y + beta * z`` on BSR matrices `x`, `y` and `z`, and returns `z`.

    The `x`, `y` and `z` matrices are allowed to alias.
    If the matrix `z` is not provided as input, it will be allocated and treated as zero.

    Args:
        x: Read-only left factor of the matrix-matrix product.
        y: Read-only right factor of the matrix-matrix product.
        z: Mutable left-hand-side. If `z` is not provided, it will be allocated and treated as zero.
        alpha: Uniform scaling factor for the ``x * y`` product
        beta: Uniform scaling factor for `z`
        work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_mm_work_arrays` in `work_arrays`.

    Returns:
        The updated matrix `z` (newly allocated if `z` was not provided).

    Raises:
        ValueError: if the matrices reside on different devices, have different scalar types,
            or have incompatible block sizes or numbers of rows/columns.
    """

    if z is None:
        # If not output matrix is provided, allocate it for convenience
        z_block_shape = (x.block_shape[0], y.block_shape[1])
        if z_block_shape == (1, 1):
            z_block_type = x.scalar_type
        else:
            z_block_type = wp.types.matrix(shape=z_block_shape, dtype=x.scalar_type)
        z = bsr_zeros(x.nrow, y.ncol, block_type=z_block_type, device=x.values.device)
        beta = 0.0

    if x.values.device != y.values.device or x.values.device != z.values.device:
        raise ValueError("All arguments must reside on the same device")

    if x.scalar_type != y.scalar_type or x.scalar_type != z.scalar_type:
        raise ValueError("Matrices must have the same scalar type")

    if (
        x.block_shape[0] != z.block_shape[0]
        or y.block_shape[1] != z.block_shape[1]
        or x.block_shape[1] != y.block_shape[0]
    ):
        raise ValueError("Incompatible block sizes for matrix multiplication")

    if x.nrow != z.nrow or z.ncol != y.ncol or x.ncol != y.nrow:
        raise ValueError("Incompatible number of rows/columns for matrix multiplication")

    device = z.values.device

    if alpha == 0.0 or x.nnz == 0 or y.nnz == 0:
        # Easy case
        return bsr_scale(z, beta)

    # Convert scaling factors to the matrix scalar type so they can be passed to kernels
    if not isinstance(alpha, z.scalar_type):
        alpha = z.scalar_type(alpha)
    if not isinstance(beta, z.scalar_type):
        beta = z.scalar_type(beta)

    if work_arrays is None:
        work_arrays = bsr_mm_work_arrays()

    # Existing z blocks must be kept when they are scaled by a non-zero beta, or when z is also an
    # operand (its data is overwritten before the product values are computed)
    z_aliasing = z == x or z == y
    copied_z_nnz = z.nnz if beta != 0.0 or z_aliasing else 0

    work_arrays._allocate_stage_1(device, z, copied_z_nnz, z_aliasing)

    # Prefix sum of number of (unmerged) mm blocks per row
    wp.launch(
        kernel=_bsr_mm_count_coeffs,
        device=device,
        dim=z.nrow,
        inputs=[copied_z_nnz, x.offsets, x.columns, y.offsets, work_arrays._mm_row_counts],
    )
    warp.utils.array_scan(work_arrays._mm_row_counts, work_arrays._mm_row_counts)

    # Get back total counts on host
    if device.is_cuda:
        # Read the last scanned entry through the pinned buffer, then wait for the copy to complete
        wp.copy(dest=work_arrays._pinned_count_buffer, src=work_arrays._mm_row_counts, src_offset=z.nrow, count=1)
        wp.synchronize_stream(wp.get_stream(device))
        mm_nnz = int(work_arrays._pinned_count_buffer.numpy()[0])
    else:
        mm_nnz = int(work_arrays._mm_row_counts.numpy()[z.nrow])

    work_arrays._allocate_stage_2(mm_nnz)

    # If z has a non-zero scale, save current data before overwriting it
    if copied_z_nnz > 0:
        # Copy z row and column indices
        # (they occupy the first copied_z_nnz slots reserved by _bsr_mm_count_coeffs)
        wp.copy(dest=work_arrays._mm_cols, src=z.columns, count=copied_z_nnz)
        wp.launch(
            kernel=_bsr_get_block_row, device=device, dim=copied_z_nnz, inputs=[0, z.offsets, work_arrays._mm_rows]
        )
        # Save current z values in temporary buffer
        wp.copy(src=z.values, dest=work_arrays._old_z_values, count=copied_z_nnz)
        if z_aliasing:
            # If z is aliasing with x or y, need to save topology as well
            wp.copy(src=z.columns, dest=work_arrays._old_z_columns, count=copied_z_nnz)
            wp.copy(src=z.offsets, dest=work_arrays._old_z_offsets, count=z.nrow + 1)

    # Fill unmerged mm blocks rows and columns
    wp.launch(
        kernel=_bsr_mm_list_coeffs,
        device=device,
        dim=z.nrow,
        inputs=[
            x.offsets,
            x.columns,
            y.offsets,
            y.columns,
            work_arrays._mm_row_counts,
            work_arrays._mm_rows,
            work_arrays._mm_cols,
        ],
    )

    # Increase dest array size if needed
    if z.columns.shape[0] < mm_nnz:
        z.columns = wp.empty(shape=(mm_nnz,), dtype=int, device=device)

    from warp.context import runtime

    if device.is_cpu:
        native_func = runtime.core.bsr_matrix_from_triplets_float_host
    else:
        native_func = runtime.core.bsr_matrix_from_triplets_float_device

    # Rebuild z's offsets/columns from all listed (row, column) pairs; the native call returns the
    # resulting block count. The two ``0`` arguments are passed where value pointers would presumably go,
    # so only the topology is built here (TODO confirm native signature).
    z.nnz = native_func(
        z.block_shape[0],
        z.block_shape[1],
        z.nrow,
        mm_nnz,
        work_arrays._mm_rows.ptr,
        work_arrays._mm_cols.ptr,
        0,
        z.offsets.ptr,
        z.columns.ptr,
        0,
    )

    _bsr_ensure_fits(z)
    z.values.zero_()

    if copied_z_nnz > 0:
        # Add back original z values
        wp.launch(
            kernel=_bsr_axpy_add_block,
            device=device,
            dim=copied_z_nnz,
            inputs=[
                0,
                beta,
                work_arrays._mm_rows,
                work_arrays._mm_cols,
                z.offsets,
                z.columns,
                work_arrays._old_z_values,
                z.values,
            ],
        )

    # Add mm blocks to z values
    if (warp.types.type_is_matrix(x.values.dtype) or warp.types.type_is_matrix(y.values.dtype)) and not (
        warp.types.type_is_matrix(z.values.dtype)
    ):
        # Result block type is scalar, but operands are matrices
        # Cast result to (1x1) matrix to perform multiplication
        mm_values = z.values.view(wp.types.matrix(shape=(1, 1), dtype=z.scalar_type))
    else:
        mm_values = z.values

    # Operands that alias z are read from the copies saved above, since z's topology and values
    # have been overwritten at this point
    wp.launch(
        kernel=_bsr_mm_compute_values,
        device=device,
        dim=z.nrow,
        inputs=[
            alpha,
            work_arrays._old_z_offsets if x == z else x.offsets,
            work_arrays._old_z_columns if x == z else x.columns,
            work_arrays._old_z_values if x == z else x.values,
            work_arrays._old_z_offsets if y == z else y.offsets,
            work_arrays._old_z_columns if y == z else y.columns,
            work_arrays._old_z_values if y == z else y.values,
            z.offsets,
            z.columns,
            mm_values,
        ],
    )

    return z
1116
-
1117
-
1118
# Sparse matrix-vector product, one thread per block row:
# y[row] = alpha * sum_k(A[row, k] * x[k]) + beta * y[row]
@wp.kernel
def _bsr_mv_kernel(
    alpha: Any,
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    beta: Any,
    y: wp.array(dtype=Any),
):
    row = wp.tid()

    # zero-initialize with type of y elements
    scalar_zero = type(alpha)(0)
    v = y.dtype(scalar_zero)

    # When alpha is zero, neither A nor x is read for this row
    if alpha != scalar_zero:
        beg = A_offsets[row]
        end = A_offsets[row + 1]
        for block in range(beg, end):
            v += A_values[block] * x[A_columns[block]]
        v *= alpha

    # When beta is zero, the previous y[row] is not read (so y may be uninitialized)
    if beta != scalar_zero:
        v += beta * y[row]

    y[row] = v
1145
-
1146
-
1147
def bsr_mv(
    A: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    x: "Array[Vector[Cols, Scalar] | Scalar]",
    y: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
    alpha: Scalar = 1.0,
    beta: Scalar = 0.0,
    work_buffer: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
) -> "Array[Vector[Rows, Scalar] | Scalar]":
    """
    Performs the sparse matrix-vector product ``y := alpha * A * x + beta * y`` and returns `y`.

    The `x` and `y` vectors are allowed to alias.

    Args:
        A: Read-only, left matrix factor of the matrix-vector product.
        x: Read-only, right vector factor of the matrix-vector product.
        y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero.
        alpha: Uniform scaling factor for `x`. If zero, `x` will not be read and may be left uninitialized.
        beta: Uniform scaling factor for `y`. If zero, `y` will not be read and may be left uninitialized.
        work_buffer: Temporary storage is required if and only if `x` and `y` are the same vector. If provided the `work_buffer` array
          will be used for this purpose, otherwise a temporary allocation will be performed.

    Returns:
        The result array. When the blocks of `A` have a single row, this may be a view of `y`
        reinterpreting scalars as length-1 vectors (or the reverse).

    Raises:
        ValueError: on a device mismatch, incompatible sizes, or an unsuitable `work_buffer`.
    """
    scalar_type = A.scalar_type
    device = A.values.device
    block_rows, block_cols = A.block_shape

    if y is None:
        # No output given: allocate a zeroed one and ignore beta
        if block_rows == 1:
            y_dtype = scalar_type
        else:
            y_dtype = wp.vec(length=block_rows, dtype=scalar_type)
        y = wp.empty(shape=(A.nrow,), device=device, dtype=y_dtype)
        y.zero_()
        beta = 0.0

    # Kernels need the scaling factors expressed in the matrix scalar type
    alpha = alpha if isinstance(alpha, scalar_type) else scalar_type(alpha)
    beta = beta if isinstance(beta, scalar_type) else scalar_type(beta)

    if device != x.device or device != y.device:
        raise ValueError("A, x and y must reside on the same device")

    if x.shape[0] != A.ncol:
        raise ValueError("Number of columns of A must match number of rows of x")
    if y.shape[0] != A.nrow:
        raise ValueError("Number of rows of A must match number of rows of y")

    if x == y:
        # In-place product: snapshot the input so the kernel never reads already-overwritten entries
        if work_buffer is None:
            work_buffer = wp.empty_like(y)
        elif work_buffer.size < y.size:
            raise ValueError(f"Work buffer size is insufficient, needs to be at least {y.size}")
        elif not wp.types.types_equal(work_buffer.dtype, y.dtype):
            raise ValueError(f"Work buffer must have same data type as y, {wp.types.type_repr(y.dtype)}")

        wp.copy(dest=work_buffer, src=y, count=y.size)
        x = work_buffer

    blocks_are_matrices = warp.types.type_is_matrix(A.values.dtype)

    def _match_block_operand(vec_array, block_dim):
        # Reinterpret a vector operand so its elements multiply with A's blocks: length-1 vectors
        # for matrix blocks with a unit dimension, plain scalars for scalar (CSR) blocks
        if block_dim != 1:
            return vec_array
        if blocks_are_matrices:
            if vec_array.dtype == scalar_type:
                return vec_array.view(dtype=wp.vec(length=1, dtype=scalar_type))
            return vec_array
        if vec_array.dtype != scalar_type:
            return vec_array.view(dtype=scalar_type)
        return vec_array

    y = _match_block_operand(y, block_rows)
    x = _match_block_operand(x, block_cols)

    wp.launch(
        kernel=_bsr_mv_kernel,
        device=device,
        dim=A.nrow,
        inputs=[alpha, A.offsets, A.columns, A.values, x, beta, y],
    )

    return y
1
+ from typing import Any, Generic, Optional, Tuple, TypeVar, Union
2
+
3
+ import warp as wp
4
+ import warp.types
5
+ import warp.utils
6
+ from warp.types import Array, Cols, Rows, Scalar, Vector
7
+
8
# typing hints

# Type variable standing for the block type of a BsrMatrix.
# NOTE(review): the TypeVar is registered under the name "BlockType" but bound to `_BlockType`;
# static type checkers usually expect both names to match -- confirm this is intentional.
_BlockType = TypeVar("BlockType")


class _MatrixBlockType(Generic[Rows, Cols, Scalar]):
    """Placeholder used in type hints for matrix-valued blocks, parameterized by rows, columns and scalar type."""

    pass


class _ScalarBlockType(Generic[Scalar]):
    """Placeholder used in type hints for scalar-valued blocks (presumably CSR matrices)."""

    pass


# Block type of a sparse matrix, for type hints: either a matrix block or a scalar block
BlockType = Union[_MatrixBlockType[Rows, Cols, Scalar], _ScalarBlockType[Scalar]]

# Generated BsrMatrix struct types, keyed by block-type string (populated by bsr_matrix_t)
_struct_cache = {}
24
+
25
+
26
class BsrMatrix(Generic[_BlockType]):
    """Untyped base class for BSR and CSR matrices.

    Should not be constructed directly but through functions such as :func:`bsr_zeros`.

    Attributes:
        nrow (int): Number of rows of blocks
        ncol (int): Number of columns of blocks
        nnz (int): Number of non-zero blocks: must be equal to ``offsets[nrow]``, cached on host for convenience
        offsets (Array[int]): Array of size at least ``1 + nrow`` such that the start and end indices of the blocks of row ``r`` are ``offsets[r]`` and ``offsets[r+1]``, respectively.
        columns (Array[int]): Array of size at least equal to ``nnz`` containing block column indices
        values (Array[BlockType]): Array of size at least equal to ``nnz`` containing block values
    """

    @property
    def scalar_type(self) -> Scalar:
        """Scalar type for individual block coefficients. For CSR matrices, this is the same as the block type"""
        return warp.types.type_scalar_type(self.values.dtype)

    @property
    def block_shape(self) -> Tuple[int, int]:
        """Shape of the individual blocks; ``(1, 1)`` for scalar (CSR) blocks"""
        return getattr(self.values.dtype, "_shape_", (1, 1))

    @property
    def block_size(self) -> int:
        """Size of the individual blocks, i.e. number of rows per block times number of columns per block"""
        return warp.types.type_length(self.values.dtype)

    @property
    def shape(self) -> Tuple[int, int]:
        """Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/columns per block"""
        block_shape = self.block_shape
        return (self.nrow * block_shape[0], self.ncol * block_shape[1])

    @property
    def dtype(self) -> type:
        """Data type for individual block values"""
        return self.values.dtype

    @property
    def device(self) -> wp.context.Device:
        """Device on which offsets, columns and values are allocated -- assumed to be the same for all three arrays"""
        return self.values.device
70
+
71
+
72
def bsr_matrix_t(dtype: BlockType):
    """Returns the warp struct type of BSR/CSR matrices whose blocks are of type `dtype`.

    One struct type is generated per block type and memoized in the module-level ``_struct_cache``.

    Args:
        dtype: Block type: a warp matrix type, or a scalar type for CSR matrices.
            It is first converted with :func:`warp.types.type_to_warp`.

    Raises:
        ValueError: if `dtype` is neither a warp matrix type nor a warp scalar type
    """
    dtype = wp.types.type_to_warp(dtype)

    if not warp.types.type_is_matrix(dtype) and dtype not in warp.types.scalar_types:
        raise ValueError(
            f"BsrMatrix block type must be either warp matrix or scalar; got {warp.types.type_repr(dtype)}"
        )

    class BsrMatrixTyped(BsrMatrix):
        # Number of rows of blocks
        nrow: int
        # Number of columns of blocks
        ncol: int
        # Number of non-zero blocks: equal to offsets[nrow], cached on host for convenience
        # (the offsets array may be longer than nrow + 1, so offsets[-1] is not reliable)
        nnz: int
        # Array of size at least 1 + nrow
        offsets: wp.array(dtype=int)
        # Array of size at least equal to nnz
        columns: wp.array(dtype=int)
        # Array of size at least equal to nnz
        values: wp.array(dtype=dtype)

    module = wp.get_module(BsrMatrix.__module__)

    # Cache key: scalar type name plus block dimensions for matrix blocks, type name otherwise
    if hasattr(dtype, "_shape_"):
        type_str = f"{warp.types.type_scalar_type(dtype).__name__}_{dtype._shape_[0]}_{dtype._shape_[1]}"
    else:
        type_str = dtype.__name__
    key = f"{BsrMatrix.__qualname__}_{type_str}"

    if key not in _struct_cache:
        _struct_cache[key] = wp.codegen.Struct(
            cls=BsrMatrixTyped,
            key=key,
            module=module,
        )

    return _struct_cache[key]
109
+
110
+
111
def bsr_zeros(
    rows_of_blocks: int,
    cols_of_blocks: int,
    block_type: BlockType,
    device: wp.context.Devicelike = None,
) -> BsrMatrix:
    """
    Constructs and returns an empty BSR or CSR matrix with the given shape

    Args:
        rows_of_blocks: Number of rows of blocks
        cols_of_blocks: Number of columns of blocks
        block_type: Type of individual blocks. For CSR matrices, this should be a scalar type;
                    for BSR matrices, this should be a matrix type (e.g. from :func:`warp.mat`)
        device: Device on which to allocate the matrix arrays

    Returns:
        A matrix with no stored blocks (``nnz == 0``) and all row offsets set to zero.
    """

    bsr = bsr_matrix_t(block_type)()

    bsr.nrow = int(rows_of_blocks)
    bsr.ncol = int(cols_of_blocks)
    bsr.nnz = 0
    bsr.columns = wp.empty(shape=(0,), dtype=int, device=device)
    bsr.values = wp.empty(shape=(0,), dtype=block_type, device=device)
    # Every row is empty, so all nrow + 1 offsets are zero
    bsr.offsets = wp.zeros(shape=(bsr.nrow + 1,), dtype=int, device=device)

    return bsr
139
+
140
+
141
+ def _bsr_ensure_fits(bsr: BsrMatrix, nrow: int = None, nnz: int = None):
142
+ if nrow is None:
143
+ nrow = bsr.nrow
144
+ if nnz is None:
145
+ nnz = bsr.nnz
146
+
147
+ if bsr.offsets.size < nrow + 1:
148
+ bsr.offsets = wp.empty(shape=(nrow + 1,), dtype=int, device=bsr.offsets.device)
149
+ if bsr.columns.size < nnz:
150
+ bsr.columns = wp.empty(shape=(nnz,), dtype=int, device=bsr.columns.device)
151
+ if bsr.values.size < nnz:
152
+ bsr.values = wp.empty(shape=(nnz,), dtype=bsr.values.dtype, device=bsr.values.device)
153
+
154
+
155
def bsr_set_zero(
    bsr: BsrMatrix,
    rows_of_blocks: Optional[int] = None,
    cols_of_blocks: Optional[int] = None,
):
    """
    Removes every block from a BSR matrix, optionally resizing it first.

    Args:
        bsr: The BSR or CSR matrix to clear
        rows_of_blocks: New number of rows of blocks; the current one is kept when ``None``
        cols_of_blocks: New number of columns of blocks; the current one is kept when ``None``
    """
    if rows_of_blocks is not None:
        bsr.nrow = int(rows_of_blocks)
    if cols_of_blocks is not None:
        bsr.ncol = int(cols_of_blocks)

    # An empty matrix only needs a large-enough offsets array, entirely zeroed
    bsr.nnz = 0
    _bsr_ensure_fits(bsr, nnz=0)
    bsr.offsets.zero_()
176
+
177
+
178
def bsr_set_from_triplets(
    dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    rows: "Array[int]",
    columns: "Array[int]",
    values: "Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]",
):
    """
    Fills a BSR matrix with values defined by coordinate-oriented (COO) triplets, discarding existing blocks.

    The first dimension of the three input arrays must match, and determines the number of non-zeros in the constructed matrix.

    Args:
        dest: Sparse matrix to populate
        rows: Row index for each non-zero
        columns: Column index for each non-zero
        values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
          to the `dest` matrix's block type, or a 3d array with data type equal to the `dest` matrix's scalar type.

    Raises:
        ValueError: if the arrays reside on different devices, have mismatched lengths, or if `values`
          has an incompatible type, shape or memory layout
        NotImplementedError: if the scalar type of `dest` is neither ``wp.float32`` nor ``wp.float64``
    """

    if values.device != columns.device or values.device != rows.device or values.device != dest.values.device:
        raise ValueError("All arguments must reside on the same device")

    if values.shape[0] != rows.shape[0] or values.shape[0] != columns.shape[0]:
        raise ValueError("All triplet arrays must have the same length")

    # Accept either array1d(dtype) or contiguous array3d(scalar_type) as values
    if values.ndim == 1:
        if values.dtype != dest.values.dtype:
            raise ValueError("Values array type must correspond to that of dest matrix")
    elif values.ndim == 3:
        if values.shape[1:] != dest.block_shape:
            raise ValueError(
                f"Last two dimensions in values array ({values.shape[1:]}) should correspond to matrix block shape {(dest.block_shape)})"
            )

        if warp.types.type_scalar_type(values.dtype) != dest.scalar_type:
            raise ValueError("Scalar type of values array should correspond to that of matrix")

        if not values.is_contiguous:
            raise ValueError("Multi-dimensional values array should be contiguous")
    else:
        raise ValueError("Number of dimension for values array should be 1 or 3")

    nnz = rows.shape[0]
    if nnz == 0:
        bsr_set_zero(dest)
        return

    device = dest.values.device
    scalar_type = dest.scalar_type
    from warp.context import runtime

    # Native implementations exist only for float32 and float64. Start from None so that other
    # scalar types reach the NotImplementedError below instead of an UnboundLocalError, and
    # resolve the function before growing `dest` so that an unsupported type leaves it untouched.
    native_func = None
    if scalar_type == wp.float32:
        if device.is_cpu:
            native_func = runtime.core.bsr_matrix_from_triplets_float_host
        else:
            native_func = runtime.core.bsr_matrix_from_triplets_float_device
    elif scalar_type == wp.float64:
        if device.is_cpu:
            native_func = runtime.core.bsr_matrix_from_triplets_double_host
        else:
            native_func = runtime.core.bsr_matrix_from_triplets_double_device

    if not native_func:
        raise NotImplementedError(f"bsr_from_triplets not implemented for scalar type {scalar_type}")

    # Increase dest array sizes if needed
    _bsr_ensure_fits(dest, nnz=nnz)

    # The native routine returns the number of blocks actually stored in `dest`
    dest.nnz = native_func(
        dest.block_shape[0],
        dest.block_shape[1],
        dest.nrow,
        nnz,
        rows.ptr,
        columns.ptr,
        values.ptr,
        dest.offsets.ptr,
        dest.columns.ptr,
        dest.values.ptr,
    )
259
+
260
+
261
def bsr_assign(
    dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    src: BsrMatrix[BlockType[Rows, Cols, Any]],
):
    """Copies the content of the `src` matrix to `dest`, casting the block values if the two matrices use distinct scalar types.

    `dest` takes the dimensions and number of non-zeros of `src`; its arrays are grown if needed.

    Raises:
        ValueError: if the matrices live on different devices or have different block shapes
    """
    if dest.values.device != src.values.device:
        raise ValueError("Source and destination matrices must reside on the same device")
    if dest.block_shape != src.block_shape:
        raise ValueError("Source and destination matrices must have the same block shape")

    # Adopt the source dimensions, then make sure dest storage is large enough
    dest.nrow, dest.ncol, dest.nnz = src.nrow, src.ncol, src.nnz
    _bsr_ensure_fits(dest)

    wp.copy(dest=dest.offsets, src=src.offsets, count=src.nrow + 1)
    if src.nnz <= 0:
        return

    wp.copy(dest=dest.columns, src=src.columns, count=src.nnz)
    # array_cast converts values when the scalar types differ
    warp.utils.array_cast(out_array=dest.values, in_array=src.values, count=src.nnz)
283
+
284
+
285
def bsr_copy(A: BsrMatrix, scalar_type: Optional[Scalar] = None):
    """Returns a copy of matrix ``A``, possibly changing its scalar type.

    Args:
        A: Matrix to copy
        scalar_type: If provided, the returned matrix will use this scalar type instead of the one from `A`.
    """
    if scalar_type is None:
        block_type = A.values.dtype
    else:
        # Scalar blocks (CSR) stay scalars; matrix blocks keep their shape with the new scalar type
        block_type = (
            scalar_type if A.block_shape == (1, 1) else wp.types.matrix(shape=A.block_shape, dtype=scalar_type)
        )

    result = bsr_zeros(
        rows_of_blocks=A.nrow,
        cols_of_blocks=A.ncol,
        block_type=block_type,
        device=A.values.device,
    )
    bsr_assign(dest=result, src=A)
    return result
306
+
307
+
308
def bsr_set_transpose(
    dest: BsrMatrix[BlockType[Cols, Rows, Scalar]],
    src: BsrMatrix[BlockType[Rows, Cols, Scalar]],
):
    """Assigns the transposed matrix `src` to matrix `dest`

    Raises:
        ValueError: if the matrices live on different devices, have different scalar types,
          or if the block shape of `dest` is not the transpose of that of `src`
        NotImplementedError: if the scalar type is neither ``wp.float32`` nor ``wp.float64``
    """

    if dest.values.device != src.values.device:
        raise ValueError("All arguments must reside on the same device")

    if dest.scalar_type != src.scalar_type:
        raise ValueError("All arguments must have the same scalar type")

    transpose_block_shape = src.block_shape[::-1]

    if dest.block_shape != transpose_block_shape:
        raise ValueError(f"Destination block shape must be {transpose_block_shape}")

    dest.nrow = src.ncol
    dest.ncol = src.nrow

    if src.nnz == 0:
        # The transpose is empty too. Resize and zero the row offsets so that `dest` does not
        # keep the (possibly too short) offsets of its previous content.
        bsr_set_zero(dest)
        return

    from warp.context import runtime

    # Start from None so that unsupported scalar types raise NotImplementedError rather than
    # UnboundLocalError; resolve it before touching dest storage.
    native_func = None
    if dest.scalar_type == wp.float32:
        if dest.values.device.is_cpu:
            native_func = runtime.core.bsr_transpose_float_host
        else:
            native_func = runtime.core.bsr_transpose_float_device
    elif dest.scalar_type == wp.float64:
        if dest.values.device.is_cpu:
            native_func = runtime.core.bsr_transpose_double_host
        else:
            native_func = runtime.core.bsr_transpose_double_device

    if not native_func:
        raise NotImplementedError(f"bsr_set_transpose not implemented for scalar type {dest.scalar_type}")

    dest.nnz = src.nnz

    # Increase dest array sizes if needed
    _bsr_ensure_fits(dest)

    native_func(
        src.block_shape[0],
        src.block_shape[1],
        src.nrow,
        src.ncol,
        src.nnz,
        src.offsets.ptr,
        src.columns.ptr,
        src.values.ptr,
        dest.offsets.ptr,
        dest.columns.ptr,
        dest.values.ptr,
    )
364
+
365
+
366
def bsr_transposed(A: BsrMatrix):
    """Returns a copy of the transposed matrix `A`

    Scalar-block (CSR) matrices keep their block type; matrix blocks get the transposed block shape.
    """
    if A.block_shape != (1, 1):
        block_type = wp.types.matrix(shape=A.block_shape[::-1], dtype=A.scalar_type)
    else:
        block_type = A.values.dtype

    result = bsr_zeros(
        rows_of_blocks=A.ncol,
        cols_of_blocks=A.nrow,
        block_type=block_type,
        device=A.values.device,
    )
    bsr_set_transpose(dest=result, src=A)
    return result
382
+
383
+
384
@wp.kernel
def _bsr_get_diag_kernel(
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
    out: wp.array(dtype=Any),
):
    # One thread per diagonal entry: thread r looks for the block at column r in row r
    r = wp.tid()
    row_beg = A_offsets[r]
    row_end = A_offsets[r + 1]

    # lower_bound over the row's column indices (assumes they are sorted within a row)
    block = wp.lower_bound(A_columns, row_beg, row_end, r)
    if block < row_end:
        if A_columns[block] == r:
            out[r] = A_values[block]
399
+
400
+
401
def bsr_get_diag(A: BsrMatrix[_BlockType], out: "Optional[Array[BlockType]]" = None) -> "Array[BlockType]":
    """Returns the array of blocks that constitute the diagonal of a sparse matrix.

    Only the first ``min(A.nrow, A.ncol)`` entries are written, and only for rows that store a
    diagonal block. A freshly allocated `out` is zero-initialized. When `out` is provided, the
    entries of rows without a stored diagonal block keep their previous content.

    Args:
        A: the sparse matrix from which to extract the diagonal
        out: if provided, the array into which to store the diagonal blocks

    Raises:
        ValueError: if `out` has the wrong dtype, lives on another device, or is too short
    """
    dim = min(A.nrow, A.ncol)

    if out is None:
        out = wp.zeros(shape=(dim,), dtype=A.values.dtype, device=A.values.device)
    elif out.dtype != A.values.dtype:
        raise ValueError(f"Output array must have type {A.values.dtype}")
    elif out.device != A.values.device:
        raise ValueError(f"Output array must reside on device {A.values.device}")
    elif out.shape[0] < dim:
        raise ValueError(f"Output array must be of length at least {dim}")

    wp.launch(
        kernel=_bsr_get_diag_kernel,
        dim=dim,
        device=A.values.device,
        inputs=[A.offsets, A.columns, A.values, out],
    )
    return out
429
+
430
+
431
@wp.kernel
def _bsr_set_diag_kernel(
    diag: wp.array(dtype=Any),
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
):
    # Thread i stores diag[i] as the single block of row i, at column i
    i = wp.tid()

    if i == 0:
        A_offsets[0] = 0

    A_offsets[i + 1] = i + 1
    A_columns[i] = i
    A_values[i] = diag[i]
445
+
446
+
447
@wp.kernel
def _bsr_set_diag_constant_kernel(
    diag_value: Any,
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
):
    # Thread i stores the shared value diag_value as the single block of row i, at column i
    i = wp.tid()

    if i == 0:
        A_offsets[0] = 0

    A_offsets[i + 1] = i + 1
    A_columns[i] = i
    A_values[i] = diag_value
461
+
462
+
463
def bsr_set_diag(
    A: BsrMatrix[BlockType],
    diag: "Union[BlockType, Array[BlockType]]",
    rows_of_blocks: Optional[int] = None,
    cols_of_blocks: Optional[int] = None,
):
    """Sets `A` as a block-diagonal matrix

    Args:
        A: the sparse matrix to modify
        diag: Either a warp array of type ``A.values.dtype``, in which case each element will define one block of the diagonal,
          or a constant value of type ``A.values.dtype``, in which case it will get assigned to all diagonal blocks.
        rows_of_blocks: If not ``None``, the new number of rows of blocks
        cols_of_blocks: If not ``None``, the new number of columns of blocks

    The shape of the matrix will be defined by one of the following, in that order:
      - `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal.
      - the first dimension of `diag`, if `diag` is an array
      - the current dimensions of `A` otherwise

    The first ``min(nrow, ncol)`` rows each hold one diagonal block; any remaining rows are left empty.

    Raises:
        ValueError: if `diag` is an array with fewer entries than the number of diagonal blocks
    """

    if rows_of_blocks is None and cols_of_blocks is not None:
        rows_of_blocks = cols_of_blocks
    if cols_of_blocks is None and rows_of_blocks is not None:
        cols_of_blocks = rows_of_blocks

    diag_is_array = warp.types.is_array(diag)

    if diag_is_array and rows_of_blocks is None:
        rows_of_blocks = diag.shape[0]
        cols_of_blocks = diag.shape[0]

    # After normalization, both dimensions are either given or both None
    nrow = A.nrow if rows_of_blocks is None else int(rows_of_blocks)
    ncol = A.ncol if cols_of_blocks is None else int(cols_of_blocks)
    nnz = min(nrow, ncol)

    # Validate before mutating A, so a bad argument leaves the matrix untouched
    if diag_is_array and diag.shape[0] < nnz:
        raise ValueError(f"Diagonal array must have at least {nnz} entries")

    A.nrow = nrow
    A.ncol = ncol
    A.nnz = nnz
    _bsr_ensure_fits(A)

    if diag_is_array:
        kernel = _bsr_set_diag_kernel
    else:
        if not warp.types.type_is_value(type(diag)):
            # Cast to launchable type
            diag = A.values.dtype(diag)
        kernel = _bsr_set_diag_constant_kernel

    wp.launch(
        kernel=kernel,
        dim=nnz,
        device=A.values.device,
        inputs=[diag, A.offsets, A.columns, A.values],
    )

    # The kernels only write offsets[0..nnz] (and nothing at all when nnz == 0). Rows past the
    # last diagonal block (nrow > ncol) must be marked empty explicitly, otherwise their
    # offsets keep stale or uninitialized values.
    if nnz < nrow or nnz == 0:
        A.offsets[nnz : nrow + 1].fill_(nnz)
518
+
519
+
520
def bsr_diag(
    diag: "Union[BlockType, Array[BlockType]]",
    rows_of_blocks: Optional[int] = None,
    cols_of_blocks: Optional[int] = None,
) -> BsrMatrix["BlockType"]:
    """Creates and returns a block-diagonal BSR matrix from a given block value or array of block values.

    Args:
        diag: Either a warp array of block values, in which case each element defines one block of the diagonal,
          or a single block value, in which case it gets assigned to all diagonal blocks.
        rows_of_blocks: If not ``None``, the number of rows of blocks
        cols_of_blocks: If not ``None``, the number of columns of blocks

    The shape of the matrix is defined by one of the following, in that order:
      - `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal.
      - the first dimension of `diag`, if `diag` is an array

    Raises:
        ValueError: if `diag` is a single value and neither dimension is provided
    """
    # A single given dimension applies to both
    if rows_of_blocks is None:
        rows_of_blocks = cols_of_blocks
    elif cols_of_blocks is None:
        cols_of_blocks = rows_of_blocks

    if warp.types.is_array(diag):
        if rows_of_blocks is None:
            rows_of_blocks = cols_of_blocks = diag.shape[0]
        block_type = diag.dtype
        device = diag.device
    else:
        if rows_of_blocks is None:
            raise ValueError(
                "rows_of_blocks and/or cols_of_blocks must be provided for constructing a diagonal matrix with uniform diagonal"
            )

        block_type = type(diag)
        # Two-dimensional non-warp values (e.g. nested sequences with a shape) become warp matrices
        if not warp.types.type_is_matrix(block_type) and len(getattr(diag, "shape", ())) == 2:
            block_type = wp.mat(shape=diag.shape, dtype=diag.dtype)
        device = None

    A = bsr_zeros(
        rows_of_blocks,
        cols_of_blocks,
        block_type=block_type,
        device=device,
    )
    bsr_set_diag(A, diag)
    return A
572
+
573
+
574
def bsr_set_identity(A: BsrMatrix, rows_of_blocks: Optional[int] = None):
    """Sets `A` as the identity matrix

    Args:
        A: the sparse matrix to modify
        rows_of_blocks: if provided, the matrix will be resized as a square matrix with `rows_of_blocks` rows and columns.
    """
    if A.block_shape != (1, 1):
        from numpy import eye

        # Square identity block sized from the block row count; cast by bsr_set_diag
        identity = eye(A.block_shape[0])
    else:
        identity = A.scalar_type(1.0)

    bsr_set_diag(A, diag=identity, rows_of_blocks=rows_of_blocks, cols_of_blocks=rows_of_blocks)
590
+
591
+
592
def bsr_identity(
    rows_of_blocks: int,
    block_type: BlockType[Rows, Rows, Scalar],
    device: wp.context.Devicelike = None,
) -> BsrMatrix[BlockType[Rows, Rows, Scalar]]:
    """Creates and returns a square identity matrix.

    Args:
        rows_of_blocks: Number of rows and columns of blocks in the created matrix.
        block_type: Block type for the newly created matrix -- must be square
        device: Device onto which to allocate the data arrays
    """
    identity = bsr_zeros(
        rows_of_blocks=rows_of_blocks,
        cols_of_blocks=rows_of_blocks,
        block_type=block_type,
        device=device,
    )
    bsr_set_identity(identity)
    return identity
612
+
613
+
614
@wp.kernel
def _bsr_scale_kernel(
    alpha: Any,
    values: wp.array(dtype=Any),
):
    # One thread per stored block: scale it in place
    i = wp.tid()
    values[i] = alpha * values[i]
620
+
621
+
622
def bsr_scale(x: BsrMatrix, alpha: Scalar) -> BsrMatrix:
    """
    Performs the operation ``x := alpha * x`` on BSR matrix `x` and returns `x`

    Scaling by zero clears the matrix (all blocks are removed), rather than storing zero blocks.
    """
    # Scaling by one, or a matrix without stored blocks, leaves x unchanged
    if alpha == 1.0 or x.nnz <= 0:
        return x

    if alpha == 0.0:
        bsr_set_zero(x)
        return x

    if not isinstance(alpha, x.scalar_type):
        alpha = x.scalar_type(alpha)

    wp.launch(
        kernel=_bsr_scale_kernel,
        dim=x.nnz,
        device=x.values.device,
        inputs=[alpha, x.values],
    )
    return x
642
+
643
+
644
@wp.kernel
def _bsr_get_block_row(dest_offset: int, bsr_offsets: wp.array(dtype=int), rows: wp.array(dtype=int)):
    # Thread `block` recovers the row owning stored block `block`: one before the first
    # offset >= block + 1. The result goes to rows[dest_offset + block].
    block = wp.tid()
    owner_row = wp.lower_bound(bsr_offsets, block + 1) - 1
    rows[dest_offset + block] = owner_row
650
+
651
+
652
@wp.kernel
def _bsr_axpy_add_block(
    src_offset: int,
    scale: Any,
    rows: wp.array(dtype=int),
    cols: wp.array(dtype=int),
    dst_offsets: wp.array(dtype=int),
    dst_columns: wp.array(dtype=int),
    src_values: wp.array(dtype=Any),
    dst_values: wp.array(dtype=Any),
):
    # Thread i adds scale * src_values[i] to the destination block at coordinates
    # (rows[i + src_offset], cols[i + src_offset]). That block is located with lower_bound
    # over the destination row's columns and is assumed to be present.
    i = wp.tid()
    coord = i + src_offset
    row = rows[coord]
    col = cols[coord]

    row_beg = dst_offsets[row]
    row_end = dst_offsets[row + 1]
    block = wp.lower_bound(dst_columns, row_beg, row_end, col)

    dst_values[block] = dst_values[block] + scale * src_values[i]
672
+
673
+
674
class bsr_axpy_work_arrays:
    """Opaque structure for persisting :func:`bsr_axpy` temporary work buffers across calls"""

    def __init__(self):
        self._reset(None)

    def _reset(self, device):
        # Forget every buffer; they are reallocated lazily on `device`
        self.device = device
        self._sum_rows = self._sum_cols = None
        self._old_y_values = self._old_x_values = None

    def _allocate(self, device, y: BsrMatrix, sum_nnz: int):
        # Buffers allocated on another device cannot be reused
        if self.device != device:
            self._reset(device)

        def too_small(buffer, size):
            return buffer is None or buffer.size < size

        # Row / column coordinates for the concatenated y and x blocks
        if too_small(self._sum_rows, sum_nnz):
            self._sum_rows = wp.empty(shape=(sum_nnz), dtype=int, device=self.device)
        if too_small(self._sum_cols, sum_nnz):
            self._sum_cols = wp.empty(shape=(sum_nnz), dtype=int, device=self.device)

        # Copy of the y values, kept while y is rebuilt
        if too_small(self._old_y_values, y.nnz):
            self._old_y_values = wp.empty(shape=(y.nnz), dtype=y.values.dtype, device=self.device)
698
+
699
+
700
def bsr_axpy(
    x: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    y: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
    alpha: Scalar = 1.0,
    beta: Scalar = 1.0,
    work_arrays: Optional[bsr_axpy_work_arrays] = None,
) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
    """
    Performs the sparse matrix addition ``y := alpha * x + beta * y`` on BSR matrices `x` and `y` and returns `y`.

    The `x` and `y` matrices are allowed to alias.

    Args:
        x: Read-only right-hand-side.
        y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero.
        alpha: Uniform scaling factor for `x`
        beta: Uniform scaling factor for `y`
        work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_axpy_work_arrays` in `work_arrays`.

    Raises:
        ValueError: in the general case, if the matrices live on different devices, have different
          block types, or different dimensions
    """
    if y is None:
        # No output matrix: allocate one, and treat it as zero
        y = bsr_zeros(x.nrow, x.ncol, block_type=x.values.dtype, device=x.values.device)
        beta = 0.0

    # Cases reducing to a copy and/or a scaling
    if beta == 0.0 or y.nnz == 0:
        bsr_assign(src=x, dest=y)
        return bsr_scale(y, alpha=alpha)
    if alpha == 0.0 or x.nnz == 0:
        return bsr_scale(y, alpha=beta)

    if not isinstance(alpha, y.scalar_type):
        alpha = y.scalar_type(alpha)
    if not isinstance(beta, y.scalar_type):
        beta = y.scalar_type(beta)

    if x == y:
        # Aliasing: alpha * y + beta * y is a single scaling
        return bsr_scale(y, alpha=alpha.value + beta.value)

    # General case: merge the topologies of x and y, then accumulate both contributions
    if x.values.device != y.values.device:
        raise ValueError("All arguments must reside on the same device")
    if x.scalar_type != y.scalar_type or x.block_shape != y.block_shape:
        raise ValueError("Matrices must have the same block type")
    if x.nrow != y.nrow or x.ncol != y.ncol:
        raise ValueError("Matrices must have the same number of rows and columns")

    work = bsr_axpy_work_arrays() if work_arrays is None else work_arrays

    old_y_nnz = y.nnz
    sum_nnz = x.nnz + old_y_nnz
    device = y.values.device
    work._allocate(device, y, sum_nnz)

    # Coordinates of the y blocks go first, followed by those of the x blocks
    for coord_offset, matrix in ((0, y), (old_y_nnz, x)):
        wp.copy(work._sum_cols, matrix.columns, coord_offset, 0, matrix.nnz)
        wp.launch(
            kernel=_bsr_get_block_row,
            device=device,
            dim=matrix.nnz,
            inputs=[coord_offset, matrix.offsets, work._sum_rows],
        )

    # y storage is about to be overwritten: keep its current values aside
    wp.copy(dest=work._old_y_values, src=y.values, count=old_y_nnz)

    if y.columns.shape[0] < sum_nnz:
        y.columns = wp.empty(shape=(sum_nnz,), dtype=int, device=device)

    from warp.context import runtime

    if device.is_cpu:
        build_topology = runtime.core.bsr_matrix_from_triplets_float_host
    else:
        build_topology = runtime.core.bsr_matrix_from_triplets_float_device

    # Value pointers are passed as 0: only the merged offsets/columns are computed here
    y.nnz = build_topology(
        y.block_shape[0],
        y.block_shape[1],
        y.nrow,
        sum_nnz,
        work._sum_rows.ptr,
        work._sum_cols.ptr,
        0,
        y.offsets.ptr,
        y.columns.ptr,
        0,
    )

    _bsr_ensure_fits(y)
    y.values.zero_()

    # Accumulate beta * (old y), then alpha * x, into the merged blocks
    contributions = (
        (0, beta, old_y_nnz, work._old_y_values),
        (old_y_nnz, alpha, x.nnz, x.values),
    )
    for coord_offset, scale, count, src_values in contributions:
        wp.launch(
            kernel=_bsr_axpy_add_block,
            device=device,
            dim=count,
            inputs=[
                coord_offset,
                scale,
                work._sum_rows,
                work._sum_cols,
                y.offsets,
                y.columns,
                src_values,
                y.values,
            ],
        )

    return y
840
+
841
+
842
@wp.kernel
def _bsr_mm_count_coeffs(
    z_nnz: int,
    x_offsets: wp.array(dtype=int),
    x_columns: wp.array(dtype=int),
    y_offsets: wp.array(dtype=int),
    counts: wp.array(dtype=int),
):
    # Thread `row` stores in counts[row + 1] the number of (unmerged, possibly duplicate)
    # product blocks of that row: for each x block (row, k), the number of blocks in y row k.
    # counts[0] receives z_nnz, the number of leading slots reserved for existing z blocks.
    row = wp.tid()

    if row == 0:
        counts[0] = z_nnz

    row_beg = x_offsets[row]
    row_end = x_offsets[row + 1]

    total = int(0)
    for x_block in range(row_beg, row_end):
        k = x_columns[x_block]
        total += y_offsets[k + 1] - y_offsets[k]

    counts[row + 1] = total
864
+
865
+
866
@wp.kernel
def _bsr_mm_list_coeffs(
    x_offsets: wp.array(dtype=int),
    x_columns: wp.array(dtype=int),
    y_offsets: wp.array(dtype=int),
    y_columns: wp.array(dtype=int),
    mm_offsets: wp.array(dtype=int),
    mm_rows: wp.array(dtype=int),
    mm_cols: wp.array(dtype=int),
):
    # Thread `row` enumerates the (row, column) coordinates of every unmerged product block
    # of that row, writing them consecutively starting at slot mm_offsets[row]
    row = wp.tid()
    slot = mm_offsets[row]

    row_beg = x_offsets[row]
    row_end = x_offsets[row + 1]

    for x_block in range(row_beg, row_end):
        k = x_columns[x_block]

        k_beg = y_offsets[k]
        k_end = y_offsets[k + 1]
        for y_block in range(k_beg, k_end):
            mm_cols[slot] = y_columns[y_block]
            mm_rows[slot] = row
            slot += 1
891
+
892
+
893
@wp.kernel
def _bsr_mm_compute_values(
    alpha: Any,
    x_offsets: wp.array(dtype=int),
    x_columns: wp.array(dtype=int),
    x_values: wp.array(dtype=Any),
    y_offsets: wp.array(dtype=int),
    y_columns: wp.array(dtype=int),
    y_values: wp.array(dtype=Any),
    mm_offsets: wp.array(dtype=int),
    mm_cols: wp.array(dtype=int),
    mm_values: wp.array(dtype=Any),
):
    # Thread `mm_block` computes output block (row, col) as the sum over k of x(row, k) * y(k, col),
    # then adds alpha times that sum to mm_values[mm_block]
    mm_block = wp.tid()

    # Owning row: one before the first offset >= mm_block + 1
    row = wp.lower_bound(mm_offsets, mm_block + 1) - 1
    col = mm_cols[mm_block]

    # Zero-initialized accumulator of the output block type
    acc = mm_values.dtype(type(alpha)(0.0))

    x_beg = x_offsets[row]
    x_end = x_offsets[row + 1]
    for x_block in range(x_beg, x_end):
        k = x_columns[x_block]
        k_beg = y_offsets[k]
        k_end = y_offsets[k + 1]

        # lower_bound over the columns of y row k (assumed sorted within the row)
        y_block = wp.lower_bound(y_columns, k_beg, k_end, col)
        if y_block < k_end and y_columns[y_block] == col:
            acc += x_values[x_block] * y_values[y_block]

    mm_values[mm_block] += alpha * acc
925
+
926
+
927
class bsr_mm_work_arrays:
    """Opaque structure for persisting :func:`bsr_mm` temporary work buffers across calls"""

    def __init__(self):
        self._reset(None)

    def _reset(self, device):
        # Drop every buffer; they are lazily reallocated on `device`
        self.device = device
        for buffer_name in (
            "_pinned_count_buffer",
            "_mm_row_counts",
            "_mm_rows",
            "_mm_cols",
            "_old_z_values",
            "_old_z_offsets",
            "_old_z_columns",
        ):
            setattr(self, buffer_name, None)

    def _allocate_stage_1(self, device, z: BsrMatrix, copied_z_nnz: int, z_aliasing: bool):
        # Buffers whose sizes are known before any computation
        if self.device != device:
            self._reset(device)

        if self.device.is_cuda and self._pinned_count_buffer is None:
            # Single-entry pinned host buffer for reading back a count from the device
            self._pinned_count_buffer = wp.empty(shape=(1,), dtype=int, pinned=True, device="cpu")

        offsets_len = z.nrow + 1
        if self._mm_row_counts is None or self._mm_row_counts.size < offsets_len:
            self._mm_row_counts = wp.empty(shape=(offsets_len,), dtype=int, device=self.device)

        if copied_z_nnz > 0 and (self._old_z_values is None or self._old_z_values.size < copied_z_nnz):
            self._old_z_values = wp.empty(shape=(copied_z_nnz,), dtype=z.values.dtype, device=self.device)

        if not z_aliasing:
            return

        # z aliases an input: its topology must be saved as well
        if self._old_z_columns is None or self._old_z_columns.size < z.nnz:
            self._old_z_columns = wp.empty(shape=(z.nnz,), dtype=z.columns.dtype, device=self.device)
        if self._old_z_offsets is None or self._old_z_offsets.size < offsets_len:
            self._old_z_offsets = wp.empty(shape=(offsets_len,), dtype=z.offsets.dtype, device=self.device)

    def _allocate_stage_2(self, mm_nnz: int):
        # Buffers sized by the number of unmerged product blocks
        if self._mm_rows is None or self._mm_rows.size < mm_nnz:
            self._mm_rows = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
        if self._mm_cols is None or self._mm_cols.size < mm_nnz:
            self._mm_cols = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
971
+
972
+
973
+ def bsr_mm(
974
+ x: BsrMatrix[BlockType[Rows, Any, Scalar]],
975
+ y: BsrMatrix[BlockType[Any, Cols, Scalar]],
976
+ z: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
977
+ alpha: Scalar = 1.0,
978
+ beta: Scalar = 0.0,
979
+ work_arrays: Optional[bsr_mm_work_arrays] = None,
980
+ ) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
981
+ """
982
+ Performs the sparse matrix-matrix multiplication ``z := alpha * x * y + beta * z`` on BSR matrices `x`, `y` and `z`, and returns `z`.
983
+
984
+ The `x`, `y` and `z` matrices are allowed to alias.
985
+ If the matrix `z` is not provided as input, it will be allocated and treated as zero.
986
+
987
+ Args:
988
+ x: Read-only left factor of the matrix-matrix product.
989
+ y: Read-only right factor of the matrix-matrix product.
990
+ z: Mutable left-hand-side. If `z` is not provided, it will be allocated and treated as zero.
991
+ alpha: Uniform scaling factor for the ``x * y`` product
992
+ beta: Uniform scaling factor for `z`
993
+ work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_mm_work_arrays` in `work_arrays`.
994
+ """
995
+
996
+ if z is None:
997
+ # If not output matrix is provided, allocate it for convenience
998
+ z_block_shape = (x.block_shape[0], y.block_shape[1])
999
+ if z_block_shape == (1, 1):
1000
+ z_block_type = x.scalar_type
1001
+ else:
1002
+ z_block_type = wp.types.matrix(shape=z_block_shape, dtype=x.scalar_type)
1003
+ z = bsr_zeros(x.nrow, y.ncol, block_type=z_block_type, device=x.values.device)
1004
+ beta = 0.0
1005
+
1006
+ if x.values.device != y.values.device or x.values.device != z.values.device:
1007
+ raise ValueError("All arguments must reside on the same device")
1008
+
1009
+ if x.scalar_type != y.scalar_type or x.scalar_type != z.scalar_type:
1010
+ raise ValueError("Matrices must have the same scalar type")
1011
+
1012
+ if (
1013
+ x.block_shape[0] != z.block_shape[0]
1014
+ or y.block_shape[1] != z.block_shape[1]
1015
+ or x.block_shape[1] != y.block_shape[0]
1016
+ ):
1017
+ raise ValueError("Incompatible block sizes for matrix multiplication")
1018
+
1019
+ if x.nrow != z.nrow or z.ncol != y.ncol or x.ncol != y.nrow:
1020
+ raise ValueError("Incompatible number of rows/columns for matrix multiplication")
1021
+
1022
+ device = z.values.device
1023
+
1024
+ if alpha == 0.0 or x.nnz == 0 or y.nnz == 0:
1025
+ # Easy case
1026
+ return bsr_scale(z, beta)
1027
+
1028
+ if not isinstance(alpha, z.scalar_type):
1029
+ alpha = z.scalar_type(alpha)
1030
+ if not isinstance(beta, z.scalar_type):
1031
+ beta = z.scalar_type(beta)
1032
+
1033
+ if work_arrays is None:
1034
+ work_arrays = bsr_mm_work_arrays()
1035
+
1036
+ z_aliasing = z == x or z == y
1037
+ copied_z_nnz = z.nnz if beta != 0.0 or z_aliasing else 0
1038
+
1039
+ work_arrays._allocate_stage_1(device, z, copied_z_nnz, z_aliasing)
1040
+
1041
+ # Prefix sum of number of (unmerged) mm blocks per row
1042
+ wp.launch(
1043
+ kernel=_bsr_mm_count_coeffs,
1044
+ device=device,
1045
+ dim=z.nrow,
1046
+ inputs=[
1047
+ copied_z_nnz,
1048
+ x.offsets,
1049
+ x.columns,
1050
+ y.offsets,
1051
+ work_arrays._mm_row_counts,
1052
+ ],
1053
+ )
1054
+ warp.utils.array_scan(work_arrays._mm_row_counts, work_arrays._mm_row_counts)
1055
+
1056
+ # Get back total counts on host
1057
+ if device.is_cuda:
1058
+ wp.copy(
1059
+ dest=work_arrays._pinned_count_buffer,
1060
+ src=work_arrays._mm_row_counts,
1061
+ src_offset=z.nrow,
1062
+ count=1,
1063
+ )
1064
+ wp.synchronize_stream(wp.get_stream(device))
1065
+ mm_nnz = int(work_arrays._pinned_count_buffer.numpy()[0])
1066
+ else:
1067
+ mm_nnz = int(work_arrays._mm_row_counts.numpy()[z.nrow])
1068
+
1069
+ work_arrays._allocate_stage_2(mm_nnz)
1070
+
1071
+ # If z has a non-zero scale, save current data before overwriting it
1072
+ if copied_z_nnz > 0:
1073
+ # Copy z row and column indices
1074
+ wp.copy(dest=work_arrays._mm_cols, src=z.columns, count=copied_z_nnz)
1075
+ wp.launch(
1076
+ kernel=_bsr_get_block_row,
1077
+ device=device,
1078
+ dim=copied_z_nnz,
1079
+ inputs=[0, z.offsets, work_arrays._mm_rows],
1080
+ )
1081
+ # Save current z values in temporary buffer
1082
+ wp.copy(src=z.values, dest=work_arrays._old_z_values, count=copied_z_nnz)
1083
+ if z_aliasing:
1084
+ # If z is aliasing with x or y, need to save topology as well
1085
+ wp.copy(src=z.columns, dest=work_arrays._old_z_columns, count=copied_z_nnz)
1086
+ wp.copy(src=z.offsets, dest=work_arrays._old_z_offsets, count=z.nrow + 1)
1087
+
1088
+ # Fill unmerged mm blocks rows and columns
1089
+ wp.launch(
1090
+ kernel=_bsr_mm_list_coeffs,
1091
+ device=device,
1092
+ dim=z.nrow,
1093
+ inputs=[
1094
+ x.offsets,
1095
+ x.columns,
1096
+ y.offsets,
1097
+ y.columns,
1098
+ work_arrays._mm_row_counts,
1099
+ work_arrays._mm_rows,
1100
+ work_arrays._mm_cols,
1101
+ ],
1102
+ )
1103
+
1104
+ # Increase dest array size if needed
1105
+ if z.columns.shape[0] < mm_nnz:
1106
+ z.columns = wp.empty(shape=(mm_nnz,), dtype=int, device=device)
1107
+
1108
+ from warp.context import runtime
1109
+
1110
+ if device.is_cpu:
1111
+ native_func = runtime.core.bsr_matrix_from_triplets_float_host
1112
+ else:
1113
+ native_func = runtime.core.bsr_matrix_from_triplets_float_device
1114
+
1115
+ z.nnz = native_func(
1116
+ z.block_shape[0],
1117
+ z.block_shape[1],
1118
+ z.nrow,
1119
+ mm_nnz,
1120
+ work_arrays._mm_rows.ptr,
1121
+ work_arrays._mm_cols.ptr,
1122
+ 0,
1123
+ z.offsets.ptr,
1124
+ z.columns.ptr,
1125
+ 0,
1126
+ )
1127
+
1128
+ _bsr_ensure_fits(z)
1129
+ z.values.zero_()
1130
+
1131
+ if copied_z_nnz > 0:
1132
+ # Add back original z values
1133
+ wp.launch(
1134
+ kernel=_bsr_axpy_add_block,
1135
+ device=device,
1136
+ dim=copied_z_nnz,
1137
+ inputs=[
1138
+ 0,
1139
+ beta,
1140
+ work_arrays._mm_rows,
1141
+ work_arrays._mm_cols,
1142
+ z.offsets,
1143
+ z.columns,
1144
+ work_arrays._old_z_values,
1145
+ z.values,
1146
+ ],
1147
+ )
1148
+
1149
+ # Add mm blocks to z values
1150
+ if (warp.types.type_is_matrix(x.values.dtype) or warp.types.type_is_matrix(y.values.dtype)) and not (
1151
+ warp.types.type_is_matrix(z.values.dtype)
1152
+ ):
1153
+ # Result block type is scalar, but operands are matrices
1154
+ # Cast result to (1x1) matrix to perform multiplication
1155
+ mm_values = z.values.view(wp.types.matrix(shape=(1, 1), dtype=z.scalar_type))
1156
+ else:
1157
+ mm_values = z.values
1158
+
1159
+ wp.launch(
1160
+ kernel=_bsr_mm_compute_values,
1161
+ device=device,
1162
+ dim=z.nnz,
1163
+ inputs=[
1164
+ alpha,
1165
+ work_arrays._old_z_offsets if x == z else x.offsets,
1166
+ work_arrays._old_z_columns if x == z else x.columns,
1167
+ work_arrays._old_z_values if x == z else x.values,
1168
+ work_arrays._old_z_offsets if y == z else y.offsets,
1169
+ work_arrays._old_z_columns if y == z else y.columns,
1170
+ work_arrays._old_z_values if y == z else y.values,
1171
+ z.offsets,
1172
+ z.columns,
1173
+ mm_values,
1174
+ ],
1175
+ )
1176
+
1177
+ return z
1178
+
1179
+
1180
@wp.kernel
def _bsr_mv_kernel(
    alpha: Any,
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    beta: Any,
    y: wp.array(dtype=Any),
):
    # Computes one output entry per thread:
    #   y[row] = alpha * sum(A_values[b] * x[A_columns[b]] for b in [A_offsets[row], A_offsets[row + 1]))
    #            + beta * y[row]
    # The blocks of block-row `row` are the contiguous range
    # [A_offsets[row], A_offsets[row + 1]); A_columns[b] is the index into x
    # that block b multiplies.
    # NOTE: x must not alias y -- this thread reads arbitrary x[A_columns[...]]
    # entries while other threads overwrite their own y[row].
    row = wp.tid()

    # zero-initialize with type of y elements
    scalar_zero = type(alpha)(0)
    v = y.dtype(scalar_zero)

    # When alpha is zero the product is skipped entirely and x is never read.
    if alpha != scalar_zero:
        beg = A_offsets[row]
        end = A_offsets[row + 1]
        for block in range(beg, end):
            # Accumulate this block's contribution to the row.
            v += A_values[block] * x[A_columns[block]]
        v *= alpha

    # The previous y value is only read when it actually contributes.
    if beta != scalar_zero:
        v += beta * y[row]

    y[row] = v
+
1208
+
1209
def bsr_mv(
    A: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    x: "Array[Vector[Cols, Scalar] | Scalar]",
    y: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
    alpha: Scalar = 1.0,
    beta: Scalar = 0.0,
    work_buffer: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
) -> "Array[Vector[Rows, Scalar] | Scalar]":
    """
    Performs the sparse matrix-vector product ``y := alpha * A * x + beta * y`` and returns `y`.

    The `x` and `y` vectors are allowed to alias.

    Args:
        A: Read-only, left matrix factor of the matrix-vector product.
        x: Read-only, right vector factor of the matrix-vector product.
        y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero
            (in that case `beta` is ignored).
        alpha: Uniform scaling factor for `x`. If zero, `x` will not be read and may be left uninitialized.
        beta: Uniform scaling factor for `y`. If zero, `y` will not be read and may be left uninitialized.
        work_buffer: Temporary storage is required if and only if `x` and `y` are the same vector. If provided the `work_buffer` array
            will be used for this purpose, otherwise a temporary allocation will be performed.

    Returns:
        The output array `y`: either the array passed by the caller, or the newly allocated one
        when `y` was None. Internal scalar <-> length-1-vector reinterpretations are not exposed.

    Raises:
        ValueError: if `A`, `x` and `y` are not on the same device, if their dimensions are
            incompatible, or if the provided `work_buffer` is too small, has a different data type
            than `y`, or resides on a different device than `y`.
    """

    if y is None:
        # If no output array is provided, allocate one for convenience.
        y_vec_len = A.block_shape[0]
        y_dtype = A.scalar_type if y_vec_len == 1 else wp.vec(length=y_vec_len, dtype=A.scalar_type)
        y = wp.empty(shape=(A.nrow,), device=A.values.device, dtype=y_dtype)
        # No zero-fill needed: with beta == 0 the kernel never reads y and writes all A.nrow entries.
        beta = 0.0

    if not isinstance(alpha, A.scalar_type):
        alpha = A.scalar_type(alpha)
    if not isinstance(beta, A.scalar_type):
        beta = A.scalar_type(beta)

    if A.values.device != x.device or A.values.device != y.device:
        raise ValueError("A, x and y must reside on the same device")

    if x.shape[0] != A.ncol:
        raise ValueError("Number of columns of A must match number of rows of x")
    if y.shape[0] != A.nrow:
        raise ValueError("Number of rows of A must match number of rows of y")

    # The dtype promotion below may rebind `y` to a reinterpreted view; keep the
    # caller-visible array so that it (and not the view) is what gets returned.
    y_out = y

    if x == y:
        # Aliasing case: each kernel thread reads many x entries while other threads
        # write y, so x must be snapshotted into temporary storage first.
        if work_buffer is None:
            work_buffer = wp.empty_like(y)
        elif work_buffer.size < y.size:
            raise ValueError(f"Work buffer size is insufficient, needs to be at least {y.size}")
        elif not wp.types.types_equal(work_buffer.dtype, y.dtype):
            raise ValueError(f"Work buffer must have same data type as y, {wp.types.type_repr(y.dtype)}")
        elif work_buffer.device != y.device:
            raise ValueError("Work buffer must reside on the same device as y")

        # Save old y values before overwriting vector
        wp.copy(dest=work_buffer, src=y, count=y.size)
        x = work_buffer

    # Promote scalar vectors to length-1 vecs and conversely, so that the operand
    # types match the block type of A inside the kernel.
    if warp.types.type_is_matrix(A.values.dtype):
        if A.block_shape[0] == 1:
            if y.dtype == A.scalar_type:
                y = y.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
        if A.block_shape[1] == 1:
            if x.dtype == A.scalar_type:
                x = x.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
    else:
        if A.block_shape[0] == 1:
            if y.dtype != A.scalar_type:
                y = y.view(dtype=A.scalar_type)
        if A.block_shape[1] == 1:
            if x.dtype != A.scalar_type:
                x = x.view(dtype=A.scalar_type)

    wp.launch(
        kernel=_bsr_mv_kernel,
        device=A.values.device,
        dim=A.nrow,
        inputs=[alpha, A.offsets, A.columns, A.values, x, beta, y],
    )

    return y_out