warp-lang 1.9.1__py3-none-manylinux_2_34_aarch64.whl → 1.10.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (346)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +882 -305
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/{builtins.py → _src/builtins.py} +1435 -379
  8. warp/_src/codegen.py +4361 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/{fem → _src/fem}/domain.py +42 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +32 -15
  22. warp/{fem → _src/fem}/field/restriction.py +3 -1
  23. warp/{fem → _src/fem}/field/virtual.py +55 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +79 -163
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +16 -22
  28. warp/{fem → _src/fem}/geometry/element.py +34 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +50 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +14 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +14 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +42 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +256 -247
  34. warp/{fem → _src/fem}/geometry/partition.py +123 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +28 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +42 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +28 -45
  38. warp/{fem → _src/fem}/integrate.py +166 -158
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +17 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +97 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +22 -11
  46. warp/_src/fem/space/basis_space.py +681 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +5 -3
  48. warp/{fem → _src/fem}/space/function_space.py +16 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +6 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +6 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +6 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +5 -9
  53. warp/{fem → _src/fem}/space/partition.py +119 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +6 -10
  55. warp/{fem → _src/fem}/space/restriction.py +68 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +11 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +10 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +8 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +5 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +5 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +5 -9
  63. warp/_src/fem/space/topology.py +461 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +5 -9
  65. warp/_src/fem/types.py +114 -0
  66. warp/_src/fem/utils.py +488 -0
  67. warp/_src/jax.py +188 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +389 -0
  70. warp/_src/jax_experimental/ffi.py +1286 -0
  71. warp/_src/jax_experimental/xla_ffi.py +658 -0
  72. warp/_src/marching_cubes.py +710 -0
  73. warp/_src/math.py +416 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +165 -0
  76. warp/_src/optim/linear.py +1608 -0
  77. warp/_src/optim/sgd.py +114 -0
  78. warp/_src/paddle.py +408 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +291 -0
  81. warp/_src/render/render_opengl.py +3638 -0
  82. warp/_src/render/render_usd.py +939 -0
  83. warp/_src/render/utils.py +162 -0
  84. warp/_src/sparse.py +2718 -0
  85. warp/_src/tape.py +1208 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +393 -0
  88. warp/_src/types.py +5888 -0
  89. warp/_src/utils.py +1695 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.so +0 -0
  92. warp/bin/warp.so +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -721
  95. warp/codegen.py +6 -4251
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -8062
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +1 -1
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -365
  136. warp/jax_experimental/ffi.py +17 -873
  137. warp/jax_experimental/xla_ffi.py +5 -605
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +314 -37
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sparse.cu +7 -3
  159. warp/native/spatial.h +12 -0
  160. warp/native/tile.h +681 -89
  161. warp/native/tile_radix_sort.h +3 -3
  162. warp/native/tile_reduce.h +394 -46
  163. warp/native/tile_scan.h +4 -4
  164. warp/native/vec.h +469 -0
  165. warp/native/version.h +23 -0
  166. warp/native/volume.cpp +1 -1
  167. warp/native/volume.cu +1 -0
  168. warp/native/volume.h +1 -1
  169. warp/native/volume_builder.cu +2 -0
  170. warp/native/warp.cpp +57 -29
  171. warp/native/warp.cu +521 -250
  172. warp/native/warp.h +11 -8
  173. warp/optim/__init__.py +6 -3
  174. warp/optim/adam.py +6 -145
  175. warp/optim/linear.py +14 -1585
  176. warp/optim/sgd.py +6 -94
  177. warp/paddle.py +6 -388
  178. warp/render/__init__.py +8 -4
  179. warp/render/imgui_manager.py +7 -267
  180. warp/render/render_opengl.py +6 -3618
  181. warp/render/render_usd.py +6 -919
  182. warp/render/utils.py +6 -142
  183. warp/sparse.py +37 -2563
  184. warp/tape.py +6 -1188
  185. warp/tests/__main__.py +1 -1
  186. warp/tests/cuda/test_async.py +4 -4
  187. warp/tests/cuda/test_conditional_captures.py +1 -1
  188. warp/tests/cuda/test_multigpu.py +1 -1
  189. warp/tests/cuda/test_streams.py +58 -1
  190. warp/tests/geometry/test_bvh.py +157 -22
  191. warp/tests/geometry/test_marching_cubes.py +0 -1
  192. warp/tests/geometry/test_mesh.py +5 -3
  193. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  194. warp/tests/geometry/test_mesh_query_point.py +5 -2
  195. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  196. warp/tests/geometry/test_volume_write.py +5 -5
  197. warp/tests/interop/test_dlpack.py +18 -17
  198. warp/tests/interop/test_jax.py +772 -49
  199. warp/tests/interop/test_paddle.py +1 -1
  200. warp/tests/test_adam.py +0 -1
  201. warp/tests/test_arithmetic.py +9 -9
  202. warp/tests/test_array.py +578 -100
  203. warp/tests/test_array_reduce.py +3 -3
  204. warp/tests/test_atomic.py +12 -8
  205. warp/tests/test_atomic_bitwise.py +209 -0
  206. warp/tests/test_atomic_cas.py +4 -4
  207. warp/tests/test_bool.py +2 -2
  208. warp/tests/test_builtins_resolution.py +5 -571
  209. warp/tests/test_codegen.py +33 -14
  210. warp/tests/test_conditional.py +1 -1
  211. warp/tests/test_context.py +6 -6
  212. warp/tests/test_copy.py +242 -161
  213. warp/tests/test_ctypes.py +3 -3
  214. warp/tests/test_devices.py +24 -2
  215. warp/tests/test_examples.py +16 -84
  216. warp/tests/test_fabricarray.py +35 -35
  217. warp/tests/test_fast_math.py +0 -2
  218. warp/tests/test_fem.py +56 -10
  219. warp/tests/test_fixedarray.py +3 -3
  220. warp/tests/test_func.py +8 -5
  221. warp/tests/test_generics.py +1 -1
  222. warp/tests/test_indexedarray.py +24 -24
  223. warp/tests/test_intersect.py +39 -9
  224. warp/tests/test_large.py +1 -1
  225. warp/tests/test_lerp.py +3 -1
  226. warp/tests/test_linear_solvers.py +1 -1
  227. warp/tests/test_map.py +35 -4
  228. warp/tests/test_mat.py +52 -62
  229. warp/tests/test_mat_constructors.py +4 -5
  230. warp/tests/test_mat_lite.py +1 -1
  231. warp/tests/test_mat_scalar_ops.py +121 -121
  232. warp/tests/test_math.py +34 -0
  233. warp/tests/test_module_aot.py +4 -4
  234. warp/tests/test_modules_lite.py +28 -2
  235. warp/tests/test_print.py +11 -11
  236. warp/tests/test_quat.py +93 -58
  237. warp/tests/test_runlength_encode.py +1 -1
  238. warp/tests/test_scalar_ops.py +38 -10
  239. warp/tests/test_smoothstep.py +1 -1
  240. warp/tests/test_sparse.py +126 -15
  241. warp/tests/test_spatial.py +105 -87
  242. warp/tests/test_special_values.py +6 -6
  243. warp/tests/test_static.py +7 -7
  244. warp/tests/test_struct.py +13 -2
  245. warp/tests/test_triangle_closest_point.py +48 -1
  246. warp/tests/test_types.py +27 -15
  247. warp/tests/test_utils.py +52 -52
  248. warp/tests/test_vec.py +29 -29
  249. warp/tests/test_vec_constructors.py +5 -5
  250. warp/tests/test_vec_scalar_ops.py +97 -97
  251. warp/tests/test_version.py +75 -0
  252. warp/tests/tile/test_tile.py +178 -0
  253. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  254. warp/tests/tile/test_tile_cholesky.py +7 -4
  255. warp/tests/tile/test_tile_load.py +26 -2
  256. warp/tests/tile/test_tile_mathdx.py +3 -3
  257. warp/tests/tile/test_tile_matmul.py +1 -1
  258. warp/tests/tile/test_tile_mlp.py +2 -4
  259. warp/tests/tile/test_tile_reduce.py +214 -13
  260. warp/tests/unittest_suites.py +6 -14
  261. warp/tests/unittest_utils.py +10 -9
  262. warp/tests/walkthrough_debug.py +3 -1
  263. warp/torch.py +6 -373
  264. warp/types.py +29 -5764
  265. warp/utils.py +10 -1659
  266. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/METADATA +46 -99
  267. warp_lang-1.10.0.dist-info/RECORD +468 -0
  268. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  269. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  270. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  271. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  272. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  273. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  274. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  275. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  276. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  277. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  278. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  279. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  280. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  281. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  282. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  283. warp/examples/assets/cartpole.urdf +0 -110
  284. warp/examples/assets/crazyflie.usd +0 -0
  285. warp/examples/assets/nv_ant.xml +0 -92
  286. warp/examples/assets/nv_humanoid.xml +0 -183
  287. warp/examples/assets/quadruped.urdf +0 -268
  288. warp/examples/optim/example_bounce.py +0 -266
  289. warp/examples/optim/example_cloth_throw.py +0 -228
  290. warp/examples/optim/example_drone.py +0 -870
  291. warp/examples/optim/example_inverse_kinematics.py +0 -182
  292. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  293. warp/examples/optim/example_softbody_properties.py +0 -400
  294. warp/examples/optim/example_spring_cage.py +0 -245
  295. warp/examples/optim/example_trajectory.py +0 -227
  296. warp/examples/sim/example_cartpole.py +0 -143
  297. warp/examples/sim/example_cloth.py +0 -225
  298. warp/examples/sim/example_cloth_self_contact.py +0 -316
  299. warp/examples/sim/example_granular.py +0 -130
  300. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  301. warp/examples/sim/example_jacobian_ik.py +0 -244
  302. warp/examples/sim/example_particle_chain.py +0 -124
  303. warp/examples/sim/example_quadruped.py +0 -203
  304. warp/examples/sim/example_rigid_chain.py +0 -203
  305. warp/examples/sim/example_rigid_contact.py +0 -195
  306. warp/examples/sim/example_rigid_force.py +0 -133
  307. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  308. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  309. warp/examples/sim/example_soft_body.py +0 -196
  310. warp/examples/tile/example_tile_walker.py +0 -327
  311. warp/sim/__init__.py +0 -74
  312. warp/sim/articulation.py +0 -793
  313. warp/sim/collide.py +0 -2570
  314. warp/sim/graph_coloring.py +0 -307
  315. warp/sim/import_mjcf.py +0 -791
  316. warp/sim/import_snu.py +0 -227
  317. warp/sim/import_urdf.py +0 -579
  318. warp/sim/import_usd.py +0 -898
  319. warp/sim/inertia.py +0 -357
  320. warp/sim/integrator.py +0 -245
  321. warp/sim/integrator_euler.py +0 -2000
  322. warp/sim/integrator_featherstone.py +0 -2101
  323. warp/sim/integrator_vbd.py +0 -2487
  324. warp/sim/integrator_xpbd.py +0 -3295
  325. warp/sim/model.py +0 -4821
  326. warp/sim/particles.py +0 -121
  327. warp/sim/render.py +0 -431
  328. warp/sim/utils.py +0 -431
  329. warp/tests/sim/disabled_kinematics.py +0 -244
  330. warp/tests/sim/test_cloth.py +0 -863
  331. warp/tests/sim/test_collision.py +0 -743
  332. warp/tests/sim/test_coloring.py +0 -347
  333. warp/tests/sim/test_inertia.py +0 -161
  334. warp/tests/sim/test_model.py +0 -226
  335. warp/tests/sim/test_sim_grad.py +0 -287
  336. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  337. warp/tests/sim/test_sim_kinematics.py +0 -98
  338. warp/thirdparty/__init__.py +0 -0
  339. warp_lang-1.9.1.dist-info/RECORD +0 -456
  340. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  341. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  342. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  343. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  344. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/WHEEL +0 -0
  345. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/licenses/LICENSE.md +0 -0
  346. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/top_level.txt +0 -0
warp/_src/sparse.py ADDED
@@ -0,0 +1,2718 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import ctypes
19
+ import weakref
20
+ from typing import Any, Generic, TypeVar, Union
21
+
22
+ import warp as wp
23
+ import warp._src.utils
24
+ from warp._src.types import (
25
+ Array,
26
+ Cols,
27
+ Rows,
28
+ Scalar,
29
+ Vector,
30
+ is_array,
31
+ scalar_types,
32
+ type_is_matrix,
33
+ type_repr,
34
+ type_scalar_type,
35
+ type_size,
36
+ type_size_in_bytes,
37
+ type_to_warp,
38
+ types_equal,
39
+ )
40
+
41
# NOTE(review): presumably makes Warp treat this private module under its public name
# ``warp.sparse`` -- confirm against how Warp reads ``_wp_module_name_``.
_wp_module_name_ = "warp.sparse"

# Names exported by ``from <module> import *``.
__all__ = [
    "BsrMatrix",
    "bsr_assign",
    "bsr_axpy",
    "bsr_block_index",
    "bsr_copy",
    "bsr_diag",
    "bsr_from_triplets",
    "bsr_get_diag",
    "bsr_identity",
    "bsr_matrix_t",
    "bsr_mm",
    "bsr_mm_work_arrays",
    "bsr_mv",
    "bsr_row_index",
    "bsr_scale",
    "bsr_set_diag",
    "bsr_set_from_triplets",
    "bsr_set_identity",
    "bsr_set_transpose",
    "bsr_set_zero",
    "bsr_transposed",
    "bsr_zeros",
]


# typing hints

# Type variable for the block type of a ``BsrMatrix``. The name passed to ``TypeVar``
# deliberately differs from the variable name, hence the lint suppression.
_BlockType = TypeVar("BlockType")  # noqa: PLC0132
72
+
73
+
74
class _MatrixBlockType(Generic[Rows, Cols, Scalar]):
    """Typing-only placeholder for matrix blocks, parameterized by row count, column count and scalar type."""
76
+
77
+
78
class _ScalarBlockType(Generic[Scalar]):
    """Typing-only placeholder for scalar (CSR) blocks, parameterized by scalar type."""
80
+
81
+
82
# Type hint for the block type of a sparse matrix: either a matrix block
# (rows x cols of a scalar type, BSR) or a plain scalar (CSR).
BlockType = Union[_MatrixBlockType[Rows, Cols, Scalar], _ScalarBlockType[Scalar]]

