warp-lang 1.9.0__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package registry's advisory page for more details.

Files changed (350) hide show
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2220 -313
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1497 -226
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +1 -1
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +313 -201
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +14 -14
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +529 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/_src/sparse.py ADDED
@@ -0,0 +1,2716 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ import ctypes
19
+ import weakref
20
+ from typing import Any, Generic, TypeVar, Union
21
+
22
+ import warp as wp
23
+ import warp._src.utils
24
+ from warp._src.types import (
25
+ Array,
26
+ Cols,
27
+ Rows,
28
+ Scalar,
29
+ Vector,
30
+ is_array,
31
+ scalar_types,
32
+ type_is_matrix,
33
+ type_repr,
34
+ type_scalar_type,
35
+ type_size,
36
+ type_size_in_bytes,
37
+ type_to_warp,
38
+ types_equal,
39
+ )
40
+
41
# Public API of this module.
__all__ = [
    "BsrMatrix",
    "bsr_assign",
    "bsr_axpy",
    "bsr_block_index",
    "bsr_copy",
    "bsr_diag",
    "bsr_from_triplets",
    "bsr_get_diag",
    "bsr_identity",
    "bsr_matrix_t",
    "bsr_mm",
    "bsr_mm_work_arrays",
    "bsr_mv",
    "bsr_row_index",
    "bsr_scale",
    "bsr_set_diag",
    "bsr_set_from_triplets",
    "bsr_set_identity",
    "bsr_set_transpose",
    "bsr_set_zero",
    "bsr_transposed",
    "bsr_zeros",
]


# typing hints

_BlockType = TypeVar("BlockType")  # noqa: PLC0132


class _MatrixBlockType(Generic[Rows, Cols, Scalar]):
    # Annotation-only marker: a Rows x Cols matrix block of Scalar (BSR case).
    pass


class _ScalarBlockType(Generic[Scalar]):
    # Annotation-only marker: a single Scalar per block (CSR case).
    pass


# A block is either a matrix (BSR) or a scalar (CSR).
BlockType = Union[_MatrixBlockType[Rows, Cols, Scalar], _ScalarBlockType[Scalar]]

# Cache of generated warp Struct types, keyed by a per-block-type string
# (see bsr_matrix_t).
_struct_cache = {}
# Per-device pools of host buffers/events used for asynchronous nnz readback,
# keyed by device ordinal (see _allocate_transfer_buf / _redeem_transfer_buf).
_transfer_buffer_cache = {}
84
+
85
+
86
class BsrMatrix(Generic[_BlockType]):
    """Untyped base class for BSR and CSR matrices.

    Should not be constructed directly but through functions such as :func:`bsr_zeros`.

    Attributes:
        nrow (int): Number of rows of blocks.
        ncol (int): Number of columns of blocks.
        nnz (int): Upper bound for the number of non-zero blocks, used for
            dimensioning launches. The exact number is at ``offsets[nrow]``.
            See also :meth:`nnz_sync`.
        offsets (Array[int]): Array of size at least ``1 + nrow`` such that the
            start and end indices of the blocks of row ``r`` are ``offsets[r]``
            and ``offsets[r+1]``, respectively.
        columns (Array[int]): Array of size at least equal to ``nnz`` containing
            block column indices.
        values (Array[BlockType]): Array of size at least equal to ``nnz``
            containing block values.
    """

    @property
    def scalar_type(self) -> Scalar:
        """Scalar type for individual block coefficients. For CSR matrices, this is the same as the block type."""
        return type_scalar_type(self.values.dtype)

    @property
    def block_shape(self) -> tuple[int, int]:
        """Shape of the individual blocks."""
        # Scalar (CSR) block types have no _shape_ attribute; treat them as 1x1.
        return getattr(self.values.dtype, "_shape_", (1, 1))

    @property
    def block_size(self) -> int:
        """Size of the individual blocks, i.e. number of rows per block times number of columns per block."""
        return type_size(self.values.dtype)

    @property
    def shape(self) -> tuple[int, int]:
        """Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/columns per block."""
        block_shape = self.block_shape
        return (self.nrow * block_shape[0], self.ncol * block_shape[1])

    @property
    def dtype(self) -> type:
        """Data type for individual block values."""
        return self.values.dtype

    @property
    def device(self) -> wp._src.context.Device:
        """Device on which ``offsets``, ``columns``, and ``values`` are allocated -- assumed to be the same for all three arrays."""
        return self.values.device

    @property
    def requires_grad(self) -> bool:
        """Read-only property indicating whether the matrix participates in adjoint computations."""
        return self.values.requires_grad

    @property
    def scalar_values(self) -> wp.array:
        """Accesses the ``values`` array as a 3d scalar array (aliasing the same memory)."""
        values_view = _as_3d_array(self.values, self.block_shape)
        values_view._ref = self.values  # keep ref in case we're garbage collected
        return values_view

    def uncompress_rows(self, out: wp.array = None) -> wp.array:
        """Compute the row index for each non-zero block from the compressed row offsets.

        Args:
            out: Optional destination array of size at least ``nnz``; a new
                array is allocated when not provided.
        """
        if out is None:
            out = wp.empty(self.nnz, dtype=int, device=self.device)

        wp.launch(
            kernel=_bsr_get_block_row,
            device=self.device,
            dim=self.nnz,
            inputs=[self.nrow, self.offsets, out],
        )
        return out

    def nnz_sync(self) -> int:
        """
        Synchronize the number of non-zeros from the device offsets array to the host.

        Ensures that any ongoing transfer of the exact nnz number from the device offsets array to the host has completed,
        or, if none has been scheduled yet, starts a new transfer and waits for it to complete.

        Then updates the host-side nnz upper bound to match the exact one, and returns it.

        See also :meth:`notify_nnz_changed`.
        """

        buf, event = self._nnz_transfer_if_any()
        if buf is None:
            # No transfer in flight yet; schedule one now.
            buf, event = self._copy_nnz_async()

        if event is not None:
            # CPU devices have no event; only CUDA transfers need a wait.
            wp.synchronize_event(event)
        self.nnz = int(buf.numpy()[0])
        return self.nnz

    def notify_nnz_changed(self, nnz: int | None = None) -> None:
        """Notify the matrix that the number of non-zeros has been changed from outside of the ``warp.sparse`` builtin functions.

        Should be called in particular when the offsets array has been modified, or when the nnz upper bound has changed.
        Makes sure that the matrix is properly resized and starts the asynchronous transfer of the actual non-zero count.

        Args:
            nnz: The new upper-bound for the number of non-zeros. If not provided, it will be read from the device offsets array (requires a synchronization).
        """

        if nnz is None:
            # Blocking path: read the exact count back from the device.
            self.nnz_sync()
        else:
            # Non-blocking path: trust the caller's bound, refresh asynchronously.
            self._copy_nnz_async()

        _bsr_ensure_fits(self, nnz=nnz)

    def copy_nnz_async(self) -> None:
        """
        Starts the asynchronous transfer of the exact nnz from the device offsets array to host and records an event for completion.

        Deprecated; prefer :meth:`notify_nnz_changed` instead, which will make sure to resize arrays if necessary.
        """
        wp._src.utils.warn(
            "The `copy_nnz_async` method is deprecated and will be removed in a future version. Prefer `notify_nnz_changed` instead.",
            DeprecationWarning,
        )
        self._copy_nnz_async()

    def _copy_nnz_async(self) -> tuple[wp.array, wp.Event]:
        """Start copying ``offsets[nrow]`` (the exact nnz) to the host buffer; returns ``(buf, event)``."""
        buf, event = self._setup_nnz_transfer()
        if buf is not None:
            stream = wp.get_stream(self.device) if self.device.is_cuda else None
            # offsets[nrow] holds the exact block count for the whole matrix.
            wp.copy(src=self.offsets, dest=buf, src_offset=self.nrow, count=1, stream=stream)
            if event is not None:
                stream.record_event(event, external=True)
        return buf, event

    def _setup_nnz_transfer(self) -> tuple[wp.array, wp.Event]:
        """Lazily acquire and memoize a pooled host buffer/event pair for nnz readback."""
        buf, event = self._nnz_transfer_if_any()
        if buf is not None:
            return buf, event

        buf, event = _allocate_transfer_buf(self.device)
        if buf is not None:
            # buf may still be None if device is currently capturing
            # NOTE(review): __setattr__ is invoked via the base class, presumably to
            # bypass attribute handling on the generated struct subclass -- confirm.
            BsrMatrix.__setattr__(self, "_nnz_transfer", (buf, event))
            # Return the pair to the per-device pool once this matrix is collected.
            weakref.finalize(self, _redeem_transfer_buf, self.device, buf, event)

        return buf, event

    def _nnz_transfer_if_any(self) -> tuple[wp.array, wp.Event]:
        """Return the memoized ``(buf, event)`` pair, or ``(None, None)`` if none was set up."""
        return getattr(self, "_nnz_transfer", (None, None))

    # Overloaded math operators
    def __add__(self, y):
        return bsr_axpy(y, bsr_copy(self))

    def __iadd__(self, y):
        return bsr_axpy(y, self)

    def __radd__(self, x):
        return bsr_axpy(x, bsr_copy(self))

    def __sub__(self, y):
        return bsr_axpy(y, bsr_copy(self), alpha=-1.0)

    def __rsub__(self, x):
        return bsr_axpy(x, bsr_copy(self), beta=-1.0)

    def __isub__(self, y):
        return bsr_axpy(y, self, alpha=-1.0)

    def __mul__(self, y):
        # Scaling is lazy: a _BsrScalingExpression is returned rather than a new matrix.
        return _BsrScalingExpression(self, y)

    def __rmul__(self, x):
        return _BsrScalingExpression(self, x)

    def __imul__(self, y):
        return bsr_scale(self, y)

    def __matmul__(self, y):
        # Matrix @ vector -> bsr_mv; matrix @ matrix -> bsr_mm.
        if isinstance(y, wp.array):
            return bsr_mv(self, y)

        return bsr_mm(self, y)

    def __rmatmul__(self, x):
        # vector @ matrix is computed as a transposed matrix-vector product.
        if isinstance(x, wp.array):
            return bsr_mv(self, x, transpose=True)

        return bsr_mm(x, self)

    def __imatmul__(self, y):
        return bsr_mm(self, y, self)

    def __truediv__(self, y):
        return _BsrScalingExpression(self, 1.0 / y)

    def __neg__(self):
        return _BsrScalingExpression(self, -1.0)

    def transpose(self):
        """Return a transposed copy of this matrix."""
        return bsr_transposed(self)
289
+
290
+
291
def _allocate_transfer_buf(device):
    """Fetch a pooled host-side nnz readback buffer for ``device``, or allocate one.

    Returns a ``(buffer, event)`` pair, or ``(None, None)`` when a fresh
    allocation would be required while the device is capturing a graph.
    """
    try:
        live, free = _transfer_buffer_cache[device.ordinal]
    except KeyError:
        live, free = _transfer_buffer_cache.setdefault(device.ordinal, ([], []))

    # Reuse a previously redeemed pair whenever one is available.
    if free:
        return free.pop()

    # Cannot allocate new resources mid-capture; the caller handles None.
    if device.is_capturing:
        return None, None

    on_cuda = device.is_cuda
    pair = (
        wp.empty(dtype=int, shape=(1,), device="cpu", pinned=on_cuda),
        wp.Event(device) if on_cuda else None,
    )
    # Keep a strong reference so buffer and event outlive their borrowers.
    live.append(pair)
    return pair
309
+
310
+
311
def _redeem_transfer_buf(device, buf, event):
    """Return a no-longer-used ``(buf, event)`` pair to the device's free pool."""
    free_pool = _transfer_buffer_cache[device.ordinal][1]
    free_pool.append((buf, event))
314
+
315
+
316
def bsr_matrix_t(dtype: BlockType):
    """Return the concrete warp struct type for BSR/CSR matrices with the given block type.

    ``dtype`` must be either a warp matrix type (BSR) or a scalar type (CSR);
    anything else raises ``ValueError``. Generated struct types are memoized in
    ``_struct_cache`` so repeated calls with the same block type return the
    same object.
    """
    dtype = type_to_warp(dtype)

    if not type_is_matrix(dtype) and dtype not in scalar_types:
        raise ValueError(f"BsrMatrix block type must be either warp matrix or scalar; got {type_repr(dtype)}")

    class BsrMatrixTyped(BsrMatrix):
        nrow: int
        """Number of rows of blocks."""
        ncol: int
        """Number of columns of blocks."""
        nnz: int
        """Upper bound for the number of non-zeros."""
        offsets: wp.array(dtype=int)
        """Array of size at least ``1 + nrow``."""
        columns: wp.array(dtype=int)
        """Array of size at least equal to ``nnz``."""
        values: wp.array(dtype=dtype)

    module = wp.get_module(BsrMatrix.__module__)

    # Cache key: block scalar type plus block shape for matrices, plain type name for scalars.
    if hasattr(dtype, "_shape_"):
        type_str = f"{type_scalar_type(dtype).__name__}_{dtype._shape_[0]}_{dtype._shape_[1]}"
    else:
        type_str = dtype.__name__
    key = f"{BsrMatrix.__qualname__}_{type_str}"

    if key not in _struct_cache:
        BsrMatrixTyped.dtype = dtype  # necessary for eval_annotations
        _struct_cache[key] = wp._src.codegen.Struct(
            key=key,
            cls=BsrMatrixTyped,
            module=module,
        )

    return _struct_cache[key]
352
+
353
+
354
def bsr_zeros(
    rows_of_blocks: int,
    cols_of_blocks: int,
    block_type: BlockType,
    device: wp._src.context.Devicelike = None,
) -> BsrMatrix:
    """Construct and return an empty BSR or CSR matrix with the given shape.

    Args:
        rows_of_blocks: Number of rows of blocks.
        cols_of_blocks: Number of columns of blocks.
        block_type: Type of individual blocks.
            For CSR matrices, this should be a scalar type.
            For BSR matrices, this should be a matrix type (e.g. from :func:`warp.mat`).
        device: Device on which to allocate the matrix arrays.

    Returns:
        A zero matrix: ``nnz == 0``, empty ``columns``/``values`` arrays, and an
        all-zero ``offsets`` array of size ``1 + rows_of_blocks``.
    """

    bsr = bsr_matrix_t(block_type)()

    bsr.nrow = int(rows_of_blocks)
    bsr.ncol = int(cols_of_blocks)
    bsr.nnz = 0
    bsr.columns = wp.empty(shape=(0,), dtype=int, device=device)
    bsr.values = wp.empty(shape=(0,), dtype=block_type, device=device)
    bsr.offsets = wp.zeros(shape=(bsr.nrow + 1,), dtype=int, device=device)

    return bsr