# Warp struct types created by ``bsr_matrix_t``, keyed by a string derived from the
# block type's scalar name and shape.
_struct_cache = {}
# Per-device pools of host buffers (pinned for CUDA devices) and CUDA events used to read
# back the exact non-zero count, keyed by device ordinal. Each value is a pair
# ``(all_allocated, free_pool)``: the first list keeps every allocation alive, the second
# holds the pairs currently available for reuse.
_transfer_buffer_cache = {}
86
+
87
+
88
class BsrMatrix(Generic[_BlockType]):
    """Untyped base class for BSR and CSR matrices.

    Should not be constructed directly but through functions such as :func:`bsr_zeros`.

    Attributes:
        nrow (int): Number of rows of blocks.
        ncol (int): Number of columns of blocks.
        nnz (int): Upper bound for the number of non-zero blocks, used for
            dimensioning launches. The exact number is at ``offsets[nrow]``.
            See also :meth:`nnz_sync`.
        offsets (Array[int]): Array of size at least ``1 + nrow`` such that the
            start and end indices of the blocks of row ``r`` are ``offsets[r]``
            and ``offsets[r+1]``, respectively.
        columns (Array[int]): Array of size at least equal to ``nnz`` containing
            block column indices.
        values (Array[BlockType]): Array of size at least equal to ``nnz``
            containing block values.
    """

    @property
    def scalar_type(self) -> Scalar:
        """Scalar type for individual block coefficients. For CSR matrices, this is the same as the block type."""
        return type_scalar_type(self.values.dtype)

    @property
    def block_shape(self) -> tuple[int, int]:
        """Shape of the individual blocks (``(1, 1)`` for CSR matrices)."""
        return getattr(self.values.dtype, "_shape_", (1, 1))

    @property
    def block_size(self) -> int:
        """Size of the individual blocks, i.e. number of rows per block times number of columns per block."""
        return type_size(self.values.dtype)

    @property
    def shape(self) -> tuple[int, int]:
        """Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/columns per block."""
        block_shape = self.block_shape
        return (self.nrow * block_shape[0], self.ncol * block_shape[1])

    @property
    def dtype(self) -> type:
        """Data type for individual block values."""
        return self.values.dtype

    @property
    def device(self) -> wp._src.context.Device:
        """Device on which ``offsets``, ``columns``, and ``values`` are allocated -- assumed to be the same for all three arrays."""
        return self.values.device

    @property
    def requires_grad(self) -> bool:
        """Read-only property indicating whether the matrix participates in adjoint computations."""
        return self.values.requires_grad

    @property
    def scalar_values(self) -> wp.array:
        """Accesses the ``values`` array as a 3d scalar array of shape ``(values.shape[0], *block_shape)``."""
        values_view = _as_3d_array(self.values, self.block_shape)
        values_view._ref = self.values  # keep ref in case we're garbage collected
        return values_view

    def uncompress_rows(self, out: wp.array = None) -> wp.array:
        """Compute the row index for each non-zero block from the compressed row offsets.

        Args:
            out: Optional output array. If ``None``, an int array of size ``nnz`` is
                allocated on the matrix device.

        Returns:
            The array holding the block row indices.
        """
        if out is None:
            out = wp.empty(self.nnz, dtype=int, device=self.device)

        wp.launch(
            kernel=_bsr_get_block_row,
            device=self.device,
            dim=self.nnz,
            inputs=[self.nrow, self.offsets, out],
        )
        return out

    def nnz_sync(self) -> int:
        """
        Synchronize the number of non-zeros from the device offsets array to the host.

        Ensures that any ongoing transfer of the exact nnz number from the device offsets array to the host has completed,
        or, if none has been scheduled yet, starts a new transfer and waits for it to complete.

        Then updates the host-side nnz upper bound to match the exact one, and returns it.

        Raises:
            RuntimeError: If no host transfer buffer is available, which happens when the
                device is capturing and no buffer could be reused.

        See also :meth:`notify_nnz_changed`.
        """

        buf, event = self._nnz_transfer_if_any()
        if buf is None:
            buf, event = self._copy_nnz_async()

        if buf is None:
            # _allocate_transfer_buf does not allocate while the device is capturing.
            raise RuntimeError(
                "Cannot synchronize the number of non-zeros: no host transfer buffer is available "
                "while the device is capturing."
            )

        if event is not None:
            wp.synchronize_event(event)
        self.nnz = int(buf.numpy()[0])
        return self.nnz

    def notify_nnz_changed(self, nnz: int | None = None) -> None:
        """Notify the matrix that the number of non-zeros has been changed from outside of the ``warp.sparse`` builtin functions.

        Should be called in particular when the offsets array has been modified, or when the nnz upper bound has changed.
        Makes sure that the matrix is properly resized and starts the asynchronous transfer of the actual non-zero count.

        Args:
            nnz: The new upper-bound for the number of non-zeros. If not provided, it will be read from the device offsets array (requires a synchronization).
        """

        if nnz is None:
            self.nnz_sync()
        else:
            self._copy_nnz_async()

        _bsr_ensure_fits(self, nnz=nnz)

    def copy_nnz_async(self) -> None:
        """
        Starts the asynchronous transfer of the exact nnz from the device offsets array to host and records an event for completion.

        Deprecated; prefer :meth:`notify_nnz_changed` instead, which will make sure to resize arrays if necessary.
        """
        wp._src.utils.warn(
            "The `copy_nnz_async` method is deprecated and will be removed in a future version. Prefer `notify_nnz_changed` instead.",
            DeprecationWarning,
        )
        self._copy_nnz_async()

    def _copy_nnz_async(self) -> tuple[wp.array, wp.Event]:
        """Schedule the copy of ``offsets[nrow]`` (the exact nnz) into the host transfer buffer.

        Returns the ``(buffer, event)`` pair; both are ``None`` if no buffer could be obtained.
        The event is ``None`` for non-CUDA devices, in which case no event is recorded.
        """
        buf, event = self._setup_nnz_transfer()
        if buf is not None:
            stream = wp.get_stream(self.device) if self.device.is_cuda else None
            wp.copy(src=self.offsets, dest=buf, src_offset=self.nrow, count=1, stream=stream)
            if event is not None:
                stream.record_event(event, external=True)
        return buf, event

    def _setup_nnz_transfer(self) -> tuple[wp.array, wp.Event]:
        """Return this matrix's transfer ``(buffer, event)``, acquiring one from the pool on first use."""
        buf, event = self._nnz_transfer_if_any()
        if buf is not None:
            return buf, event

        buf, event = _allocate_transfer_buf(self.device)
        if buf is not None:
            # buf may still be None if device is currently capturing
            # NOTE(review): the base-class __setattr__ is called explicitly, presumably to bypass
            # an override on typed struct instances, since ``_nnz_transfer`` is not a field -- confirm.
            BsrMatrix.__setattr__(self, "_nnz_transfer", (buf, event))
            # Return the pair to the per-device pool once this matrix is garbage collected.
            weakref.finalize(self, _redeem_transfer_buf, self.device, buf, event)

        return buf, event

    def _nnz_transfer_if_any(self) -> tuple[wp.array, wp.Event]:
        """Return the ``(buffer, event)`` pair already attached to this matrix, or ``(None, None)``."""
        return getattr(self, "_nnz_transfer", (None, None))

    # Overloaded math operators.
    # These delegate to bsr_axpy / bsr_copy / bsr_scale / bsr_mv / bsr_mm, defined elsewhere in
    # this module. NOTE(review): bsr_axpy(x, y, alpha, beta) is assumed to compute
    # ``y = alpha * x + beta * y`` in place and return ``y`` -- confirm.
    def __add__(self, y):
        return bsr_axpy(y, bsr_copy(self))

    def __iadd__(self, y):
        return bsr_axpy(y, self)

    def __radd__(self, x):
        return bsr_axpy(x, bsr_copy(self))

    def __sub__(self, y):
        # copy(self) - y
        return bsr_axpy(y, bsr_copy(self), alpha=-1.0)

    def __rsub__(self, x):
        # x - copy(self)
        return bsr_axpy(x, bsr_copy(self), beta=-1.0)

    def __isub__(self, y):
        return bsr_axpy(y, self, alpha=-1.0)

    def __mul__(self, y):
        return _BsrScalingExpression(self, y)

    def __rmul__(self, x):
        return _BsrScalingExpression(self, x)

    def __imul__(self, y):
        return bsr_scale(self, y)

    def __matmul__(self, y):
        # Matrix-vector product for dense arrays, matrix-matrix product otherwise.
        if isinstance(y, wp.array):
            return bsr_mv(self, y)

        return bsr_mm(self, y)

    def __rmatmul__(self, x):
        # x @ A for a dense array x is evaluated as A^T @ x.
        if isinstance(x, wp.array):
            return bsr_mv(self, x, transpose=True)

        return bsr_mm(x, self)

    def __imatmul__(self, y):
        return bsr_mm(self, y, self)

    def __truediv__(self, y):
        return _BsrScalingExpression(self, 1.0 / y)

    def __neg__(self):
        return _BsrScalingExpression(self, -1.0)

    def transpose(self):
        """Return a transposed copy of this matrix."""
        return bsr_transposed(self)
+
292
+
293
def _allocate_transfer_buf(device):
    """Acquire a host ``(buffer, event)`` pair used to read back the nnz count for ``device``.

    A pair previously returned through ``_redeem_transfer_buf`` is reused when available.
    Otherwise, a new one-element int host buffer is allocated (pinned for CUDA devices),
    together with a CUDA event (``None`` for non-CUDA devices). While the device is
    capturing, nothing new is allocated and ``(None, None)`` is returned instead.
    """
    allocated, free_list = _transfer_buffer_cache.setdefault(device.ordinal, ([], []))

    if free_list:
        return free_list.pop()

    if device.is_capturing:
        return None, None

    host_buffer = wp.empty(dtype=int, shape=(1,), device="cpu", pinned=device.is_cuda)
    completion_event = None
    if device.is_cuda:
        completion_event = wp.Event(device)

    # Keep a strong reference to every allocation so it cannot be collected before being redeemed.
    allocated.append((host_buffer, completion_event))
    return host_buffer, completion_event
311
+
312
+
313
def _redeem_transfer_buf(device, buf, event):
    """Give a ``(buf, event)`` pair back to the free pool of ``device`` so it can be reused.

    Registered via ``weakref.finalize`` when a matrix acquires a transfer buffer.
    """
    _, free_list = _transfer_buffer_cache[device.ordinal]
    free_list.append((buf, event))
316
+
317
+
318
def bsr_matrix_t(dtype: BlockType):
    """Return the Warp struct type for BSR/CSR matrices with the given block type.

    Struct types are cached under a key built from the block's scalar type name and
    shape, so repeated calls with an equivalent ``dtype`` return the same object.

    Args:
        dtype: Block type, converted with :func:`type_to_warp` first. Must be a Warp
            matrix type (BSR) or a scalar type (CSR).

    Raises:
        ValueError: If ``dtype`` is neither a matrix type nor a scalar type.
    """
    dtype = type_to_warp(dtype)

    if not type_is_matrix(dtype) and dtype not in scalar_types:
        raise ValueError(f"BsrMatrix block type must be either warp matrix or scalar; got {type_repr(dtype)}")

    if hasattr(dtype, "_shape_"):
        type_str = f"{type_scalar_type(dtype).__name__}_{dtype._shape_[0]}_{dtype._shape_[1]}"
    else:
        type_str = dtype.__name__
    key = f"{BsrMatrix.__qualname__}_{type_str}"

    module = wp.get_module(BsrMatrix.__module__)

    if key not in _struct_cache:
        # The typed subclass is only needed to build a new struct, so it is created on
        # cache misses only (previously a throwaway class was built on every call).
        class BsrMatrixTyped(BsrMatrix):
            nrow: int
            """Number of rows of blocks."""
            ncol: int
            """Number of columns of blocks."""
            nnz: int
            """Upper bound for the number of non-zeros."""
            offsets: wp.array(dtype=int)
            """Array of size at least ``1 + nrow``."""
            columns: wp.array(dtype=int)
            """Array of size at least equal to ``nnz``."""
            values: wp.array(dtype=dtype)

        BsrMatrixTyped.dtype = dtype  # necessary for eval_annotations
        _struct_cache[key] = wp._src.codegen.Struct(
            key=key,
            cls=BsrMatrixTyped,
            module=module,
        )

    return _struct_cache[key]
354
+
355
+
356
def bsr_zeros(
    rows_of_blocks: int,
    cols_of_blocks: int,
    block_type: BlockType,
    device: wp._src.context.Devicelike = None,
) -> BsrMatrix:
    """Construct and return an empty BSR or CSR matrix with the given shape.

    Args:
        rows_of_blocks: Number of rows of blocks.
        cols_of_blocks: Number of columns of blocks.
        block_type: Type of individual blocks.
            For CSR matrices, this should be a scalar type.
            For BSR matrices, this should be a matrix type (e.g. from :func:`warp.mat`).
        device: Device on which to allocate the matrix arrays.

    Returns:
        A matrix with ``nnz == 0``, empty ``columns`` and ``values`` arrays, and a
        zero-filled ``offsets`` array of size ``rows_of_blocks + 1``.
    """

    bsr = bsr_matrix_t(block_type)()

    bsr.nrow = int(rows_of_blocks)
    bsr.ncol = int(cols_of_blocks)
    bsr.nnz = 0
    bsr.columns = wp.empty(shape=(0,), dtype=int, device=device)
    bsr.values = wp.empty(shape=(0,), dtype=block_type, device=device)
    # All-zero offsets: every row has an empty range of blocks.
    bsr.offsets = wp.zeros(shape=(bsr.nrow + 1,), dtype=int, device=device)

    return bsr
384
+
385
+
386
+ def _bsr_resize(bsr: BsrMatrix, rows_of_blocks: int | None = None, cols_of_blocks: int | None = None) -> None:
387
+ if rows_of_blocks is not None:
388
+ bsr.nrow = int(rows_of_blocks)
389
+ if cols_of_blocks is not None:
390
+ bsr.ncol = int(cols_of_blocks)
391
+
392
+ if bsr.offsets.size < bsr.nrow + 1:
393
+ bsr.offsets = wp.empty(shape=(bsr.nrow + 1,), dtype=int, device=bsr.offsets.device)
394
+
395
+
396
+ def _bsr_ensure_fits(bsr: BsrMatrix, nnz: int | None = None) -> None:
397
+ if nnz is None:
398
+ nnz = bsr.nnz
399
+ else:
400
+ # update nnz upper bound
401
+ bsr.nnz = int(nnz)
402
+
403
+ if bsr.columns.size < nnz:
404
+ bsr.columns = wp.empty(shape=(nnz,), dtype=int, device=bsr.columns.device)
405
+ if bsr.values.size < nnz:
406
+ bsr.values = wp.empty(
407
+ shape=(nnz,), dtype=bsr.values.dtype, device=bsr.values.device, requires_grad=bsr.values.requires_grad
408
+ )
409
+
410
+
411
def bsr_set_zero(bsr: BsrMatrix, rows_of_blocks: int | None = None, cols_of_blocks: int | None = None):
    """Set a BSR matrix to zero, possibly changing its size.

    All row offsets are cleared and the non-zero upper bound is reset to 0.

    Args:
        bsr: The BSR or CSR matrix to set to zero.
        rows_of_blocks: If not ``None``, the new number of rows of blocks.
        cols_of_blocks: If not ``None``, the new number of columns of blocks.
    """
    _bsr_resize(bsr, rows_of_blocks=rows_of_blocks, cols_of_blocks=cols_of_blocks)

    # Empty every row, then record the new (zero) nnz bound.
    bsr.offsets.zero_()
    bsr.notify_nnz_changed(nnz=0)
423
+
424
+
425
def _as_3d_array(arr, block_shape):
    """Reinterpret a block-valued array as a scalar array of shape ``(arr.shape[0], *block_shape)``.

    The view is built from ``arr.ptr`` and ``arr.capacity`` on the same device. If ``arr``
    has a gradient array, that gradient is reinterpreted the same way and attached as the
    view's ``grad``. The view does not by itself keep ``arr`` alive; callers such as
    ``BsrMatrix.scalar_values`` hold a reference to the source array.
    """
    grad_view = None
    if arr.grad is not None:
        grad_view = _as_3d_array(arr.grad, block_shape)

    scalar_dtype = type_scalar_type(arr.dtype)
    view_shape = (arr.shape[0], *block_shape)
    return wp.array(
        ptr=arr.ptr,
        capacity=arr.capacity,
        device=arr.device,
        dtype=scalar_dtype,
        shape=view_shape,
        grad=grad_view,
    )
434
+
435
+
436
+ def _optional_ctypes_pointer(array: wp.array | None, ctype):
437
+ return None if array is None else ctypes.cast(array.ptr, ctypes.POINTER(ctype))
438
+
439
+
440
+ def _optional_ctypes_event(event: wp.Event | None):
441
+ return None if event is None else event.cuda_event
442
+
443
+
444
# Bit masks keyed by Warp scalar type. For floating-point types the mask clears only the
# sign bit, so the masked bit pattern of both +0.0 and -0.0 is 0. For integer types the
# mask keeps every bit of the value.
# NOTE(review): presumably used to detect numerical zeros when ``prune_numerical_zeros``
# is requested in ``bsr_set_from_triplets`` -- confirm at the use site. Unsigned integer
# types are not listed here.
_zero_value_masks = {
    wp.float16: 0x7FFF,
    wp.float32: 0x7FFFFFFF,
    wp.float64: 0x7FFFFFFFFFFFFFFF,
    wp.int8: 0xFF,
    wp.int16: 0xFFFF,
    wp.int32: 0xFFFFFFFF,
    wp.int64: 0xFFFFFFFFFFFFFFFF,
}
453
+
454
+
455
@wp.kernel
def _bsr_accumulate_triplet_values(
    row_count: int,
    tpl_summed_offsets: wp.array(dtype=int),
    tpl_summed_indices: wp.array(dtype=int),
    tpl_values: wp.array3d(dtype=Any),
    bsr_offsets: wp.array(dtype=int),
    bsr_values: wp.array3d(dtype=Any),
):
    """Sum the values of all triplets that map to the same BSR block.

    One thread per ``(block, i, j)`` scalar entry. The triplets contributing to block ``b`` are
    ``tpl_summed_indices[beg:end]`` with ``beg = tpl_summed_offsets[b - 1]`` (0 for ``b == 0``) and
    ``end = tpl_summed_offsets[b]``, i.e. ``tpl_summed_offsets`` holds the end offset of each block's range.
    """
    block, i, j = wp.tid()

    # Threads past the actual number of blocks (read from the device-side offsets) have nothing to do
    if block >= bsr_offsets[row_count]:
        return

    if block == 0:
        beg = 0
    else:
        beg = tpl_summed_offsets[block - 1]
    end = tpl_summed_offsets[block]

    # NOTE(review): assumes each block has at least one contributing triplet (beg < end) -- confirm for masked mode
    val = tpl_values[tpl_summed_indices[beg], i, j]
    for k in range(beg + 1, end):
        val += tpl_values[tpl_summed_indices[k], i, j]

    bsr_values[block, i, j] = val