382
+
383
+
384
def _bsr_resize(bsr: BsrMatrix, rows_of_blocks: int | None = None, cols_of_blocks: int | None = None) -> None:
    """Update the block dimensions of ``bsr``, growing the offsets array if required.

    A ``None`` dimension keeps the current value.
    """
    for attr, requested in (("nrow", rows_of_blocks), ("ncol", cols_of_blocks)):
        if requested is not None:
            setattr(bsr, attr, int(requested))

    required_offsets = bsr.nrow + 1
    if bsr.offsets.size < required_offsets:
        bsr.offsets = wp.empty(shape=(required_offsets,), dtype=int, device=bsr.offsets.device)
392
+
393
+
394
def _bsr_ensure_fits(bsr: BsrMatrix, nnz: int | None = None) -> None:
    """Grow the ``columns``/``values`` arrays so at least ``nnz`` blocks fit.

    When ``nnz`` is provided it also becomes the matrix's new non-zero upper
    bound; otherwise the current ``bsr.nnz`` bound is used.
    """
    if nnz is None:
        nnz = bsr.nnz
    else:
        # Record the caller-supplied upper bound.
        bsr.nnz = int(nnz)

    if bsr.columns.size < nnz:
        bsr.columns = wp.empty(shape=(nnz,), dtype=int, device=bsr.columns.device)

    values = bsr.values
    if values.size < nnz:
        bsr.values = wp.empty(
            shape=(nnz,),
            dtype=values.dtype,
            device=values.device,
            requires_grad=values.requires_grad,
        )
407
+
408
+
409
def bsr_set_zero(bsr: BsrMatrix, rows_of_blocks: int | None = None, cols_of_blocks: int | None = None):
    """Empty out a BSR or CSR matrix, optionally changing its block dimensions first.

    Args:
        bsr: The BSR or CSR matrix to set to zero.
        rows_of_blocks: New number of rows of blocks, or ``None`` to keep the current one.
        cols_of_blocks: New number of columns of blocks, or ``None`` to keep the current one.
    """
    # Resize first -- this may reallocate the offsets array that we zero below.
    _bsr_resize(bsr, rows_of_blocks=rows_of_blocks, cols_of_blocks=cols_of_blocks)
    bsr.offsets.zero_()
    bsr.notify_nnz_changed(nnz=0)
421
+
422
+
423
def _as_3d_array(arr, block_shape):
    """Alias a 1d array of blocks as a 3d scalar array sharing the same memory.

    The gradient array, when present, is aliased recursively in the same way.
    """
    if arr.grad is None:
        grad_view = None
    else:
        grad_view = _as_3d_array(arr.grad, block_shape)

    return wp.array(
        ptr=arr.ptr,
        capacity=arr.capacity,
        device=arr.device,
        dtype=type_scalar_type(arr.dtype),
        shape=(arr.shape[0], *block_shape),
        grad=grad_view,
    )
432
+
433
+
434
+ def _optional_ctypes_pointer(array: wp.array | None, ctype):
435
+ return None if array is None else ctypes.cast(array.ptr, ctypes.POINTER(ctype))
436
+
437
+
438
def _optional_ctypes_event(event: wp.Event | None):
    """Return the underlying CUDA event handle of ``event``, or ``None`` when no event is given."""
    if event is None:
        return None
    return event.cuda_event
# Per-scalar-type bit masks used to detect numerically-zero values from their
# raw bit pattern: AND-ing a value's bits with its mask yields 0 only for zero.
# Floating-point masks clear the sign bit so that -0.0 also counts as zero.
# Types absent from this table map to a mask of 0 (see the `.get(..., 0)`
# lookup in bsr_set_from_triplets), which disables numerical-zero pruning.
_zero_value_masks = {
    wp.float16: 0x7FFF,
    wp.float32: 0x7FFFFFFF,
    wp.float64: 0x7FFFFFFFFFFFFFFF,
    wp.int8: 0xFF,
    wp.int16: 0xFFFF,
    wp.int32: 0xFFFFFFFF,
    wp.int64: 0xFFFFFFFFFFFFFFFF,
}
@wp.kernel
def _bsr_accumulate_triplet_values(
    row_count: int,
    tpl_summed_offsets: wp.array(dtype=int),
    tpl_summed_indices: wp.array(dtype=int),
    tpl_values: wp.array3d(dtype=Any),
    bsr_offsets: wp.array(dtype=int),
    bsr_values: wp.array3d(dtype=Any),
):
    # Sums duplicate COO triplet values into their deduplicated destination block.
    # One thread per (destination block, row-in-block, column-in-block).
    block, i, j = wp.tid()

    # bsr_offsets[row_count] is the actual number of blocks; the launch
    # dimension is only an upper bound, so excess threads exit early.
    if block >= bsr_offsets[row_count]:
        return

    # [beg, end) delimits the run of triplets mapping to this block
    if block == 0:
        beg = 0
    else:
        beg = tpl_summed_offsets[block - 1]
    end = tpl_summed_offsets[block]

    # Accumulate the (i, j) scalar over every duplicate triplet of the block
    val = tpl_values[tpl_summed_indices[beg], i, j]
    for k in range(beg + 1, end):
        val += tpl_values[tpl_summed_indices[k], i, j]

    bsr_values[block, i, j] = val
def bsr_set_from_triplets(
    dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    rows: Array[int],
    columns: Array[int],
    values: Array[Scalar | BlockType[Rows, Cols, Scalar]] | None = None,
    count: Array[int] | None = None,
    prune_numerical_zeros: bool = True,
    masked: bool = False,
):
    """Fill a BSR matrix with values defined by coordinate-oriented (COO) triplets, discarding existing blocks.

    The first dimension of the three input arrays must match and indicates the number of COO triplets.

    Args:
        dest: Sparse matrix to populate.
        rows: Row index for each non-zero.
        columns: Columns index for each non-zero.
        values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
            to the ``dest`` matrix's block type, or a 3d array with data type equal to the ``dest`` matrix's scalar type.
            If ``None``, the values array of the resulting matrix will be allocated but uninitialized.
        count: Single-element array indicating the number of triplets. If ``None``, the number of triplets is determined from the shape of
            ``rows`` and ``columns`` arrays.
        prune_numerical_zeros: If ``True``, will ignore the zero-valued blocks.
        masked: If ``True``, ignore blocks that are not existing non-zeros of ``dest``.
    """

    if rows.device != columns.device or rows.device != dest.device:
        raise ValueError(
            f"Rows and columns must reside on the destination matrix device, got {rows.device}, {columns.device} and {dest.device}"
        )

    if rows.shape[0] != columns.shape[0]:
        raise ValueError(
            f"Rows and columns arrays must have the same length, got {rows.shape[0]} and {columns.shape[0]}"
        )

    if rows.dtype != wp.int32 or columns.dtype != wp.int32:
        raise TypeError("Rows and columns arrays must be of type int32")

    if count is not None:
        if count.device != rows.device:
            raise ValueError(f"Count and rows must reside on the same device, got {count.device} and {rows.device}")

        if count.shape != (1,):
            raise ValueError(f"Count array must be a single-element array, got {count.shape}")

        if count.dtype != wp.int32:
            raise TypeError("Count array must be of type int32")

    # Accept either array1d(dtype) or contiguous array3d(scalar_type) as values
    if values is not None:
        if values.device != rows.device:
            raise ValueError(f"Values and rows must reside on the same device, got {values.device} and {rows.device}")

        if values.shape[0] != rows.shape[0]:
            raise ValueError(
                f"Values and rows arrays must have the same length, got {values.shape[0]} and {rows.shape[0]}"
            )

        if values.ndim == 1:
            if not types_equal(values.dtype, dest.values.dtype):
                raise ValueError(
                    f"Values array type must correspond to that of the dest matrix, got {type_repr(values.dtype)} and {type_repr(dest.values.dtype)}"
                )
        elif values.ndim == 3:
            if values.shape[1:] != dest.block_shape:
                raise ValueError(
                    f"Last two dimensions in values array ({values.shape[1:]}) should correspond to matrix block shape {(dest.block_shape)})"
                )

            if type_scalar_type(values.dtype) != dest.scalar_type:
                raise ValueError(
                    f"Scalar type of values array ({type_repr(values.dtype)}) should correspond to that of matrix ({type_repr(dest.scalar_type)})"
                )
        else:
            raise ValueError(f"Number of dimension for values array should be 1 or 3, got {values.ndim}")

        # The native zero-pruning pass reads raw bits linearly, so it requires
        # contiguous value storage
        if prune_numerical_zeros and not values.is_contiguous:
            raise ValueError("Values array should be contiguous for numerical zero pruning")

    nnz = rows.shape[0]
    if nnz == 0:
        bsr_set_zero(dest)
        return

    # Increase dest array sizes if needed; masked mode keeps dest's topology
    if not masked:
        _bsr_ensure_fits(dest, nnz=nnz)

    device = dest.values.device
    scalar_type = dest.scalar_type
    # Mask of 0 disables pruning (also for types without a registered mask)
    zero_value_mask = _zero_value_masks.get(scalar_type, 0) if prune_numerical_zeros else 0

    # compute the BSR topology

    from warp._src.context import runtime

    if device.is_cpu:
        native_func = runtime.core.wp_bsr_matrix_from_triplets_host
    else:
        native_func = runtime.core.wp_bsr_matrix_from_triplets_device

    nnz_buf, nnz_event = dest._setup_nnz_transfer()
    # Per-deduplicated-block triplet ranges, filled by the native routine and
    # consumed by the accumulation kernel below
    summed_triplet_offsets = wp.empty(shape=(nnz,), dtype=wp.int32, device=device)
    summed_triplet_indices = wp.empty(shape=(nnz,), dtype=wp.int32, device=device)

    with wp.ScopedDevice(device):
        native_func(
            dest.block_size,
            type_size_in_bytes(scalar_type),
            dest.nrow,
            dest.ncol,
            nnz,
            _optional_ctypes_pointer(count, ctype=ctypes.c_int32),
            ctypes.cast(rows.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(columns.ptr, ctypes.POINTER(ctypes.c_int32)),
            # values are only inspected bitwise for zero pruning, hence int32
            _optional_ctypes_pointer(values, ctype=ctypes.c_int32),
            zero_value_mask,
            masked,
            ctypes.cast(summed_triplet_offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(summed_triplet_indices.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
            _optional_ctypes_pointer(nnz_buf, ctype=ctypes.c_int32),
            _optional_ctypes_event(nnz_event),
        )

    # now accumulate repeated blocks.
    # FIX: when no values were provided, skip the accumulation launch —
    # `_as_3d_array(None, ...)` would raise AttributeError, while the contract
    # documents that the values storage is then left uninitialized.
    if values is not None:
        wp.launch(
            _bsr_accumulate_triplet_values,
            dim=(nnz, *dest.block_shape),
            inputs=[
                dest.nrow,
                summed_triplet_offsets,
                summed_triplet_indices,
                _as_3d_array(values, dest.block_shape),
                dest.offsets,
            ],
            outputs=[dest.scalar_values],
        )
def bsr_from_triplets(
    rows_of_blocks: int,
    cols_of_blocks: int,
    rows: Array[int],
    columns: Array[int],
    values: Array[Scalar | BlockType[Rows, Cols, Scalar]],
    prune_numerical_zeros: bool = True,
):
    """Constructs a BSR matrix with values defined by coordinate-oriented (COO) triplets.

    The first dimension of the three input arrays must match and indicates the number of COO triplets.

    Args:
        rows_of_blocks: Number of rows of blocks.
        cols_of_blocks: Number of columns of blocks.
        rows: Row index for each non-zero.
        columns: Columns index for each non-zero.
        values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
            to the ``dest`` matrix's block type, or a 3d array with data type equal to the ``dest`` matrix's scalar type.
        prune_numerical_zeros: If ``True``, will ignore the zero-valued blocks.
    """

    # Deduce the block type: a 3d scalar array describes (count, rows, cols)
    # blocks, while a 1d array already carries the block dtype.
    block_type = wp.mat(shape=values.shape[1:], dtype=values.dtype) if values.ndim == 3 else values.dtype

    result = bsr_zeros(
        rows_of_blocks=rows_of_blocks,
        cols_of_blocks=cols_of_blocks,
        block_type=block_type,
        device=values.device,
    )
    # Propagate gradient tracking from the input values
    result.values.requires_grad = values.requires_grad
    bsr_set_from_triplets(result, rows, columns, values, prune_numerical_zeros=prune_numerical_zeros)
    return result
class _BsrExpression(Generic[_BlockType]):
    """Marker base class for lazy BSR expressions (e.g. scaled matrices).

    Instances can be passed wherever a :class:`BsrMatrix` is accepted and are
    evaluated on demand by the consuming operation.
    """

    pass
class _BsrScalingExpression(_BsrExpression):
    """Lazy expression representing ``scale * mat`` without materializing the product.

    Matrix metadata (shape, topology, dtype, device) is forwarded from the
    wrapped matrix; arithmetic operators fold the scale factor into the
    resulting operation instead of eagerly scaling the values.
    """

    def __init__(self, mat, scale):
        # mat: underlying BsrMatrix; scale: scalar factor applied lazily
        self.mat = mat
        self.scale = scale

    def eval(self):
        # Materialize the scaled matrix as a concrete copy
        return bsr_copy(self)

    # --- metadata forwarded from the wrapped matrix ---

    @property
    def nrow(self) -> int:
        return self.mat.nrow

    @property
    def ncol(self) -> int:
        return self.mat.ncol

    @property
    def nnz(self) -> int:
        return self.mat.nnz

    @property
    def offsets(self) -> wp.array:
        return self.mat.offsets

    @property
    def columns(self) -> wp.array:
        return self.mat.columns

    @property
    def scalar_type(self) -> Scalar:
        return self.mat.scalar_type

    @property
    def block_shape(self) -> tuple[int, int]:
        return self.mat.block_shape

    @property
    def block_size(self) -> int:
        return self.mat.block_size

    @property
    def shape(self) -> tuple[int, int]:
        return self.mat.shape

    @property
    def dtype(self) -> type:
        return self.mat.dtype

    @property
    def requires_grad(self) -> bool:
        return self.mat.requires_grad

    @property
    def device(self) -> wp._src.context.Device:
        return self.mat.device

    # Overloaded math operators
    # Sums fold the scale into the axpy coefficient; products fold it into alpha.

    def __add__(self, y):
        return bsr_axpy(y, bsr_copy(self.mat), alpha=self.scale)

    def __radd__(self, x):
        return bsr_axpy(x, bsr_copy(self.mat), beta=self.scale)

    def __sub__(self, y):
        return bsr_axpy(y, bsr_copy(self.mat), alpha=-self.scale)

    def __rsub__(self, x):
        return bsr_axpy(x, bsr_copy(self.mat), beta=-self.scale)

    def __mul__(self, y):
        # Scaling a scaled expression stays lazy: just combine the factors
        return _BsrScalingExpression(self.mat, y * self.scale)

    def __rmul__(self, x):
        return _BsrScalingExpression(self.mat, x * self.scale)

    def __matmul__(self, y):
        # Array operand -> matrix-vector product, otherwise matrix-matrix
        if isinstance(y, wp.array):
            return bsr_mv(self.mat, y, alpha=self.scale)

        return bsr_mm(self.mat, y, alpha=self.scale)

    def __rmatmul__(self, x):
        if isinstance(x, wp.array):
            # x @ A is computed as A^T x
            return bsr_mv(self.mat, x, alpha=self.scale, transpose=True)

        return bsr_mm(x, self.mat, alpha=self.scale)

    def __truediv__(self, y):
        return _BsrScalingExpression(self.mat, self.scale / y)

    def __neg__(self):
        return _BsrScalingExpression(self.mat, -self.scale)

    def transpose(self):
        """Returns a transposed copy of this matrix"""
        return _BsrScalingExpression(self.mat.transpose(), self.scale)
# A concrete BSR matrix, or a lazy expression (such as a scaled matrix) that
# evaluates to one; accepted by most functions of this module.
BsrMatrixOrExpression = Union[BsrMatrix[_BlockType], _BsrExpression[_BlockType]]
def _extract_matrix_and_scale(bsr: BsrMatrixOrExpression):
    """Decompose ``bsr`` into an underlying ``BsrMatrix`` and a scalar factor.

    Concrete matrices carry an implicit factor of ``1.0``; scaling expressions
    contribute their recorded factor. Anything else is rejected.
    """
    if isinstance(bsr, _BsrScalingExpression):
        return bsr.mat, bsr.scale
    if isinstance(bsr, BsrMatrix):
        return bsr, 1.0

    raise ValueError("Argument cannot be interpreted as a BsrMatrix")
@wp.func
def bsr_row_index(
    offsets: wp.array(dtype=int),
    row_count: int,
    block_index: int,
) -> int:
    """Returns the index of the row containing a given block, or -1 if no such row exists.

    Args:
        offsets: Array of size at least ``1 + row_count`` containing the offsets of the blocks in each row.
        row_count: Number of rows of blocks.
        block_index: Index of the block.
    """
    # Binary-search the first row whose end offset exceeds block_index; for an
    # out-of-range block (>= offsets[row_count], the total block count) the
    # whole expression collapses to 0 - 1 == -1.
    return wp.where(block_index < offsets[row_count], wp.lower_bound(offsets, 0, row_count + 1, block_index + 1), 0) - 1
@wp.func
def bsr_block_index(
    row: int,
    col: int,
    bsr_offsets: wp.array(dtype=int),
    bsr_columns: wp.array(dtype=int),
) -> int:
    """
    Returns the index of the block at block-coordinates (row, col), or -1 if no such block exists.
    Assumes that the segments of ``bsr_columns`` corresponding to each row are sorted.

    Args:
        row: Row of the block.
        col: Column of the block.
        bsr_offsets: Array of size at least ``1 + row`` containing the offsets of the blocks in each row.
        bsr_columns: Array of size at least equal to ``bsr_offsets[row + 1]`` containing the column indices of the blocks.
    """

    # Negative rows (e.g. from a failed bsr_row_index lookup) never match
    if row < 0:
        return -1

    # Blocks of this row occupy [row_beg, row_end) in the columns array
    row_beg = bsr_offsets[row]
    row_end = bsr_offsets[row + 1]

    if row_beg == row_end:
        # Empty row
        return -1

    # Columns within a row are sorted, so binary-search for `col`
    block_index = wp.lower_bound(bsr_columns, row_beg, row_end, col)
    # NOTE(review): assumes the index returned by wp.lower_bound is safe to
    # read from bsr_columns even when `col` exceeds every column of the row
    # (i.e. it is clamped within the array) — confirm wp.lower_bound semantics.
    return wp.where(bsr_columns[block_index] == col, block_index, -1)
@wp.kernel(enable_backward=False)
def _bsr_assign_list_blocks(
    src_subrows: int,
    src_subcols: int,
    dest_subrows: int,
    dest_subcols: int,
    src_row_count: int,
    src_offsets: wp.array(dtype=int),
    src_columns: wp.array(dtype=int),
    dest_rows: wp.array(dtype=int),
    dest_cols: wp.array(dtype=int),
):
    # Emits, for each (source block, sub-row, sub-column), the destination
    # block coordinates as a COO triplet. Used by bsr_assign when source and
    # destination block shapes differ: each source block splits into
    # src_subrows x src_subcols sub-blocks, which are regrouped into
    # destination blocks of dest_subrows x dest_subcols sub-blocks.
    block, subrow, subcol = wp.tid()
    # Flattened triplet index for this (block, subrow, subcol) thread
    dest_block = (block * src_subcols + subcol) * src_subrows + subrow

    row = bsr_row_index(src_offsets, src_row_count, block)
    if row == -1:
        # Thread beyond the actual block count (launch dim is an upper bound):
        # emit -1 coordinates so the triplet is discarded downstream.
        dest_rows[dest_block] = row  # invalid
        dest_cols[dest_block] = row
    else:
        # Sub-block coordinates in the global sub-block grid, then rescaled
        # to destination block coordinates
        dest_subrow = row * src_subrows + subrow
        dest_subcol = src_columns[block] * src_subcols + subcol
        dest_rows[dest_block] = dest_subrow // dest_subrows
        dest_cols[dest_block] = dest_subcol // dest_subcols
@wp.kernel
def _bsr_assign_copy_blocks(
    scale: Any,
    src_subrows: int,
    src_subcols: int,
    dest_subrows: int,
    dest_subcols: int,
    src_row_count: int,
    src_offsets: wp.array(dtype=int),
    src_columns: wp.array(dtype=int),
    src_values: wp.array3d(dtype=Any),
    dest_offsets: wp.array(dtype=int),
    dest_columns: wp.array(dtype=int),
    dest_values: wp.array3d(dtype=Any),
):
    # Copies (and rescales) each source sub-block into its destination block.
    # One thread per (source block, sub-row, sub-column).
    # FIX: removed a stray duplicate `src_block = wp.tid()` assignment that
    # preceded the 3-component unpack — a single-component wp.tid() is not
    # valid for a 3-D launch, and the value was immediately overwritten.
    src_block, subrow, subcol = wp.tid()

    # Skip threads beyond the actual source block count
    src_row = bsr_row_index(src_offsets, src_row_count, src_block)
    if src_row == -1:
        return

    src_col = src_columns[src_block]

    # Locate the destination block covering this sub-block
    dest_subrow = src_row * src_subrows + subrow
    dest_subcol = src_col * src_subcols + subcol
    dest_row = dest_subrow // dest_subrows
    dest_col = dest_subcol // dest_subcols

    dest_block = bsr_block_index(dest_row, dest_col, dest_offsets, dest_columns)
    if dest_block == -1:
        # Not a stored non-zero of dest (masked assignment): drop the sub-block
        return

    # Position of the sub-block inside the destination block
    split_row = dest_subrow - dest_subrows * dest_row
    split_col = dest_subcol - dest_subcols * dest_col

    # Scalar extent of one sub-block
    rows_per_subblock = src_values.shape[1] // src_subrows
    cols_per_subblock = src_values.shape[2] // src_subcols

    dest_base_i = split_row * rows_per_subblock
    dest_base_j = split_col * cols_per_subblock

    src_base_i = subrow * rows_per_subblock
    src_base_j = subcol * cols_per_subblock

    # Copy the scalar sub-block, casting to the destination scalar type
    for i in range(rows_per_subblock):
        for j in range(cols_per_subblock):
            dest_values[dest_block, i + dest_base_i, j + dest_base_j] = dest_values.dtype(
                scale * src_values[src_block, i + src_base_i, j + src_base_j]
            )
def bsr_assign(
    dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    src: BsrMatrixOrExpression[BlockType[Any, Any, Any]],
    structure_only: bool = False,
    masked: bool = False,
):
    """Copy the content of the ``src`` BSR matrix to ``dest``.

    Args:
        src: Matrix to be copied.
        dest: Destination matrix. May have a different block shape or scalar type
            than ``src``, in which case the required casting will be performed.
        structure_only: If ``True``, only the non-zero indices are copied, and uninitialized value storage is allocated
            to accommodate at least ``src.nnz`` blocks. If ``structure_only`` is ``False``, values are also copied with implicit
            casting if the two matrices use distinct scalar types.
        masked: If ``True``, keep the non-zero topology of ``dest`` unchanged.
    """

    src, src_scale = _extract_matrix_and_scale(src)

    if dest.values.device != src.values.device:
        raise ValueError("Source and destination matrices must reside on the same device")

    # Decompose block rows: either each src block splits into src_subrows dest
    # blocks, or dest_subrows src blocks merge into one dest block.
    if src.block_shape[0] >= dest.block_shape[0]:
        src_subrows = src.block_shape[0] // dest.block_shape[0]
        dest_subrows = 1
    else:
        dest_subrows = dest.block_shape[0] // src.block_shape[0]
        src_subrows = 1

    # Cross-multiplied equality check verifies exact divisibility
    if src_subrows * dest.block_shape[0] != src.block_shape[0] * dest_subrows:
        raise ValueError(
            f"Incompatible dest and src block shapes; block rows must evenly divide one another (Got {dest.block_shape[0]}, {src.block_shape[0]})"
        )

    # Same decomposition for block columns
    if src.block_shape[1] >= dest.block_shape[1]:
        src_subcols = src.block_shape[1] // dest.block_shape[1]
        dest_subcols = 1
    else:
        dest_subcols = dest.block_shape[1] // src.block_shape[1]
        src_subcols = 1

    if src_subcols * dest.block_shape[1] != src.block_shape[1] * dest_subcols:
        raise ValueError(
            f"Incompatible dest and src block shapes; block columns must evenly divide one another (Got {dest.block_shape[1]}, {src.block_shape[1]})"
        )

    # Destination dimensions implied by re-blocking the source
    dest_nrow = (src.nrow * src_subrows) // dest_subrows
    dest_ncol = (src.ncol * src_subcols) // dest_subcols

    if src.nrow * src_subrows != dest_nrow * dest_subrows or src.ncol * src_subcols != dest_ncol * dest_subcols:
        raise ValueError(
            f"The requested block shape {dest.block_shape} does not evenly divide the source matrix of total size {src.shape}"
        )

    # Upper bound on destination blocks: every sub-block may land in its own block
    nnz_alloc = src.nnz * src_subrows * src_subcols
    if masked:
        # Masked mode reuses dest's topology, so its dimensions must already match
        if dest_nrow != dest.nrow or dest_ncol != dest.ncol:
            raise ValueError(
                f"Incompatible destination matrix size, expected ({dest_nrow}, {dest_ncol}), got ({dest.nrow}, {dest.ncol})"
            )
    else:
        _bsr_resize(dest, rows_of_blocks=dest_nrow, cols_of_blocks=dest_ncol)

    if dest.block_shape == src.block_shape and not masked:
        # Direct copy: identical topology, no re-blocking needed

        wp.copy(dest=dest.offsets, src=src.offsets, count=src.nrow + 1)
        dest.notify_nnz_changed(nnz=nnz_alloc)

        if nnz_alloc > 0:
            wp.copy(dest=dest.columns, src=src.columns, count=nnz_alloc)

        if not structure_only:
            # Cast values to the destination scalar type, then apply the
            # scale factor carried by the source expression
            warp._src.utils.array_cast(out_array=dest.values, in_array=src.values, count=nnz_alloc)
            bsr_scale(dest, src_scale)

    else:
        if not masked:
            # Compute destination rows and columns as COO triplets
            dest_rows = wp.empty(nnz_alloc, dtype=int, device=dest.device)
            dest_cols = wp.empty(nnz_alloc, dtype=int, device=dest.device)
            wp.launch(
                _bsr_assign_list_blocks,
                dim=(src.nnz, src_subrows, src_subcols),
                device=dest.device,
                inputs=[
                    src_subrows,
                    src_subcols,
                    dest_subrows,
                    dest_subcols,
                    src.nrow,
                    src.offsets,
                    src.columns,
                    dest_rows,
                    dest_cols,
                ],
            )

            _bsr_ensure_fits(dest, nnz=nnz_alloc)

            # Compute destination offsets from triplets (topology only:
            # no values, no zero pruning)
            from warp._src.context import runtime

            if dest.device.is_cpu:
                native_func = runtime.core.wp_bsr_matrix_from_triplets_host
            else:
                native_func = runtime.core.wp_bsr_matrix_from_triplets_device

            nnz_buf, nnz_event = dest._setup_nnz_transfer()
            with wp.ScopedDevice(dest.device):
                native_func(
                    dest.block_size,
                    0,  # scalar_size_in_bytes
                    dest.nrow,
                    dest.ncol,
                    nnz_alloc,
                    None,  # device nnz
                    ctypes.cast(dest_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
                    ctypes.cast(dest_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
                    None,  # triplet values
                    0,  # zero_value_mask
                    masked,
                    None,  # summed block offsets
                    None,  # summed block indices
                    ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
                    ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
                    _optional_ctypes_pointer(nnz_buf, ctype=ctypes.c_int32),
                    _optional_ctypes_event(nnz_event),
                )

        # copy block values, scattering each source sub-block into the
        # (zero-initialized) destination blocks
        if not structure_only:
            dest.values.zero_()
            wp.launch(
                _bsr_assign_copy_blocks,
                dim=(src.nnz, src_subrows, src_subcols),
                device=dest.device,
                inputs=[
                    src.scalar_type(src_scale),
                    src_subrows,
                    src_subcols,
                    dest_subrows,
                    dest_subcols,
                    src.nrow,
                    src.offsets,
                    src.columns,
                    src.scalar_values,
                    dest.offsets,
                    dest.columns,
                    dest.scalar_values,
                ],
            )
def bsr_copy(
    A: BsrMatrixOrExpression,
    scalar_type: Scalar | None = None,
    block_shape: tuple[int, int] | None = None,
    structure_only: bool = False,
):
    """Return a copy of matrix ``A``, possibly changing its scalar type.

    Args:
        A: Matrix to be copied.
        scalar_type: If provided, the returned matrix will use this scalar type instead of the one from ``A``.
        block_shape: If provided, the returned matrix will use blocks of this shape instead of the one from ``A``.
            Both dimensions of ``block_shape`` must be either a multiple or an exact divider of the ones from ``A``.
        structure_only: If ``True``, only the non-zeros indices are copied, and uninitialized value storage is allocated
            to accommodate at least ``src.nnz`` blocks. If ``structure_only`` is ``False``, values are also copied with implicit
            casting if the two matrices use distinct scalar types.
    """
    # Fall back to the source matrix's characteristics when not overridden
    target_scalar = A.scalar_type if scalar_type is None else scalar_type
    target_shape = A.block_shape if block_shape is None else block_shape

    # 1x1-block matrices store plain scalars rather than 1x1 matrix blocks
    block_type = target_scalar if target_shape == (1, 1) else wp.mat(shape=target_shape, dtype=target_scalar)

    result = bsr_zeros(
        rows_of_blocks=A.nrow,
        cols_of_blocks=A.ncol,
        block_type=block_type,
        device=A.device,
    )
    result.values.requires_grad = A.requires_grad
    bsr_assign(dest=result, src=A, structure_only=structure_only)
    return result
@wp.kernel
def _bsr_transpose_values(
    col_count: int,
    scale: Any,
    bsr_offsets: wp.array(dtype=int),
    bsr_columns: wp.array(dtype=int),
    bsr_values: wp.array3d(dtype=Any),
    block_index_map: wp.array(dtype=int),
    transposed_bsr_offsets: wp.array(dtype=int),
    transposed_bsr_columns: wp.array(dtype=int),
    transposed_bsr_values: wp.array3d(dtype=Any),
):
    # Fills the values of the transposed matrix, one thread per
    # (transposed block, row-in-block, column-in-block).
    block, i, j = wp.tid()

    # Skip threads beyond the actual transposed block count
    if block >= transposed_bsr_offsets[col_count]:
        return

    if block_index_map:
        # Mapping from transposed block to source block was precomputed
        # by the native transpose routine
        src_block = block_index_map[block]
    else:
        # No map (masked assignment): locate the source block from the
        # transposed coordinates, skipping blocks absent from the source
        row = bsr_row_index(transposed_bsr_offsets, col_count, block)
        col = transposed_bsr_columns[block]
        src_block = bsr_block_index(col, row, bsr_offsets, bsr_columns)
        if src_block == -1:
            return

    # Transpose the block contents while applying the scale factor
    transposed_bsr_values[block, i, j] = bsr_values[src_block, j, i] * scale
def bsr_set_transpose(
    dest: BsrMatrix[BlockType[Cols, Rows, Scalar]],
    src: BsrMatrixOrExpression[BlockType[Rows, Cols, Scalar]],
    masked: bool = False,
):
    """
    Assign the transposed matrix ``src`` to matrix ``dest``.

    Args:
        dest: Sparse matrix to populate.
        src: Sparse matrix to transpose.
        masked: If ``True``, keep the non-zero topology of ``dest`` unchanged.
    """

    src, src_scale = _extract_matrix_and_scale(src)

    if dest.values.device != src.values.device:
        raise ValueError(
            f"All arguments must reside on the same device, got {dest.values.device} and {src.values.device}"
        )

    if dest.scalar_type != src.scalar_type:
        raise ValueError(f"All arguments must have the same scalar type, got {dest.scalar_type} and {src.scalar_type}")

    # dest blocks must be the flipped shape of src blocks
    transpose_block_shape = src.block_shape[::-1]

    if dest.block_shape != transpose_block_shape:
        raise ValueError(f"Destination block shape must be {transpose_block_shape}, got {dest.block_shape}")

    if masked:
        if dest.nrow != src.ncol or dest.ncol != src.nrow:
            raise ValueError(
                f"Destination matrix must have {src.ncol} rows and {src.nrow} columns, got {dest.nrow} and {dest.ncol}"
            )
        # Masked mode keeps dest's topology: no block index map is built, and
        # the value kernel falls back to per-block lookups in the source.
        block_index_map = None
        dest.values.zero_()
    else:
        _bsr_resize(dest, rows_of_blocks=src.ncol, cols_of_blocks=src.nrow)

        nnz = src.nnz
        if nnz == 0:
            bsr_set_zero(dest)
            return

        # Increase dest array sizes if needed
        _bsr_ensure_fits(dest, nnz=nnz)

        from warp._src.context import runtime

        if dest.values.device.is_cpu:
            native_func = runtime.core.wp_bsr_transpose_host
        else:
            native_func = runtime.core.wp_bsr_transpose_device

        # Scratch mapping from each transposed block to its source block,
        # filled by the native routine and consumed by the value kernel
        block_index_map = wp.empty(shape=2 * nnz, dtype=int, device=src.device)

        with wp.ScopedDevice(dest.device):
            native_func(
                src.nrow,
                src.ncol,
                nnz,
                ctypes.cast(src.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
                ctypes.cast(src.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
                ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
                ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
                ctypes.cast(block_index_map.ptr, ctypes.POINTER(ctypes.c_int32)),
            )

        dest._copy_nnz_async()

    wp.launch(
        _bsr_transpose_values,
        dim=(dest.nnz, *dest.block_shape),
        device=dest.device,
        inputs=[
            src.ncol,
            dest.scalar_type(src_scale),
            src.offsets,
            src.columns,
            src.scalar_values,
            block_index_map,
            dest.offsets,
            dest.columns,
        ],
        outputs=[dest.scalar_values],
    )
def bsr_transposed(A: BsrMatrixOrExpression) -> BsrMatrix:
    """Return a copy of the transposed matrix ``A``."""

    # 1x1-block matrices keep their scalar block type; otherwise the block
    # shape is flipped for the transpose.
    if A.block_shape == (1, 1):
        block_type = A.values.dtype
    else:
        block_type = wp.mat(shape=A.block_shape[::-1], dtype=A.scalar_type)

    result = bsr_zeros(
        rows_of_blocks=A.ncol,
        cols_of_blocks=A.nrow,
        block_type=block_type,
        device=A.device,
    )
    result.values.requires_grad = A.requires_grad
    bsr_set_transpose(dest=result, src=A)
    return result
@wp.kernel
def _bsr_get_diag_kernel(
    scale: Any,
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array3d(dtype=Any),
    out: wp.array3d(dtype=Any),
):
    # Extracts the diagonal blocks of A into `out`, one thread per
    # (diagonal row, row-in-block, column-in-block).
    row, br, bc = wp.tid()

    # Copy the (row, row) block when present; `out` is expected to be
    # zero-initialized so that missing diagonal blocks stay zero.
    diag = bsr_block_index(row, row, A_offsets, A_columns)
    if diag != -1:
        out[row, br, bc] = scale * A_values[diag, br, bc]
def bsr_get_diag(A: BsrMatrixOrExpression[BlockType], out: Array[BlockType] | None = None) -> Array[BlockType]:
    """Return the array of blocks that constitute the diagonal of a sparse matrix.

    Args:
        A: The sparse matrix from which to extract the diagonal.
        out: If provided, the array into which to store the diagonal blocks.
    """

    A, scale = _extract_matrix_and_scale(A)

    # The diagonal of a rectangular matrix stops at the smaller dimension
    diag_len = min(A.nrow, A.ncol)
    block_values = A.values

    if out is None:
        # Zero-init so rows without a stored diagonal block read as zero
        out = wp.zeros(shape=(diag_len,), dtype=block_values.dtype, device=block_values.device)
    else:
        # Validate the user-provided destination array
        if not types_equal(out.dtype, block_values.dtype):
            raise ValueError(f"Output array must have type {block_values.dtype}, got {out.dtype}")
        if out.device != block_values.device:
            raise ValueError(f"Output array must reside on device {block_values.device}, got {out.device}")
        if out.shape[0] < diag_len:
            raise ValueError(f"Output array must be of length at least {diag_len}, got {out.shape[0]}")

    wp.launch(
        kernel=_bsr_get_diag_kernel,
        dim=(diag_len, *A.block_shape),
        device=block_values.device,
        inputs=[A.scalar_type(scale), A.offsets, A.columns, A.scalar_values, _as_3d_array(out, A.block_shape)],
    )

    return out
@wp.kernel(enable_backward=False)
def _bsr_set_diag_kernel(
    nnz: int,
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
):
    # Builds the topology of a block-diagonal matrix with `nnz` diagonal
    # blocks: offsets ramp up by one per row and clamp at `nnz` so trailing
    # rows of a tall matrix are empty, and block k lives in column k.
    # The wp.min clamp implies an intended launch dimension of (row count + 1),
    # covering the full offsets table.
    row = wp.tid()
    A_offsets[row] = wp.min(row, nnz)
    if row < nnz:
        A_columns[row] = row
def bsr_set_diag(
    A: BsrMatrix[BlockType],
    diag: BlockType | Array[BlockType],
    rows_of_blocks: int | None = None,
    cols_of_blocks: int | None = None,
) -> None:
    """Set ``A`` as a block-diagonal matrix.

    Args:
        A: The sparse matrix to modify.
        diag: Specifies the values for diagonal blocks. Can be one of:

            - A Warp array of type ``A.values.dtype``: Each element defines one block of the diagonal
            - A constant value of type ``A.values.dtype``: This value is assigned to all diagonal blocks
            - ``None``: Diagonal block values are left uninitialized

        rows_of_blocks: If not ``None``, the new number of rows of blocks.
        cols_of_blocks: If not ``None``, the new number of columns of blocks.

    The shape of the matrix will be defined one of the following, in this order:

    - ``rows_of_blocks`` and ``cols_of_blocks``, if provided.
      If only one is given, the second is assumed equal.
    - The first dimension of ``diag``, if ``diag`` is an array
    - The current dimensions of ``A`` otherwise
    """

    # If only one of the two dimensions is specified, make the matrix square
    if rows_of_blocks is None and cols_of_blocks is not None:
        rows_of_blocks = cols_of_blocks
    if cols_of_blocks is None and rows_of_blocks is not None:
        cols_of_blocks = rows_of_blocks

    # Fall back to the length of the diagonal array for the matrix size
    if is_array(diag):
        if rows_of_blocks is None:
            rows_of_blocks = diag.shape[0]
            cols_of_blocks = diag.shape[0]

    if rows_of_blocks is not None:
        _bsr_resize(A, rows_of_blocks, cols_of_blocks)

    # A block-diagonal matrix has one block per row, up to the smaller dimension
    nnz = min(A.nrow, A.ncol)
    A.notify_nnz_changed(nnz=nnz)  # notify change of nnz upper bound

    # Write offsets for ALL rows plus the terminating entry.
    # FIX: launch over A.nrow + 1 rather than nnz + 1 — for tall matrices
    # (nrow > ncol) the trailing offsets were previously left uninitialized;
    # the kernel's wp.min(row, nnz) clamp exists precisely to fill them.
    wp.launch(
        kernel=_bsr_set_diag_kernel,
        dim=A.nrow + 1,
        device=A.offsets.device,
        inputs=[nnz, A.offsets, A.columns],
    )

    A.notify_nnz_changed(nnz=nnz)  # notify change of offsets

    # Fill the diagonal block values
    if is_array(diag):
        wp.copy(src=diag, dest=A.values, count=nnz)
    elif diag is not None:
        A.values.fill_(diag)
def bsr_diag(
    diag: BlockType | Array[BlockType] | None = None,
    rows_of_blocks: int | None = None,
    cols_of_blocks: int | None = None,
    block_type: BlockType | None = None,
    device=None,
) -> BsrMatrix[BlockType]:
    """Create and return a block-diagonal BSR matrix from an given block value or array of block values.

    Args:
        diag: Specifies the values for diagonal blocks. Can be one of:

            - A Warp array of type ``A.values.dtype``: Each element defines one block of the diagonal
            - A constant value of type ``A.values.dtype``: This value is assigned to all diagonal blocks
        rows_of_blocks: If not ``None``, the new number of rows of blocks
        cols_of_blocks: If not ``None``, the new number of columns of blocks
        block_type: If ``diag`` is ``None``, block type of the matrix. Otherwise deduced from ``diag``
        device: If ``diag`` is not a Warp array, device on which to allocate the matrix. Otherwise deduced from ``diag``

    The shape of the matrix will be defined one of the following, in this order:

    - ``rows_of_blocks`` and ``cols_of_blocks``, if provided.
      If only one is given, the second is assumed equal.
    - The first dimension of ``diag`` if ``diag`` is an array.
    """

    # A single specified dimension implies a square matrix
    if rows_of_blocks is None and cols_of_blocks is not None:
        rows_of_blocks = cols_of_blocks
    if cols_of_blocks is None and rows_of_blocks is not None:
        cols_of_blocks = rows_of_blocks

    diag_is_array = is_array(diag)

    if diag_is_array:
        # Size, block type and device all follow from the diagonal array
        if rows_of_blocks is None:
            rows_of_blocks = diag.shape[0]
            cols_of_blocks = diag.shape[0]

        block_type = diag.dtype
        device = diag.device
    else:
        # With a uniform diagonal the size cannot be inferred
        if rows_of_blocks is None:
            raise ValueError(
                "rows_of_blocks and/or cols_of_blocks must be provided for constructing a diagonal matrix with uniform diagonal"
            )

        if block_type is None:
            if diag is None:
                raise ValueError("Either `diag` or `block_type` needs to be provided")

            # Deduce the block type from the uniform diagonal value; 2d
            # array-like values are promoted to a Warp matrix type
            block_type = type(diag)
            if not type_is_matrix(block_type) and len(getattr(diag, "shape", ())) == 2:
                block_type = wp.mat(shape=diag.shape, dtype=diag.dtype)

    result = bsr_zeros(rows_of_blocks, cols_of_blocks, block_type=block_type, device=device)
    if diag_is_array:
        result.values.requires_grad = diag.requires_grad
    bsr_set_diag(result, diag)
    return result
def bsr_set_identity(A: BsrMatrix, rows_of_blocks: int | None = None) -> None:
    """Turn ``A`` into the identity matrix in place.

    Args:
        A: The sparse matrix to modify.
        rows_of_blocks: If provided, the matrix will be resized as a square
            matrix with ``rows_of_blocks`` rows and columns.
    """

    # Scalar blocks take a plain 1; matrix blocks take an identity block.
    if A.block_shape == (1, 1):
        diag_block = A.scalar_type(1.0)
    else:
        from numpy import eye

        diag_block = eye(A.block_shape[0])

    bsr_set_diag(A, diag=diag_block, rows_of_blocks=rows_of_blocks, cols_of_blocks=rows_of_blocks)
1418
+
1419
+
1420
def bsr_identity(
    rows_of_blocks: int,
    block_type: BlockType[Rows, Rows, Scalar],
    device: wp._src.context.Devicelike = None,
) -> BsrMatrix[BlockType[Rows, Rows, Scalar]]:
    """Create and return a square identity matrix.

    Args:
        rows_of_blocks: Number of rows and columns of blocks in the created matrix.
        block_type: Block type for the newly created matrix. Must be square.
        device: Device onto which to allocate the data arrays.

    Returns:
        The newly allocated identity matrix.
    """
    # Allocate an empty square matrix, then fill its diagonal in place.
    result = bsr_zeros(rows_of_blocks, rows_of_blocks, block_type=block_type, device=device)
    bsr_set_identity(result)
    return result
1440
+
1441
+
1442
@wp.kernel
def _bsr_scale_kernel(
    alpha: Any,
    values: wp.array(dtype=Any),
):
    # Scale one block value per thread: values[i] *= alpha.
    # NOTE(review): a second `_bsr_scale_kernel` with a 3d `values` array is
    # defined just below; presumably Warp resolves between them by argument
    # signature at launch time — confirm against Warp's kernel-overload rules.
    row = wp.tid()
    values[row] = alpha * values[row]
1449
+
1450
+
1451
@wp.kernel
def _bsr_scale_kernel(
    alpha: Any,
    values: wp.array3d(dtype=Any),
):
    # 3d variant: one thread per scalar coefficient of each block
    # (block index, block row, block column).
    row, br, bc = wp.tid()
    values[row, br, bc] = alpha * values[row, br, bc]
1458
+
1459
+
1460
def bsr_scale(x: BsrMatrixOrExpression, alpha: Scalar) -> BsrMatrix:
    """Scale the BSR matrix ``x`` in place by ``alpha`` (``x := alpha * x``) and return it."""

    # Fold any scale carried by the expression into alpha.
    mat, expr_scale = _extract_matrix_and_scale(x)
    alpha *= expr_scale

    # Identity scaling or empty matrix: nothing to do.
    if alpha == 1.0 or mat.nnz == 0:
        return mat

    # Zero scaling: clear the values directly, no kernel launch needed.
    if alpha == 0.0:
        mat.values.zero_()
        return mat

    wp.launch(
        kernel=_bsr_scale_kernel,
        dim=(mat.nnz, *mat.block_shape),
        device=mat.values.device,
        inputs=[mat.scalar_type(alpha), mat.scalar_values],
    )
    return mat
1480
+
1481
+
1482
@wp.kernel(enable_backward=False)
def _bsr_get_block_row(row_count: int, bsr_offsets: wp.array(dtype=int), rows: wp.array(dtype=int)):
    # For each block, recover its row index from the compressed row offsets
    # (i.e. uncompress the CSR-style offsets into an explicit row array).
    block = wp.tid()
    rows[block] = bsr_row_index(bsr_offsets, row_count, block)
1486
+
1487
+
1488
@wp.kernel
def _bsr_axpy_add_block(
    src_offset: int,
    scale: Any,
    rows: wp.array(dtype=int),
    cols: wp.array(dtype=int),
    dst_offsets: wp.array(dtype=int),
    dst_columns: wp.array(dtype=int),
    src_values: wp.array3d(dtype=Any),
    dst_values: wp.array3d(dtype=Any),
):
    # Scatter-add scaled source blocks into the destination matrix:
    # for source block i at (rows[i+src_offset], cols[i+src_offset]),
    # accumulate scale * src into the matching destination block, one thread
    # per scalar coefficient (i, br, bc).
    i, br, bc = wp.tid()
    row = rows[i + src_offset]
    col = cols[i + src_offset]

    # Look up the destination block for this (row, col); -1 means the
    # coordinate is absent from the destination topology and is skipped.
    block = bsr_block_index(row, col, dst_offsets, dst_columns)
    if block != -1:
        dst_values[block, br, bc] += scale * src_values[i, br, bc]
1506
+
1507
+
1508
@wp.kernel
def _bsr_axpy_masked(
    alpha: Any,
    row_count: int,
    src_offsets: wp.array(dtype=int),
    src_columns: wp.array(dtype=int),
    src_values: wp.array3d(dtype=Any),
    dst_offsets: wp.array(dtype=int),
    dst_columns: wp.array(dtype=int),
    dst_values: wp.array3d(dtype=Any),
):
    # Masked axpy: iterate over the *destination* topology and gather the
    # matching source block, so dst's sparsity pattern is never changed.
    # One thread per scalar coefficient (block, br, bc).
    block, br, bc = wp.tid()

    # Recover this block's row from the destination offsets; -1 marks a
    # block index past the current nnz.
    row = bsr_row_index(dst_offsets, row_count, block)
    if row == -1:
        return

    col = dst_columns[block]
    # Blocks present in dst but absent from src contribute nothing.
    src_block = bsr_block_index(row, col, src_offsets, src_columns)
    if src_block != -1:
        dst_values[block, br, bc] += alpha * src_values[src_block, br, bc]
1529
+
1530
+
1531
class bsr_axpy_work_arrays:
    """Opaque structure for persisting :func:`bsr_axpy` temporary work buffers across calls."""

    def __init__(self):
        self._reset(None)

    def _reset(self, device):
        # Drop all cached buffers; they are lazily (re)allocated in _allocate().
        self.device = device
        self._sum_rows = None
        self._sum_cols = None
        self._old_y_values = None
        self._old_x_values = None  # NOTE(review): not referenced by the visible bsr_axpy code — confirm before removing

    def _allocate(self, device, y: BsrMatrix, sum_nnz: int):
        # Buffers are tied to a device; a device change invalidates the cache.
        if self.device != device:
            self._reset(device)

        # Grow-only allocation: existing buffers are reused when large enough.
        if self._sum_rows is None or self._sum_rows.size < sum_nnz:
            self._sum_rows = wp.empty(shape=(sum_nnz), dtype=int, device=self.device)
        if self._sum_cols is None or self._sum_cols.size < sum_nnz:
            self._sum_cols = wp.empty(shape=(sum_nnz), dtype=int, device=self.device)

        if self._old_y_values is None or self._old_y_values.size < y.nnz:
            self._old_y_values = wp.empty_like(y.values[: y.nnz])
1555
+
1556
+
1557
def bsr_axpy(
    x: BsrMatrixOrExpression,
    y: BsrMatrix[BlockType[Rows, Cols, Scalar]] | None = None,
    alpha: Scalar = 1.0,
    beta: Scalar = 1.0,
    masked: bool = False,
    work_arrays: bsr_axpy_work_arrays | None = None,
) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
    """
    Perform the sparse matrix addition ``y := alpha * X + beta * y`` on BSR matrices ``x`` and ``y`` and return ``y``.

    The ``x`` and ``y`` matrices are allowed to alias.

    Args:
        x: Read-only first operand.
        y: Mutable second operand and output matrix. If ``y`` is not provided, it will be allocated and treated as zero.
        alpha: Uniform scaling factor for ``x``.
        beta: Uniform scaling factor for ``y``.
        masked: If ``True``, keep the non-zero topology of ``y`` unchanged.
        work_arrays: In most cases, this function will require the use of temporary storage.
            This storage can be reused across calls by passing an instance of
            :class:`bsr_axpy_work_arrays` in ``work_arrays``.

    Raises:
        ValueError: If the operands live on different devices, have mismatched
            block types or shapes, or if ``masked=True`` without a ``y`` operand.
    """

    # Fold any scale carried by the expression (e.g. `2.0 * A`) into alpha.
    x, x_scale = _extract_matrix_and_scale(x)
    alpha *= x_scale

    if y is None:
        if masked:
            raise ValueError("Left-hand-side 'y' matrix must be provided for masked addition")

        # If no output matrix is provided, allocate it for convenience
        y = bsr_zeros(x.nrow, x.ncol, block_type=x.values.dtype, device=x.values.device)
        y.values.requires_grad = x.requires_grad
        beta = 0.0

    x_nnz = x.nnz
    y_nnz = y.nnz

    # Handle easy cases first
    if beta == 0.0 or y_nnz == 0:
        # y contributes nothing: reduce to a (possibly masked) copy + scale of x.
        bsr_assign(src=x, dest=y, masked=masked)
        return bsr_scale(y, alpha=alpha)

    if alpha == 0.0 or x_nnz == 0:
        # x contributes nothing: reduce to an in-place scale of y.
        return bsr_scale(y, alpha=beta)

    if x == y:
        # Aliasing case
        return bsr_scale(y, alpha=alpha + beta)

    # General case

    # Promote python scalars to the matrix's scalar type for kernel launches.
    if not isinstance(alpha, y.scalar_type):
        alpha = y.scalar_type(alpha)
    if not isinstance(beta, y.scalar_type):
        beta = y.scalar_type(beta)

    if x.values.device != y.values.device:
        raise ValueError(f"All arguments must reside on the same device, got {x.values.device} and {y.values.device}")

    if x.scalar_type != y.scalar_type or x.block_shape != y.block_shape:
        raise ValueError(
            f"Matrices must have the same block type, got ({x.block_shape}, {x.scalar_type}) and ({y.block_shape}, {y.scalar_type})"
        )

    if x.nrow != y.nrow or x.ncol != y.ncol:
        raise ValueError(
            f"Matrices must have the same number of rows and columns, got ({x.nrow}, {x.ncol}) and ({y.nrow}, {y.ncol})"
        )

    device = y.values.device
    if masked:
        # Topology of y is preserved: scale y in place, then gather matching
        # x blocks into it.
        bsr_scale(y, alpha=beta.value)
        wp.launch(
            kernel=_bsr_axpy_masked,
            device=device,
            dim=(y_nnz, y.block_shape[0], y.block_shape[1]),
            inputs=[
                alpha,
                x.nrow,
                x.offsets,
                x.columns,
                x.scalar_values,
                y.offsets,
                y.columns,
                y.scalar_values,
            ],
        )

    else:
        if work_arrays is None:
            work_arrays = bsr_axpy_work_arrays()

        # Build the union topology from the concatenated (row, col) triplets
        # of y followed by x.
        sum_nnz = x_nnz + y_nnz
        work_arrays._allocate(device, y, sum_nnz)

        wp.copy(work_arrays._sum_cols, y.columns, 0, 0, y_nnz)
        y.uncompress_rows(out=work_arrays._sum_rows)

        wp.copy(work_arrays._sum_cols, x.columns, y_nnz, 0, x_nnz)
        x.uncompress_rows(out=work_arrays._sum_rows[y_nnz:])

        # Save old y values before overwriting matrix
        wp.copy(dest=work_arrays._old_y_values, src=y.values, count=y.nnz)

        # Increase dest array sizes if needed
        _bsr_ensure_fits(y, nnz=sum_nnz)

        from warp._src.context import runtime

        if device.is_cpu:
            native_func = runtime.core.wp_bsr_matrix_from_triplets_host
        else:
            native_func = runtime.core.wp_bsr_matrix_from_triplets_device

        old_y_nnz = y_nnz
        nnz_buf, nnz_event = y._setup_nnz_transfer()

        # Compress the merged triplets into y's offsets/columns in place.
        # Values are not passed here; they are re-accumulated below.
        with wp.ScopedDevice(y.device):
            native_func(
                y.block_size,
                0,  # scalar_size_in_bytes
                y.nrow,
                y.ncol,
                sum_nnz,
                None,  # device nnz
                ctypes.cast(work_arrays._sum_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
                ctypes.cast(work_arrays._sum_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
                None,  # triplet values
                0,  # zero_value_mask
                masked,
                None,  # summed block offsets
                None,  # summed block indices
                ctypes.cast(y.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
                ctypes.cast(y.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
                _optional_ctypes_pointer(nnz_buf, ctype=ctypes.c_int32),
                _optional_ctypes_event(nnz_event),
            )

        # Topology has changed; rebuild values from the saved copies.
        y.values.zero_()

        # Accumulate beta * old y blocks (triplets [0, old_y_nnz)) ...
        wp.launch(
            kernel=_bsr_axpy_add_block,
            device=device,
            dim=(old_y_nnz, y.block_shape[0], y.block_shape[1]),
            inputs=[
                0,
                beta,
                work_arrays._sum_rows,
                work_arrays._sum_cols,
                y.offsets,
                y.columns,
                _as_3d_array(work_arrays._old_y_values, y.block_shape),
                y.scalar_values,
            ],
        )

        # ... then alpha * x blocks (triplets [old_y_nnz, sum_nnz)).
        wp.launch(
            kernel=_bsr_axpy_add_block,
            device=device,
            dim=(x_nnz, y.block_shape[0], y.block_shape[1]),
            inputs=[
                old_y_nnz,
                alpha,
                work_arrays._sum_rows,
                work_arrays._sum_cols,
                y.offsets,
                y.columns,
                x.scalar_values,
                y.scalar_values,
            ],
        )

    return y
1732
+
1733
+
1734
def make_bsr_mm_count_coeffs(tile_size):
    """Build the kernel counting per-x-block product terms for :func:`bsr_mm`.

    One kernel instance is generated (and cached) per ``tile_size``; with
    ``tile_size > 1`` a whole tile of lanes cooperates on each row and
    partial results are combined with tile reductions.
    """
    from warp._src.fem.cache import dynamic_kernel

    @dynamic_kernel(suffix=tile_size)
    def bsr_mm_count_coeffs(
        y_ncol: int,
        z_nnz: int,
        x_offsets: wp.array(dtype=int),
        x_columns: wp.array(dtype=int),
        y_offsets: wp.array(dtype=int),
        y_columns: wp.array(dtype=int),
        row_min: wp.array(dtype=int),
        block_counts: wp.array(dtype=int),
    ):
        row, lane = wp.tid()
        row_count = int(0)

        # Range of x blocks on this row.
        x_beg = x_offsets[row]
        x_end = x_offsets[row + 1]

        # Track the min/max output column touched by this row's products.
        min_col = y_ncol
        max_col = int(0)

        # Lanes stride over the row's x blocks; each x block contributes one
        # product term per block of the corresponding y row.
        for x_block in range(x_beg + lane, x_end, tile_size):
            x_col = x_columns[x_block]
            y_row_end = y_offsets[x_col + 1]
            y_row_beg = y_offsets[x_col]
            block_count = y_row_end - y_row_beg
            if block_count != 0:
                # y columns are sorted, so first/last block bound the column range.
                min_col = wp.min(y_columns[y_row_beg], min_col)
                max_col = wp.max(y_columns[y_row_end - 1], max_col)

            block_counts[x_block + 1] = block_count
            row_count += block_count

        # Combine per-lane partials across the tile.
        if wp.static(tile_size) > 1:
            row_count = wp.tile_sum(wp.tile(row_count))[0]
            min_col = wp.tile_min(wp.tile(min_col))[0]
            max_col = wp.tile_max(wp.tile(max_col))[0]
        col_range_size = wp.max(0, max_col - min_col + 1)

        if row_count > col_range_size:
            # Optimization for deep products.
            # Do not store the whole list of src product terms, they would be highly redundant
            # Instead just mark a range in the output matrix

            if lane == 0:
                row_min[row] = min_col
                # All of this row's count is attributed to its last slot.
                block_counts[x_end] = col_range_size

            # Zero the per-block counts that the range now replaces.
            for x_block in range(x_beg + lane, x_end - 1, tile_size):
                block_counts[x_block + 1] = 0
        elif lane == 0:
            # -1 marks "no range optimization" for this row.
            row_min[row] = -1

        # Slot 0 seeds the prefix sum with the count of copied z blocks.
        if lane == 0 and row == 0:
            block_counts[0] = z_nnz

    return bsr_mm_count_coeffs
1793
+
1794
+
1795
@wp.kernel(enable_backward=False)
def _bsr_mm_list_coeffs(
    copied_z_nnz: int,
    mm_nnz: int,
    x_nrow: int,
    x_offsets: wp.array(dtype=int),
    x_columns: wp.array(dtype=int),
    y_offsets: wp.array(dtype=int),
    y_columns: wp.array(dtype=int),
    mm_row_min: wp.array(dtype=int),
    mm_offsets: wp.array(dtype=int),
    mm_rows: wp.array(dtype=int),
    mm_cols: wp.array(dtype=int),
    mm_src_blocks: wp.array(dtype=int),
):
    # Expand the counted product terms into explicit (row, col, src_block)
    # triplets. The first `copied_z_nnz` slots are reserved for copied z
    # blocks, so product triplets start after them.
    mm_block = wp.tid() + copied_z_nnz

    x_nnz = x_offsets[x_nrow]

    # Find the x block owning this product slot via the per-block prefix sums.
    x_block = bsr_row_index(mm_offsets, x_nnz, mm_block)

    if x_block == -1:
        # Slot beyond the actual product count: flag as invalid.
        mm_cols[mm_block] = -1
        mm_rows[mm_block] = -1
        return

    # Last valid slot checks whether the true product count overflowed the
    # caller-provided capacity and warns once.
    if mm_block + 1 == mm_nnz and mm_nnz < mm_offsets[x_nnz]:
        wp.printf(
            "Number of potential `bsr_mm` blocks (%d) exceeded `max_nnz` (%d)\n",
            mm_offsets[x_nnz] - copied_z_nnz,
            mm_nnz - copied_z_nnz,
        )

    # Position of this slot within its x block's run of product terms.
    pos = mm_block - mm_offsets[x_block]

    row = bsr_row_index(x_offsets, x_nrow, x_block)

    row_min_col = mm_row_min[row]
    if row_min_col == -1:
        # Explicit-triplet row: the term maps to the pos-th block of the
        # y row selected by this x block's column.
        x_col = x_columns[x_block]
        y_beg = y_offsets[x_col]
        y_block = y_beg + pos
        col = y_columns[y_block]
        src_block = x_block
    else:
        # Range-optimized row: columns are the contiguous range starting at
        # row_min_col; no single source block applies (-1 sentinel).
        col = row_min_col + pos
        src_block = -1

    mm_cols[mm_block] = col
    mm_rows[mm_block] = row
    mm_src_blocks[mm_block] = src_block
1846
+
1847
+
1848
@wp.func
def _bsr_mm_use_triplets(
    row: int,
    mm_block: int,
    mm_row_min: wp.array(dtype=int),
    row_offsets: wp.array(dtype=int),
    summed_triplet_offsets: wp.array(dtype=int),
):
    # Decide how to accumulate the product for `mm_block`:
    # - (True, beg, end): iterate the summed-triplet list [beg, end)
    # - (False, beg, end): iterate the x blocks [beg, end) of `row`
    # The triplet list is preferred when it is sufficiently (3x) shorter
    # than the row's x-block run.
    x_beg = row_offsets[row]
    x_end = row_offsets[row + 1]

    # mm_row_min may be an empty array (masked products); then triplets
    # are unavailable.
    if mm_row_min:
        # -1 means this row did not use the contiguous-range optimization,
        # so per-block triplet runs exist for it.
        if mm_row_min[row] == -1:
            if mm_block == 0:
                block_beg = 0
            else:
                block_beg = summed_triplet_offsets[mm_block - 1]
            block_end = summed_triplet_offsets[mm_block]

            if x_end - x_beg > 3 * (block_end - block_beg):
                return True, block_beg, block_end

    return False, x_beg, x_end
1871
+
1872
+
1873
@wp.kernel(enable_backward=False)
def _bsr_mm_compute_values(
    alpha: Any,
    x_offsets: wp.array(dtype=int),
    x_columns: wp.array(dtype=int),
    x_values: wp.array(dtype=Any),
    y_offsets: wp.array(dtype=int),
    y_columns: wp.array(dtype=int),
    y_values: wp.array(dtype=Any),
    mm_row_min: wp.array(dtype=int),
    summed_triplet_offsets: wp.array(dtype=int),
    summed_triplet_src_blocks: wp.indexedarray(dtype=int),
    mm_row_count: int,
    mm_offsets: wp.array(dtype=int),
    mm_cols: wp.array(dtype=int),
    mm_values: wp.array(dtype=Any),
):
    # Accumulate alpha * (x @ y) into mm_values, one thread per output block.
    mm_block = wp.tid()

    # Recover the output row; -1 marks a block index past the current nnz.
    row = bsr_row_index(mm_offsets, mm_row_count, mm_block)
    if row == -1:
        return

    # Choose between iterating the summed-triplet list or the full x row.
    use_triplets, block_beg, block_end = _bsr_mm_use_triplets(
        row, mm_block, mm_row_min, x_offsets, summed_triplet_offsets
    )

    mm_val = mm_values.dtype(type(alpha)(0.0))
    col = mm_cols[mm_block]
    if use_triplets:
        for tpl_idx in range(block_beg, block_end):
            x_block = summed_triplet_src_blocks[tpl_idx]
            # Fix: only dereference x_columns once we know x_block is a valid
            # index. Entries copied from the original z matrix carry the -1
            # sentinel, and reading x_columns[-1] before the guard was an
            # out-of-bounds access. This also matches the ordering used by the
            # tiled variant in make_bsr_mm_compute_values_tiled_outer.
            if x_block != -1:
                x_col = x_columns[x_block]
                y_block = bsr_block_index(x_col, col, y_offsets, y_columns)
                mm_val += x_values[x_block] * y_values[y_block]
    else:
        for x_block in range(block_beg, block_end):
            x_col = x_columns[x_block]
            # The y row may not hold this output column; skip missing blocks.
            y_block = bsr_block_index(x_col, col, y_offsets, y_columns)
            if y_block != -1:
                mm_val += x_values[x_block] * y_values[y_block]

    mm_values[mm_block] += alpha * mm_val
1917
+
1918
+
1919
def make_bsr_mm_compute_values_tiled_outer(subblock_rows, subblock_cols, block_depth, scalar_type, tile_size):
    """Build the tiled :func:`bsr_mm` value kernel for a given block decomposition.

    The output block is partitioned into (subblock_rows x subblock_cols)
    sub-blocks; each sub-block is accumulated as a sum of rank-1 outer
    products over the contraction depth, with `tile_size` lanes striding
    over the product terms and combining results via a tile reduction.
    One kernel is generated (and cached) per parameter combination.
    """
    from warp._src.fem.cache import dynamic_func, dynamic_kernel

    # Per-lane accumulator and the column/row slice types for the outer product.
    mm_type = wp.mat(dtype=scalar_type, shape=(subblock_rows, subblock_cols))

    x_col_vec_t = wp.vec(dtype=scalar_type, length=subblock_rows)
    y_row_vec_t = wp.vec(dtype=scalar_type, length=subblock_cols)

    suffix = (subblock_rows, subblock_cols, block_depth, tile_size, scalar_type.__name__)

    @dynamic_func(suffix=suffix)
    def _outer_product(
        x_values: wp.array2d(dtype=Any),
        y_values: wp.array2d(dtype=Any),
        brow_off: int,
        bcol_off: int,
        block_col: int,
        brow_count: int,
        bcol_count: int,
    ):
        # Gather one column slice of x and one row slice of y (zero-padded
        # past brow_count/bcol_count) and return their outer product.
        x_col_vec = x_col_vec_t()
        y_row_vec = y_row_vec_t()

        for k in range(brow_count):
            x_col_vec[k] = x_values[brow_off + k, block_col]
        for k in range(bcol_count):
            y_row_vec[k] = y_values[block_col, bcol_off + k]

        return wp.outer(x_col_vec, y_row_vec)

    @dynamic_kernel(suffix=suffix, kernel_options={"enable_backward": False})
    def bsr_mm_compute_values(
        alpha: Any,
        x_offsets: wp.array(dtype=int),
        x_columns: wp.array(dtype=int),
        x_values: wp.array3d(dtype=Any),
        y_offsets: wp.array(dtype=int),
        y_columns: wp.array(dtype=int),
        y_values: wp.array3d(dtype=Any),
        mm_row_min: wp.array(dtype=int),
        summed_triplet_offsets: wp.array(dtype=int),
        summed_triplet_src_blocks: wp.indexedarray(dtype=int),
        mm_row_count: int,
        mm_offsets: wp.array(dtype=int),
        mm_cols: wp.array(dtype=int),
        mm_values: wp.array3d(dtype=Any),
    ):
        # Thread layout: (output block, sub-block row, sub-block col, lane).
        mm_block, subrow, subcol, lane = wp.tid()

        # Offsets of this sub-block within the output block, clamped at the
        # block boundary for ragged edge sub-blocks.
        brow_off = subrow * wp.static(subblock_rows)
        bcol_off = subcol * wp.static(subblock_cols)

        brow_count = wp.min(mm_values.shape[1] - brow_off, subblock_rows)
        bcol_count = wp.min(mm_values.shape[2] - bcol_off, subblock_cols)

        # Recover the output row; -1 marks a block index past the current nnz.
        mm_row = bsr_row_index(mm_offsets, mm_row_count, mm_block)
        if mm_row == -1:
            return

        lane_val = mm_type()

        # Choose between iterating the summed-triplet list or the full x row.
        use_triplets, block_beg, block_end = _bsr_mm_use_triplets(
            mm_row, mm_block, mm_row_min, x_offsets, summed_triplet_offsets
        )

        # Total scalar contraction length: one column per depth entry of each term.
        col_count = (block_end - block_beg) * block_depth

        mm_col = mm_cols[mm_block]
        if use_triplets:
            # Lanes stride over (term, depth-column) pairs.
            for col in range(lane, col_count, tile_size):
                tpl_block = col // wp.static(block_depth)
                block_col = col - tpl_block * wp.static(block_depth)
                tpl_block += block_beg

                # -1 marks entries copied from the original z matrix; skip them.
                x_block = summed_triplet_src_blocks[tpl_block]
                if x_block != -1:
                    x_col = x_columns[x_block]
                    y_block = bsr_block_index(x_col, mm_col, y_offsets, y_columns)
                    lane_val += _outer_product(
                        x_values[x_block], y_values[y_block], brow_off, bcol_off, block_col, brow_count, bcol_count
                    )
        else:
            for col in range(lane, col_count, tile_size):
                x_block = col // wp.static(block_depth)
                block_col = col - x_block * wp.static(block_depth)
                x_block += block_beg

                x_col = x_columns[x_block]
                # The y row may not hold this output column; skip missing blocks.
                y_block = bsr_block_index(x_col, mm_col, y_offsets, y_columns)

                if y_block != -1:
                    lane_val += _outer_product(
                        x_values[x_block], y_values[y_block], brow_off, bcol_off, block_col, brow_count, bcol_count
                    )

        # Reduce per-lane partial sub-blocks across the tile.
        mm_val = wp.tile_sum(wp.tile(lane_val, preserve_type=True))[0]

        # Lanes cooperatively scatter the reduced sub-block into the output,
        # skipping padded coefficients outside the block boundary.
        for coef in range(lane, wp.static(subblock_cols * subblock_rows), tile_size):
            br = coef // subblock_cols
            bc = coef - br * subblock_cols
            if br < brow_count and bc < bcol_count:
                mm_values[mm_block, br + brow_off, bc + bcol_off] += mm_val[br, bc] * alpha

    return bsr_mm_compute_values
2023
+
2024
+
2025
class bsr_mm_work_arrays:
    """Opaque structure for persisting :func:`bsr_mm` temporary work buffers across calls."""

    def __init__(self):
        self._reset(None)

    def _reset(self, device):
        # Drop all cached buffers; they are lazily (re)allocated by the staged
        # allocation methods below.
        self.device = device
        self._mm_row_min = None
        self._mm_block_counts = None
        self._mm_rows = None
        self._mm_cols = None
        self._mm_src_blocks = None
        self._old_z_values = None
        self._old_z_offsets = None
        self._old_z_columns = None
        self._mm_nnz = 0

    def _allocate_stage_1(self, device, x_nnz: int, z: BsrMatrix, beta: float, z_aliasing: bool):
        # Allocate the buffers whose sizes depend only on the operands, not on
        # the (yet unknown) product topology.
        # A device change invalidates all cached buffers.
        if self.device != device:
            self._reset(device)

        # Allocations that do not depend on any computation
        self._copied_z_nnz = z.nnz if beta != 0.0 or z_aliasing else 0

        # Grow-only allocations: reuse existing buffers when large enough.
        # Fix: the reallocation condition for `_mm_row_min` previously tested
        # `_mm_block_counts.size`, which could leave a reused `_mm_row_min`
        # undersized (and written out of bounds) when a later product involves
        # a matrix with more rows.
        if self._mm_row_min is None or self._mm_row_min.size < z.nrow + 1:
            self._mm_row_min = wp.empty(shape=(z.nrow + 1,), dtype=int, device=self.device)
        if self._mm_block_counts is None or self._mm_block_counts.size < x_nnz + 1:
            self._mm_block_counts = wp.empty(shape=(x_nnz + 1,), dtype=int, device=self.device)

        if self._copied_z_nnz > 0:
            # Current z values must be saved before the topology is rebuilt.
            if self._old_z_values is None or self._old_z_values.size < self._copied_z_nnz:
                self._old_z_values = wp.empty(shape=(self._copied_z_nnz,), dtype=z.values.dtype, device=self.device)

            if z_aliasing:
                # z aliases x or y: its topology must be saved as well.
                if self._old_z_columns is None or self._old_z_columns.size < z.nnz:
                    self._old_z_columns = wp.empty(shape=(z.nnz,), dtype=z.columns.dtype, device=self.device)
                if self._old_z_offsets is None or self._old_z_offsets.size < z.nrow + 1:
                    self._old_z_offsets = wp.empty(shape=(z.nrow + 1,), dtype=z.offsets.dtype, device=self.device)

    def _allocate_stage_2(self, mm_nnz: int):
        # Allocations that depend on unmerged nnz estimate
        self._mm_nnz = mm_nnz
        if self._mm_rows is None or self._mm_rows.size < mm_nnz:
            self._mm_rows = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
        if self._mm_cols is None or self._mm_cols.size < mm_nnz:
            self._mm_cols = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
        if self._mm_src_blocks is None or self._mm_src_blocks.size < mm_nnz:
            self._mm_src_blocks = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
2074
+
2075
+
2076
+ def bsr_mm(
2077
+ x: BsrMatrixOrExpression[BlockType[Rows, Any, Scalar]],
2078
+ y: BsrMatrixOrExpression[BlockType[Any, Cols, Scalar]],
2079
+ z: BsrMatrix[BlockType[Rows, Cols, Scalar]] | None = None,
2080
+ alpha: Scalar = 1.0,
2081
+ beta: Scalar = 0.0,
2082
+ masked: bool = False,
2083
+ work_arrays: bsr_mm_work_arrays | None = None,
2084
+ reuse_topology: bool = False,
2085
+ tile_size: int = 0,
2086
+ max_new_nnz: int | None = None,
2087
+ ) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
2088
+ """
2089
+ Perform the sparse matrix-matrix multiplication ``z := alpha * x @ y + beta * z`` on BSR matrices ``x``, ``y`` and ``z``, and return ``z``.
2090
+
2091
+ The ``x``, ``y`` and ``z`` matrices are allowed to alias.
2092
+ If the matrix ``z`` is not provided as input, it will be allocated and treated as zero.
2093
+
2094
+ This method can be graph-captured if either:
2095
+ - `masked=True`
2096
+ - `reuse_topology=True`
2097
+ - `max_new_nnz` is provided
2098
+
2099
+ Args:
2100
+ x: Read-only left operand of the matrix-matrix product.
2101
+ y: Read-only right operand of the matrix-matrix product.
2102
+ z: Mutable affine operand and result matrix. If ``z`` is not provided, it will be allocated and treated as zero.
2103
+ alpha: Uniform scaling factor for the ``x @ y`` product
2104
+ beta: Uniform scaling factor for ``z``
2105
+ masked: If ``True``, keep the non-zero topology of ``z`` unchanged.
2106
+ work_arrays: In most cases, this function will require the use of temporary storage.
2107
+ This storage can be reused across calls by passing an instance of
2108
+ :class:`bsr_mm_work_arrays` in ``work_arrays``.
2109
+ reuse_topology: If ``True``, reuse the product topology information
2110
+ stored in ``work_arrays`` rather than recompute it from scratch.
2111
+ The matrices ``x``, ``y`` and ``z`` must be structurally similar to
2112
+ the previous call in which ``work_arrays`` were populated.
2113
+ max_new_nnz: If provided, the maximum number of non-zeros for the matrix-matrix product result
2114
+ (not counting the existing non-zeros in ``z``).
2115
+ tile_size: If a positive integer, use tiles of this size to compute the matrix-matrix product.
2116
+ If negative, disable tile-based computation. Defaults to ``0``, which determines whether to
2117
+ use tiles using using an heuristic based on the matrix shape and number of non-zeros..
2118
+ """
2119
+
2120
+ x, x_scale = _extract_matrix_and_scale(x)
2121
+ alpha *= x_scale
2122
+ y, y_scale = _extract_matrix_and_scale(y)
2123
+ alpha *= y_scale
2124
+
2125
+ if z is None:
2126
+ if masked:
2127
+ raise ValueError("Left-hand-side 'z' matrix must be provided for masked multiplication")
2128
+
2129
+ # If not output matrix is provided, allocate it for convenience
2130
+ z_block_shape = (x.block_shape[0], y.block_shape[1])
2131
+ if z_block_shape == (1, 1):
2132
+ z_block_type = x.scalar_type
2133
+ else:
2134
+ z_block_type = wp.mat(shape=z_block_shape, dtype=x.scalar_type)
2135
+ z = bsr_zeros(x.nrow, y.ncol, block_type=z_block_type, device=x.values.device)
2136
+ z.values.requires_grad = x.requires_grad or y.requires_grad
2137
+ beta = 0.0
2138
+
2139
+ if x.values.device != y.values.device or x.values.device != z.values.device:
2140
+ raise ValueError(
2141
+ f"All arguments must reside on the same device, got {x.values.device}, {y.values.device} and {z.values.device}"
2142
+ )
2143
+
2144
+ if x.scalar_type != y.scalar_type or x.scalar_type != z.scalar_type:
2145
+ raise ValueError(
2146
+ f"Matrices must have the same scalar type, got {x.scalar_type}, {y.scalar_type} and {z.scalar_type}"
2147
+ )
2148
+
2149
+ if (
2150
+ x.block_shape[0] != z.block_shape[0]
2151
+ or y.block_shape[1] != z.block_shape[1]
2152
+ or x.block_shape[1] != y.block_shape[0]
2153
+ ):
2154
+ raise ValueError(
2155
+ f"Incompatible block sizes for matrix multiplication, got ({x.block_shape}, {y.block_shape}) and ({z.block_shape})"
2156
+ )
2157
+
2158
+ if x.nrow != z.nrow or z.ncol != y.ncol or x.ncol != y.nrow:
2159
+ raise ValueError(
2160
+ f"Incompatible number of rows/columns for matrix multiplication, got ({x.nrow}, {x.ncol}) and ({y.nrow}, {y.ncol})"
2161
+ )
2162
+
2163
+ device = z.values.device
2164
+
2165
+ if alpha == 0.0 or x.nnz == 0 or y.nnz == 0:
2166
+ # Easy case
2167
+ return bsr_scale(z, beta)
2168
+
2169
+ z_aliasing = z == x or z == y
2170
+
2171
+ if masked:
2172
+ # no need to copy z, scale in-place
2173
+ copied_z_nnz = 0
2174
+ mm_nnz = z.nnz
2175
+
2176
+ if z_aliasing:
2177
+ raise ValueError("`masked=True` is not supported for aliased inputs")
2178
+
2179
+ if beta == 0.0:
2180
+ # do not bsr_scale(0), this would not preserve topology
2181
+ z.values.zero_()
2182
+ else:
2183
+ bsr_scale(z, beta)
2184
+ elif reuse_topology:
2185
+ if work_arrays is None:
2186
+ raise ValueError("`work_arrays` must not be ``None`` in order to reuse matrix-matrix product topology")
2187
+
2188
+ copied_z_nnz = work_arrays._copied_z_nnz
2189
+ mm_nnz = work_arrays._mm_nnz
2190
+ else:
2191
+ if work_arrays is None:
2192
+ work_arrays = bsr_mm_work_arrays()
2193
+
2194
+ if max_new_nnz is None:
2195
+ if device.is_capturing:
2196
+ raise RuntimeError(
2197
+ "`bsr_mm` requires either `reuse_topology=True`, `masked=True` or `max_new_nnz` to be set for use in graph capture"
2198
+ )
2199
+ z.nnz_sync()
2200
+
2201
+ work_arrays._allocate_stage_1(device, x.nnz, z, beta, z_aliasing)
2202
+ copied_z_nnz = work_arrays._copied_z_nnz
2203
+
2204
+ # Prefix sum of number of (unmerged) mm blocks per row
2205
+ # Use either a thread or a block per row depending on avg nnz/row
2206
+ work_arrays._mm_block_counts.zero_()
2207
+ count_tile_size = 32
2208
+ if not device.is_cuda or x.nnz < 3 * count_tile_size * x.nrow:
2209
+ count_tile_size = 1
2210
+
2211
+ wp.launch(
2212
+ kernel=make_bsr_mm_count_coeffs(count_tile_size),
2213
+ device=device,
2214
+ dim=(z.nrow, count_tile_size),
2215
+ block_dim=count_tile_size if count_tile_size > 1 else 256,
2216
+ inputs=[
2217
+ y.ncol,
2218
+ copied_z_nnz,
2219
+ x.offsets,
2220
+ x.columns,
2221
+ y.offsets,
2222
+ y.columns,
2223
+ work_arrays._mm_row_min,
2224
+ work_arrays._mm_block_counts,
2225
+ ],
2226
+ )
2227
+ warp._src.utils.array_scan(work_arrays._mm_block_counts[: x.nnz + 1], work_arrays._mm_block_counts[: x.nnz + 1])
2228
+
2229
+ if max_new_nnz is not None:
2230
+ mm_nnz = max_new_nnz + copied_z_nnz
2231
+ else:
2232
+ # Get back total counts on host -- we need a synchronization here
2233
+ # Use pinned buffer from z, we are going to need it later anyway
2234
+ nnz_buf, _ = z._setup_nnz_transfer()
2235
+ stream = wp.get_stream(device) if device.is_cuda else None
2236
+ wp.copy(dest=nnz_buf, src=work_arrays._mm_block_counts, src_offset=x.nnz, count=1, stream=stream)
2237
+ if device.is_cuda:
2238
+ wp.synchronize_stream(stream)
2239
+ mm_nnz = int(nnz_buf.numpy()[0])
2240
+
2241
+ if mm_nnz == copied_z_nnz:
2242
+ # x@y = 0
2243
+ return bsr_scale(z, beta)
2244
+
2245
+ work_arrays._allocate_stage_2(mm_nnz)
2246
+
2247
+ # If z has a non-zero scale, save current data before overwriting it
2248
+ if copied_z_nnz > 0:
2249
+ # Copy z row and column indices
2250
+ wp.copy(dest=work_arrays._mm_cols, src=z.columns, count=copied_z_nnz)
2251
+ z.uncompress_rows(out=work_arrays._mm_rows)
2252
+ work_arrays._mm_src_blocks[:copied_z_nnz].fill_(-1)
2253
+ if z_aliasing:
2254
+ # If z is aliasing with x or y, need to save topology as well
2255
+ wp.copy(src=z.columns, dest=work_arrays._old_z_columns, count=copied_z_nnz)
2256
+ wp.copy(src=z.offsets, dest=work_arrays._old_z_offsets, count=z.nrow + 1)
2257
+
2258
+ # Fill unmerged mm blocks rows and columns
2259
+ wp.launch(
2260
+ kernel=_bsr_mm_list_coeffs,
2261
+ device=device,
2262
+ dim=mm_nnz - copied_z_nnz,
2263
+ inputs=[
2264
+ copied_z_nnz,
2265
+ mm_nnz,
2266
+ x.nrow,
2267
+ x.offsets,
2268
+ x.columns,
2269
+ y.offsets,
2270
+ y.columns,
2271
+ work_arrays._mm_row_min,
2272
+ work_arrays._mm_block_counts,
2273
+ work_arrays._mm_rows,
2274
+ work_arrays._mm_cols,
2275
+ work_arrays._mm_src_blocks,
2276
+ ],
2277
+ )
2278
+
2279
+ alpha = z.scalar_type(alpha)
2280
+ beta = z.scalar_type(beta)
2281
+
2282
+ if copied_z_nnz > 0:
2283
+ # Save current z values in temporary buffer
2284
+ wp.copy(src=z.values, dest=work_arrays._old_z_values, count=copied_z_nnz)
2285
+
2286
+ if not masked:
2287
+ # Increase dest array size if needed
2288
+ if z.columns.shape[0] < mm_nnz:
2289
+ z.columns = wp.empty(shape=(mm_nnz,), dtype=int, device=device)
2290
+
2291
+ from warp._src.context import runtime
2292
+
2293
+ if device.is_cpu:
2294
+ native_func = runtime.core.wp_bsr_matrix_from_triplets_host
2295
+ else:
2296
+ native_func = runtime.core.wp_bsr_matrix_from_triplets_device
2297
+
2298
+ nnz_buf, nnz_event = z._setup_nnz_transfer()
2299
+ summed_triplet_offsets = wp.empty(shape=(mm_nnz,), dtype=wp.int32, device=device)
2300
+ summed_triplet_indices = wp.empty(shape=(mm_nnz,), dtype=wp.int32, device=device)
2301
+
2302
+ with wp.ScopedDevice(z.device):
2303
+ native_func(
2304
+ z.block_size,
2305
+ 0, # scalar_size_in_bytes
2306
+ z.nrow,
2307
+ z.ncol,
2308
+ mm_nnz,
2309
+ None, # device nnz
2310
+ ctypes.cast(work_arrays._mm_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
2311
+ ctypes.cast(work_arrays._mm_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
2312
+ None, # triplet values
2313
+ 0, # zero_value_mask
2314
+ False, # masked_topology
2315
+ ctypes.cast(summed_triplet_offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
2316
+ ctypes.cast(summed_triplet_indices.ptr, ctypes.POINTER(ctypes.c_int32)),
2317
+ ctypes.cast(z.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
2318
+ ctypes.cast(z.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
2319
+ _optional_ctypes_pointer(nnz_buf, ctype=ctypes.c_int32),
2320
+ _optional_ctypes_event(nnz_event),
2321
+ )
2322
+
2323
+ # Resize z to fit mm result if necessary
2324
+ # If we are not reusing the product topology, this needs another synchronization
2325
+ if not reuse_topology:
2326
+ work_arrays.result_nnz = z.nnz_sync() if max_new_nnz is None else mm_nnz
2327
+
2328
+ _bsr_ensure_fits(z, nnz=work_arrays.result_nnz)
2329
+ z.values.zero_()
2330
+
2331
+ if copied_z_nnz > 0:
2332
+ # Add back original z values
2333
+ wp.launch(
2334
+ kernel=_bsr_axpy_add_block,
2335
+ device=device,
2336
+ dim=(copied_z_nnz, z.block_shape[0], z.block_shape[1]),
2337
+ inputs=[
2338
+ 0,
2339
+ beta,
2340
+ work_arrays._mm_rows,
2341
+ work_arrays._mm_cols,
2342
+ z.offsets,
2343
+ z.columns,
2344
+ _as_3d_array(work_arrays._old_z_values, z.block_shape),
2345
+ z.scalar_values,
2346
+ ],
2347
+ )
2348
+
2349
+ max_subblock_dim = 12
2350
+ if tile_size > 0:
2351
+ use_tiles = True
2352
+ elif tile_size < 0:
2353
+ use_tiles = False
2354
+ else:
2355
+ # Heuristic for using tiled variant: few or very large blocks
2356
+ tile_size = 64
2357
+ max_tiles_per_sm = 2048 // tile_size # assume 64 resident warps per SM
2358
+ use_tiles = device.is_cuda and (
2359
+ max(x.block_size, y.block_size, z.block_size) > max_subblock_dim**2
2360
+ or z.nnz < max_tiles_per_sm * device.sm_count
2361
+ )
2362
+
2363
+ if use_tiles:
2364
+ subblock_rows = min(max_subblock_dim, z.block_shape[0])
2365
+ subblock_cols = min(max_subblock_dim, z.block_shape[1])
2366
+
2367
+ wp.launch(
2368
+ kernel=make_bsr_mm_compute_values_tiled_outer(
2369
+ subblock_rows, subblock_cols, x.block_shape[1], z.scalar_type, tile_size
2370
+ ),
2371
+ device=device,
2372
+ dim=(
2373
+ z.nnz,
2374
+ (z.block_shape[0] + subblock_rows - 1) // subblock_rows,
2375
+ (z.block_shape[1] + subblock_cols - 1) // subblock_cols,
2376
+ tile_size,
2377
+ ),
2378
+ block_dim=tile_size,
2379
+ inputs=[
2380
+ alpha,
2381
+ work_arrays._old_z_offsets if x == z else x.offsets,
2382
+ work_arrays._old_z_columns if x == z else x.columns,
2383
+ _as_3d_array(work_arrays._old_z_values, z.block_shape) if x == z else x.scalar_values,
2384
+ work_arrays._old_z_offsets if y == z else y.offsets,
2385
+ work_arrays._old_z_columns if y == z else y.columns,
2386
+ _as_3d_array(work_arrays._old_z_values, z.block_shape) if y == z else y.scalar_values,
2387
+ None if masked else work_arrays._mm_row_min,
2388
+ None if masked else summed_triplet_offsets,
2389
+ None if masked else work_arrays._mm_src_blocks[summed_triplet_indices],
2390
+ z.nrow,
2391
+ z.offsets,
2392
+ z.columns,
2393
+ z.scalar_values,
2394
+ ],
2395
+ )
2396
+
2397
+ return z
2398
+
2399
+ # Add mm blocks to z values
2400
+ if (type_is_matrix(x.values.dtype) or type_is_matrix(y.values.dtype)) and not (type_is_matrix(z.values.dtype)):
2401
+ # Result block type is scalar, but operands are matrices
2402
+ # Cast result to (1x1) matrix to perform multiplication
2403
+ mm_values = z.values.view(wp.mat(shape=(1, 1), dtype=z.scalar_type))
2404
+ else:
2405
+ mm_values = z.values
2406
+
2407
+ wp.launch(
2408
+ kernel=_bsr_mm_compute_values,
2409
+ device=device,
2410
+ dim=z.nnz,
2411
+ inputs=[
2412
+ alpha,
2413
+ work_arrays._old_z_offsets if x == z else x.offsets,
2414
+ work_arrays._old_z_columns if x == z else x.columns,
2415
+ work_arrays._old_z_values if x == z else x.values,
2416
+ work_arrays._old_z_offsets if y == z else y.offsets,
2417
+ work_arrays._old_z_columns if y == z else y.columns,
2418
+ work_arrays._old_z_values if y == z else y.values,
2419
+ None if masked else work_arrays._mm_row_min,
2420
+ None if masked else summed_triplet_offsets,
2421
+ None if masked else work_arrays._mm_src_blocks[summed_triplet_indices],
2422
+ z.nrow,
2423
+ z.offsets,
2424
+ z.columns,
2425
+ mm_values,
2426
+ ],
2427
+ )
2428
+
2429
+ return z
2430
+
2431
+
2432
def make_bsr_mv_kernel(block_cols: int):
    """Return a cached kernel computing ``y = alpha * A x + beta * y`` with one thread per scalar row of ``y``.

    Args:
        block_cols: Number of scalar columns per BSR block; baked into the kernel
            so the inner loop can be unrolled with ``wp.static``.
    """
    from warp._src.fem.cache import dynamic_kernel

    @dynamic_kernel(suffix=block_cols, kernel_options={"enable_backward": False})
    def bsr_mv_kernel(
        alpha: Any,
        A_offsets: wp.array(dtype=int),
        A_columns: wp.array(dtype=int),
        A_values: wp.array3d(dtype=Any),
        x: wp.array(dtype=Any),
        beta: Any,
        y: wp.array(dtype=Any),
    ):
        # One thread per (block row, sub-row) pair
        row, subrow = wp.tid()

        block_rows = A_values.shape[1]

        # Index of the scalar y coefficient computed by this thread
        yi = row * block_rows + subrow

        # zero-initialize with type of y elements
        scalar_zero = type(alpha)(0)
        v = scalar_zero

        if alpha != scalar_zero:
            beg = A_offsets[row]
            end = A_offsets[row + 1]
            for block in range(beg, end):
                # Start of the x sub-vector matching this block's column
                xs = A_columns[block] * block_cols
                for col in range(wp.static(block_cols)):
                    v += A_values[block, subrow, col] * x[xs + col]
            # Apply the uniform scale once per row rather than per term
            v *= alpha

        if beta != scalar_zero:
            v += beta * y[yi]

        y[yi] = v

    return bsr_mv_kernel
2470
+
2471
+
2472
def make_bsr_mv_tiled_kernel(tile_size: int):
    """Return a cached kernel computing ``y = alpha * A x + beta * y`` with one ``tile_size``-lane tile per scalar row of ``y``.

    Intended for rows with many non-zeros: each lane accumulates a private
    partial sum over a strided subset of the row's scalar columns, then the
    partials are reduced with ``wp.tile_sum``.
    """
    from warp._src.fem.cache import dynamic_kernel

    @dynamic_kernel(suffix=tile_size, kernel_options={"enable_backward": False})
    def bsr_mv_tiled_kernel(
        alpha: Any,
        A_offsets: wp.array(dtype=int),
        A_columns: wp.array(dtype=int),
        A_values: wp.array3d(dtype=Any),
        x: wp.array(dtype=Any),
        beta: Any,
        y: wp.array(dtype=Any),
    ):
        # One tile of `tile_size` lanes cooperates on a single scalar row of y
        row, subrow, lane = wp.tid()

        scalar_zero = type(alpha)(0)
        block_rows = A_values.shape[1]
        block_cols = A_values.shape[2]

        # Index of the scalar y coefficient this tile is responsible for
        yi = row * block_rows + subrow

        if beta == scalar_zero:
            subrow_sum = wp.tile_zeros(shape=(1,), dtype=y.dtype)
        else:
            subrow_sum = beta * wp.tile_load(y, 1, yi)

        if alpha != scalar_zero:
            block_beg = A_offsets[row]
            # Total number of scalar columns stored across this block row
            col_count = (A_offsets[row + 1] - block_beg) * block_cols

            col = lane  # NOTE(review): dead assignment -- the loop below rebinds `col`
            lane_sum = y.dtype(0)

            # Each lane strides over the row's scalar columns, accumulating privately
            for col in range(lane, col_count, tile_size):
                # Decompose the flat column index into (block, in-block column)
                block = col // block_cols
                block_col = col - block * block_cols
                block += block_beg

                xi = x[A_columns[block] * block_cols + block_col]
                lane_sum += A_values[block, subrow, block_col] * xi

            lane_sum *= alpha
            # Reduce the per-lane partial sums across the tile
            subrow_sum += wp.tile_sum(wp.tile(lane_sum))

        wp.tile_store(y, subrow_sum, yi)

    return bsr_mv_tiled_kernel
2519
+
2520
+
2521
def make_bsr_mv_transpose_kernel(block_rows: int):
    """Return a cached kernel accumulating ``y += alpha * A^T x`` with one thread per (block, sub-column).

    Results are combined with ``wp.atomic_add`` since multiple blocks may map to
    the same output column, so summation order (and thus floating-point result)
    is non-deterministic. The caller is expected to have pre-scaled ``y`` by
    ``beta`` before launching.
    """
    from warp._src.fem.cache import dynamic_kernel

    @dynamic_kernel(suffix=block_rows, kernel_options={"enable_backward": False})
    def bsr_mv_transpose_kernel(
        alpha: Any,
        A_row_count: int,
        A_offsets: wp.array(dtype=int),
        A_columns: wp.array(dtype=int),
        A_values: wp.array3d(dtype=Any),
        x: wp.array(dtype=Any),
        y: wp.array(dtype=Any),
    ):
        block, subcol = wp.tid()

        # Recover the block row containing this block; -1 means out of range
        row = bsr_row_index(A_offsets, A_row_count, block)
        if row == -1:
            return

        block_cols = A_values.shape[2]

        A_block = A_values[block]

        # Dot product of the block's sub-column with the matching x sub-vector
        col_sum = type(alpha)(0)
        for subrow in range(wp.static(block_rows)):
            col_sum += A_block[subrow, subcol] * x[row * block_rows + subrow]

        # Distinct blocks may target the same y coefficient, hence the atomic
        wp.atomic_add(y, A_columns[block] * block_cols + subcol, alpha * col_sum)

    return bsr_mv_transpose_kernel
2551
+
2552
+
2553
def _vec_array_view(array: wp.array, dtype: type, expected_scalar_count: int) -> wp.array:
    """Reinterpret a contiguous 1d or 2d array as a flat 1d array of ``dtype``.

    Validates that the array holds exactly ``expected_scalar_count`` scalars and
    that its scalar type matches ``dtype``'s, then returns a zero-copy view
    (reusing the input directly when it already has the requested layout). The
    gradient array, when present, gets a matching view.
    """

    total_scalars = array.size * type_size(array.dtype)
    if total_scalars != expected_scalar_count:
        raise ValueError(f"Invalid array scalar size, expected {expected_scalar_count}, got {total_scalars}")

    # Fast path: the input already has the requested flat layout
    if array.ndim == 1 and types_equal(array.dtype, dtype):
        return array

    if type_scalar_type(array.dtype) != type_scalar_type(dtype):
        raise ValueError(f"Incompatible scalar types, expected {type_repr(array.dtype)}, got {type_repr(dtype)}")
    if array.ndim > 2:
        raise ValueError(f"Incompatible array number of dimensions, expected 1 or 2, got {array.ndim}")
    if not array.is_contiguous:
        raise ValueError("Array must be contiguous")

    dtype_width = type_size(dtype)
    view_count = total_scalars // dtype_width
    if view_count * dtype_width != total_scalars:
        raise ValueError(
            f"Array of shape {array.shape} and type {type_repr(array.dtype)} cannot be reshaped to an array of type {type_repr(dtype)}"
        )

    def _make_view(src):
        grad_view = _make_view(src.grad) if src.grad is not None else None
        return wp.array(
            data=None,
            ptr=src.ptr,
            capacity=src.capacity,
            device=src.device,
            dtype=dtype,
            shape=view_count,
            grad=grad_view,
        )

    result = _make_view(array)
    # Keep the source array alive for as long as the view exists
    result._ref = array
    return result
2593
+
2594
+
2595
def bsr_mv(
    A: BsrMatrixOrExpression[BlockType[Rows, Cols, Scalar]],
    x: Array[Vector[Cols, Scalar] | Scalar],
    y: Array[Vector[Rows, Scalar] | Scalar] | None = None,
    alpha: Scalar = 1.0,
    beta: Scalar = 0.0,
    transpose: bool = False,
    work_buffer: Array[Vector[Rows, Scalar] | Scalar] | None = None,
    tile_size: int = 0,
) -> Array[Vector[Rows, Scalar] | Scalar]:
    """Perform the sparse matrix-vector product ``y := alpha * A * x + beta * y`` and return ``y``.

    The ``x`` and ``y`` vectors are allowed to alias.

    Args:
        A: Read-only, left matrix operand of the matrix-vector product.
        x: Read-only, right vector operand of the matrix-vector product.
        y: Mutable affine operand and result vector. If ``y`` is not provided, it will be allocated and treated as zero.
        alpha: Uniform scaling factor for ``x``. If zero, ``x`` will not be read and may be left uninitialized.
        beta: Uniform scaling factor for ``y``. If zero, ``y`` will not be read and may be left uninitialized.
        transpose: If ``True``, use the transpose of the matrix ``A``. In this case the result is **non-deterministic**.
        work_buffer: Temporary storage is required if and only if ``x`` and ``y`` are the same vector.
            If provided, the ``work_buffer`` array will be used for this purpose,
            otherwise a temporary allocation will be performed.
        tile_size: If a positive integer, use tiles of this size to compute the matrix-vector product.
            If negative, disable tile-based computation. Defaults to ``0``, which determines whether to
            use tiles using a heuristic based on the matrix shape and number of non-zeros.
    """

    # Fold any scalar expression factor (e.g. from `2.0 * A`) into alpha
    A, A_scale = _extract_matrix_and_scale(A)
    alpha *= A_scale

    # Logical (possibly transposed) dimensions used for shape validation
    if transpose:
        block_shape = A.block_shape[1], A.block_shape[0]
        nrow, ncol = A.ncol, A.nrow
    else:
        block_shape = A.block_shape
        nrow, ncol = A.nrow, A.ncol

    if y is None:
        # If no output array is provided, allocate one for convenience
        y_vec_len = block_shape[0]
        y_dtype = A.scalar_type if y_vec_len == 1 else wp.vec(length=y_vec_len, dtype=A.scalar_type)
        y = wp.empty(shape=(nrow,), device=A.values.device, dtype=y_dtype, requires_grad=x.requires_grad)
        beta = 0.0

    # Convert scale factors to the matrix scalar type expected by the kernels
    alpha = A.scalar_type(alpha)
    beta = A.scalar_type(beta)

    device = A.values.device
    if A.values.device != x.device or A.values.device != y.device:
        raise ValueError(
            f"A, x, and y must reside on the same device, got {A.values.device}, {x.device} and {y.device}"
        )

    if x.ptr == y.ptr:
        # Aliasing case, need temporary storage
        if work_buffer is None:
            work_buffer = wp.empty_like(y)
        elif work_buffer.size < y.size:
            raise ValueError(f"Work buffer size is insufficient, needs to be at least {y.size}, got {work_buffer.size}")
        elif not types_equal(work_buffer.dtype, y.dtype):
            raise ValueError(
                f"Work buffer must have same data type as y, {type_repr(y.dtype)} vs {type_repr(work_buffer.dtype)}"
            )

        # Save old y values before overwriting vector
        wp.copy(dest=work_buffer, src=y, count=y.size)
        x = work_buffer

    # Flat scalar views let the kernels ignore the caller's vector/scalar dtype
    try:
        x_view = _vec_array_view(x, A.scalar_type, expected_scalar_count=ncol * block_shape[1])
    except ValueError as err:
        raise ValueError("Incompatible 'x' vector for bsr_mv") from err
    try:
        y_view = _vec_array_view(y, A.scalar_type, expected_scalar_count=nrow * block_shape[0])
    except ValueError as err:
        raise ValueError("Incompatible 'y' vector for bsr_mv") from err

    # heuristic to use tiled version for long rows
    if tile_size > 0:
        use_tiles = True
    elif tile_size < 0:
        use_tiles = False
    else:
        tile_size = 64
        use_tiles = device.is_cuda and A.nnz * A.block_size > 2 * tile_size * A.shape[0]

    if transpose:
        # The transpose kernel only accumulates alpha * A^T x, so apply the
        # beta scaling of y up-front
        if beta.value == 0.0:
            y.zero_()
        elif beta.value != 1.0:
            wp.launch(
                kernel=_bsr_scale_kernel,
                device=y.device,
                dim=y_view.shape[0],
                inputs=[beta, y_view],
            )
        if alpha.value != 0.0:
            wp.launch(
                kernel=make_bsr_mv_transpose_kernel(block_rows=block_shape[1]),
                device=A.values.device,
                dim=(A.nnz, block_shape[0]),
                inputs=[alpha, A.nrow, A.offsets, A.columns, A.scalar_values, x_view, y_view],
            )
    elif use_tiles:
        wp.launch(
            kernel=make_bsr_mv_tiled_kernel(tile_size),
            device=A.values.device,
            dim=(nrow, block_shape[0], tile_size),
            block_dim=tile_size,
            inputs=[alpha, A.offsets, A.columns, A.scalar_values, x_view, beta, y_view],
        )
    else:
        # One thread per output scalar coefficient
        wp.launch(
            kernel=make_bsr_mv_kernel(block_cols=block_shape[1]),
            device=A.values.device,
            dim=(nrow, block_shape[0]),
            inputs=[alpha, A.offsets, A.columns, A.scalar_values, x_view, beta, y_view],
        )

    return y