480
+
481
+
482
def bsr_set_from_triplets(
    dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    rows: Array[int],
    columns: Array[int],
    values: Array[Scalar | BlockType[Rows, Cols, Scalar]] | None = None,
    count: Array[int] | None = None,
    prune_numerical_zeros: bool = True,
    masked: bool = False,
):
    """Fill a BSR matrix with values defined by coordinate-oriented (COO) triplets, discarding existing blocks.

    The first dimension of the three input arrays must match and indicates the number of COO triplets.
    Triplets that share the same (row, column) coordinates are summed.

    Args:
        dest: Sparse matrix to populate.
        rows: Row index for each non-zero.
        columns: Columns index for each non-zero.
        values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
          to the ``dest`` matrix's block type, or a 3d array with data type equal to the ``dest`` matrix's scalar type.
          If ``None``, the values array of the resulting matrix will be allocated but uninitialized.
        count: Single-element array indicating the number of triplets. If ``None``, the number of triplets is determined from the shape of
          ``rows`` and ``columns`` arrays.
        prune_numerical_zeros: If ``True``, will ignore the zero-valued blocks.
        masked: If ``True``, ignore blocks that are not existing non-zeros of ``dest``.

    Raises:
        ValueError: If the arrays reside on different devices, or have mismatched lengths or shapes.
        TypeError: If the index or count arrays are not of type ``int32``.
    """

    if rows.device != columns.device or rows.device != dest.device:
        raise ValueError(
            f"Rows and columns must reside on the destination matrix device, got {rows.device}, {columns.device} and {dest.device}"
        )

    if rows.shape[0] != columns.shape[0]:
        raise ValueError(
            f"Rows and columns arrays must have the same length, got {rows.shape[0]} and {columns.shape[0]}"
        )

    if rows.dtype != wp.int32 or columns.dtype != wp.int32:
        raise TypeError("Rows and columns arrays must be of type int32")

    if count is not None:
        if count.device != rows.device:
            raise ValueError(f"Count and rows must reside on the same device, got {count.device} and {rows.device}")

        if count.shape != (1,):
            raise ValueError(f"Count array must be a single-element array, got {count.shape}")

        if count.dtype != wp.int32:
            raise TypeError("Count array must be of type int32")

    # Accept either array1d(dtype) or contiguous array3d(scalar_type) as values
    if values is not None:
        if values.device != rows.device:
            raise ValueError(f"Values and rows must reside on the same device, got {values.device} and {rows.device}")

        if values.shape[0] != rows.shape[0]:
            raise ValueError(
                f"Values and rows arrays must have the same length, got {values.shape[0]} and {rows.shape[0]}"
            )

        if values.ndim == 1:
            if not types_equal(values.dtype, dest.values.dtype):
                raise ValueError(
                    f"Values array type must correspond to that of the dest matrix, got {type_repr(values.dtype)} and {type_repr(dest.values.dtype)}"
                )
        elif values.ndim == 3:
            if values.shape[1:] != dest.block_shape:
                raise ValueError(
                    f"Last two dimensions in values array ({values.shape[1:]}) should correspond to matrix block shape {(dest.block_shape)})"
                )

            if type_scalar_type(values.dtype) != dest.scalar_type:
                raise ValueError(
                    f"Scalar type of values array ({type_repr(values.dtype)}) should correspond to that of matrix ({type_repr(dest.scalar_type)})"
                )
        else:
            raise ValueError(f"Number of dimension for values array should be 1 or 3, got {values.ndim}")

        if prune_numerical_zeros and not values.is_contiguous:
            raise ValueError("Values array should be contiguous for numerical zero pruning")

    nnz = rows.shape[0]
    if nnz == 0:
        bsr_set_zero(dest)
        return

    # Increase dest array sizes if needed
    if not masked:
        _bsr_ensure_fits(dest, nnz=nnz)

    device = dest.values.device
    scalar_type = dest.scalar_type
    # Zero-pruning inspects the triplet values, so it only applies when values are provided
    zero_value_mask = (
        _zero_value_masks.get(scalar_type, 0) if (prune_numerical_zeros and values is not None) else 0
    )

    # compute the BSR topology

    from warp._src.context import runtime

    if device.is_cpu:
        native_func = runtime.core.wp_bsr_matrix_from_triplets_host
    else:
        native_func = runtime.core.wp_bsr_matrix_from_triplets_device

    nnz_buf, nnz_event = dest._setup_nnz_transfer()
    summed_triplet_offsets = wp.empty(shape=(nnz,), dtype=wp.int32, device=device)
    summed_triplet_indices = wp.empty(shape=(nnz,), dtype=wp.int32, device=device)

    with wp.ScopedDevice(device):
        native_func(
            dest.block_size,
            type_size_in_bytes(scalar_type),
            dest.nrow,
            dest.ncol,
            nnz,
            _optional_ctypes_pointer(count, ctype=ctypes.c_int32),
            ctypes.cast(rows.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(columns.ptr, ctypes.POINTER(ctypes.c_int32)),
            _optional_ctypes_pointer(values, ctype=ctypes.c_int32),
            zero_value_mask,
            masked,
            ctypes.cast(summed_triplet_offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(summed_triplet_indices.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
            _optional_ctypes_pointer(nnz_buf, ctype=ctypes.c_int32),
            _optional_ctypes_event(nnz_event),
        )

        if values is None:
            # Topology only: block values are intentionally left uninitialized (see docstring).
            # Accumulating would dereference a None values array.
            return

        # now accumulate repeated blocks
        wp.launch(
            _bsr_accumulate_triplet_values,
            dim=(nnz, *dest.block_shape),
            device=device,
            inputs=[
                dest.nrow,
                summed_triplet_offsets,
                summed_triplet_indices,
                _as_3d_array(values, dest.block_shape),
                dest.offsets,
            ],
            outputs=[dest.scalar_values],
        )
622
+
623
+
624
def bsr_from_triplets(
    rows_of_blocks: int,
    cols_of_blocks: int,
    rows: Array[int],
    columns: Array[int],
    values: Array[Scalar | BlockType[Rows, Cols, Scalar]],
    prune_numerical_zeros: bool = True,
):
    """Build and return a new BSR matrix from coordinate-oriented (COO) triplets.

    The first dimension of the three input arrays must match and indicates the number of COO triplets.

    Args:
        rows_of_blocks: Number of rows of blocks.
        cols_of_blocks: Number of columns of blocks.
        rows: Row index for each non-zero.
        columns: Columns index for each non-zero.
        values: Block values for each non-zero. Either a one-dimensional array whose data type is the block type
          of the returned matrix, or a 3d array of scalars whose two trailing dimensions give the block shape.
        prune_numerical_zeros: If ``True``, will ignore the zero-valued blocks.
    """

    # A 3d scalar array describes blocks of shape ``values.shape[1:]``; otherwise dtype is already the block type
    if values.ndim != 3:
        block_type = values.dtype
    else:
        block_type = wp.mat(shape=values.shape[1:], dtype=values.dtype)

    matrix = bsr_zeros(
        rows_of_blocks=rows_of_blocks,
        cols_of_blocks=cols_of_blocks,
        block_type=block_type,
        device=values.device,
    )
    matrix.values.requires_grad = values.requires_grad
    bsr_set_from_triplets(matrix, rows, columns, values, prune_numerical_zeros=prune_numerical_zeros)
    return matrix
657
+
658
+
659
class _BsrExpression(Generic[_BlockType]):
    """Marker base class for deferred expressions built on top of BSR matrices."""
661
+
662
+
663
class _BsrScalingExpression(_BsrExpression):
    """Deferred expression representing ``scale * mat`` for a BSR matrix ``mat``.

    Read-only matrix properties are forwarded to ``mat``. Arithmetic operators either fold further
    scalar factors into ``scale`` or evaluate the expression through ``bsr_axpy``, ``bsr_mv`` or ``bsr_mm``.
    """

    def __init__(self, mat, scale):
        self.mat = mat
        self.scale = scale

    def eval(self):
        """Materialize the expression as a new matrix."""
        return bsr_copy(self)

    @property
    def nrow(self) -> int:
        return self.mat.nrow

    @property
    def ncol(self) -> int:
        return self.mat.ncol

    @property
    def nnz(self) -> int:
        return self.mat.nnz

    @property
    def offsets(self) -> wp.array:
        return self.mat.offsets

    @property
    def columns(self) -> wp.array:
        return self.mat.columns

    @property
    def scalar_type(self) -> Scalar:
        return self.mat.scalar_type

    @property
    def block_shape(self) -> tuple[int, int]:
        return self.mat.block_shape

    @property
    def block_size(self) -> int:
        return self.mat.block_size

    @property
    def shape(self) -> tuple[int, int]:
        return self.mat.shape

    @property
    def dtype(self) -> type:
        return self.mat.dtype

    @property
    def requires_grad(self) -> bool:
        return self.mat.requires_grad

    @property
    def device(self) -> wp._src.context.Device:
        return self.mat.device

    # Overloaded math operators
    # bsr_axpy(x, y, alpha, beta) returns y := alpha * x + beta * y, so the scale of this expression
    # must always be applied through ``beta`` (it multiplies the copy of ``self.mat``).
    def __add__(self, y):
        # scale * mat + y
        return bsr_axpy(y, bsr_copy(self.mat), beta=self.scale)

    def __radd__(self, x):
        # x + scale * mat
        return bsr_axpy(x, bsr_copy(self.mat), beta=self.scale)

    def __sub__(self, y):
        # scale * mat - y
        return bsr_axpy(y, bsr_copy(self.mat), alpha=-1.0, beta=self.scale)

    def __rsub__(self, x):
        # x - scale * mat
        return bsr_axpy(x, bsr_copy(self.mat), beta=-self.scale)

    def __mul__(self, y):
        return _BsrScalingExpression(self.mat, y * self.scale)

    def __rmul__(self, x):
        return _BsrScalingExpression(self.mat, x * self.scale)

    def __matmul__(self, y):
        if isinstance(y, wp.array):
            return bsr_mv(self.mat, y, alpha=self.scale)

        return bsr_mm(self.mat, y, alpha=self.scale)

    def __rmatmul__(self, x):
        if isinstance(x, wp.array):
            # x @ (scale * mat) == scale * mat^T x for a vector x
            return bsr_mv(self.mat, x, alpha=self.scale, transpose=True)

        return bsr_mm(x, self.mat, alpha=self.scale)

    def __truediv__(self, y):
        return _BsrScalingExpression(self.mat, self.scale / y)

    def __neg__(self):
        return _BsrScalingExpression(self.mat, -self.scale)

    def transpose(self):
        """Returns a transposed copy of this matrix"""
        return _BsrScalingExpression(self.mat.transpose(), self.scale)
759
+
760
+
761
# Type alias for arguments that accept either a concrete BsrMatrix or a deferred expression of one
# (e.g. a scaled matrix); such arguments are typically unwrapped with ``_extract_matrix_and_scale``.
BsrMatrixOrExpression = Union[BsrMatrix[_BlockType], _BsrExpression[_BlockType]]
762
+
763
+
764
def _extract_matrix_and_scale(bsr: BsrMatrixOrExpression):
    """Split ``bsr`` into its underlying matrix and scalar factor.

    Returns:
        A ``(matrix, scale)`` pair; ``scale`` is ``1.0`` for a plain ``BsrMatrix``.

    Raises:
        ValueError: If ``bsr`` is neither a ``BsrMatrix`` nor a scaling expression.
    """
    if isinstance(bsr, BsrMatrix):
        matrix, scale = bsr, 1.0
    elif isinstance(bsr, _BsrScalingExpression):
        matrix, scale = bsr.mat, bsr.scale
    else:
        raise ValueError("Argument cannot be interpreted as a BsrMatrix")
    return matrix, scale
771
+
772
+
773
@wp.func
def bsr_row_index(
    offsets: wp.array(dtype=int),
    row_count: int,
    block_index: int,
) -> int:
    """Returns the index of the row containing a given block, or -1 if no such row exists.

    The containing row ``r`` satisfies ``offsets[r] <= block_index < offsets[r + 1]``; it is located
    with a lower-bound search over ``offsets`` (assumed non-decreasing). Block indices at or past
    ``offsets[row_count]`` yield -1.

    Args:
        offsets: Array of size at least ``1 + row_count`` containing the offsets of the blocks in each row.
        row_count: Number of rows of blocks.
        block_index: Index of the block.
    """
    # lower_bound(..., block_index + 1) gives the first i with offsets[i] > block_index, i.e. row + 1.
    # Out-of-range blocks map to 0 so that the trailing "- 1" produces -1.
    return wp.where(block_index < offsets[row_count], wp.lower_bound(offsets, 0, row_count + 1, block_index + 1), 0) - 1
787
+
788
+
789
@wp.func
def bsr_block_index(
    row: int,
    col: int,
    bsr_offsets: wp.array(dtype=int),
    bsr_columns: wp.array(dtype=int),
) -> int:
    """
    Returns the index of the block at block-coordinates (row, col), or -1 if no such block exists.
    Assumes that the segments of ``bsr_columns`` corresponding to each row are sorted.

    Args:
        row: Row of the block.
        col: Column of the block.
        bsr_offsets: Array of size at least ``1 + row`` containing the offsets of the blocks in each row.
        bsr_columns: Array of size at least equal to ``bsr_offsets[row + 1]`` containing the column indices of the blocks.
    """

    if row < 0:
        return -1

    row_beg = bsr_offsets[row]
    row_end = bsr_offsets[row + 1]

    if row_beg == row_end:
        return -1

    block_index = wp.lower_bound(bsr_columns, row_beg, row_end, col)

    # When every column of the row is smaller than ``col``, the search lands at ``row_end``.
    # That slot belongs to the next row (or lies past the end of ``bsr_columns``), so it must not
    # be read: a matching column there would be a block of another row.
    if block_index >= row_end:
        return -1

    return wp.where(bsr_columns[block_index] == col, block_index, -1)
818
+
819
+
820
@wp.kernel(enable_backward=False)
def _bsr_assign_list_blocks(
    src_subrows: int,
    src_subcols: int,
    dest_subrows: int,
    dest_subcols: int,
    src_row_count: int,
    src_offsets: wp.array(dtype=int),
    src_columns: wp.array(dtype=int),
    dest_rows: wp.array(dtype=int),
    dest_cols: wp.array(dtype=int),
):
    """List destination block coordinates when re-blocking a source matrix.

    One thread per ``(source block, sub-row, sub-column)``. Each source block is split into
    ``src_subrows x src_subcols`` sub-blocks. Each sub-block belongs to destination block
    ``(sub_row // dest_subrows, sub_col // dest_subcols)``, expressed in global sub-block coordinates.
    Threads whose block index lies past the stored blocks write ``(-1, -1)``.
    Several sub-blocks may map to the same destination block.
    """
    block, subrow, subcol = wp.tid()
    # Flat output slot for this (block, subrow, subcol) thread
    dest_block = (block * src_subcols + subcol) * src_subrows + subrow

    row = bsr_row_index(src_offsets, src_row_count, block)
    if row == -1:
        dest_rows[dest_block] = row  # invalid
        dest_cols[dest_block] = row
    else:
        # Global sub-block coordinates, then destination block coordinates
        dest_subrow = row * src_subrows + subrow
        dest_subcol = src_columns[block] * src_subcols + subcol
        dest_rows[dest_block] = dest_subrow // dest_subrows
        dest_cols[dest_block] = dest_subcol // dest_subcols
844
+
845
+
846
@wp.kernel
def _bsr_assign_copy_blocks(
    scale: Any,
    src_subrows: int,
    src_subcols: int,
    dest_subrows: int,
    dest_subcols: int,
    src_row_count: int,
    src_offsets: wp.array(dtype=int),
    src_columns: wp.array(dtype=int),
    src_values: wp.array3d(dtype=Any),
    dest_offsets: wp.array(dtype=int),
    dest_columns: wp.array(dtype=int),
    dest_values: wp.array3d(dtype=Any),
):
    """Copy (scaled, type-cast) sub-blocks of a source matrix into the matching destination blocks.

    One thread per ``(source block, sub-row, sub-column)``. Sub-blocks whose destination block does not
    exist in the destination topology are skipped.
    """
    src_block, subrow, subcol = wp.tid()

    src_row = bsr_row_index(src_offsets, src_row_count, src_block)
    if src_row == -1:
        return

    src_col = src_columns[src_block]

    # Global sub-block coordinates, then destination block coordinates
    dest_subrow = src_row * src_subrows + subrow
    dest_subcol = src_col * src_subcols + subcol
    dest_row = dest_subrow // dest_subrows
    dest_col = dest_subcol // dest_subcols

    dest_block = bsr_block_index(dest_row, dest_col, dest_offsets, dest_columns)
    if dest_block == -1:
        return

    # Position of this sub-block inside the destination block
    split_row = dest_subrow - dest_subrows * dest_row
    split_col = dest_subcol - dest_subcols * dest_col

    rows_per_subblock = src_values.shape[1] // src_subrows
    cols_per_subblock = src_values.shape[2] // src_subcols

    dest_base_i = split_row * rows_per_subblock
    dest_base_j = split_col * cols_per_subblock

    src_base_i = subrow * rows_per_subblock
    src_base_j = subcol * cols_per_subblock

    for i in range(rows_per_subblock):
        for j in range(cols_per_subblock):
            dest_values[dest_block, i + dest_base_i, j + dest_base_j] = dest_values.dtype(
                scale * src_values[src_block, i + src_base_i, j + src_base_j]
            )
896
+
897
+
898
def bsr_assign(
    dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    src: BsrMatrixOrExpression[BlockType[Any, Any, Any]],
    structure_only: bool = False,
    masked: bool = False,
):
    """Copy the content of the ``src`` BSR matrix to ``dest``.

    Args:
        src: Matrix to be copied.
        dest: Destination matrix. May have a different block shape or scalar type
          than ``src``, in which case the required casting will be performed.
        structure_only: If ``True``, only the non-zero indices are copied, and uninitialized value storage is allocated
          to accommodate at least ``src.nnz`` blocks. If ``structure_only`` is ``False``, values are also copied with implicit
          casting if the two matrices use distinct scalar types.
        masked: If ``True``, keep the non-zero topology of ``dest`` unchanged.

    Raises:
        ValueError: If the matrices reside on different devices, if the block shapes do not evenly divide
          one another or the total matrix size, or (when ``masked``) if ``dest`` has an incompatible size.
    """

    src, src_scale = _extract_matrix_and_scale(src)

    if dest.values.device != src.values.device:
        raise ValueError("Source and destination matrices must reside on the same device")

    # Block-row ratio: the taller block shape is split into ``*_subrows`` pieces of the shorter one
    if src.block_shape[0] >= dest.block_shape[0]:
        src_subrows = src.block_shape[0] // dest.block_shape[0]
        dest_subrows = 1
    else:
        dest_subrows = dest.block_shape[0] // src.block_shape[0]
        src_subrows = 1

    if src_subrows * dest.block_shape[0] != src.block_shape[0] * dest_subrows:
        raise ValueError(
            f"Incompatible dest and src block shapes; block rows must evenly divide one another (Got {dest.block_shape[0]}, {src.block_shape[0]})"
        )

    # Same ratio computation for block columns
    if src.block_shape[1] >= dest.block_shape[1]:
        src_subcols = src.block_shape[1] // dest.block_shape[1]
        dest_subcols = 1
    else:
        dest_subcols = dest.block_shape[1] // src.block_shape[1]
        src_subcols = 1

    if src_subcols * dest.block_shape[1] != src.block_shape[1] * dest_subcols:
        raise ValueError(
            f"Incompatible dest and src block shapes; block columns must evenly divide one another (Got {dest.block_shape[1]}, {src.block_shape[1]})"
        )

    dest_nrow = (src.nrow * src_subrows) // dest_subrows
    dest_ncol = (src.ncol * src_subcols) // dest_subcols

    if src.nrow * src_subrows != dest_nrow * dest_subrows or src.ncol * src_subcols != dest_ncol * dest_subcols:
        raise ValueError(
            f"The requested block shape {dest.block_shape} does not evenly divide the source matrix of total size {src.shape}"
        )

    # Upper bound on the number of destination blocks: one per source sub-block
    nnz_alloc = src.nnz * src_subrows * src_subcols
    if masked:
        if dest_nrow != dest.nrow or dest_ncol != dest.ncol:
            raise ValueError(
                f"Incompatible destination matrix size, expected ({dest_nrow}, {dest_ncol}), got ({dest.nrow}, {dest.ncol})"
            )
    else:
        _bsr_resize(dest, rows_of_blocks=dest_nrow, cols_of_blocks=dest_ncol)

    if dest.block_shape == src.block_shape and not masked:
        # Direct copy

        wp.copy(dest=dest.offsets, src=src.offsets, count=src.nrow + 1)
        dest.notify_nnz_changed(nnz=nnz_alloc)

        if nnz_alloc > 0:
            wp.copy(dest=dest.columns, src=src.columns, count=nnz_alloc)

            if not structure_only:
                # Cast to dest scalar type, then apply the expression's scale
                warp._src.utils.array_cast(out_array=dest.values, in_array=src.values, count=nnz_alloc)
                bsr_scale(dest, src_scale)

    else:
        if not masked:
            # Compute destination rows and columns
            dest_rows = wp.empty(nnz_alloc, dtype=int, device=dest.device)
            dest_cols = wp.empty(nnz_alloc, dtype=int, device=dest.device)
            wp.launch(
                _bsr_assign_list_blocks,
                dim=(src.nnz, src_subrows, src_subcols),
                device=dest.device,
                inputs=[
                    src_subrows,
                    src_subcols,
                    dest_subrows,
                    dest_subcols,
                    src.nrow,
                    src.offsets,
                    src.columns,
                    dest_rows,
                    dest_cols,
                ],
            )

            _bsr_ensure_fits(dest, nnz=nnz_alloc)

            # Compute destination offsets from triplets
            from warp._src.context import runtime

            if dest.device.is_cpu:
                native_func = runtime.core.wp_bsr_matrix_from_triplets_host
            else:
                native_func = runtime.core.wp_bsr_matrix_from_triplets_device

            nnz_buf, nnz_event = dest._setup_nnz_transfer()
            with wp.ScopedDevice(dest.device):
                # Topology-only call: no values, no zero pruning, no summed-triplet outputs
                native_func(
                    dest.block_size,
                    0,  # scalar_size_in_bytes
                    dest.nrow,
                    dest.ncol,
                    nnz_alloc,
                    None,  # device nnz
                    ctypes.cast(dest_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
                    ctypes.cast(dest_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
                    None,  # triplet values
                    0,  # zero_value_mask
                    masked,
                    None,  # summed block offsets
                    None,  # summed block indices
                    ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
                    ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
                    _optional_ctypes_pointer(nnz_buf, ctype=ctypes.c_int32),
                    _optional_ctypes_event(nnz_event),
                )

        # copy block values
        if not structure_only:
            # Destination blocks not covered by any source sub-block stay zero
            dest.values.zero_()
            wp.launch(
                _bsr_assign_copy_blocks,
                dim=(src.nnz, src_subrows, src_subcols),
                device=dest.device,
                inputs=[
                    src.scalar_type(src_scale),
                    src_subrows,
                    src_subcols,
                    dest_subrows,
                    dest_subcols,
                    src.nrow,
                    src.offsets,
                    src.columns,
                    src.scalar_values,
                    dest.offsets,
                    dest.columns,
                    dest.scalar_values,
                ],
            )
1051
+
1052
+
1053
def bsr_copy(
    A: BsrMatrixOrExpression,
    scalar_type: Scalar | None = None,
    block_shape: tuple[int, int] | None = None,
    structure_only: bool = False,
):
    """Create and return a copy of ``A``, optionally with a different scalar type or block shape.

    Args:
        A: Matrix (or scaled matrix expression) to copy.
        scalar_type: Scalar type of the copy; defaults to the scalar type of ``A``.
        block_shape: Block shape of the copy; defaults to the block shape of ``A``.
          Each dimension must be either a multiple or an exact divider of the corresponding one in ``A``.
        structure_only: If ``True``, only the non-zeros indices are copied, and uninitialized value storage is allocated
          to accommodate at least ``src.nnz`` blocks. If ``structure_only`` is ``False``, values are also copied with implicit
          casting if the two matrices use distinct scalar types.
    """
    target_scalar = A.scalar_type if scalar_type is None else scalar_type
    target_shape = A.block_shape if block_shape is None else block_shape

    # 1x1 blocks are stored as plain scalars rather than 1x1 matrices
    if target_shape != (1, 1):
        block_type = wp.mat(shape=target_shape, dtype=target_scalar)
    else:
        block_type = target_scalar

    result = bsr_zeros(
        rows_of_blocks=A.nrow,
        cols_of_blocks=A.ncol,
        block_type=block_type,
        device=A.device,
    )
    result.values.requires_grad = A.requires_grad
    bsr_assign(dest=result, src=A, structure_only=structure_only)
    return result
1089
+
1090
+
1091
@wp.kernel
def _bsr_transpose_values(
    col_count: int,
    scale: Any,
    bsr_offsets: wp.array(dtype=int),
    bsr_columns: wp.array(dtype=int),
    bsr_values: wp.array3d(dtype=Any),
    block_index_map: wp.array(dtype=int),
    transposed_bsr_offsets: wp.array(dtype=int),
    transposed_bsr_columns: wp.array(dtype=int),
    transposed_bsr_values: wp.array3d(dtype=Any),
):
    """Fill the values of a transposed matrix from the source matrix, multiplied by ``scale``.

    One thread per ``(block, i, j)`` entry of the transposed matrix. ``col_count`` is the number of block
    rows of the transposed matrix (the source's number of block columns). The source block is taken from
    ``block_index_map`` when that array is provided. Otherwise it is looked up by coordinates, and
    transposed blocks with no source counterpart are left untouched.
    """
    block, i, j = wp.tid()

    # Skip threads past the actual number of transposed blocks
    if block >= transposed_bsr_offsets[col_count]:
        return

    if block_index_map:
        src_block = block_index_map[block]
    else:
        row = bsr_row_index(transposed_bsr_offsets, col_count, block)
        col = transposed_bsr_columns[block]
        # Transposed block (row, col) comes from source block (col, row)
        src_block = bsr_block_index(col, row, bsr_offsets, bsr_columns)
        if src_block == -1:
            return

    # Transpose within the block as well
    transposed_bsr_values[block, i, j] = bsr_values[src_block, j, i] * scale
1118
+
1119
+
1120
def bsr_set_transpose(
    dest: BsrMatrix[BlockType[Cols, Rows, Scalar]],
    src: BsrMatrixOrExpression[BlockType[Rows, Cols, Scalar]],
    masked: bool = False,
):
    """
    Assign the transposed matrix ``src`` to matrix ``dest``.

    Args:
        dest: Sparse matrix to populate.
        src: Sparse matrix to transpose.
        masked: If ``True``, keep the non-zero topology of ``dest`` unchanged.

    Raises:
        ValueError: On device, scalar type, block shape, or (when ``masked``) matrix size mismatch.
    """

    src, src_scale = _extract_matrix_and_scale(src)

    if dest.values.device != src.values.device:
        raise ValueError(
            f"All arguments must reside on the same device, got {dest.values.device} and {src.values.device}"
        )

    if dest.scalar_type != src.scalar_type:
        raise ValueError(f"All arguments must have the same scalar type, got {dest.scalar_type} and {src.scalar_type}")

    transpose_block_shape = src.block_shape[::-1]

    if dest.block_shape != transpose_block_shape:
        raise ValueError(f"Destination block shape must be {transpose_block_shape}, got {dest.block_shape}")

    if masked:
        if dest.nrow != src.ncol or dest.ncol != src.nrow:
            raise ValueError(
                f"Destination matrix must have {src.ncol} rows and {src.nrow} columns, got {dest.nrow} and {dest.ncol}"
            )
        # Keep dest topology; values kernel looks up source blocks by coordinates instead of a map
        block_index_map = None
        dest.values.zero_()
    else:
        _bsr_resize(dest, rows_of_blocks=src.ncol, cols_of_blocks=src.nrow)

        nnz = src.nnz
        if nnz == 0:
            bsr_set_zero(dest)
            return

        # Increase dest array sizes if needed
        _bsr_ensure_fits(dest, nnz=nnz)

        from warp._src.context import runtime

        if dest.values.device.is_cpu:
            native_func = runtime.core.wp_bsr_transpose_host
        else:
            native_func = runtime.core.wp_bsr_transpose_device

        # NOTE(review): sized 2 * nnz, presumably as scratch space for the native transpose -- confirm
        block_index_map = wp.empty(shape=2 * nnz, dtype=int, device=src.device)

        with wp.ScopedDevice(dest.device):
            # Builds the transposed topology and the dest-to-src block index map
            native_func(
                src.nrow,
                src.ncol,
                nnz,
                ctypes.cast(src.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
                ctypes.cast(src.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
                ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
                ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
                ctypes.cast(block_index_map.ptr, ctypes.POINTER(ctypes.c_int32)),
            )

            dest._copy_nnz_async()

    wp.launch(
        _bsr_transpose_values,
        dim=(dest.nnz, *dest.block_shape),
        device=dest.device,
        inputs=[
            src.ncol,
            dest.scalar_type(src_scale),
            src.offsets,
            src.columns,
            src.scalar_values,
            block_index_map,
            dest.offsets,
            dest.columns,
        ],
        outputs=[dest.scalar_values],
    )
1206
+
1207
+
1208
def bsr_transposed(A: BsrMatrixOrExpression) -> BsrMatrix:
    """Return a copy of the transposed matrix ``A``.

    Args:
        A: Matrix, or scaled matrix expression, to transpose. The expression's scale factor is
          applied to the returned matrix.
    """

    if A.block_shape == (1, 1):
        # Scaling expressions do not expose ``values``; read the block type from the underlying matrix
        matrix, _ = _extract_matrix_and_scale(A)
        block_type = matrix.values.dtype
    else:
        block_type = wp.mat(shape=A.block_shape[::-1], dtype=A.scalar_type)

    transposed = bsr_zeros(
        rows_of_blocks=A.ncol,
        cols_of_blocks=A.nrow,
        block_type=block_type,
        device=A.device,
    )
    transposed.values.requires_grad = A.requires_grad
    bsr_set_transpose(dest=transposed, src=A)
    return transposed
1225
+
1226
+
1227
@wp.kernel
def _bsr_get_diag_kernel(
    scale: Any,
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array3d(dtype=Any),
    out: wp.array3d(dtype=Any),
):
    """Copy the diagonal blocks of ``A``, multiplied by ``scale``, into ``out``.

    One thread per ``(row, block_row, block_col)`` entry. Rows without a stored diagonal block
    leave the corresponding entries of ``out`` untouched.
    """
    row, br, bc = wp.tid()

    diag = bsr_block_index(row, row, A_offsets, A_columns)
    if diag != -1:
        out[row, br, bc] = scale * A_values[diag, br, bc]
1240
+
1241
+
1242
def bsr_get_diag(A: BsrMatrixOrExpression[BlockType], out: Array[BlockType] | None = None) -> Array[BlockType]:
    """Return the array of blocks forming the diagonal of a sparse matrix.

    Args:
        A: The sparse matrix (or scaled matrix expression) whose diagonal is extracted.
        out: Optional destination array. When it is provided, entries of rows that have no stored diagonal
          block are left unchanged; a newly allocated array is zero-initialized instead.

    Returns:
        The array holding the ``min(A.nrow, A.ncol)`` diagonal blocks (``out`` if given).

    Raises:
        ValueError: If ``out`` has the wrong dtype, device, or is too short.
    """

    matrix, scale = _extract_matrix_and_scale(A)

    block_dtype = matrix.values.dtype
    device = matrix.values.device
    diag_len = min(matrix.nrow, matrix.ncol)

    if out is None:
        out = wp.zeros(shape=(diag_len,), dtype=block_dtype, device=device)
    else:
        if not types_equal(out.dtype, block_dtype):
            raise ValueError(f"Output array must have type {block_dtype}, got {out.dtype}")
        if out.device != device:
            raise ValueError(f"Output array must reside on device {device}, got {out.device}")
        if out.shape[0] < diag_len:
            raise ValueError(f"Output array must be of length at least {diag_len}, got {out.shape[0]}")

    kernel_inputs = [
        matrix.scalar_type(scale),
        matrix.offsets,
        matrix.columns,
        matrix.scalar_values,
        _as_3d_array(out, matrix.block_shape),
    ]
    wp.launch(
        kernel=_bsr_get_diag_kernel,
        dim=(diag_len, *matrix.block_shape),
        device=device,
        inputs=kernel_inputs,
    )

    return out
1272
+
1273
+
1274
@wp.kernel(enable_backward=False)
def _bsr_set_diag_kernel(
    nnz: int,
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
):
    """Write a diagonal topology: row ``r < nnz`` holds one block in column ``r``.

    Thread ``row`` writes ``A_offsets[row] = min(row, nnz)``, so every launched row at or past ``nnz``
    is empty. Offsets of rows outside the launch range are not written.
    """
    row = wp.tid()
    # Offsets saturate at nnz: rows past the diagonal length hold no block
    A_offsets[row] = wp.min(row, nnz)
    if row < nnz:
        A_columns[row] = row
1284
+
1285
+
1286
def bsr_set_diag(
    A: BsrMatrix[BlockType],
    diag: BlockType | Array[BlockType],
    rows_of_blocks: int | None = None,
    cols_of_blocks: int | None = None,
) -> None:
    """Set ``A`` as a block-diagonal matrix.

    Args:
        A: The sparse matrix to modify.
        diag: Specifies the values for diagonal blocks. Can be one of:

          - A Warp array of type ``A.values.dtype``: Each element defines one block of the diagonal
          - A constant value of type ``A.values.dtype``: This value is assigned to all diagonal blocks
          - ``None``: Diagonal block values are left uninitialized

        rows_of_blocks: If not ``None``, the new number of rows of blocks.
        cols_of_blocks: If not ``None``, the new number of columns of blocks.

    The shape of the matrix will be defined one of the following, in this order:

    - ``rows_of_blocks`` and ``cols_of_blocks``, if provided.
      If only one is given, the second is assumed equal.
    - The first dimension of ``diag``, if ``diag`` is an array
    - The current dimensions of ``A`` otherwise
    """

    if rows_of_blocks is None and cols_of_blocks is not None:
        rows_of_blocks = cols_of_blocks
    if cols_of_blocks is None and rows_of_blocks is not None:
        cols_of_blocks = rows_of_blocks

    if is_array(diag):
        if rows_of_blocks is None:
            rows_of_blocks = diag.shape[0]
            cols_of_blocks = diag.shape[0]

    if rows_of_blocks is not None:
        _bsr_resize(A, rows_of_blocks, cols_of_blocks)

    nnz = min(A.nrow, A.ncol)
    # Record the new nnz upper bound and make sure columns/values can hold it before the kernel writes them
    _bsr_ensure_fits(A, nnz=nnz)

    # One thread per row offset (nrow + 1 of them): when there are more rows than columns, the trailing
    # rows must also be written as empty, otherwise their offsets would keep stale values.
    wp.launch(
        kernel=_bsr_set_diag_kernel,
        dim=A.nrow + 1,
        device=A.offsets.device,
        inputs=[nnz, A.offsets, A.columns],
    )

    A.notify_nnz_changed(nnz=nnz)  # notify change of offsets

    if is_array(diag):
        wp.copy(src=diag, dest=A.values, count=nnz)
    elif diag is not None:
        A.values.fill_(diag)
1342
+
1343
+
1344
def bsr_diag(
    diag: BlockType | Array[BlockType] | None = None,
    rows_of_blocks: int | None = None,
    cols_of_blocks: int | None = None,
    block_type: BlockType | None = None,
    device=None,
) -> BsrMatrix[BlockType]:
    """Build and return a block-diagonal BSR matrix from a single block value or an array of block values.

    Args:
        diag: Values for the diagonal blocks. Either:

          - A Warp array of block values: element ``i`` becomes diagonal block ``i``
          - A constant block value: assigned to every diagonal block
        rows_of_blocks: If not ``None``, the number of rows of blocks
        cols_of_blocks: If not ``None``, the number of columns of blocks
        block_type: Block type of the matrix when ``diag`` is ``None``; otherwise deduced from ``diag``
        device: Device on which to allocate the matrix when ``diag`` is not a Warp array; otherwise deduced from ``diag``

    The matrix shape is taken from, in order of precedence:

    - ``rows_of_blocks`` and ``cols_of_blocks`` (a single given value implies a square matrix);
    - the first dimension of ``diag`` when ``diag`` is an array.

    Raises:
        ValueError: If the shape cannot be determined, or if neither ``diag`` nor ``block_type`` is given.
    """

    # A single provided dimension implies a square matrix
    if rows_of_blocks is None:
        rows_of_blocks = cols_of_blocks
    elif cols_of_blocks is None:
        cols_of_blocks = rows_of_blocks

    diag_is_array = is_array(diag)

    if diag_is_array:
        if rows_of_blocks is None:
            rows_of_blocks = cols_of_blocks = diag.shape[0]

        block_type = diag.dtype
        device = diag.device
    else:
        if rows_of_blocks is None:
            raise ValueError(
                "rows_of_blocks and/or cols_of_blocks must be provided for constructing a diagonal matrix with uniform diagonal"
            )

        if block_type is None:
            if diag is None:
                raise ValueError("Either `diag` or `block_type` needs to be provided")

            block_type = type(diag)
            # Two-dimensional values that are not Warp matrices (e.g. array-likes) become matrix blocks
            if not type_is_matrix(block_type) and len(getattr(diag, "shape", ())) == 2:
                block_type = wp.mat(shape=diag.shape, dtype=diag.dtype)

    result = bsr_zeros(rows_of_blocks, cols_of_blocks, block_type=block_type, device=device)
    if diag_is_array:
        result.values.requires_grad = diag.requires_grad
    bsr_set_diag(result, diag)
    return result
1401
+
1402
+
1403
def bsr_set_identity(A: BsrMatrix, rows_of_blocks: int | None = None) -> None:
    """Overwrite ``A`` in place with the identity matrix.

    Args:
        A: Sparse matrix that receives the identity.
        rows_of_blocks: Optional new size. When given, ``A`` is resized to a square matrix
            with ``rows_of_blocks`` block rows and block columns.
    """

    if A.block_shape != (1, 1):
        # Non-scalar blocks: every diagonal block is a dense identity block
        from numpy import eye

        diagonal_block = eye(A.block_shape[0])
    else:
        diagonal_block = A.scalar_type(1.0)

    bsr_set_diag(
        A,
        diag=diagonal_block,
        rows_of_blocks=rows_of_blocks,
        cols_of_blocks=rows_of_blocks,
    )
1420
+
1421
+
1422
def bsr_identity(
    rows_of_blocks: int,
    block_type: BlockType[Rows, Rows, Scalar],
    device: wp._src.context.Devicelike = None,
) -> BsrMatrix[BlockType[Rows, Rows, Scalar]]:
    """Allocate a new square block matrix and initialize it to the identity.

    Args:
        rows_of_blocks: Number of block rows, which is also the number of block columns.
        block_type: Block type of the new matrix. Must be square.
        device: Device on which the data arrays are allocated.
    """
    identity = bsr_zeros(rows_of_blocks, rows_of_blocks, block_type=block_type, device=device)
    bsr_set_identity(identity)
    return identity
1442
+
1443
+
1444
# Kernel: multiply every entry of the 1d array `values` by `alpha`, in place.
# One thread per array entry.
# NOTE(review): a second kernel named `_bsr_scale_kernel` (taking a 3d array) is defined
# right after this one, and `bsr_scale` launches with a 3d launch dimension on
# `scalar_values`. Unless `@wp.kernel` keeps both same-named definitions, this 1d
# variant is shadowed at module level -- confirm whether it is still needed.
@wp.kernel
def _bsr_scale_kernel(
    alpha: Any,
    values: wp.array(dtype=Any),
):
    row = wp.tid()
    values[row] = alpha * values[row]
1451
+
1452
+
1453
# Kernel: multiply every coefficient of a 3d block-values array (block, block row, block col)
# by `alpha`, in place. One thread per scalar coefficient; `bsr_scale` launches it with
# dim=(nnz, *block_shape) on the matrix's `scalar_values`.
@wp.kernel
def _bsr_scale_kernel(
    alpha: Any,
    values: wp.array3d(dtype=Any),
):
    row, br, bc = wp.tid()
    values[row, br, bc] = alpha * values[row, br, bc]
1460
+
1461
+
1462
def bsr_scale(x: BsrMatrixOrExpression, alpha: Scalar) -> BsrMatrix:
    """Scale BSR matrix ``x`` in place, ``x := alpha * x``, and return it.

    If ``x`` is a matrix expression, the scale factor extracted from it is combined with ``alpha``.
    A combined factor of exactly zero zeroes the value array instead of launching a kernel.
    """

    matrix, expr_scale = _extract_matrix_and_scale(x)
    factor = alpha * expr_scale

    # Identity scaling or an empty matrix: nothing to do
    if factor == 1.0 or matrix.nnz <= 0:
        return matrix

    if factor == 0.0:
        matrix.values.zero_()
        return matrix

    typed_factor = matrix.scalar_type(factor)
    wp.launch(
        kernel=_bsr_scale_kernel,
        dim=(matrix.nnz, *matrix.block_shape),
        device=matrix.values.device,
        inputs=[typed_factor, matrix.scalar_values],
    )
    return matrix
1482
+
1483
+
1484
# Kernel: for each block index (one thread per block), store in `rows[block]` the row that
# owns it, as computed by `bsr_row_index` from the compressed row offsets `bsr_offsets`.
# NOTE(review): other kernels in this file test `bsr_row_index(...) == -1`, so blocks outside
# the populated range presumably get -1 here -- confirm.
@wp.kernel(enable_backward=False)
def _bsr_get_block_row(row_count: int, bsr_offsets: wp.array(dtype=int), rows: wp.array(dtype=int)):
    block = wp.tid()
    rows[block] = bsr_row_index(bsr_offsets, row_count, block)
1488
+
1489
+
1490
# Kernel: accumulate `scale * src_values[i]` into the destination block that has the same
# (row, col) coordinates as source block `i`.
# Launched with dim=(source block count, block rows, block cols); one thread per scalar coefficient.
#   src_offset: offset added to the thread's block index when reading `rows`/`cols`. This lets
#               one pair of coordinate arrays hold several sources back to back (`bsr_axpy`
#               stores y's coordinates first, then x's at offset `old_y_nnz`).
#   rows, cols: uncompressed (row, column) coordinates of the source blocks.
#   dst_offsets, dst_columns: compressed topology of the destination matrix.
# Source blocks with no matching destination block (`bsr_block_index` returns -1) are skipped.
@wp.kernel
def _bsr_axpy_add_block(
    src_offset: int,
    scale: Any,
    rows: wp.array(dtype=int),
    cols: wp.array(dtype=int),
    dst_offsets: wp.array(dtype=int),
    dst_columns: wp.array(dtype=int),
    src_values: wp.array3d(dtype=Any),
    dst_values: wp.array3d(dtype=Any),
):
    i, br, bc = wp.tid()
    row = rows[i + src_offset]
    col = cols[i + src_offset]

    block = bsr_block_index(row, col, dst_offsets, dst_columns)
    if block != -1:
        dst_values[block, br, bc] += scale * src_values[i, br, bc]
1508
+
1509
+
1510
# Kernel for masked addition: `dst += alpha * src`, restricted to the blocks already present
# in `dst` (the topology of `dst` is not modified).
# One thread per (dst block, block row, block col) coefficient. For each dst block, the `src`
# block with the same (row, col) is looked up. Missing src blocks (-1) contribute nothing.
# Threads for which `bsr_row_index` yields -1 exit immediately.
@wp.kernel
def _bsr_axpy_masked(
    alpha: Any,
    row_count: int,
    src_offsets: wp.array(dtype=int),
    src_columns: wp.array(dtype=int),
    src_values: wp.array3d(dtype=Any),
    dst_offsets: wp.array(dtype=int),
    dst_columns: wp.array(dtype=int),
    dst_values: wp.array3d(dtype=Any),
):
    block, br, bc = wp.tid()

    row = bsr_row_index(dst_offsets, row_count, block)
    if row == -1:
        return

    col = dst_columns[block]
    src_block = bsr_block_index(row, col, src_offsets, src_columns)
    if src_block != -1:
        dst_values[block, br, bc] += alpha * src_values[src_block, br, bc]
1531
+
1532
+
1533
class bsr_axpy_work_arrays:
    """Opaque structure for persisting :func:`bsr_axpy` temporary work buffers across calls."""

    def __init__(self):
        self._reset(None)

    def _reset(self, device):
        """Forget all buffers and bind the structure to ``device``."""
        self.device = device
        # Concatenated (row, col) coordinates of y's and x's blocks
        self._sum_rows = None
        self._sum_cols = None
        # Copy of y's values taken before y's topology is rebuilt
        self._old_y_values = None
        # NOTE(review): not read by the visible code; kept for backward compatibility
        self._old_x_values = None

    def _allocate(self, device, y: BsrMatrix, sum_nnz: int):
        """Ensure buffers are large enough for ``sum_nnz`` coordinates and a copy of ``y``'s values.

        Buffers are reallocated when the device changes or when they are too small. The values
        buffer is also reallocated when its dtype differs from ``y.values.dtype``, so that a work
        structure reused with a different block type never copies into a mismatched buffer.
        """
        if self.device != device:
            self._reset(device)

        if self._sum_rows is None or self._sum_rows.size < sum_nnz:
            self._sum_rows = wp.empty(shape=(sum_nnz,), dtype=int, device=self.device)
        if self._sum_cols is None or self._sum_cols.size < sum_nnz:
            self._sum_cols = wp.empty(shape=(sum_nnz,), dtype=int, device=self.device)

        if (
            self._old_y_values is None
            or self._old_y_values.size < y.nnz
            or self._old_y_values.dtype != y.values.dtype
        ):
            self._old_y_values = wp.empty_like(y.values[: y.nnz])
1557
+
1558
+
1559
def bsr_axpy(
    x: BsrMatrixOrExpression,
    y: BsrMatrix[BlockType[Rows, Cols, Scalar]] | None = None,
    alpha: Scalar = 1.0,
    beta: Scalar = 1.0,
    masked: bool = False,
    work_arrays: bsr_axpy_work_arrays | None = None,
) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
    """
    Perform the sparse matrix addition ``y := alpha * x + beta * y`` on BSR matrices ``x`` and ``y`` and return ``y``.

    The ``x`` and ``y`` matrices are allowed to alias.

    Args:
        x: Read-only first operand. If ``x`` is a matrix expression, the scale factor extracted
            from it by ``_extract_matrix_and_scale`` is folded into ``alpha``.
        y: Mutable second operand and output matrix. If ``y`` is not provided, it will be allocated and treated as zero.
        alpha: Uniform scaling factor for ``x``.
        beta: Uniform scaling factor for ``y``.
        masked: If ``True``, keep the non-zero topology of ``y`` unchanged.
        work_arrays: In most cases, this function will require the use of temporary storage.
            This storage can be reused across calls by passing an instance of
            :class:`bsr_axpy_work_arrays` in ``work_arrays``.

    Returns:
        The output matrix ``y``, newly allocated if ``y`` was ``None``.

    Raises:
        ValueError: If ``masked`` is ``True`` while ``y`` is ``None``. Also raised in the general
            case (both operands non-trivial, not aliased) if ``x`` and ``y`` differ in device,
            block type, or dimensions.
    """

    x, x_scale = _extract_matrix_and_scale(x)
    alpha *= x_scale

    if y is None:
        if masked:
            raise ValueError("Left-hand-side 'y' matrix must be provided for masked addition")

        # If no output matrix is provided, allocate it for convenience
        y = bsr_zeros(x.nrow, x.ncol, block_type=x.values.dtype, device=x.values.device)
        y.values.requires_grad = x.requires_grad
        beta = 0.0

    x_nnz = x.nnz
    y_nnz = y.nnz

    # Handle easy cases first
    # y contributes nothing: y := alpha * x (copy x into y, then scale)
    if beta == 0.0 or y_nnz == 0:
        bsr_assign(src=x, dest=y, masked=masked)
        return bsr_scale(y, alpha=alpha)

    # x contributes nothing: y := beta * y
    if alpha == 0.0 or x_nnz == 0:
        return bsr_scale(y, alpha=beta)

    if x == y:
        # Aliasing case
        return bsr_scale(y, alpha=alpha + beta)

    # General case

    # Convert the scaling factors to the matrix scalar type before passing them to kernels
    if not isinstance(alpha, y.scalar_type):
        alpha = y.scalar_type(alpha)
    if not isinstance(beta, y.scalar_type):
        beta = y.scalar_type(beta)

    if x.values.device != y.values.device:
        raise ValueError(f"All arguments must reside on the same device, got {x.values.device} and {y.values.device}")

    if x.scalar_type != y.scalar_type or x.block_shape != y.block_shape:
        raise ValueError(
            f"Matrices must have the same block type, got ({x.block_shape}, {x.scalar_type}) and ({y.block_shape}, {y.scalar_type})"
        )

    if x.nrow != y.nrow or x.ncol != y.ncol:
        raise ValueError(
            f"Matrices must have the same number of rows and columns, got ({x.nrow}, {x.ncol}) and ({y.nrow}, {y.ncol})"
        )

    device = y.values.device
    if masked:
        # Topology of y is kept: scale y by beta, then add alpha * x on y's existing blocks only
        bsr_scale(y, alpha=beta.value)
        wp.launch(
            kernel=_bsr_axpy_masked,
            device=device,
            dim=(y_nnz, y.block_shape[0], y.block_shape[1]),
            inputs=[
                alpha,
                x.nrow,
                x.offsets,
                x.columns,
                x.scalar_values,
                y.offsets,
                y.columns,
                y.scalar_values,
            ],
        )

    else:
        if work_arrays is None:
            work_arrays = bsr_axpy_work_arrays()

        sum_nnz = x_nnz + y_nnz
        work_arrays._allocate(device, y, sum_nnz)

        # Build the concatenated list of (row, col) coordinates: y's blocks first, then x's
        wp.copy(work_arrays._sum_cols, y.columns, 0, 0, y_nnz)
        y.uncompress_rows(out=work_arrays._sum_rows)

        wp.copy(work_arrays._sum_cols, x.columns, y_nnz, 0, x_nnz)
        x.uncompress_rows(out=work_arrays._sum_rows[y_nnz:])

        # Save old y values before overwriting matrix
        wp.copy(dest=work_arrays._old_y_values, src=y.values, count=y.nnz)

        # Increase dest array sizes if needed
        _bsr_ensure_fits(y, nnz=sum_nnz)

        from warp._src.context import runtime

        if device.is_cpu:
            native_func = runtime.core.wp_bsr_matrix_from_triplets_host
        else:
            native_func = runtime.core.wp_bsr_matrix_from_triplets_device

        old_y_nnz = y_nnz
        nnz_buf, nnz_event = y._setup_nnz_transfer()

        # Rebuild y's topology (offsets/columns) from the concatenated coordinates.
        # Only coordinates are passed (no triplet values); values are filled in below.
        with wp.ScopedDevice(y.device):
            native_func(
                y.block_size,
                0,  # scalar_size_in_bytes
                y.nrow,
                y.ncol,
                sum_nnz,
                None,  # device nnz
                ctypes.cast(work_arrays._sum_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
                ctypes.cast(work_arrays._sum_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
                None,  # triplet values
                0,  # zero_value_mask
                masked,
                None,  # summed block offsets
                None,  # summed block indices
                ctypes.cast(y.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
                ctypes.cast(y.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
                _optional_ctypes_pointer(nnz_buf, ctype=ctypes.c_int32),
                _optional_ctypes_event(nnz_event),
            )

        y.values.zero_()

        # Scatter-add beta * (saved old y values) into the new topology
        wp.launch(
            kernel=_bsr_axpy_add_block,
            device=device,
            dim=(old_y_nnz, y.block_shape[0], y.block_shape[1]),
            inputs=[
                0,
                beta,
                work_arrays._sum_rows,
                work_arrays._sum_cols,
                y.offsets,
                y.columns,
                _as_3d_array(work_arrays._old_y_values, y.block_shape),
                y.scalar_values,
            ],
        )

        # Scatter-add alpha * x; x's coordinates start at offset old_y_nnz in the work arrays
        wp.launch(
            kernel=_bsr_axpy_add_block,
            device=device,
            dim=(x_nnz, y.block_shape[0], y.block_shape[1]),
            inputs=[
                old_y_nnz,
                alpha,
                work_arrays._sum_rows,
                work_arrays._sum_cols,
                y.offsets,
                y.columns,
                x.scalar_values,
                y.scalar_values,
            ],
        )

    return y
1734
+
1735
+
1736
def make_bsr_mm_count_coeffs(tile_size):
    """Return a kernel, specialized for ``tile_size``, that counts unmerged ``x @ y`` product blocks.

    The kernel is launched with dim=(number of rows, tile_size); ``tile_size`` lanes cooperate on
    each row of ``x``. For each block of ``x`` in the row, it writes into
    ``block_counts[x_block + 1]`` the number of ``y`` blocks in the row of ``y`` matching that
    block's column.

    If the total count for the row exceeds the width of the spanned column range
    ``[min_col, max_col]``, the row is instead marked as a dense column range. In that case
    ``row_min[row] = min_col``, the last entry of the row receives the range width, and the other
    entries are zeroed. Otherwise ``row_min[row] = -1``.

    Lane 0 of row 0 writes ``z_nnz`` into ``block_counts[0]``, so that a prefix scan of
    ``block_counts`` yields offsets placed after the ``z_nnz`` copied blocks of ``z``.
    """
    from warp._src.fem.cache import dynamic_kernel

    @dynamic_kernel(suffix=tile_size)
    def bsr_mm_count_coeffs(
        y_ncol: int,
        z_nnz: int,
        x_offsets: wp.array(dtype=int),
        x_columns: wp.array(dtype=int),
        y_offsets: wp.array(dtype=int),
        y_columns: wp.array(dtype=int),
        row_min: wp.array(dtype=int),
        block_counts: wp.array(dtype=int),
    ):
        row, lane = wp.tid()
        row_count = int(0)

        x_beg = x_offsets[row]
        x_end = x_offsets[row + 1]

        min_col = y_ncol
        max_col = int(0)

        # Lanes stride over the x blocks of this row
        for x_block in range(x_beg + lane, x_end, tile_size):
            x_col = x_columns[x_block]
            y_row_end = y_offsets[x_col + 1]
            y_row_beg = y_offsets[x_col]
            block_count = y_row_end - y_row_beg
            if block_count != 0:
                # NOTE(review): using the first/last column of the y row as its min/max assumes
                # column indices are sorted within each row of y -- confirm
                min_col = wp.min(y_columns[y_row_beg], min_col)
                max_col = wp.max(y_columns[y_row_end - 1], max_col)

            block_counts[x_block + 1] = block_count
            row_count += block_count

        if wp.static(tile_size) > 1:
            # Reduce per-lane partial results across the tile
            row_count = wp.tile_sum(wp.tile(row_count))[0]
            min_col = wp.tile_min(wp.tile(min_col))[0]
            max_col = wp.tile_max(wp.tile(max_col))[0]
        col_range_size = wp.max(0, max_col - min_col + 1)

        if row_count > col_range_size:
            # Optimization for deep products.
            # Do not store the whole list of src product terms, they would be highly redundant
            # Instead just mark a range in the output matrix

            if lane == 0:
                row_min[row] = min_col
                block_counts[x_end] = col_range_size

            for x_block in range(x_beg + lane, x_end - 1, tile_size):
                block_counts[x_block + 1] = 0
        elif lane == 0:
            row_min[row] = -1

        if lane == 0 and row == 0:
            block_counts[0] = z_nnz

    return bsr_mm_count_coeffs
1795
+
1796
+
1797
# Kernel: expand the prefix-summed per-x-block counts (`mm_offsets`) into explicit (row, col)
# coordinates for each unmerged product block. One thread per product block, starting after
# the first `copied_z_nnz` entries, which hold the existing blocks of z.
#   mm_rows / mm_cols: coordinates of each product block, or -1 for entries that no x block owns.
#   mm_src_blocks: the contributing x block, or -1 when the row was collapsed to a contiguous
#                  column range by the counting kernel (mm_row_min[row] != -1).
# The thread for the last slot (`mm_nnz - 1`) prints a warning if the capacity `mm_nnz` is
# smaller than the total number of potential product blocks.
@wp.kernel(enable_backward=False)
def _bsr_mm_list_coeffs(
    copied_z_nnz: int,
    mm_nnz: int,
    x_nrow: int,
    x_offsets: wp.array(dtype=int),
    x_columns: wp.array(dtype=int),
    y_offsets: wp.array(dtype=int),
    y_columns: wp.array(dtype=int),
    mm_row_min: wp.array(dtype=int),
    mm_offsets: wp.array(dtype=int),
    mm_rows: wp.array(dtype=int),
    mm_cols: wp.array(dtype=int),
    mm_src_blocks: wp.array(dtype=int),
):
    mm_block = wp.tid() + copied_z_nnz

    x_nnz = x_offsets[x_nrow]

    # Find which x block generated this product block
    x_block = bsr_row_index(mm_offsets, x_nnz, mm_block)

    if x_block == -1:
        mm_cols[mm_block] = -1
        mm_rows[mm_block] = -1
        return

    if mm_block + 1 == mm_nnz and mm_nnz < mm_offsets[x_nnz]:
        wp.printf(
            "Number of potential `bsr_mm` blocks (%d) exceeded `max_nnz` (%d)\n",
            mm_offsets[x_nnz] - copied_z_nnz,
            mm_nnz - copied_z_nnz,
        )

    # Position of this product block among those generated by `x_block`
    pos = mm_block - mm_offsets[x_block]

    row = bsr_row_index(x_offsets, x_nrow, x_block)

    row_min_col = mm_row_min[row]
    if row_min_col == -1:
        # Regular row: the pos-th block of the matching y row
        x_col = x_columns[x_block]
        y_beg = y_offsets[x_col]
        y_block = y_beg + pos
        col = y_columns[y_block]
        src_block = x_block
    else:
        # Collapsed row: consecutive columns starting at the row's min column
        col = row_min_col + pos
        src_block = -1

    mm_cols[mm_block] = col
    mm_rows[mm_block] = row
    mm_src_blocks[mm_block] = src_block
1848
+
1849
+
1850
# Decide which list of contributions to iterate when computing output block `mm_block` of row `row`.
# - Returns (True, beg, end), a range into the summed-triplet arrays, when the row was not
#   collapsed to a column range (mm_row_min[row] == -1) and that range is more than 3x shorter
#   than the row's range of x blocks.
# - Otherwise returns (False, x_beg, x_end), the range of x blocks in row `row`.
# `summed_triplet_offsets[k]` is read as the end offset of output block k (block 0 starts at 0).
# NOTE(review): the `if mm_row_min:` test presumably checks for a null array, in which case the
# triplet path is never taken -- confirm.
@wp.func
def _bsr_mm_use_triplets(
    row: int,
    mm_block: int,
    mm_row_min: wp.array(dtype=int),
    row_offsets: wp.array(dtype=int),
    summed_triplet_offsets: wp.array(dtype=int),
):
    x_beg = row_offsets[row]
    x_end = row_offsets[row + 1]

    if mm_row_min:
        if mm_row_min[row] == -1:
            if mm_block == 0:
                block_beg = 0
            else:
                block_beg = summed_triplet_offsets[mm_block - 1]
            block_end = summed_triplet_offsets[mm_block]

            if x_end - x_beg > 3 * (block_end - block_beg):
                return True, block_beg, block_end

    return False, x_beg, x_end
1873
+
1874
+
1875
# Kernel: compute the value of each output block of `x @ y` and accumulate `alpha * value`
# into `mm_values`. One thread per output block.
# Depending on `_bsr_mm_use_triplets`, the thread either:
#   - iterates the summed triplets that contributed to the block; entries equal to -1 come from
#     collapsed column ranges and are skipped, or
#   - iterates all x blocks of the row and looks up the matching y block, skipping pairs whose
#     y block does not exist.
# `x_columns` is indexed only after checking that the triplet's source block is not -1.
# This avoids reading `x_columns[-1]` and matches the tiled variant of this kernel.
@wp.kernel(enable_backward=False)
def _bsr_mm_compute_values(
    alpha: Any,
    x_offsets: wp.array(dtype=int),
    x_columns: wp.array(dtype=int),
    x_values: wp.array(dtype=Any),
    y_offsets: wp.array(dtype=int),
    y_columns: wp.array(dtype=int),
    y_values: wp.array(dtype=Any),
    mm_row_min: wp.array(dtype=int),
    summed_triplet_offsets: wp.array(dtype=int),
    summed_triplet_src_blocks: wp.indexedarray(dtype=int),
    mm_row_count: int,
    mm_offsets: wp.array(dtype=int),
    mm_cols: wp.array(dtype=int),
    mm_values: wp.array(dtype=Any),
):
    mm_block = wp.tid()

    row = bsr_row_index(mm_offsets, mm_row_count, mm_block)
    if row == -1:
        return

    use_triplets, block_beg, block_end = _bsr_mm_use_triplets(
        row, mm_block, mm_row_min, x_offsets, summed_triplet_offsets
    )

    mm_val = mm_values.dtype(type(alpha)(0.0))
    col = mm_cols[mm_block]
    if use_triplets:
        for tpl_idx in range(block_beg, block_end):
            x_block = summed_triplet_src_blocks[tpl_idx]
            if x_block != -1:
                x_col = x_columns[x_block]
                y_block = bsr_block_index(x_col, col, y_offsets, y_columns)
                mm_val += x_values[x_block] * y_values[y_block]
    else:
        for x_block in range(block_beg, block_end):
            x_col = x_columns[x_block]
            y_block = bsr_block_index(x_col, col, y_offsets, y_columns)
            if y_block != -1:
                mm_val += x_values[x_block] * y_values[y_block]

    mm_values[mm_block] += alpha * mm_val
1919
+
1920
+
1921
def make_bsr_mm_compute_values_tiled_outer(subblock_rows, subblock_cols, block_depth, scalar_type, tile_size):
    """Return a tile-based kernel computing ``x @ y`` output block values as sums of outer products.

    Each output block is split into sub-blocks of shape ``(subblock_rows, subblock_cols)``. The
    kernel is launched with thread indices ``(mm_block, subrow, subcol, lane)``. For one sub-block,
    the ``tile_size`` lanes stride over the ``(block_end - block_beg) * block_depth`` rank-1 terms.
    Each term is the outer product of one column of an ``x`` block with the matching row of a
    ``y`` block. Per-lane partial sums are reduced with ``wp.tile_sum``, and the result is added,
    scaled by ``alpha``, into ``mm_values``.

    Sub-blocks that extend past the edge of a block are clipped (``brow_count``/``bcol_count``).
    NOTE(review): ``block_depth`` is presumably the inner block dimension (columns of x blocks /
    rows of y blocks) -- confirm at the call site.
    """
    from warp._src.fem.cache import dynamic_func, dynamic_kernel

    mm_type = wp.mat(dtype=scalar_type, shape=(subblock_rows, subblock_cols))

    x_col_vec_t = wp.vec(dtype=scalar_type, length=subblock_rows)
    y_row_vec_t = wp.vec(dtype=scalar_type, length=subblock_cols)

    suffix = (subblock_rows, subblock_cols, block_depth, tile_size, scalar_type.__name__)

    # Outer product of column `block_col` of the x block (rows brow_off .. brow_off+brow_count)
    # with row `block_col` of the y block (cols bcol_off .. bcol_off+bcol_count).
    # Vector entries beyond the counts keep their default-constructed values (presumably zero).
    @dynamic_func(suffix=suffix)
    def _outer_product(
        x_values: wp.array2d(dtype=Any),
        y_values: wp.array2d(dtype=Any),
        brow_off: int,
        bcol_off: int,
        block_col: int,
        brow_count: int,
        bcol_count: int,
    ):
        x_col_vec = x_col_vec_t()
        y_row_vec = y_row_vec_t()

        for k in range(brow_count):
            x_col_vec[k] = x_values[brow_off + k, block_col]
        for k in range(bcol_count):
            y_row_vec[k] = y_values[block_col, bcol_off + k]

        return wp.outer(x_col_vec, y_row_vec)

    @dynamic_kernel(suffix=suffix, kernel_options={"enable_backward": False})
    def bsr_mm_compute_values(
        alpha: Any,
        x_offsets: wp.array(dtype=int),
        x_columns: wp.array(dtype=int),
        x_values: wp.array3d(dtype=Any),
        y_offsets: wp.array(dtype=int),
        y_columns: wp.array(dtype=int),
        y_values: wp.array3d(dtype=Any),
        mm_row_min: wp.array(dtype=int),
        summed_triplet_offsets: wp.array(dtype=int),
        summed_triplet_src_blocks: wp.indexedarray(dtype=int),
        mm_row_count: int,
        mm_offsets: wp.array(dtype=int),
        mm_cols: wp.array(dtype=int),
        mm_values: wp.array3d(dtype=Any),
    ):
        mm_block, subrow, subcol, lane = wp.tid()

        # Origin and clipped extent of this thread group's sub-block within the output block
        brow_off = subrow * wp.static(subblock_rows)
        bcol_off = subcol * wp.static(subblock_cols)

        brow_count = wp.min(mm_values.shape[1] - brow_off, subblock_rows)
        bcol_count = wp.min(mm_values.shape[2] - bcol_off, subblock_cols)

        mm_row = bsr_row_index(mm_offsets, mm_row_count, mm_block)
        if mm_row == -1:
            return

        lane_val = mm_type()

        use_triplets, block_beg, block_end = _bsr_mm_use_triplets(
            mm_row, mm_block, mm_row_min, x_offsets, summed_triplet_offsets
        )

        # Total number of rank-1 terms: one per (contributing block, inner index) pair
        col_count = (block_end - block_beg) * block_depth

        mm_col = mm_cols[mm_block]
        if use_triplets:
            for col in range(lane, col_count, tile_size):
                tpl_block = col // wp.static(block_depth)
                block_col = col - tpl_block * wp.static(block_depth)
                tpl_block += block_beg

                x_block = summed_triplet_src_blocks[tpl_block]
                if x_block != -1:
                    x_col = x_columns[x_block]
                    y_block = bsr_block_index(x_col, mm_col, y_offsets, y_columns)
                    lane_val += _outer_product(
                        x_values[x_block], y_values[y_block], brow_off, bcol_off, block_col, brow_count, bcol_count
                    )
        else:
            for col in range(lane, col_count, tile_size):
                x_block = col // wp.static(block_depth)
                block_col = col - x_block * wp.static(block_depth)
                x_block += block_beg

                x_col = x_columns[x_block]
                y_block = bsr_block_index(x_col, mm_col, y_offsets, y_columns)

                if y_block != -1:
                    lane_val += _outer_product(
                        x_values[x_block], y_values[y_block], brow_off, bcol_off, block_col, brow_count, bcol_count
                    )

        # Reduce partial sub-block sums across the lanes of the tile
        mm_val = wp.tile_sum(wp.tile(lane_val, preserve_type=True))[0]

        # Lanes cooperatively scatter the reduced sub-block, skipping clipped coefficients
        for coef in range(lane, wp.static(subblock_cols * subblock_rows), tile_size):
            br = coef // subblock_cols
            bc = coef - br * subblock_cols
            if br < brow_count and bc < bcol_count:
                mm_values[mm_block, br + brow_off, bc + bcol_off] += mm_val[br, bc] * alpha

    return bsr_mm_compute_values
2025
+
2026
+
2027
class bsr_mm_work_arrays:
    """Opaque structure for persisting :func:`bsr_mm` temporary work buffers across calls."""

    def __init__(self):
        self._reset(None)

    def _reset(self, device):
        """Forget all buffers and bind the structure to ``device``."""
        self.device = device
        # Per-row start column of collapsed rows (-1 if not collapsed)
        self._mm_row_min = None
        # Per-x-block product counts, later prefix-summed into offsets
        self._mm_block_counts = None
        # Unmerged product block coordinates and their source x blocks
        self._mm_rows = None
        self._mm_cols = None
        self._mm_src_blocks = None
        # Saved copies of z's values/topology, taken before z is overwritten
        self._old_z_values = None
        self._old_z_offsets = None
        self._old_z_columns = None
        self._mm_nnz = 0

    def _allocate_stage_1(self, device, x_nnz: int, z: BsrMatrix, beta: float, z_aliasing: bool):
        """Allocate buffers whose sizes are known before the product topology is computed.

        Also records in ``_copied_z_nnz`` how many existing blocks of ``z`` must be kept. That is
        ``z.nnz`` when ``beta`` is non-zero or ``z`` aliases an operand, and 0 otherwise.
        """
        if self.device != device:
            self._reset(device)

        # Allocations that do not depend on any computation
        self._copied_z_nnz = z.nnz if beta != 0.0 or z_aliasing else 0

        # Each buffer is sized against its own requirement
        if self._mm_row_min is None or self._mm_row_min.size < z.nrow + 1:
            self._mm_row_min = wp.empty(shape=(z.nrow + 1,), dtype=int, device=self.device)
        if self._mm_block_counts is None or self._mm_block_counts.size < x_nnz + 1:
            self._mm_block_counts = wp.empty(shape=(x_nnz + 1,), dtype=int, device=self.device)

        if self._copied_z_nnz > 0:
            # Also reallocate when the block type changed, so values are never copied into a mismatched buffer
            if (
                self._old_z_values is None
                or self._old_z_values.size < self._copied_z_nnz
                or self._old_z_values.dtype != z.values.dtype
            ):
                self._old_z_values = wp.empty(shape=(self._copied_z_nnz,), dtype=z.values.dtype, device=self.device)

        if z_aliasing:
            # z is overwritten while x or y (which alias it) are still read, so keep its topology
            if self._old_z_columns is None or self._old_z_columns.size < z.nnz:
                self._old_z_columns = wp.empty(shape=(z.nnz,), dtype=z.columns.dtype, device=self.device)
            if self._old_z_offsets is None or self._old_z_offsets.size < z.nrow + 1:
                self._old_z_offsets = wp.empty(shape=(z.nrow + 1,), dtype=z.offsets.dtype, device=self.device)

    def _allocate_stage_2(self, mm_nnz: int):
        """Allocate buffers sized by ``mm_nnz``, the estimated count of unmerged product blocks."""
        # Allocations that depend on unmerged nnz estimate
        self._mm_nnz = mm_nnz
        if self._mm_rows is None or self._mm_rows.size < mm_nnz:
            self._mm_rows = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
        if self._mm_cols is None or self._mm_cols.size < mm_nnz:
            self._mm_cols = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
        if self._mm_src_blocks is None or self._mm_src_blocks.size < mm_nnz:
            self._mm_src_blocks = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
2076
+
2077
+
2078
+ def bsr_mm(
2079
+ x: BsrMatrixOrExpression[BlockType[Rows, Any, Scalar]],
2080
+ y: BsrMatrixOrExpression[BlockType[Any, Cols, Scalar]],
2081
+ z: BsrMatrix[BlockType[Rows, Cols, Scalar]] | None = None,
2082
+ alpha: Scalar = 1.0,
2083
+ beta: Scalar = 0.0,
2084
+ masked: bool = False,
2085
+ work_arrays: bsr_mm_work_arrays | None = None,
2086
+ reuse_topology: bool = False,
2087
+ tile_size: int = 0,
2088
+ max_new_nnz: int | None = None,
2089
+ ) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
2090
+ """
2091
+ Perform the sparse matrix-matrix multiplication ``z := alpha * x @ y + beta * z`` on BSR matrices ``x``, ``y`` and ``z``, and return ``z``.
2092
+
2093
+ The ``x``, ``y`` and ``z`` matrices are allowed to alias.
2094
+ If the matrix ``z`` is not provided as input, it will be allocated and treated as zero.
2095
+
2096
+ This method can be graph-captured if either:
2097
+ - `masked=True`
2098
+ - `reuse_topology=True`
2099
+ - `max_new_nnz` is provided
2100
+
2101
+ Args:
2102
+ x: Read-only left operand of the matrix-matrix product.
2103
+ y: Read-only right operand of the matrix-matrix product.
2104
+ z: Mutable affine operand and result matrix. If ``z`` is not provided, it will be allocated and treated as zero.
2105
+ alpha: Uniform scaling factor for the ``x @ y`` product
2106
+ beta: Uniform scaling factor for ``z``
2107
+ masked: If ``True``, keep the non-zero topology of ``z`` unchanged.
2108
+ work_arrays: In most cases, this function will require the use of temporary storage.
2109
+ This storage can be reused across calls by passing an instance of
2110
+ :class:`bsr_mm_work_arrays` in ``work_arrays``.
2111
+ reuse_topology: If ``True``, reuse the product topology information
2112
+ stored in ``work_arrays`` rather than recompute it from scratch.
2113
+ The matrices ``x``, ``y`` and ``z`` must be structurally similar to
2114
+ the previous call in which ``work_arrays`` were populated.
2115
+ max_new_nnz: If provided, the maximum number of non-zeros for the matrix-matrix product result
2116
+ (not counting the existing non-zeros in ``z``).
2117
+ tile_size: If a positive integer, use tiles of this size to compute the matrix-matrix product.
2118
+ If negative, disable tile-based computation. Defaults to ``0``, which determines whether to
2119
+ use tiles using using an heuristic based on the matrix shape and number of non-zeros..
2120
+ """
2121
+
2122
+ x, x_scale = _extract_matrix_and_scale(x)
2123
+ alpha *= x_scale
2124
+ y, y_scale = _extract_matrix_and_scale(y)
2125
+ alpha *= y_scale
2126
+
2127
+ if z is None:
2128
+ if masked:
2129
+ raise ValueError("Left-hand-side 'z' matrix must be provided for masked multiplication")
2130
+
2131
+ # If not output matrix is provided, allocate it for convenience
2132
+ z_block_shape = (x.block_shape[0], y.block_shape[1])
2133
+ if z_block_shape == (1, 1):
2134
+ z_block_type = x.scalar_type
2135
+ else:
2136
+ z_block_type = wp.mat(shape=z_block_shape, dtype=x.scalar_type)
2137
+ z = bsr_zeros(x.nrow, y.ncol, block_type=z_block_type, device=x.values.device)
2138
+ z.values.requires_grad = x.requires_grad or y.requires_grad
2139
+ beta = 0.0
2140
+
2141
+ if x.values.device != y.values.device or x.values.device != z.values.device:
2142
+ raise ValueError(
2143
+ f"All arguments must reside on the same device, got {x.values.device}, {y.values.device} and {z.values.device}"
2144
+ )
2145
+
2146
+ if x.scalar_type != y.scalar_type or x.scalar_type != z.scalar_type:
2147
+ raise ValueError(
2148
+ f"Matrices must have the same scalar type, got {x.scalar_type}, {y.scalar_type} and {z.scalar_type}"
2149
+ )
2150
+
2151
+ if (
2152
+ x.block_shape[0] != z.block_shape[0]
2153
+ or y.block_shape[1] != z.block_shape[1]
2154
+ or x.block_shape[1] != y.block_shape[0]
2155
+ ):
2156
+ raise ValueError(
2157
+ f"Incompatible block sizes for matrix multiplication, got ({x.block_shape}, {y.block_shape}) and ({z.block_shape})"
2158
+ )
2159
+
2160
+ if x.nrow != z.nrow or z.ncol != y.ncol or x.ncol != y.nrow:
2161
+ raise ValueError(
2162
+ f"Incompatible number of rows/columns for matrix multiplication, got ({x.nrow}, {x.ncol}) and ({y.nrow}, {y.ncol})"
2163
+ )
2164
+
2165
+ device = z.values.device
2166
+
2167
+ if alpha == 0.0 or x.nnz == 0 or y.nnz == 0:
2168
+ # Easy case
2169
+ return bsr_scale(z, beta)
2170
+
2171
+ z_aliasing = z == x or z == y
2172
+
2173
+ if masked:
2174
+ # no need to copy z, scale in-place
2175
+ copied_z_nnz = 0
2176
+ mm_nnz = z.nnz
2177
+
2178
+ if z_aliasing:
2179
+ raise ValueError("`masked=True` is not supported for aliased inputs")
2180
+
2181
+ if beta == 0.0:
2182
+ # do not bsr_scale(0), this would not preserve topology
2183
+ z.values.zero_()
2184
+ else:
2185
+ bsr_scale(z, beta)
2186
+ elif reuse_topology:
2187
+ if work_arrays is None:
2188
+ raise ValueError("`work_arrays` must not be ``None`` in order to reuse matrix-matrix product topology")
2189
+
2190
+ copied_z_nnz = work_arrays._copied_z_nnz
2191
+ mm_nnz = work_arrays._mm_nnz
2192
+ else:
2193
+ if work_arrays is None:
2194
+ work_arrays = bsr_mm_work_arrays()
2195
+
2196
+ if max_new_nnz is None:
2197
+ if device.is_capturing:
2198
+ raise RuntimeError(
2199
+ "`bsr_mm` requires either `reuse_topology=True`, `masked=True` or `max_new_nnz` to be set for use in graph capture"
2200
+ )
2201
+ z.nnz_sync()
2202
+
2203
+ work_arrays._allocate_stage_1(device, x.nnz, z, beta, z_aliasing)
2204
+ copied_z_nnz = work_arrays._copied_z_nnz
2205
+
2206
+ # Prefix sum of number of (unmerged) mm blocks per row
2207
+ # Use either a thread or a block per row depending on avg nnz/row
2208
+ work_arrays._mm_block_counts.zero_()
2209
+ count_tile_size = 32
2210
+ if not device.is_cuda or x.nnz < 3 * count_tile_size * x.nrow:
2211
+ count_tile_size = 1
2212
+
2213
+ wp.launch(
2214
+ kernel=make_bsr_mm_count_coeffs(count_tile_size),
2215
+ device=device,
2216
+ dim=(z.nrow, count_tile_size),
2217
+ block_dim=count_tile_size if count_tile_size > 1 else 256,
2218
+ inputs=[
2219
+ y.ncol,
2220
+ copied_z_nnz,
2221
+ x.offsets,
2222
+ x.columns,
2223
+ y.offsets,
2224
+ y.columns,
2225
+ work_arrays._mm_row_min,
2226
+ work_arrays._mm_block_counts,
2227
+ ],
2228
+ )
2229
+ warp._src.utils.array_scan(work_arrays._mm_block_counts[: x.nnz + 1], work_arrays._mm_block_counts[: x.nnz + 1])
2230
+
2231
+ if max_new_nnz is not None:
2232
+ mm_nnz = max_new_nnz + copied_z_nnz
2233
+ else:
2234
+ # Get back total counts on host -- we need a synchronization here
2235
+ # Use pinned buffer from z, we are going to need it later anyway
2236
+ nnz_buf, _ = z._setup_nnz_transfer()
2237
+ stream = wp.get_stream(device) if device.is_cuda else None
2238
+ wp.copy(dest=nnz_buf, src=work_arrays._mm_block_counts, src_offset=x.nnz, count=1, stream=stream)
2239
+ if device.is_cuda:
2240
+ wp.synchronize_stream(stream)
2241
+ mm_nnz = int(nnz_buf.numpy()[0])
2242
+
2243
+ if mm_nnz == copied_z_nnz:
2244
+ # x@y = 0
2245
+ return bsr_scale(z, beta)
2246
+
2247
+ work_arrays._allocate_stage_2(mm_nnz)
2248
+
2249
+ # If z has a non-zero scale, save current data before overwriting it
2250
+ if copied_z_nnz > 0:
2251
+ # Copy z row and column indices
2252
+ wp.copy(dest=work_arrays._mm_cols, src=z.columns, count=copied_z_nnz)
2253
+ z.uncompress_rows(out=work_arrays._mm_rows)
2254
+ work_arrays._mm_src_blocks[:copied_z_nnz].fill_(-1)
2255
+ if z_aliasing:
2256
+ # If z is aliasing with x or y, need to save topology as well
2257
+ wp.copy(src=z.columns, dest=work_arrays._old_z_columns, count=copied_z_nnz)
2258
+ wp.copy(src=z.offsets, dest=work_arrays._old_z_offsets, count=z.nrow + 1)
2259
+
2260
+ # Fill unmerged mm blocks rows and columns
2261
+ wp.launch(
2262
+ kernel=_bsr_mm_list_coeffs,
2263
+ device=device,
2264
+ dim=mm_nnz - copied_z_nnz,
2265
+ inputs=[
2266
+ copied_z_nnz,
2267
+ mm_nnz,
2268
+ x.nrow,
2269
+ x.offsets,
2270
+ x.columns,
2271
+ y.offsets,
2272
+ y.columns,
2273
+ work_arrays._mm_row_min,
2274
+ work_arrays._mm_block_counts,
2275
+ work_arrays._mm_rows,
2276
+ work_arrays._mm_cols,
2277
+ work_arrays._mm_src_blocks,
2278
+ ],
2279
+ )
2280
+
2281
+ alpha = z.scalar_type(alpha)
2282
+ beta = z.scalar_type(beta)
2283
+
2284
+ if copied_z_nnz > 0:
2285
+ # Save current z values in temporary buffer
2286
+ wp.copy(src=z.values, dest=work_arrays._old_z_values, count=copied_z_nnz)
2287
+
2288
+ if not masked:
2289
+ # Increase dest array size if needed
2290
+ if z.columns.shape[0] < mm_nnz:
2291
+ z.columns = wp.empty(shape=(mm_nnz,), dtype=int, device=device)
2292
+
2293
+ from warp._src.context import runtime
2294
+
2295
+ if device.is_cpu:
2296
+ native_func = runtime.core.wp_bsr_matrix_from_triplets_host
2297
+ else:
2298
+ native_func = runtime.core.wp_bsr_matrix_from_triplets_device
2299
+
2300
+ nnz_buf, nnz_event = z._setup_nnz_transfer()
2301
+ summed_triplet_offsets = wp.empty(shape=(mm_nnz,), dtype=wp.int32, device=device)
2302
+ summed_triplet_indices = wp.empty(shape=(mm_nnz,), dtype=wp.int32, device=device)
2303
+
2304
+ with wp.ScopedDevice(z.device):
2305
+ native_func(
2306
+ z.block_size,
2307
+ 0, # scalar_size_in_bytes
2308
+ z.nrow,
2309
+ z.ncol,
2310
+ mm_nnz,
2311
+ None, # device nnz
2312
+ ctypes.cast(work_arrays._mm_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
2313
+ ctypes.cast(work_arrays._mm_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
2314
+ None, # triplet values
2315
+ 0, # zero_value_mask
2316
+ False, # masked_topology
2317
+ ctypes.cast(summed_triplet_offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
2318
+ ctypes.cast(summed_triplet_indices.ptr, ctypes.POINTER(ctypes.c_int32)),
2319
+ ctypes.cast(z.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
2320
+ ctypes.cast(z.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
2321
+ _optional_ctypes_pointer(nnz_buf, ctype=ctypes.c_int32),
2322
+ _optional_ctypes_event(nnz_event),
2323
+ )
2324
+
2325
+ # Resize z to fit mm result if necessary
2326
+ # If we are not reusing the product topology, this needs another synchronization
2327
+ if not reuse_topology:
2328
+ work_arrays.result_nnz = z.nnz_sync() if max_new_nnz is None else mm_nnz
2329
+
2330
+ _bsr_ensure_fits(z, nnz=work_arrays.result_nnz)
2331
+ z.values.zero_()
2332
+
2333
+ if copied_z_nnz > 0:
2334
+ # Add back original z values
2335
+ wp.launch(
2336
+ kernel=_bsr_axpy_add_block,
2337
+ device=device,
2338
+ dim=(copied_z_nnz, z.block_shape[0], z.block_shape[1]),
2339
+ inputs=[
2340
+ 0,
2341
+ beta,
2342
+ work_arrays._mm_rows,
2343
+ work_arrays._mm_cols,
2344
+ z.offsets,
2345
+ z.columns,
2346
+ _as_3d_array(work_arrays._old_z_values, z.block_shape),
2347
+ z.scalar_values,
2348
+ ],
2349
+ )
2350
+
2351
+ max_subblock_dim = 12
2352
+ if tile_size > 0:
2353
+ use_tiles = True
2354
+ elif tile_size < 0:
2355
+ use_tiles = False
2356
+ else:
2357
+ # Heuristic for using tiled variant: few or very large blocks
2358
+ tile_size = 64
2359
+ max_tiles_per_sm = 2048 // tile_size # assume 64 resident warps per SM
2360
+ use_tiles = device.is_cuda and (
2361
+ max(x.block_size, y.block_size, z.block_size) > max_subblock_dim**2
2362
+ or z.nnz < max_tiles_per_sm * device.sm_count
2363
+ )
2364
+
2365
+ if use_tiles:
2366
+ subblock_rows = min(max_subblock_dim, z.block_shape[0])
2367
+ subblock_cols = min(max_subblock_dim, z.block_shape[1])
2368
+
2369
+ wp.launch(
2370
+ kernel=make_bsr_mm_compute_values_tiled_outer(
2371
+ subblock_rows, subblock_cols, x.block_shape[1], z.scalar_type, tile_size
2372
+ ),
2373
+ device=device,
2374
+ dim=(
2375
+ z.nnz,
2376
+ (z.block_shape[0] + subblock_rows - 1) // subblock_rows,
2377
+ (z.block_shape[1] + subblock_cols - 1) // subblock_cols,
2378
+ tile_size,
2379
+ ),
2380
+ block_dim=tile_size,
2381
+ inputs=[
2382
+ alpha,
2383
+ work_arrays._old_z_offsets if x == z else x.offsets,
2384
+ work_arrays._old_z_columns if x == z else x.columns,
2385
+ _as_3d_array(work_arrays._old_z_values, z.block_shape) if x == z else x.scalar_values,
2386
+ work_arrays._old_z_offsets if y == z else y.offsets,
2387
+ work_arrays._old_z_columns if y == z else y.columns,
2388
+ _as_3d_array(work_arrays._old_z_values, z.block_shape) if y == z else y.scalar_values,
2389
+ None if masked else work_arrays._mm_row_min,
2390
+ None if masked else summed_triplet_offsets,
2391
+ None if masked else work_arrays._mm_src_blocks[summed_triplet_indices],
2392
+ z.nrow,
2393
+ z.offsets,
2394
+ z.columns,
2395
+ z.scalar_values,
2396
+ ],
2397
+ )
2398
+
2399
+ return z
2400
+
2401
+ # Add mm blocks to z values
2402
+ if (type_is_matrix(x.values.dtype) or type_is_matrix(y.values.dtype)) and not (type_is_matrix(z.values.dtype)):
2403
+ # Result block type is scalar, but operands are matrices
2404
+ # Cast result to (1x1) matrix to perform multiplication
2405
+ mm_values = z.values.view(wp.mat(shape=(1, 1), dtype=z.scalar_type))
2406
+ else:
2407
+ mm_values = z.values
2408
+
2409
+ wp.launch(
2410
+ kernel=_bsr_mm_compute_values,
2411
+ device=device,
2412
+ dim=z.nnz,
2413
+ inputs=[
2414
+ alpha,
2415
+ work_arrays._old_z_offsets if x == z else x.offsets,
2416
+ work_arrays._old_z_columns if x == z else x.columns,
2417
+ work_arrays._old_z_values if x == z else x.values,
2418
+ work_arrays._old_z_offsets if y == z else y.offsets,
2419
+ work_arrays._old_z_columns if y == z else y.columns,
2420
+ work_arrays._old_z_values if y == z else y.values,
2421
+ None if masked else work_arrays._mm_row_min,
2422
+ None if masked else summed_triplet_offsets,
2423
+ None if masked else work_arrays._mm_src_blocks[summed_triplet_indices],
2424
+ z.nrow,
2425
+ z.offsets,
2426
+ z.columns,
2427
+ mm_values,
2428
+ ],
2429
+ )
2430
+
2431
+ return z
2432
+
2433
+
2434
+ def make_bsr_mv_kernel(block_cols: int):
2435
+ from warp._src.fem.cache import dynamic_kernel
2436
+
2437
+ @dynamic_kernel(suffix=block_cols, kernel_options={"enable_backward": False})
2438
+ def bsr_mv_kernel(
2439
+ alpha: Any,
2440
+ A_offsets: wp.array(dtype=int),
2441
+ A_columns: wp.array(dtype=int),
2442
+ A_values: wp.array3d(dtype=Any),
2443
+ x: wp.array(dtype=Any),
2444
+ beta: Any,
2445
+ y: wp.array(dtype=Any),
2446
+ ):
2447
+ row, subrow = wp.tid()
2448
+
2449
+ block_rows = A_values.shape[1]
2450
+
2451
+ yi = row * block_rows + subrow
2452
+
2453
+ # zero-initialize with type of y elements
2454
+ scalar_zero = type(alpha)(0)
2455
+ v = scalar_zero
2456
+
2457
+ if alpha != scalar_zero:
2458
+ beg = A_offsets[row]
2459
+ end = A_offsets[row + 1]
2460
+ for block in range(beg, end):
2461
+ xs = A_columns[block] * block_cols
2462
+ for col in range(wp.static(block_cols)):
2463
+ v += A_values[block, subrow, col] * x[xs + col]
2464
+ v *= alpha
2465
+
2466
+ if beta != scalar_zero:
2467
+ v += beta * y[yi]
2468
+
2469
+ y[yi] = v
2470
+
2471
+ return bsr_mv_kernel
2472
+
2473
+
2474
+ def make_bsr_mv_tiled_kernel(tile_size: int):
2475
+ from warp._src.fem.cache import dynamic_kernel
2476
+
2477
+ @dynamic_kernel(suffix=tile_size, kernel_options={"enable_backward": False})
2478
+ def bsr_mv_tiled_kernel(
2479
+ alpha: Any,
2480
+ A_offsets: wp.array(dtype=int),
2481
+ A_columns: wp.array(dtype=int),
2482
+ A_values: wp.array3d(dtype=Any),
2483
+ x: wp.array(dtype=Any),
2484
+ beta: Any,
2485
+ y: wp.array(dtype=Any),
2486
+ ):
2487
+ row, subrow, lane = wp.tid()
2488
+
2489
+ scalar_zero = type(alpha)(0)
2490
+ block_rows = A_values.shape[1]
2491
+ block_cols = A_values.shape[2]
2492
+
2493
+ yi = row * block_rows + subrow
2494
+
2495
+ if beta == scalar_zero:
2496
+ subrow_sum = wp.tile_zeros(shape=(1,), dtype=y.dtype)
2497
+ else:
2498
+ subrow_sum = beta * wp.tile_load(y, 1, yi)
2499
+
2500
+ if alpha != scalar_zero:
2501
+ block_beg = A_offsets[row]
2502
+ col_count = (A_offsets[row + 1] - block_beg) * block_cols
2503
+
2504
+ col = lane
2505
+ lane_sum = y.dtype(0)
2506
+
2507
+ for col in range(lane, col_count, tile_size):
2508
+ block = col // block_cols
2509
+ block_col = col - block * block_cols
2510
+ block += block_beg
2511
+
2512
+ xi = x[A_columns[block] * block_cols + block_col]
2513
+ lane_sum += A_values[block, subrow, block_col] * xi
2514
+
2515
+ lane_sum *= alpha
2516
+ subrow_sum += wp.tile_sum(wp.tile(lane_sum))
2517
+
2518
+ wp.tile_store(y, subrow_sum, yi)
2519
+
2520
+ return bsr_mv_tiled_kernel
2521
+
2522
+
2523
+ def make_bsr_mv_transpose_kernel(block_rows: int):
2524
+ from warp._src.fem.cache import dynamic_kernel
2525
+
2526
+ @dynamic_kernel(suffix=block_rows, kernel_options={"enable_backward": False})
2527
+ def bsr_mv_transpose_kernel(
2528
+ alpha: Any,
2529
+ A_row_count: int,
2530
+ A_offsets: wp.array(dtype=int),
2531
+ A_columns: wp.array(dtype=int),
2532
+ A_values: wp.array3d(dtype=Any),
2533
+ x: wp.array(dtype=Any),
2534
+ y: wp.array(dtype=Any),
2535
+ ):
2536
+ block, subcol = wp.tid()
2537
+
2538
+ row = bsr_row_index(A_offsets, A_row_count, block)
2539
+ if row == -1:
2540
+ return
2541
+
2542
+ block_cols = A_values.shape[2]
2543
+
2544
+ A_block = A_values[block]
2545
+
2546
+ col_sum = type(alpha)(0)
2547
+ for subrow in range(wp.static(block_rows)):
2548
+ col_sum += A_block[subrow, subcol] * x[row * block_rows + subrow]
2549
+
2550
+ wp.atomic_add(y, A_columns[block] * block_cols + subcol, alpha * col_sum)
2551
+
2552
+ return bsr_mv_transpose_kernel
2553
+
2554
+
2555
+ def _vec_array_view(array: wp.array, dtype: type, expected_scalar_count: int) -> wp.array:
2556
+ # cast a 1d or 2d array to a 1d array with the target dtype, adjusting shape as required
2557
+
2558
+ scalar_count = array.size * type_size(array.dtype)
2559
+ if scalar_count != expected_scalar_count:
2560
+ raise ValueError(f"Invalid array scalar size, expected {expected_scalar_count}, got {scalar_count}")
2561
+
2562
+ if array.ndim == 1 and types_equal(array.dtype, dtype):
2563
+ return array
2564
+
2565
+ if type_scalar_type(array.dtype) != type_scalar_type(dtype):
2566
+ raise ValueError(f"Incompatible scalar types, expected {type_repr(array.dtype)}, got {type_repr(dtype)}")
2567
+
2568
+ if array.ndim > 2:
2569
+ raise ValueError(f"Incompatible array number of dimensions, expected 1 or 2, got {array.ndim}")
2570
+
2571
+ if not array.is_contiguous:
2572
+ raise ValueError("Array must be contiguous")
2573
+
2574
+ vec_length = type_size(dtype)
2575
+ vec_count = scalar_count // vec_length
2576
+ if vec_count * vec_length != scalar_count:
2577
+ raise ValueError(
2578
+ f"Array of shape {array.shape} and type {type_repr(array.dtype)} cannot be reshaped to an array of type {type_repr(dtype)}"
2579
+ )
2580
+
2581
+ def vec_view(array):
2582
+ return wp.array(
2583
+ data=None,
2584
+ ptr=array.ptr,
2585
+ capacity=array.capacity,
2586
+ device=array.device,
2587
+ dtype=dtype,
2588
+ shape=vec_count,
2589
+ grad=None if array.grad is None else vec_view(array.grad),
2590
+ )
2591
+
2592
+ view = vec_view(array)
2593
+ view._ref = array
2594
+ return view
2595
+
2596
+
2597
+ def bsr_mv(
2598
+ A: BsrMatrixOrExpression[BlockType[Rows, Cols, Scalar]],
2599
+ x: Array[Vector[Cols, Scalar] | Scalar],
2600
+ y: Array[Vector[Rows, Scalar] | Scalar] | None = None,
2601
+ alpha: Scalar = 1.0,
2602
+ beta: Scalar = 0.0,
2603
+ transpose: bool = False,
2604
+ work_buffer: Array[Vector[Rows, Scalar] | Scalar] | None = None,
2605
+ tile_size: int = 0,
2606
+ ) -> Array[Vector[Rows, Scalar] | Scalar]:
2607
+ """Perform the sparse matrix-vector product ``y := alpha * A * x + beta * y`` and return ``y``.
2608
+
2609
+ The ``x`` and ``y`` vectors are allowed to alias.
2610
+
2611
+ Args:
2612
+ A: Read-only, left matrix operand of the matrix-vector product.
2613
+ x: Read-only, right vector operand of the matrix-vector product.
2614
+ y: Mutable affine operand and result vector. If ``y`` is not provided, it will be allocated and treated as zero.
2615
+ alpha: Uniform scaling factor for ``x``. If zero, ``x`` will not be read and may be left uninitialized.
2616
+ beta: Uniform scaling factor for ``y``. If zero, ``y`` will not be read and may be left uninitialized.
2617
+ transpose: If ``True``, use the transpose of the matrix ``A``. In this case the result is **non-deterministic**.
2618
+ work_buffer: Temporary storage is required if and only if ``x`` and ``y`` are the same vector.
2619
+ If provided, the ``work_buffer`` array will be used for this purpose,
2620
+ otherwise a temporary allocation will be performed.
2621
+ tile_size: If a positive integer, use tiles of this size to compute the matrix-matrix product.
2622
+ If negative, disable tile-based computation. Defaults to ``0``, which determines whether to
2623
+ use tiles using using an heuristic based on the matrix shape and number of non-zeros..
2624
+ """
2625
+
2626
+ A, A_scale = _extract_matrix_and_scale(A)
2627
+ alpha *= A_scale
2628
+
2629
+ if transpose:
2630
+ block_shape = A.block_shape[1], A.block_shape[0]
2631
+ nrow, ncol = A.ncol, A.nrow
2632
+ else:
2633
+ block_shape = A.block_shape
2634
+ nrow, ncol = A.nrow, A.ncol
2635
+
2636
+ if y is None:
2637
+ # If no output array is provided, allocate one for convenience
2638
+ y_vec_len = block_shape[0]
2639
+ y_dtype = A.scalar_type if y_vec_len == 1 else wp.vec(length=y_vec_len, dtype=A.scalar_type)
2640
+ y = wp.empty(shape=(nrow,), device=A.values.device, dtype=y_dtype, requires_grad=x.requires_grad)
2641
+ beta = 0.0
2642
+
2643
+ alpha = A.scalar_type(alpha)
2644
+ beta = A.scalar_type(beta)
2645
+
2646
+ device = A.values.device
2647
+ if A.values.device != x.device or A.values.device != y.device:
2648
+ raise ValueError(
2649
+ f"A, x, and y must reside on the same device, got {A.values.device}, {x.device} and {y.device}"
2650
+ )
2651
+
2652
+ if x.ptr == y.ptr:
2653
+ # Aliasing case, need temporary storage
2654
+ if work_buffer is None:
2655
+ work_buffer = wp.empty_like(y)
2656
+ elif work_buffer.size < y.size:
2657
+ raise ValueError(f"Work buffer size is insufficient, needs to be at least {y.size}, got {work_buffer.size}")
2658
+ elif not types_equal(work_buffer.dtype, y.dtype):
2659
+ raise ValueError(
2660
+ f"Work buffer must have same data type as y, {type_repr(y.dtype)} vs {type_repr(work_buffer.dtype)}"
2661
+ )
2662
+
2663
+ # Save old y values before overwriting vector
2664
+ wp.copy(dest=work_buffer, src=y, count=y.size)
2665
+ x = work_buffer
2666
+
2667
+ try:
2668
+ x_view = _vec_array_view(x, A.scalar_type, expected_scalar_count=ncol * block_shape[1])
2669
+ except ValueError as err:
2670
+ raise ValueError("Incompatible 'x' vector for bsr_mv") from err
2671
+ try:
2672
+ y_view = _vec_array_view(y, A.scalar_type, expected_scalar_count=nrow * block_shape[0])
2673
+ except ValueError as err:
2674
+ raise ValueError("Incompatible 'y' vector for bsr_mv") from err
2675
+
2676
+ # heuristic to use tiled version for long rows
2677
+ if tile_size > 0:
2678
+ use_tiles = True
2679
+ elif tile_size < 0:
2680
+ use_tiles = False
2681
+ else:
2682
+ tile_size = 64
2683
+ use_tiles = device.is_cuda and A.nnz * A.block_size > 2 * tile_size * A.shape[0]
2684
+
2685
+ if transpose:
2686
+ if beta.value == 0.0:
2687
+ y.zero_()
2688
+ elif beta.value != 1.0:
2689
+ wp.launch(
2690
+ kernel=_bsr_scale_kernel,
2691
+ device=y.device,
2692
+ dim=y_view.shape[0],
2693
+ inputs=[beta, y_view],
2694
+ )
2695
+ if alpha.value != 0.0:
2696
+ wp.launch(
2697
+ kernel=make_bsr_mv_transpose_kernel(block_rows=block_shape[1]),
2698
+ device=A.values.device,
2699
+ dim=(A.nnz, block_shape[0]),
2700
+ inputs=[alpha, A.nrow, A.offsets, A.columns, A.scalar_values, x_view, y_view],
2701
+ )
2702
+ elif use_tiles:
2703
+ wp.launch(
2704
+ kernel=make_bsr_mv_tiled_kernel(tile_size),
2705
+ device=A.values.device,
2706
+ dim=(nrow, block_shape[0], tile_size),
2707
+ block_dim=tile_size,
2708
+ inputs=[alpha, A.offsets, A.columns, A.scalar_values, x_view, beta, y_view],
2709
+ )
2710
+ else:
2711
+ wp.launch(
2712
+ kernel=make_bsr_mv_kernel(block_cols=block_shape[1]),
2713
+ device=A.values.device,
2714
+ dim=(nrow, block_shape[0]),
2715
+ inputs=[alpha, A.offsets, A.columns, A.scalar_values, x_view, beta, y_view],
2716
+ )
2717
+
2718
+ return y