warp-lang 1.9.1__py3-none-manylinux_2_34_aarch64.whl → 1.10.0rc2__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (346)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +794 -305
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1382 -377
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.so +0 -0
  92. warp/bin/warp.so +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -721
  95. warp/codegen.py +6 -4251
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -8062
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +1 -1
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -365
  136. warp/jax_experimental/ffi.py +17 -873
  137. warp/jax_experimental/xla_ffi.py +5 -605
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +314 -37
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sparse.cu +7 -3
  159. warp/native/spatial.h +12 -0
  160. warp/native/tile.h +681 -89
  161. warp/native/tile_radix_sort.h +1 -1
  162. warp/native/tile_reduce.h +394 -46
  163. warp/native/tile_scan.h +4 -4
  164. warp/native/vec.h +469 -0
  165. warp/native/version.h +23 -0
  166. warp/native/volume.cpp +1 -1
  167. warp/native/volume.cu +1 -0
  168. warp/native/volume.h +1 -1
  169. warp/native/volume_builder.cu +2 -0
  170. warp/native/warp.cpp +57 -29
  171. warp/native/warp.cu +253 -171
  172. warp/native/warp.h +11 -8
  173. warp/optim/__init__.py +6 -3
  174. warp/optim/adam.py +6 -145
  175. warp/optim/linear.py +14 -1585
  176. warp/optim/sgd.py +6 -94
  177. warp/paddle.py +6 -388
  178. warp/render/__init__.py +8 -4
  179. warp/render/imgui_manager.py +7 -267
  180. warp/render/render_opengl.py +6 -3618
  181. warp/render/render_usd.py +6 -919
  182. warp/render/utils.py +6 -142
  183. warp/sparse.py +37 -2563
  184. warp/tape.py +6 -1188
  185. warp/tests/__main__.py +1 -1
  186. warp/tests/cuda/test_async.py +4 -4
  187. warp/tests/cuda/test_conditional_captures.py +1 -1
  188. warp/tests/cuda/test_multigpu.py +1 -1
  189. warp/tests/cuda/test_streams.py +58 -1
  190. warp/tests/geometry/test_bvh.py +157 -22
  191. warp/tests/geometry/test_marching_cubes.py +0 -1
  192. warp/tests/geometry/test_mesh.py +5 -3
  193. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  194. warp/tests/geometry/test_mesh_query_point.py +5 -2
  195. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  196. warp/tests/geometry/test_volume_write.py +5 -5
  197. warp/tests/interop/test_dlpack.py +14 -14
  198. warp/tests/interop/test_jax.py +772 -49
  199. warp/tests/interop/test_paddle.py +1 -1
  200. warp/tests/test_adam.py +0 -1
  201. warp/tests/test_arithmetic.py +9 -9
  202. warp/tests/test_array.py +527 -100
  203. warp/tests/test_array_reduce.py +3 -3
  204. warp/tests/test_atomic.py +12 -8
  205. warp/tests/test_atomic_bitwise.py +209 -0
  206. warp/tests/test_atomic_cas.py +4 -4
  207. warp/tests/test_bool.py +2 -2
  208. warp/tests/test_builtins_resolution.py +5 -571
  209. warp/tests/test_codegen.py +33 -14
  210. warp/tests/test_conditional.py +1 -1
  211. warp/tests/test_context.py +6 -6
  212. warp/tests/test_copy.py +242 -161
  213. warp/tests/test_ctypes.py +3 -3
  214. warp/tests/test_devices.py +24 -2
  215. warp/tests/test_examples.py +16 -84
  216. warp/tests/test_fabricarray.py +35 -35
  217. warp/tests/test_fast_math.py +0 -2
  218. warp/tests/test_fem.py +56 -10
  219. warp/tests/test_fixedarray.py +3 -3
  220. warp/tests/test_func.py +8 -5
  221. warp/tests/test_generics.py +1 -1
  222. warp/tests/test_indexedarray.py +24 -24
  223. warp/tests/test_intersect.py +39 -9
  224. warp/tests/test_large.py +1 -1
  225. warp/tests/test_lerp.py +3 -1
  226. warp/tests/test_linear_solvers.py +1 -1
  227. warp/tests/test_map.py +35 -4
  228. warp/tests/test_mat.py +52 -62
  229. warp/tests/test_mat_constructors.py +4 -5
  230. warp/tests/test_mat_lite.py +1 -1
  231. warp/tests/test_mat_scalar_ops.py +121 -121
  232. warp/tests/test_math.py +34 -0
  233. warp/tests/test_module_aot.py +4 -4
  234. warp/tests/test_modules_lite.py +28 -2
  235. warp/tests/test_print.py +11 -11
  236. warp/tests/test_quat.py +93 -58
  237. warp/tests/test_runlength_encode.py +1 -1
  238. warp/tests/test_scalar_ops.py +38 -10
  239. warp/tests/test_smoothstep.py +1 -1
  240. warp/tests/test_sparse.py +126 -15
  241. warp/tests/test_spatial.py +105 -87
  242. warp/tests/test_special_values.py +6 -6
  243. warp/tests/test_static.py +7 -7
  244. warp/tests/test_struct.py +13 -2
  245. warp/tests/test_triangle_closest_point.py +48 -1
  246. warp/tests/test_types.py +27 -15
  247. warp/tests/test_utils.py +52 -52
  248. warp/tests/test_vec.py +29 -29
  249. warp/tests/test_vec_constructors.py +5 -5
  250. warp/tests/test_vec_scalar_ops.py +97 -97
  251. warp/tests/test_version.py +75 -0
  252. warp/tests/tile/test_tile.py +178 -0
  253. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  254. warp/tests/tile/test_tile_cholesky.py +7 -4
  255. warp/tests/tile/test_tile_load.py +26 -2
  256. warp/tests/tile/test_tile_mathdx.py +3 -3
  257. warp/tests/tile/test_tile_matmul.py +1 -1
  258. warp/tests/tile/test_tile_mlp.py +2 -4
  259. warp/tests/tile/test_tile_reduce.py +214 -13
  260. warp/tests/unittest_suites.py +6 -14
  261. warp/tests/unittest_utils.py +10 -9
  262. warp/tests/walkthrough_debug.py +3 -1
  263. warp/torch.py +6 -373
  264. warp/types.py +29 -5764
  265. warp/utils.py +10 -1659
  266. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
  267. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  268. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  269. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  270. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  271. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  283. warp/examples/assets/cartpole.urdf +0 -110
  284. warp/examples/assets/crazyflie.usd +0 -0
  285. warp/examples/assets/nv_ant.xml +0 -92
  286. warp/examples/assets/nv_humanoid.xml +0 -183
  287. warp/examples/assets/quadruped.urdf +0 -268
  288. warp/examples/optim/example_bounce.py +0 -266
  289. warp/examples/optim/example_cloth_throw.py +0 -228
  290. warp/examples/optim/example_drone.py +0 -870
  291. warp/examples/optim/example_inverse_kinematics.py +0 -182
  292. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  293. warp/examples/optim/example_softbody_properties.py +0 -400
  294. warp/examples/optim/example_spring_cage.py +0 -245
  295. warp/examples/optim/example_trajectory.py +0 -227
  296. warp/examples/sim/example_cartpole.py +0 -143
  297. warp/examples/sim/example_cloth.py +0 -225
  298. warp/examples/sim/example_cloth_self_contact.py +0 -316
  299. warp/examples/sim/example_granular.py +0 -130
  300. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  301. warp/examples/sim/example_jacobian_ik.py +0 -244
  302. warp/examples/sim/example_particle_chain.py +0 -124
  303. warp/examples/sim/example_quadruped.py +0 -203
  304. warp/examples/sim/example_rigid_chain.py +0 -203
  305. warp/examples/sim/example_rigid_contact.py +0 -195
  306. warp/examples/sim/example_rigid_force.py +0 -133
  307. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  308. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  309. warp/examples/sim/example_soft_body.py +0 -196
  310. warp/examples/tile/example_tile_walker.py +0 -327
  311. warp/sim/__init__.py +0 -74
  312. warp/sim/articulation.py +0 -793
  313. warp/sim/collide.py +0 -2570
  314. warp/sim/graph_coloring.py +0 -307
  315. warp/sim/import_mjcf.py +0 -791
  316. warp/sim/import_snu.py +0 -227
  317. warp/sim/import_urdf.py +0 -579
  318. warp/sim/import_usd.py +0 -898
  319. warp/sim/inertia.py +0 -357
  320. warp/sim/integrator.py +0 -245
  321. warp/sim/integrator_euler.py +0 -2000
  322. warp/sim/integrator_featherstone.py +0 -2101
  323. warp/sim/integrator_vbd.py +0 -2487
  324. warp/sim/integrator_xpbd.py +0 -3295
  325. warp/sim/model.py +0 -4821
  326. warp/sim/particles.py +0 -121
  327. warp/sim/render.py +0 -431
  328. warp/sim/utils.py +0 -431
  329. warp/tests/sim/disabled_kinematics.py +0 -244
  330. warp/tests/sim/test_cloth.py +0 -863
  331. warp/tests/sim/test_collision.py +0 -743
  332. warp/tests/sim/test_coloring.py +0 -347
  333. warp/tests/sim/test_inertia.py +0 -161
  334. warp/tests/sim/test_model.py +0 -226
  335. warp/tests/sim/test_sim_grad.py +0 -287
  336. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  337. warp/tests/sim/test_sim_kinematics.py +0 -98
  338. warp/thirdparty/__init__.py +0 -0
  339. warp_lang-1.9.1.dist-info/RECORD +0 -456
  340. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  341. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  342. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  343. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  344. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  345. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  346. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/sparse.py CHANGED
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,2565 +13,39 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- import ctypes
17
- from typing import Any, Generic, Optional, Tuple, TypeVar, Union
18
-
19
- import warp as wp
20
- import warp.types
21
- import warp.utils
22
- from warp.types import (
23
- Array,
24
- Cols,
25
- Rows,
26
- Scalar,
27
- Vector,
28
- is_array,
29
- scalar_types,
30
- type_is_matrix,
31
- type_repr,
32
- type_scalar_type,
33
- type_size,
34
- type_size_in_bytes,
35
- type_to_warp,
36
- types_equal,
37
- )
38
-
39
- __all__ = [
40
- "BsrMatrix",
41
- "bsr_assign",
42
- "bsr_axpy",
43
- "bsr_copy",
44
- "bsr_diag",
45
- "bsr_from_triplets",
46
- "bsr_get_diag",
47
- "bsr_identity",
48
- "bsr_matrix_t",
49
- "bsr_mm",
50
- "bsr_mm_work_arrays",
51
- "bsr_mv",
52
- "bsr_scale",
53
- "bsr_set_diag",
54
- "bsr_set_from_triplets",
55
- "bsr_set_identity",
56
- "bsr_set_transpose",
57
- "bsr_set_zero",
58
- "bsr_transposed",
59
- "bsr_zeros",
60
- ]
61
-
62
-
63
- # typing hints
64
-
65
- _BlockType = TypeVar("BlockType") # noqa: PLC0132
66
-
67
-
68
- class _MatrixBlockType(Generic[Rows, Cols, Scalar]):
69
- pass
70
-
71
-
72
- class _ScalarBlockType(Generic[Scalar]):
73
- pass
74
-
75
-
76
- BlockType = Union[_MatrixBlockType[Rows, Cols, Scalar], _ScalarBlockType[Scalar]]
77
-
78
- _struct_cache = {}
79
- _transfer_buffer_cache = {}
80
-
81
-
82
- class BsrMatrix(Generic[_BlockType]):
83
- """Untyped base class for BSR and CSR matrices.
84
-
85
- Should not be constructed directly but through functions such as :func:`bsr_zeros`.
86
-
87
- Attributes:
88
- nrow (int): Number of rows of blocks.
89
- ncol (int): Number of columns of blocks.
90
- nnz (int): Upper bound for the number of non-zero blocks, used for
91
- dimensioning launches. The exact number is at ``offsets[nrow-1]``.
92
- See also :meth:`nnz_sync`.
93
- offsets (Array[int]): Array of size at least ``1 + nrow`` such that the
94
- start and end indices of the blocks of row ``r`` are ``offsets[r]``
95
- and ``offsets[r+1]``, respectively.
96
- columns (Array[int]): Array of size at least equal to ``nnz`` containing
97
- block column indices.
98
- values (Array[BlockType]): Array of size at least equal to ``nnz``
99
- containing block values.
100
- """
101
-
102
- @property
103
- def scalar_type(self) -> Scalar:
104
- """Scalar type for individual block coefficients. For CSR matrices, this is the same as the block type."""
105
- return type_scalar_type(self.values.dtype)
106
-
107
- @property
108
- def block_shape(self) -> Tuple[int, int]:
109
- """Shape of the individual blocks."""
110
- return getattr(self.values.dtype, "_shape_", (1, 1))
111
-
112
- @property
113
- def block_size(self) -> int:
114
- """Size of the individual blocks, i.e. number of rows per block times number of columns per block."""
115
- return type_size(self.values.dtype)
116
-
117
- @property
118
- def shape(self) -> Tuple[int, int]:
119
- """Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/columns per block."""
120
- block_shape = self.block_shape
121
- return (self.nrow * block_shape[0], self.ncol * block_shape[1])
122
-
123
- @property
124
- def dtype(self) -> type:
125
- """Data type for individual block values."""
126
- return self.values.dtype
127
-
128
- @property
129
- def device(self) -> wp.context.Device:
130
- """Device on which ``offsets``, ``columns``, and ``values`` are allocated -- assumed to be the same for all three arrays."""
131
- return self.values.device
132
-
133
- @property
134
- def requires_grad(self) -> bool:
135
- """Read-only property indicating whether the matrix participates in adjoint computations."""
136
- return self.values.requires_grad
137
-
138
- @property
139
- def scalar_values(self) -> wp.array:
140
- """Accesses the ``values`` array as a 3d scalar array."""
141
- values_view = _as_3d_array(self.values, self.block_shape)
142
- values_view._ref = self.values # keep ref in case we're garbage collected
143
- return values_view
144
-
145
- def uncompress_rows(self, out: wp.array = None) -> wp.array:
146
- """Compute the row index for each non-zero block from the compressed row offsets."""
147
- if out is None:
148
- out = wp.empty(self.nnz, dtype=int, device=self.device)
149
-
150
- wp.launch(
151
- kernel=_bsr_get_block_row,
152
- device=self.device,
153
- dim=self.nnz,
154
- inputs=[self.nrow, self.offsets, out],
155
- )
156
- return out
157
-
158
- def nnz_sync(self):
159
- """Ensures that any ongoing transfer of the exact nnz number from the device offsets array to the host has completed,
160
- or, if none has been scheduled yet, starts a new transfer and waits for it to complete.
161
- Then updates the nnz upper bound.
162
-
163
- See also :meth:`copy_nnz_async`.
164
- """
165
-
166
- buf, event = self._nnz_transfer_if_any()
167
- if buf is None:
168
- self.copy_nnz_async()
169
- buf, event = self._nnz_transfer_if_any()
170
-
171
- if event is not None:
172
- wp.synchronize_event(event)
173
- self.nnz = int(buf.numpy()[0])
174
- return self.nnz
175
-
176
- def copy_nnz_async(self) -> None:
177
- """
178
- Start the asynchronous transfer of the exact nnz from the device offsets array to host and records an event for completion.
179
-
180
- Needs to be called whenever the offsets array has been modified from outside ``warp.sparse``.
181
-
182
- See also :meth:`nnz_sync`.
183
- """
184
-
185
- buf, event = self._setup_nnz_transfer()
186
- stream = wp.get_stream(self.device) if self.device.is_cuda else None
187
- wp.copy(src=self.offsets, dest=buf, src_offset=self.nrow, count=1, stream=stream)
188
- if event is not None:
189
- stream.record_event(event)
190
-
191
- def _setup_nnz_transfer(self):
192
- buf, event = self._nnz_transfer_if_any()
193
- if buf is not None:
194
- return buf, event
195
-
196
- buf, event = _allocate_transfer_buf(self.device)
197
- if buf is not None:
198
- BsrMatrix.__setattr__(self, "_nnz_transfer", (buf, event))
199
-
200
- return buf, event
201
-
202
- def _nnz_transfer_if_any(self):
203
- return getattr(self, "_nnz_transfer", (None, None))
204
-
205
- def __del__(self):
206
- buf, event = self._nnz_transfer_if_any()
207
- if buf is not None:
208
- _redeem_transfer_buf(self.device, buf, event)
209
-
210
- # Overloaded math operators
211
- def __add__(self, y):
212
- return bsr_axpy(y, bsr_copy(self))
213
-
214
- def __iadd__(self, y):
215
- return bsr_axpy(y, self)
216
-
217
- def __radd__(self, x):
218
- return bsr_axpy(x, bsr_copy(self))
219
-
220
- def __sub__(self, y):
221
- return bsr_axpy(y, bsr_copy(self), alpha=-1.0)
222
-
223
- def __rsub__(self, x):
224
- return bsr_axpy(x, bsr_copy(self), beta=-1.0)
225
-
226
- def __isub__(self, y):
227
- return bsr_axpy(y, self, alpha=-1.0)
228
-
229
- def __mul__(self, y):
230
- return _BsrScalingExpression(self, y)
231
-
232
- def __rmul__(self, x):
233
- return _BsrScalingExpression(self, x)
234
-
235
- def __imul__(self, y):
236
- return bsr_scale(self, y)
237
-
238
- def __matmul__(self, y):
239
- if isinstance(y, wp.array):
240
- return bsr_mv(self, y)
241
-
242
- return bsr_mm(self, y)
243
-
244
- def __rmatmul__(self, x):
245
- if isinstance(x, wp.array):
246
- return bsr_mv(self, x, transpose=True)
247
-
248
- return bsr_mm(x, self)
249
-
250
- def __imatmul__(self, y):
251
- return bsr_mm(self, y, self)
252
-
253
- def __truediv__(self, y):
254
- return _BsrScalingExpression(self, 1.0 / y)
255
-
256
- def __neg__(self):
257
- return _BsrScalingExpression(self, -1.0)
258
-
259
- def transpose(self):
260
- """Return a transposed copy of this matrix."""
261
- return bsr_transposed(self)
262
-
263
-
264
- def _allocate_transfer_buf(device):
265
- if device.ordinal in _transfer_buffer_cache:
266
- all_, pool = _transfer_buffer_cache[device.ordinal]
267
- else:
268
- all_ = []
269
- pool = []
270
- _transfer_buffer_cache[device.ordinal] = (all_, pool)
271
-
272
- if pool:
273
- return pool.pop()
274
-
275
- if device.is_capturing:
276
- return None, None
277
-
278
- buf = wp.empty(dtype=int, shape=(1,), device="cpu", pinned=device.is_cuda)
279
- event = wp.Event(device) if device.is_cuda else None
280
- all_.append((buf, event)) # keep a reference to the buffer and event, prevent garbage collection before redeem
281
- return buf, event
282
-
283
-
284
- def _redeem_transfer_buf(device, buf, event):
285
- all_, pool = _transfer_buffer_cache[device.ordinal]
286
- pool.append((buf, event))
287
-
288
-
289
- def bsr_matrix_t(dtype: BlockType):
290
- dtype = type_to_warp(dtype)
291
-
292
- if not type_is_matrix(dtype) and dtype not in scalar_types:
293
- raise ValueError(f"BsrMatrix block type must be either warp matrix or scalar; got {type_repr(dtype)}")
294
-
295
- class BsrMatrixTyped(BsrMatrix):
296
- nrow: int
297
- """Number of rows of blocks."""
298
- ncol: int
299
- """Number of columns of blocks."""
300
- nnz: int
301
- """Upper bound for the number of non-zeros."""
302
- offsets: wp.array(dtype=int)
303
- """Array of size at least ``1 + nrow``."""
304
- columns: wp.array(dtype=int)
305
- """Array of size at least equal to ``nnz``."""
306
- values: wp.array(dtype=dtype)
307
-
308
- module = wp.get_module(BsrMatrix.__module__)
309
-
310
- if hasattr(dtype, "_shape_"):
311
- type_str = f"{type_scalar_type(dtype).__name__}_{dtype._shape_[0]}_{dtype._shape_[1]}"
312
- else:
313
- type_str = dtype.__name__
314
- key = f"{BsrMatrix.__qualname__}_{type_str}"
315
-
316
- if key not in _struct_cache:
317
- _struct_cache[key] = wp.codegen.Struct(
318
- key=key,
319
- cls=BsrMatrixTyped,
320
- module=module,
321
- )
322
-
323
- return _struct_cache[key]
324
-
325
-
326
- def bsr_zeros(
327
- rows_of_blocks: int,
328
- cols_of_blocks: int,
329
- block_type: BlockType,
330
- device: wp.context.Devicelike = None,
331
- ) -> BsrMatrix:
332
- """Construct and return an empty BSR or CSR matrix with the given shape.
333
-
334
- Args:
335
- bsr: The BSR or CSR matrix to set to zero.
336
- rows_of_blocks: Number of rows of blocks.
337
- cols_of_blocks: Number of columns of blocks.
338
- block_type: Type of individual blocks.
339
- For CSR matrices, this should be a scalar type.
340
- For BSR matrices, this should be a matrix type (e.g. from :func:`warp.mat`).
341
- device: Device on which to allocate the matrix arrays.
342
- """
343
-
344
- bsr = bsr_matrix_t(block_type)()
345
-
346
- bsr.nrow = int(rows_of_blocks)
347
- bsr.ncol = int(cols_of_blocks)
348
- bsr.nnz = 0
349
- bsr.columns = wp.empty(shape=(0,), dtype=int, device=device)
350
- bsr.values = wp.empty(shape=(0,), dtype=block_type, device=device)
351
- bsr.offsets = wp.zeros(shape=(bsr.nrow + 1,), dtype=int, device=device)
352
-
353
- return bsr
354
-
355
-
356
- def _bsr_ensure_fits(bsr: BsrMatrix, nrow: Optional[int] = None, nnz: Optional[int] = None) -> None:
357
- if nrow is None:
358
- nrow = bsr.nrow
359
- if nnz is None:
360
- nnz = bsr.nnz
361
- else:
362
- # update nnz upper bound
363
- bsr.nnz = int(nnz)
364
-
365
- if bsr.offsets.size < nrow + 1:
366
- bsr.offsets = wp.empty(shape=(nrow + 1,), dtype=int, device=bsr.offsets.device)
367
- if bsr.columns.size < nnz:
368
- bsr.columns = wp.empty(shape=(nnz,), dtype=int, device=bsr.columns.device)
369
- if bsr.values.size < nnz:
370
- bsr.values = wp.empty(
371
- shape=(nnz,), dtype=bsr.values.dtype, device=bsr.values.device, requires_grad=bsr.values.requires_grad
372
- )
373
-
374
-
375
- def bsr_set_zero(
376
- bsr: BsrMatrix,
377
- rows_of_blocks: Optional[int] = None,
378
- cols_of_blocks: Optional[int] = None,
379
- ):
380
- """Set a BSR matrix to zero, possibly changing its size.
381
-
382
- Args:
383
- bsr: The BSR or CSR matrix to set to zero.
384
- rows_of_blocks: If not ``None``, the new number of rows of blocks.
385
- cols_of_blocks: If not ``None``, the new number of columns of blocks.
386
- """
387
-
388
- if rows_of_blocks is not None:
389
- bsr.nrow = int(rows_of_blocks)
390
- if cols_of_blocks is not None:
391
- bsr.ncol = int(cols_of_blocks)
392
-
393
- _bsr_ensure_fits(bsr, nnz=0)
394
- bsr.offsets.zero_()
395
- bsr.copy_nnz_async()
396
-
397
-
398
- def _as_3d_array(arr, block_shape):
399
- return wp.array(
400
- ptr=arr.ptr,
401
- capacity=arr.capacity,
402
- device=arr.device,
403
- dtype=type_scalar_type(arr.dtype),
404
- shape=(arr.shape[0], *block_shape),
405
- grad=None if arr.grad is None else _as_3d_array(arr.grad, block_shape),
406
- )
407
-
408
-
409
- def _optional_ctypes_pointer(array: Optional[wp.array], ctype):
410
- return None if array is None else ctypes.cast(array.ptr, ctypes.POINTER(ctype))
411
-
412
-
413
- def _optional_ctypes_event(event: Optional[wp.Event]):
414
- return None if event is None else event.cuda_event
415
-
416
-
417
- _zero_value_masks = {
418
- wp.float16: 0x7FFF,
419
- wp.float32: 0x7FFFFFFF,
420
- wp.float64: 0x7FFFFFFFFFFFFFFF,
421
- wp.int8: 0xFF,
422
- wp.int16: 0xFFFF,
423
- wp.int32: 0xFFFFFFFF,
424
- wp.int64: 0xFFFFFFFFFFFFFFFF,
425
- }
426
-
427
-
428
- @wp.kernel
429
- def _bsr_accumulate_triplet_values(
430
- row_count: int,
431
- tpl_summed_offsets: wp.array(dtype=int),
432
- tpl_summed_indices: wp.array(dtype=int),
433
- tpl_values: wp.array3d(dtype=Any),
434
- bsr_offsets: wp.array(dtype=int),
435
- bsr_values: wp.array3d(dtype=Any),
436
- ):
437
- block, i, j = wp.tid()
438
-
439
- if block >= bsr_offsets[row_count]:
440
- return
441
-
442
- if block == 0:
443
- beg = 0
444
- else:
445
- beg = tpl_summed_offsets[block - 1]
446
- end = tpl_summed_offsets[block]
447
-
448
- val = tpl_values[tpl_summed_indices[beg], i, j]
449
- for k in range(beg + 1, end):
450
- val += tpl_values[tpl_summed_indices[k], i, j]
451
-
452
- bsr_values[block, i, j] = val
453
-
454
-
455
def bsr_set_from_triplets(
    dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    rows: "Array[int]",
    columns: "Array[int]",
    values: Optional["Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]"] = None,
    count: Optional["Array[int]"] = None,
    prune_numerical_zeros: bool = True,
    masked: bool = False,
):
    """Fill a BSR matrix with values defined by coordinate-oriented (COO) triplets, discarding existing blocks.

    The first dimension of the three input arrays must match and indicates the number of COO triplets.

    Args:
        dest: Sparse matrix to populate.
        rows: Row index for each non-zero.
        columns: Columns index for each non-zero.
        values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
            to the ``dest`` matrix's block type, or a 3d array with data type equal to the ``dest`` matrix's scalar type.
            If ``None``, the values array of the resulting matrix will be allocated but uninitialized.
        count: Single-element array indicating the number of triplets. If ``None``, the number of triplets is determined from the shape of
            ``rows`` and ``columns`` arrays.
        prune_numerical_zeros: If ``True``, will ignore the zero-valued blocks.
        masked: If ``True``, ignore blocks that are not existing non-zeros of ``dest``.

    Raises:
        ValueError: If the arrays reside on different devices, have mismatching lengths or shapes,
            or if ``values`` is not contiguous while ``prune_numerical_zeros`` is set.
        TypeError: If ``rows``, ``columns`` or ``count`` are not of type int32.
    """

    if rows.device != columns.device or rows.device != dest.device:
        raise ValueError(
            f"Rows and columns must reside on the destination matrix device, got {rows.device}, {columns.device} and {dest.device}"
        )

    if rows.shape[0] != columns.shape[0]:
        raise ValueError(
            f"Rows and columns arrays must have the same length, got {rows.shape[0]} and {columns.shape[0]}"
        )

    if rows.dtype != wp.int32 or columns.dtype != wp.int32:
        raise TypeError("Rows and columns arrays must be of type int32")

    if count is not None:
        if count.device != rows.device:
            raise ValueError(f"Count and rows must reside on the same device, got {count.device} and {rows.device}")

        if count.shape != (1,):
            raise ValueError(f"Count array must be a single-element array, got {count.shape}")

        if count.dtype != wp.int32:
            raise TypeError("Count array must be of type int32")

    # Accept either array1d(dtype) or contiguous array3d(scalar_type) as values
    if values is not None:
        if values.device != rows.device:
            raise ValueError(f"Values and rows must reside on the same device, got {values.device} and {rows.device}")

        if values.shape[0] != rows.shape[0]:
            raise ValueError(
                f"Values and rows arrays must have the same length, got {values.shape[0]} and {rows.shape[0]}"
            )

        if values.ndim == 1:
            if not types_equal(values.dtype, dest.values.dtype):
                raise ValueError(
                    f"Values array type must correspond to that of the dest matrix, got {type_repr(values.dtype)} and {type_repr(dest.values.dtype)}"
                )
        elif values.ndim == 3:
            if values.shape[1:] != dest.block_shape:
                raise ValueError(
                    f"Last two dimensions in values array ({values.shape[1:]}) should correspond to matrix block shape {(dest.block_shape)})"
                )

            if type_scalar_type(values.dtype) != dest.scalar_type:
                raise ValueError(
                    f"Scalar type of values array ({type_repr(values.dtype)}) should correspond to that of matrix ({type_repr(dest.scalar_type)})"
                )
        else:
            raise ValueError(f"Number of dimension for values array should be 1 or 3, got {values.ndim}")

        if prune_numerical_zeros and not values.is_contiguous:
            raise ValueError("Values array should be contiguous for numerical zero pruning")

    nnz = rows.shape[0]
    if nnz == 0:
        bsr_set_zero(dest)
        return

    # Increase dest array sizes if needed
    # (in masked mode the topology of dest is kept, so no growth is required)
    if not masked:
        _bsr_ensure_fits(dest, nnz=nnz)

    device = dest.values.device
    scalar_type = dest.scalar_type
    # A mask of 0 is passed when pruning is disabled or the scalar type has no registered mask
    zero_value_mask = _zero_value_masks.get(scalar_type, 0) if prune_numerical_zeros else 0

    # compute the BSR topology

    from warp.context import runtime

    if device.is_cpu:
        native_func = runtime.core.wp_bsr_matrix_from_triplets_host
    else:
        native_func = runtime.core.wp_bsr_matrix_from_triplets_device

    nnz_buf, nnz_event = dest._setup_nnz_transfer()
    # Scratch buffers filled by the native routine and consumed by the accumulation kernel
    summed_triplet_offsets = wp.empty(shape=(nnz,), dtype=wp.int32, device=device)
    summed_triplet_indices = wp.empty(shape=(nnz,), dtype=wp.int32, device=device)

    with wp.ScopedDevice(device):
        native_func(
            dest.block_size,
            type_size_in_bytes(scalar_type),
            dest.nrow,
            dest.ncol,
            nnz,
            _optional_ctypes_pointer(count, ctype=ctypes.c_int32),
            ctypes.cast(rows.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(columns.ptr, ctypes.POINTER(ctypes.c_int32)),
            _optional_ctypes_pointer(values, ctype=ctypes.c_int32),
            zero_value_mask,
            masked,
            ctypes.cast(summed_triplet_offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(summed_triplet_indices.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
            _optional_ctypes_pointer(nnz_buf, ctype=ctypes.c_int32),
            _optional_ctypes_event(nnz_event),
        )

    # Without triplet values there is nothing to accumulate: the topology is set
    # and the block values are left uninitialized, as documented.
    if values is None:
        return

    # now accumulate repeated blocks
    wp.launch(
        _bsr_accumulate_triplet_values,
        dim=(nnz, *dest.block_shape),
        # launch explicitly on the matrix device rather than relying on the current default
        device=device,
        inputs=[
            dest.nrow,
            summed_triplet_offsets,
            summed_triplet_indices,
            _as_3d_array(values, dest.block_shape),
            dest.offsets,
        ],
        outputs=[dest.scalar_values],
    )
595
-
596
-
597
def bsr_from_triplets(
    rows_of_blocks: int,
    cols_of_blocks: int,
    rows: "Array[int]",
    columns: "Array[int]",
    values: "Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]",
    prune_numerical_zeros: bool = True,
):
    """Build a new BSR matrix from coordinate-oriented (COO) triplets.

    ``rows``, ``columns`` and ``values`` must share the same first dimension, which is the number of triplets.

    Args:
        rows_of_blocks: Number of rows of blocks.
        cols_of_blocks: Number of columns of blocks.
        rows: Row index for each non-zero.
        columns: Columns index for each non-zero.
        values: Block values for each non-zero. Either a one-dimensional array of blocks, whose data type
            becomes the block type of the matrix, or a 3d array of scalars whose last two dimensions
            give the block shape.
        prune_numerical_zeros: If ``True``, will ignore the zero-valued blocks.

    Returns:
        The newly created matrix, allocated on the device of ``values``.
    """

    # A 3d scalar array describes blocks through its trailing two dimensions
    block_type = wp.mat(shape=values.shape[1:], dtype=values.dtype) if values.ndim == 3 else values.dtype

    matrix = bsr_zeros(
        rows_of_blocks=rows_of_blocks,
        cols_of_blocks=cols_of_blocks,
        block_type=block_type,
        device=values.device,
    )
    # Propagate differentiability from the input values to the matrix values
    matrix.values.requires_grad = values.requires_grad
    bsr_set_from_triplets(matrix, rows, columns, values, prune_numerical_zeros=prune_numerical_zeros)
    return matrix
630
-
631
-
632
class _BsrExpression(Generic[_BlockType]):
    """Marker base class for lazily-evaluated expressions over BSR matrices."""
634
-
635
-
636
class _BsrScalingExpression(_BsrExpression):
    """Lazy representation of ``scale * mat`` for a BSR matrix ``mat``.

    Structural properties are forwarded to the wrapped matrix. Scaling
    operators compose into a new expression without touching the matrix
    values; additions and products evaluate into a new matrix or array.
    """

    def __init__(self, mat, scale):
        self.mat = mat
        self.scale = scale

    def eval(self):
        """Evaluate the expression into a new matrix."""
        return bsr_copy(self)

    @property
    def nrow(self) -> int:
        return self.mat.nrow

    @property
    def ncol(self) -> int:
        return self.mat.ncol

    @property
    def nnz(self) -> int:
        return self.mat.nnz

    @property
    def offsets(self) -> wp.array:
        return self.mat.offsets

    @property
    def columns(self) -> wp.array:
        return self.mat.columns

    @property
    def scalar_type(self) -> Scalar:
        return self.mat.scalar_type

    @property
    def block_shape(self) -> Tuple[int, int]:
        return self.mat.block_shape

    @property
    def block_size(self) -> int:
        return self.mat.block_size

    @property
    def shape(self) -> Tuple[int, int]:
        return self.mat.shape

    @property
    def dtype(self) -> type:
        return self.mat.dtype

    @property
    def requires_grad(self) -> bool:
        return self.mat.requires_grad

    @property
    def device(self) -> wp.context.Device:
        return self.mat.device

    # Overloaded math operators
    #
    # bsr_axpy(x, y, alpha, beta) computes ``y := alpha * x + beta * y``. Below, ``y`` is
    # always a copy of the wrapped matrix, so this expression's scale must be passed
    # as ``beta``; ``alpha`` only carries the sign applied to the other operand.
    def __add__(self, y):
        # (scale * mat) + y
        return bsr_axpy(y, bsr_copy(self.mat), beta=self.scale)

    def __radd__(self, x):
        # x + (scale * mat)
        return bsr_axpy(x, bsr_copy(self.mat), beta=self.scale)

    def __sub__(self, y):
        # (scale * mat) - y
        return bsr_axpy(y, bsr_copy(self.mat), alpha=-1.0, beta=self.scale)

    def __rsub__(self, x):
        # x - (scale * mat)
        return bsr_axpy(x, bsr_copy(self.mat), beta=-self.scale)

    def __mul__(self, y):
        return _BsrScalingExpression(self.mat, y * self.scale)

    def __rmul__(self, x):
        return _BsrScalingExpression(self.mat, x * self.scale)

    def __matmul__(self, y):
        # Matrix-vector product for Warp arrays, matrix-matrix product otherwise
        if isinstance(y, wp.array):
            return bsr_mv(self.mat, y, alpha=self.scale)

        return bsr_mm(self.mat, y, alpha=self.scale)

    def __rmatmul__(self, x):
        # Array on the left: product with the transposed matrix
        if isinstance(x, wp.array):
            return bsr_mv(self.mat, x, alpha=self.scale, transpose=True)

        return bsr_mm(x, self.mat, alpha=self.scale)

    def __truediv__(self, y):
        return _BsrScalingExpression(self.mat, self.scale / y)

    def __neg__(self):
        return _BsrScalingExpression(self.mat, -self.scale)

    def transpose(self):
        """Returns a transposed copy of this matrix"""
        return _BsrScalingExpression(self.mat.transpose(), self.scale)
732
-
733
-
734
# Type accepted by operations that work on either a concrete BSR matrix or a lazy expression of one
BsrMatrixOrExpression = Union[BsrMatrix[_BlockType], _BsrExpression[_BlockType]]
735
-
736
-
737
def _extract_matrix_and_scale(bsr: BsrMatrixOrExpression):
    """Split ``bsr`` into a concrete matrix and the scalar factor applied to it.

    Returns:
        ``(matrix, scale)``; the scale is ``1.0`` for a plain :class:`BsrMatrix`.

    Raises:
        ValueError: If ``bsr`` is neither a matrix nor a scaling expression.
    """
    if isinstance(bsr, _BsrScalingExpression):
        return bsr.mat, bsr.scale
    if isinstance(bsr, BsrMatrix):
        return bsr, 1.0

    raise ValueError("Argument cannot be interpreted as a BsrMatrix")
744
-
745
-
746
@wp.func
def _bsr_row_index(
    offsets: wp.array(dtype=int),
    row_count: int,
    block: int,
):
    """Index of the row containing a block, or -1 if non-existing."""
    # offsets[row_count] is the total number of blocks; anything past it is not a valid block
    if block >= offsets[row_count]:
        return -1

    # The lower bound of (block + 1) in offsets is the end offset of the block's row
    return wp.lower_bound(offsets, 0, row_count + 1, block + 1) - 1
754
-
755
-
756
@wp.func
def _bsr_block_index(
    row: int,
    col: int,
    bsr_offsets: wp.array(dtype=int),
    bsr_columns: wp.array(dtype=int),
):
    """Index of the block at block-coordinates (row, col), or -1 if non-existing.
    Assumes bsr_columns is sorted.
    """

    if row < 0:
        return -1

    row_beg = bsr_offsets[row]
    row_end = bsr_offsets[row + 1]

    # Empty row: nothing to search
    if row_beg == row_end:
        return -1

    # Search within the row's column range, then check for an exact match
    candidate = wp.lower_bound(bsr_columns, row_beg, row_end, col)
    if bsr_columns[candidate] == col:
        return candidate
    return -1
778
-
779
-
780
@wp.kernel(enable_backward=False)
def _bsr_assign_list_blocks(
    src_subrows: int,
    src_subcols: int,
    dest_subrows: int,
    dest_subcols: int,
    src_row_count: int,
    src_offsets: wp.array(dtype=int),
    src_columns: wp.array(dtype=int),
    dest_rows: wp.array(dtype=int),
    dest_cols: wp.array(dtype=int),
):
    """List the destination block coordinates of every sub-block of every source block.

    One thread handles one ``(source block, subrow, subcol)`` triple and writes one
    entry of ``dest_rows`` / ``dest_cols``; entries of non-existing source blocks
    are set to -1.
    """
    block, subrow, subcol = wp.tid()
    out_index = (block * src_subcols + subcol) * src_subrows + subrow

    src_row = _bsr_row_index(src_offsets, src_row_count, block)
    if src_row != -1:
        # Coordinates in units of sub-blocks, then coarsened to destination blocks
        fine_row = src_row * src_subrows + subrow
        fine_col = src_columns[block] * src_subcols + subcol
        dest_rows[out_index] = fine_row // dest_subrows
        dest_cols[out_index] = fine_col // dest_subcols
    else:
        # invalid
        dest_rows[out_index] = src_row
        dest_cols[out_index] = src_row
804
-
805
-
806
@wp.kernel
def _bsr_assign_copy_blocks(
    scale: Any,
    src_subrows: int,
    src_subcols: int,
    dest_subrows: int,
    dest_subcols: int,
    src_row_count: int,
    src_offsets: wp.array(dtype=int),
    src_columns: wp.array(dtype=int),
    src_values: wp.array3d(dtype=Any),
    dest_offsets: wp.array(dtype=int),
    dest_columns: wp.array(dtype=int),
    dest_values: wp.array3d(dtype=Any),
):
    """Copy scaled source sub-blocks into their location within the destination blocks.

    One thread handles one ``(source block, subrow, subcol)`` triple. Sub-blocks whose
    source block or destination block does not exist are skipped.
    """
    # A single 3d thread-index query (the previous extra scalar query was immediately overwritten)
    src_block, subrow, subcol = wp.tid()

    src_row = _bsr_row_index(src_offsets, src_row_count, src_block)
    if src_row == -1:
        return

    src_col = src_columns[src_block]

    # Coordinates in units of sub-blocks, then coarsened to destination blocks
    dest_subrow = src_row * src_subrows + subrow
    dest_subcol = src_col * src_subcols + subcol
    dest_row = dest_subrow // dest_subrows
    dest_col = dest_subcol // dest_subcols

    dest_block = _bsr_block_index(dest_row, dest_col, dest_offsets, dest_columns)
    if dest_block == -1:
        return

    # Position of the sub-block inside the destination block
    split_row = dest_subrow - dest_subrows * dest_row
    split_col = dest_subcol - dest_subcols * dest_col

    rows_per_subblock = src_values.shape[1] // src_subrows
    cols_per_subblock = src_values.shape[2] // src_subcols

    dest_base_i = split_row * rows_per_subblock
    dest_base_j = split_col * cols_per_subblock

    src_base_i = subrow * rows_per_subblock
    src_base_j = subcol * cols_per_subblock

    for i in range(rows_per_subblock):
        for j in range(cols_per_subblock):
            # Explicit cast: source and destination may use different scalar types
            dest_values[dest_block, i + dest_base_i, j + dest_base_j] = dest_values.dtype(
                scale * src_values[src_block, i + src_base_i, j + src_base_j]
            )
856
-
857
-
858
def bsr_assign(
    dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    src: BsrMatrixOrExpression[BlockType[Any, Any, Any]],
    structure_only: bool = False,
    masked: bool = False,
):
    """Assign the content of the BSR matrix (or expression) ``src`` to ``dest``.

    Args:
        dest: Destination matrix. Its block shape and scalar type may differ from those of ``src``;
          values are then cast and blocks split or merged as needed.
        src: Matrix to be copied.
        structure_only: If ``True``, only the non-zero indices are copied and value storage for at least
          ``src.nnz`` blocks is allocated but left uninitialized. If ``False``, values are copied as well.
        masked: If ``True``, prevent the assignment operation from adding new non-zero blocks to ``dest``.

    Raises:
        ValueError: If the matrices live on different devices, if their block shapes do not evenly
            divide one another, or if the resulting matrix dimensions are incompatible.
    """

    src, src_scale = _extract_matrix_and_scale(src)

    if dest.values.device != src.values.device:
        raise ValueError("Source and destination matrices must reside on the same device")

    # Along each dimension, either one src block is split into several dest blocks
    # (src_sub* > 1) or several src blocks are merged into one dest block (dest_sub* > 1).
    if src.block_shape[0] < dest.block_shape[0]:
        src_subrows = 1
        dest_subrows = dest.block_shape[0] // src.block_shape[0]
    else:
        src_subrows = src.block_shape[0] // dest.block_shape[0]
        dest_subrows = 1

    if src_subrows * dest.block_shape[0] != src.block_shape[0] * dest_subrows:
        raise ValueError(
            f"Incompatible dest and src block shapes; block rows must evenly divide one another (Got {dest.block_shape[0]}, {src.block_shape[0]})"
        )

    if src.block_shape[1] < dest.block_shape[1]:
        src_subcols = 1
        dest_subcols = dest.block_shape[1] // src.block_shape[1]
    else:
        src_subcols = src.block_shape[1] // dest.block_shape[1]
        dest_subcols = 1

    if src_subcols * dest.block_shape[1] != src.block_shape[1] * dest_subcols:
        raise ValueError(
            f"Incompatible dest and src block shapes; block columns must evenly divide one another (Got {dest.block_shape[1]}, {src.block_shape[1]})"
        )

    # Dimensions of the destination, in destination blocks
    dest_nrow = (src.nrow * src_subrows) // dest_subrows
    dest_ncol = (src.ncol * src_subcols) // dest_subcols

    if src.nrow * src_subrows != dest_nrow * dest_subrows or src.ncol * src_subcols != dest_ncol * dest_subcols:
        raise ValueError(
            f"The requested block shape {dest.block_shape} does not evenly divide the source matrix of total size {src.shape}"
        )

    # Upper bound on the number of destination blocks
    nnz_alloc = src.nnz * src_subrows * src_subcols
    if not masked:
        dest.nrow = dest_nrow
        dest.ncol = dest_ncol
        _bsr_ensure_fits(dest, nnz=nnz_alloc)
    elif dest_nrow != dest.nrow or dest_ncol != dest.ncol:
        raise ValueError(
            f"Incompatible destination matrix size, expected ({dest_nrow}, {dest_ncol}), got ({dest.nrow}, {dest.ncol})"
        )

    if dest.block_shape == src.block_shape and not masked:
        # Direct copy

        wp.copy(dest=dest.offsets, src=src.offsets, count=src.nrow + 1)
        dest.copy_nnz_async()

        if nnz_alloc > 0:
            wp.copy(dest=dest.columns, src=src.columns, count=nnz_alloc)

            if not structure_only:
                warp.utils.array_cast(out_array=dest.values, in_array=src.values, count=nnz_alloc)
                bsr_scale(dest, src_scale)

        return

    # Masked and/or multiple src blocks per dest block, go through COO format

    # Compute destination rows and columns
    coo_rows = wp.empty(nnz_alloc, dtype=int, device=dest.device)
    coo_cols = wp.empty(nnz_alloc, dtype=int, device=dest.device)
    subblock_dim = (src.nnz, src_subrows, src_subcols)
    wp.launch(
        _bsr_assign_list_blocks,
        dim=subblock_dim,
        device=dest.device,
        inputs=[
            src_subrows,
            src_subcols,
            dest_subrows,
            dest_subcols,
            src.nrow,
            src.offsets,
            src.columns,
            coo_rows,
            coo_cols,
        ],
    )

    # Compute destination offsets from triplets
    from warp.context import runtime

    if dest.device.is_cpu:
        native_func = runtime.core.wp_bsr_matrix_from_triplets_host
    else:
        native_func = runtime.core.wp_bsr_matrix_from_triplets_device

    nnz_buf, nnz_event = dest._setup_nnz_transfer()
    with wp.ScopedDevice(dest.device):
        native_func(
            dest.block_size,
            0,  # scalar_size_in_bytes
            dest.nrow,
            dest.ncol,
            nnz_alloc,
            None,  # device nnz
            ctypes.cast(coo_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(coo_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
            None,  # triplet values
            0,  # zero_value_mask
            masked,
            None,  # summed block offsets
            None,  # summed block indices
            ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
            _optional_ctypes_pointer(nnz_buf, ctype=ctypes.c_int32),
            _optional_ctypes_event(nnz_event),
        )

    if structure_only:
        return

    # merge block values
    # (zero first: the copy kernel only writes coefficients covered by a source sub-block)
    dest.values.zero_()
    wp.launch(
        _bsr_assign_copy_blocks,
        dim=subblock_dim,
        device=dest.device,
        inputs=[
            src.scalar_type(src_scale),
            src_subrows,
            src_subcols,
            dest_subrows,
            dest_subcols,
            src.nrow,
            src.offsets,
            src.columns,
            src.scalar_values,
            dest.offsets,
            dest.columns,
            dest.scalar_values,
        ],
    )
1012
-
1013
-
1014
def bsr_copy(
    A: BsrMatrixOrExpression,
    scalar_type: Optional[Scalar] = None,
    block_shape: Optional[Tuple[int, int]] = None,
    structure_only: bool = False,
):
    """Create a new matrix holding a copy of ``A``, optionally with another scalar type or block shape.

    Args:
        A: Matrix to be copied.
        scalar_type: Scalar type of the returned matrix; defaults to that of ``A``.
        block_shape: Block shape of the returned matrix; defaults to that of ``A``.
          Both dimensions of ``block_shape`` must be either a multiple or an exact divider of the ones from ``A``.
        structure_only: If ``True``, only the non-zero indices are copied and the value storage is
          allocated but left uninitialized. If ``False``, values are copied too, with implicit
          casting if the scalar types differ.

    Returns:
        The newly allocated copy, on the same device as ``A``.
    """
    target_scalar = A.scalar_type if scalar_type is None else scalar_type
    target_shape = A.block_shape if block_shape is None else block_shape

    # 1x1 blocks are stored as plain scalars rather than matrix types
    if target_shape != (1, 1):
        block_type = wp.mat(shape=target_shape, dtype=target_scalar)
    else:
        block_type = target_scalar

    result = bsr_zeros(
        rows_of_blocks=A.nrow,
        cols_of_blocks=A.ncol,
        block_type=block_type,
        device=A.device,
    )
    result.values.requires_grad = A.requires_grad
    bsr_assign(dest=result, src=A, structure_only=structure_only)
    return result
1050
-
1051
-
1052
@wp.kernel
def _bsr_transpose_values(
    col_count: int,
    scale: Any,
    bsr_values: wp.array3d(dtype=Any),
    block_index_map: wp.array(dtype=int),
    transposed_bsr_offsets: wp.array(dtype=int),
    transposed_bsr_values: wp.array3d(dtype=Any),
):
    """Fill the values of a transposed matrix from the values of the original one.

    One thread writes coefficient ``(i, j)`` of one transposed block, reading the
    swapped coefficient ``(j, i)`` of the source block given by ``block_index_map``
    and multiplying it by ``scale``.
    """
    block, i, j = wp.tid()

    # transposed_bsr_offsets[col_count] is the number of blocks of the transposed matrix
    if block < transposed_bsr_offsets[col_count]:
        src_block = block_index_map[block]
        transposed_bsr_values[block, i, j] = bsr_values[src_block, j, i] * scale
1067
-
1068
-
1069
def bsr_set_transpose(
    dest: BsrMatrix[BlockType[Cols, Rows, Scalar]],
    src: BsrMatrixOrExpression[BlockType[Rows, Cols, Scalar]],
):
    """Assign the transposed matrix ``src`` to matrix ``dest``.

    Args:
        dest: Matrix receiving the transpose; it is resized to ``src.ncol`` x ``src.nrow`` blocks.
        src: Matrix or expression to transpose.

    Raises:
        ValueError: If the matrices reside on different devices, use different scalar types,
            or if the block shape of ``dest`` is not the reverse of that of ``src``.
    """

    src, src_scale = _extract_matrix_and_scale(src)

    if dest.values.device != src.values.device:
        raise ValueError(
            f"All arguments must reside on the same device, got {dest.values.device} and {src.values.device}"
        )

    if dest.scalar_type != src.scalar_type:
        raise ValueError(f"All arguments must have the same scalar type, got {dest.scalar_type} and {src.scalar_type}")

    transpose_block_shape = src.block_shape[::-1]

    if dest.block_shape != transpose_block_shape:
        raise ValueError(f"Destination block shape must be {transpose_block_shape}, got {dest.block_shape}")

    block_count = src.nnz
    dest.nrow = src.ncol
    dest.ncol = src.nrow

    if block_count == 0:
        bsr_set_zero(dest)
        return

    # Increase dest array sizes if needed
    _bsr_ensure_fits(dest, nnz=block_count)

    from warp.context import runtime

    native_func = (
        runtime.core.wp_bsr_transpose_host if dest.values.device.is_cpu else runtime.core.wp_bsr_transpose_device
    )

    # Buffer receiving, for each transposed block, the index of its source block.
    # NOTE(review): allocated with 2 * nnz entries; presumably the extra half is scratch
    # space for the native routine -- confirm.
    block_index_map = wp.empty(shape=2 * block_count, dtype=int, device=src.device)

    int32_ptr = ctypes.POINTER(ctypes.c_int32)
    with wp.ScopedDevice(dest.device):
        native_func(
            src.nrow,
            src.ncol,
            block_count,
            ctypes.cast(src.offsets.ptr, int32_ptr),
            ctypes.cast(src.columns.ptr, int32_ptr),
            ctypes.cast(dest.offsets.ptr, int32_ptr),
            ctypes.cast(dest.columns.ptr, int32_ptr),
            ctypes.cast(block_index_map.ptr, int32_ptr),
        )

    dest.copy_nnz_async()

    # Topology is in place; now move (and scale) the block values
    wp.launch(
        _bsr_transpose_values,
        dim=(block_count, *dest.block_shape),
        device=dest.device,
        inputs=[src.ncol, dest.scalar_type(src_scale), src.scalar_values, block_index_map, dest.offsets],
        outputs=[dest.scalar_values],
    )
1131
-
1132
-
1133
def bsr_transposed(A: BsrMatrixOrExpression) -> BsrMatrix:
    """Return a copy of the transposed matrix ``A``.

    Args:
        A: Matrix or scaling expression to transpose.

    Returns:
        A newly allocated matrix with ``A.ncol`` rows and ``A.nrow`` columns of blocks,
        on the same device as ``A``.
    """

    if A.block_shape == (1, 1):
        # Use the ``dtype`` property rather than ``values.dtype``: scaling expressions
        # forward ``dtype`` to their matrix but do not expose a ``values`` attribute.
        block_type = A.dtype
    else:
        block_type = wp.mat(shape=A.block_shape[::-1], dtype=A.scalar_type)

    transposed = bsr_zeros(
        rows_of_blocks=A.ncol,
        cols_of_blocks=A.nrow,
        block_type=block_type,
        device=A.device,
    )
    transposed.values.requires_grad = A.requires_grad
    bsr_set_transpose(dest=transposed, src=A)
    return transposed
1150
-
1151
-
1152
@wp.kernel
def _bsr_get_diag_kernel(
    scale: Any,
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array3d(dtype=Any),
    out: wp.array3d(dtype=Any),
):
    """Copy the scaled diagonal blocks of a BSR matrix into ``out``.

    One thread handles coefficient ``(br, bc)`` of the diagonal block of one row.
    Entries of ``out`` for rows without a diagonal block are not written.
    """
    row, br, bc = wp.tid()

    diag_block = _bsr_block_index(row, row, A_offsets, A_columns)
    if diag_block == -1:
        return

    out[row, br, bc] = scale * A_values[diag_block, br, bc]
1165
-
1166
-
1167
def bsr_get_diag(A: BsrMatrixOrExpression[BlockType], out: "Optional[Array[BlockType]]" = None) -> "Array[BlockType]":
    """Extract the diagonal blocks of a sparse matrix into an array.

    Args:
        A: The sparse matrix from which to extract the diagonal.
        out: If provided, the array into which to store the diagonal blocks. Otherwise a
          zero-initialized array of length ``min(A.nrow, A.ncol)`` is allocated.

    Returns:
        The array holding the diagonal blocks (``out`` when it was provided).

    Raises:
        ValueError: If ``out`` has the wrong data type, resides on another device, or is too short.
    """

    A, scale = _extract_matrix_and_scale(A)

    diag_len = min(A.nrow, A.ncol)
    block_dtype = A.values.dtype
    device = A.values.device

    if out is not None:
        if not types_equal(out.dtype, A.values.dtype):
            raise ValueError(f"Output array must have type {A.values.dtype}, got {out.dtype}")
        if out.device != A.values.device:
            raise ValueError(f"Output array must reside on device {A.values.device}, got {out.device}")
        if out.shape[0] < diag_len:
            raise ValueError(f"Output array must be of length at least {diag_len}, got {out.shape[0]}")
    else:
        out = wp.zeros(shape=(diag_len,), dtype=block_dtype, device=device)

    kernel_inputs = [
        A.scalar_type(scale),
        A.offsets,
        A.columns,
        A.scalar_values,
        _as_3d_array(out, A.block_shape),
    ]
    wp.launch(
        kernel=_bsr_get_diag_kernel,
        dim=(diag_len, *A.block_shape),
        device=device,
        inputs=kernel_inputs,
    )

    return out
1197
-
1198
-
1199
@wp.kernel(enable_backward=False)
def _bsr_set_diag_kernel(
    nnz: int,
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
):
    """Write the row offsets and column indices of a block-diagonal topology.

    Thread ``row`` sets ``A_offsets[row]`` to ``row`` clamped to ``nnz`` and, for the
    first ``nnz`` rows, places the single block of the row on the diagonal.
    """
    row = wp.tid()

    if row < nnz:
        A_columns[row] = row

    # Rows at or past nnz all share the final offset
    A_offsets[row] = wp.min(row, nnz)
1209
-
1210
-
1211
def bsr_set_diag(
    A: BsrMatrix[BlockType],
    diag: "Union[BlockType, Array[BlockType]]",
    rows_of_blocks: Optional[int] = None,
    cols_of_blocks: Optional[int] = None,
) -> None:
    """Set ``A`` as a block-diagonal matrix.

    Args:
        A: The sparse matrix to modify.
        diag: Specifies the values for diagonal blocks. Can be one of:

          - A Warp array of type ``A.values.dtype``: Each element defines one block of the diagonal
          - A constant value of type ``A.values.dtype``: This value is assigned to all diagonal blocks
          - ``None``: Diagonal block values are left uninitialized

        rows_of_blocks: If not ``None``, the new number of rows of blocks.
        cols_of_blocks: If not ``None``, the new number of columns of blocks.

    The shape of the matrix will be defined one of the following, in this order:

    - ``rows_of_blocks`` and ``cols_of_blocks``, if provided.
      If only one is given, the second is assumed equal.
    - The first dimension of ``diag``, if ``diag`` is an array
    - The current dimensions of ``A`` otherwise
    """

    # If only one dimension is given, the matrix is square
    if rows_of_blocks is None and cols_of_blocks is not None:
        rows_of_blocks = cols_of_blocks
    if cols_of_blocks is None and rows_of_blocks is not None:
        cols_of_blocks = rows_of_blocks

    if is_array(diag):
        if rows_of_blocks is None:
            rows_of_blocks = diag.shape[0]
            cols_of_blocks = diag.shape[0]

    if rows_of_blocks is not None:
        A.nrow = rows_of_blocks
        A.ncol = cols_of_blocks

    nnz = min(A.nrow, A.ncol)
    _bsr_ensure_fits(A, nnz=nnz)

    # One thread per entry of the offsets array (nrow + 1 entries). Launching only
    # nnz + 1 threads would leave the offsets of rows past the diagonal, including the
    # final offsets[nrow], unwritten when the matrix has more rows than columns; the
    # kernel clamps offsets to nnz and guards the column writes for those extra rows.
    wp.launch(
        kernel=_bsr_set_diag_kernel,
        dim=A.nrow + 1,
        device=A.offsets.device,
        inputs=[nnz, A.offsets, A.columns],
    )

    if is_array(diag):
        wp.copy(src=diag, dest=A.values, count=nnz)
    elif diag is not None:
        A.values.fill_(diag)

    A.copy_nnz_async()
1268
-
1269
-
1270
def bsr_diag(
    diag: Optional[Union[BlockType, Array[BlockType]]] = None,
    rows_of_blocks: Optional[int] = None,
    cols_of_blocks: Optional[int] = None,
    block_type: Optional[BlockType] = None,
    device=None,
) -> BsrMatrix["BlockType"]:
    """Build a new block-diagonal BSR matrix from a block value or an array of block values.

    Args:
        diag: Specifies the values for diagonal blocks. Can be one of:

          - A Warp array of type ``A.values.dtype``: Each element defines one block of the diagonal
          - A constant value of type ``A.values.dtype``: This value is assigned to all diagonal blocks
        rows_of_blocks: If not ``None``, the new number of rows of blocks
        cols_of_blocks: If not ``None``, the new number of columns of blocks
        block_type: If ``diag`` is ``None``, block type of the matrix. Otherwise deduced from ``diag``
        device: If ``diag`` is not a Warp array, device on which to allocate the matrix. Otherwise deduced from ``diag``

    The shape of the matrix will be defined one of the following, in this order:

    - ``rows_of_blocks`` and ``cols_of_blocks``, if provided.
      If only one is given, the second is assumed equal.
    - The first dimension of ``diag`` if ``diag`` is an array.

    Raises:
        ValueError: If the matrix dimensions cannot be determined, or if neither ``diag``
            nor ``block_type`` is provided.
    """

    # If only one dimension is given, the matrix is square
    if rows_of_blocks is None and cols_of_blocks is not None:
        rows_of_blocks = cols_of_blocks
    if cols_of_blocks is None and rows_of_blocks is not None:
        cols_of_blocks = rows_of_blocks

    diag_is_array = is_array(diag)

    if diag_is_array:
        # Shape, block type and device all follow from the array
        if rows_of_blocks is None:
            rows_of_blocks = diag.shape[0]
            cols_of_blocks = diag.shape[0]

        block_type = diag.dtype
        device = diag.device
    else:
        if rows_of_blocks is None:
            raise ValueError(
                "rows_of_blocks and/or cols_of_blocks must be provided for constructing a diagonal matrix with uniform diagonal"
            )

        if block_type is None:
            if diag is None:
                raise ValueError("Either `diag` or `block_type` needs to be provided")

            block_type = type(diag)
            # A two-dimensional value that is not a Warp matrix is mapped to a matrix type of the same shape
            if not type_is_matrix(block_type) and len(getattr(diag, "shape", ())) == 2:
                block_type = wp.mat(shape=diag.shape, dtype=diag.dtype)

    matrix = bsr_zeros(rows_of_blocks, cols_of_blocks, block_type=block_type, device=device)
    if diag_is_array:
        matrix.values.requires_grad = diag.requires_grad
    bsr_set_diag(matrix, diag)
    return matrix
1327
-
1328
-
1329
def bsr_set_identity(A: BsrMatrix, rows_of_blocks: Optional[int] = None) -> None:
    """Turn ``A`` into the identity matrix.

    Args:
        A: The sparse matrix to modify.
        rows_of_blocks: If provided, the matrix will be resized as a square
          matrix with ``rows_of_blocks`` rows and columns.
    """

    if A.block_shape != (1, 1):
        from numpy import eye

        # Dense identity block with as many rows as the matrix blocks
        identity = eye(A.block_shape[0])
    else:
        identity = A.scalar_type(1.0)

    bsr_set_diag(A, diag=identity, rows_of_blocks=rows_of_blocks, cols_of_blocks=rows_of_blocks)
1346
-
1347
-
1348
def bsr_identity(
    rows_of_blocks: int,
    block_type: BlockType[Rows, Rows, Scalar],
    device: wp.context.Devicelike = None,
) -> BsrMatrix[BlockType[Rows, Rows, Scalar]]:
    """Build a new square identity matrix.

    Args:
        rows_of_blocks: Number of rows and columns of blocks in the created matrix.
        block_type: Block type for the newly created matrix. Must be square
        device: Device onto which to allocate the data arrays

    Returns:
        The newly created identity matrix.
    """
    identity = bsr_zeros(
        rows_of_blocks=rows_of_blocks,
        cols_of_blocks=rows_of_blocks,
        block_type=block_type,
        device=device,
    )
    bsr_set_identity(identity)
    return identity
1368
-
1369
-
1370
# NOTE: an earlier one-dimensional definition of this kernel under the same name was
# immediately shadowed by this one and could never be launched; only the
# three-dimensional variable used by bsr_scale() is kept.
@wp.kernel
def _bsr_scale_kernel(
    alpha: Any,
    values: wp.array3d(dtype=Any),
):
    """Multiply in place every coefficient of every block by ``alpha``.

    One thread handles coefficient ``(br, bc)`` of block ``row``.
    """
    row, br, bc = wp.tid()
    values[row, br, bc] = alpha * values[row, br, bc]
1386
-
1387
-
1388
def bsr_scale(x: BsrMatrixOrExpression, alpha: Scalar) -> BsrMatrix:
    """Scale the BSR matrix ``x`` in place, ``x := alpha * x``, and return it.

    If ``x`` is a scaling expression, its own factor is folded into ``alpha`` and
    the underlying matrix is the one that gets scaled and returned.
    """

    x, scale = _extract_matrix_and_scale(x)
    alpha *= scale

    # Nothing to do for a unit factor or an empty matrix
    if alpha == 1.0 or x.nnz <= 0:
        return x

    if alpha == 0.0:
        # Scaling by zero drops all blocks instead of storing explicit zeros
        bsr_set_zero(x)
        return x

    wp.launch(
        kernel=_bsr_scale_kernel,
        dim=(x.nnz, *x.block_shape),
        device=x.values.device,
        inputs=[x.scalar_type(alpha), x.scalar_values],
    )

    return x
1408
-
1409
-
1410
- @wp.kernel(enable_backward=False)
1411
- def _bsr_get_block_row(row_count: int, bsr_offsets: wp.array(dtype=int), rows: wp.array(dtype=int)):
1412
- block = wp.tid()
1413
- rows[block] = _bsr_row_index(bsr_offsets, row_count, block)
1414
-
1415
-
1416
- @wp.kernel
1417
- def _bsr_axpy_add_block(
1418
- src_offset: int,
1419
- scale: Any,
1420
- rows: wp.array(dtype=int),
1421
- cols: wp.array(dtype=int),
1422
- dst_offsets: wp.array(dtype=int),
1423
- dst_columns: wp.array(dtype=int),
1424
- src_values: wp.array3d(dtype=Any),
1425
- dst_values: wp.array3d(dtype=Any),
1426
- ):
1427
- i, br, bc = wp.tid()
1428
- row = rows[i + src_offset]
1429
- col = cols[i + src_offset]
1430
-
1431
- block = _bsr_block_index(row, col, dst_offsets, dst_columns)
1432
- if block != -1:
1433
- dst_values[block, br, bc] += scale * src_values[i, br, bc]
1434
-
1435
-
1436
- class bsr_axpy_work_arrays:
1437
- """Opaque structure for persisting :func:`bsr_axpy` temporary work buffers across calls."""
1438
-
1439
- def __init__(self):
1440
- self._reset(None)
1441
-
1442
- def _reset(self, device):
1443
- self.device = device
1444
- self._sum_rows = None
1445
- self._sum_cols = None
1446
- self._old_y_values = None
1447
- self._old_x_values = None
1448
-
1449
- def _allocate(self, device, y: BsrMatrix, sum_nnz: int):
1450
- if self.device != device:
1451
- self._reset(device)
1452
-
1453
- if self._sum_rows is None or self._sum_rows.size < sum_nnz:
1454
- self._sum_rows = wp.empty(shape=(sum_nnz), dtype=int, device=self.device)
1455
- if self._sum_cols is None or self._sum_cols.size < sum_nnz:
1456
- self._sum_cols = wp.empty(shape=(sum_nnz), dtype=int, device=self.device)
1457
-
1458
- if self._old_y_values is None or self._old_y_values.size < y.nnz:
1459
- self._old_y_values = wp.empty_like(y.values[: y.nnz])
1460
-
1461
-
1462
- def bsr_axpy(
1463
- x: BsrMatrixOrExpression,
1464
- y: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
1465
- alpha: Scalar = 1.0,
1466
- beta: Scalar = 1.0,
1467
- masked: bool = False,
1468
- work_arrays: Optional[bsr_axpy_work_arrays] = None,
1469
- ) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
1470
- """
1471
- Perform the sparse matrix addition ``y := alpha * x + beta * y`` on BSR matrices ``x`` and ``y`` and return ``y``.
1472
-
1473
- The ``x`` and ``y`` matrices are allowed to alias.
1474
-
1475
- Args:
1476
- x: Read-only first operand.
1477
- y: Mutable second operand and output matrix. If ``y`` is not provided, it will be allocated and treated as zero.
1478
- alpha: Uniform scaling factor for ``x``.
1479
- beta: Uniform scaling factor for ``y``.
1480
- masked: If ``True``, discard all blocks from ``x`` which are not
1481
- existing non-zeros of ``y``.
1482
- work_arrays: In most cases, this function will require the use of temporary storage.
1483
- This storage can be reused across calls by passing an instance of
1484
- :class:`bsr_axpy_work_arrays` in ``work_arrays``.
1485
- """
1486
-
1487
- x, x_scale = _extract_matrix_and_scale(x)
1488
- alpha *= x_scale
1489
-
1490
- if y is None:
1491
- if masked:
1492
- raise ValueError("Left-hand-side 'y' matrix must be provided for masked addition")
1493
-
1494
- # If no output matrix is provided, allocate it for convenience
1495
- y = bsr_zeros(x.nrow, x.ncol, block_type=x.values.dtype, device=x.values.device)
1496
- y.values.requires_grad = x.requires_grad
1497
- beta = 0.0
1498
-
1499
- x_nnz = x.nnz
1500
- y_nnz = y.nnz
1501
-
1502
- # Handle easy cases first
1503
- if beta == 0.0 or y_nnz == 0:
1504
- bsr_assign(src=x, dest=y)
1505
- return bsr_scale(y, alpha=alpha)
1506
-
1507
- if alpha == 0.0 or x_nnz == 0:
1508
- return bsr_scale(y, alpha=beta)
1509
-
1510
- if not isinstance(alpha, y.scalar_type):
1511
- alpha = y.scalar_type(alpha)
1512
- if not isinstance(beta, y.scalar_type):
1513
- beta = y.scalar_type(beta)
1514
-
1515
- if x == y:
1516
- # Aliasing case
1517
- return bsr_scale(y, alpha=alpha.value + beta.value)
1518
-
1519
- # General case
1520
-
1521
- if x.values.device != y.values.device:
1522
- raise ValueError(f"All arguments must reside on the same device, got {x.values.device} and {y.values.device}")
1523
-
1524
- if x.scalar_type != y.scalar_type or x.block_shape != y.block_shape:
1525
- raise ValueError(
1526
- f"Matrices must have the same block type, got ({x.block_shape}, {x.scalar_type}) and ({y.block_shape}, {y.scalar_type})"
1527
- )
1528
-
1529
- if x.nrow != y.nrow or x.ncol != y.ncol:
1530
- raise ValueError(
1531
- f"Matrices must have the same number of rows and columns, got ({x.nrow}, {x.ncol}) and ({y.nrow}, {y.ncol})"
1532
- )
1533
-
1534
- if work_arrays is None:
1535
- work_arrays = bsr_axpy_work_arrays()
1536
-
1537
- sum_nnz = x_nnz + y_nnz
1538
- device = y.values.device
1539
- work_arrays._allocate(device, y, sum_nnz)
1540
-
1541
- wp.copy(work_arrays._sum_cols, y.columns, 0, 0, y_nnz)
1542
- y.uncompress_rows(out=work_arrays._sum_rows)
1543
-
1544
- wp.copy(work_arrays._sum_cols, x.columns, y_nnz, 0, x_nnz)
1545
- x.uncompress_rows(out=work_arrays._sum_rows[y_nnz:])
1546
-
1547
- # Save old y values before overwriting matrix
1548
- wp.copy(dest=work_arrays._old_y_values, src=y.values, count=y.nnz)
1549
-
1550
- # Increase dest array sizes if needed
1551
- if not masked:
1552
- _bsr_ensure_fits(y, nnz=sum_nnz)
1553
-
1554
- from warp.context import runtime
1555
-
1556
- if device.is_cpu:
1557
- native_func = runtime.core.wp_bsr_matrix_from_triplets_host
1558
- else:
1559
- native_func = runtime.core.wp_bsr_matrix_from_triplets_device
1560
-
1561
- old_y_nnz = y_nnz
1562
- nnz_buf, nnz_event = y._setup_nnz_transfer()
1563
-
1564
- with wp.ScopedDevice(y.device):
1565
- native_func(
1566
- y.block_size,
1567
- 0, # scalar_size_in_bytes
1568
- y.nrow,
1569
- y.ncol,
1570
- sum_nnz,
1571
- None, # device nnz
1572
- ctypes.cast(work_arrays._sum_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
1573
- ctypes.cast(work_arrays._sum_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
1574
- None, # triplet values
1575
- 0, # zero_value_mask
1576
- masked,
1577
- None, # summed block offsets
1578
- None, # summed block indices
1579
- ctypes.cast(y.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
1580
- ctypes.cast(y.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
1581
- _optional_ctypes_pointer(nnz_buf, ctype=ctypes.c_int32),
1582
- _optional_ctypes_event(nnz_event),
1583
- )
1584
-
1585
- y.values.zero_()
1586
-
1587
- wp.launch(
1588
- kernel=_bsr_axpy_add_block,
1589
- device=device,
1590
- dim=(old_y_nnz, y.block_shape[0], y.block_shape[1]),
1591
- inputs=[
1592
- 0,
1593
- beta,
1594
- work_arrays._sum_rows,
1595
- work_arrays._sum_cols,
1596
- y.offsets,
1597
- y.columns,
1598
- _as_3d_array(work_arrays._old_y_values, y.block_shape),
1599
- y.scalar_values,
1600
- ],
1601
- )
1602
-
1603
- wp.launch(
1604
- kernel=_bsr_axpy_add_block,
1605
- device=device,
1606
- dim=(x_nnz, y.block_shape[0], y.block_shape[1]),
1607
- inputs=[
1608
- old_y_nnz,
1609
- alpha,
1610
- work_arrays._sum_rows,
1611
- work_arrays._sum_cols,
1612
- y.offsets,
1613
- y.columns,
1614
- x.scalar_values,
1615
- y.scalar_values,
1616
- ],
1617
- )
1618
-
1619
- return y
1620
-
1621
-
1622
- def make_bsr_mm_count_coeffs(tile_size):
1623
- from warp.fem.cache import dynamic_kernel
1624
-
1625
- @dynamic_kernel(suffix=tile_size)
1626
- def bsr_mm_count_coeffs(
1627
- y_ncol: int,
1628
- z_nnz: int,
1629
- x_offsets: wp.array(dtype=int),
1630
- x_columns: wp.array(dtype=int),
1631
- y_offsets: wp.array(dtype=int),
1632
- y_columns: wp.array(dtype=int),
1633
- row_min: wp.array(dtype=int),
1634
- block_counts: wp.array(dtype=int),
1635
- ):
1636
- row, lane = wp.tid()
1637
- row_count = int(0)
1638
-
1639
- x_beg = x_offsets[row]
1640
- x_end = x_offsets[row + 1]
1641
-
1642
- min_col = y_ncol
1643
- max_col = int(0)
1644
-
1645
- for x_block in range(x_beg + lane, x_end, tile_size):
1646
- x_col = x_columns[x_block]
1647
- y_row_end = y_offsets[x_col + 1]
1648
- y_row_beg = y_offsets[x_col]
1649
- block_count = y_row_end - y_row_beg
1650
- if block_count != 0:
1651
- min_col = wp.min(y_columns[y_row_beg], min_col)
1652
- max_col = wp.max(y_columns[y_row_end - 1], max_col)
1653
-
1654
- block_counts[x_block + 1] = block_count
1655
- row_count += block_count
1656
-
1657
- if wp.static(tile_size) > 1:
1658
- row_count = wp.tile_sum(wp.tile(row_count))[0]
1659
- min_col = wp.tile_min(wp.tile(min_col))[0]
1660
- max_col = wp.tile_max(wp.tile(max_col))[0]
1661
- col_range_size = wp.max(0, max_col - min_col + 1)
1662
-
1663
- if row_count > col_range_size:
1664
- # Optimization for deep products.
1665
- # Do not store the whole list of src product terms; they would be highly redundant
1666
- # Instead just mark a range in the output matrix
1667
-
1668
- if lane == 0:
1669
- row_min[row] = min_col
1670
- block_counts[x_end] = col_range_size
1671
-
1672
- for x_block in range(x_beg + lane, x_end - 1, tile_size):
1673
- block_counts[x_block + 1] = 0
1674
- elif lane == 0:
1675
- row_min[row] = -1
1676
-
1677
- if lane == 0 and row == 0:
1678
- block_counts[0] = z_nnz
1679
-
1680
- return bsr_mm_count_coeffs
1681
-
1682
-
1683
- @wp.kernel(enable_backward=False)
1684
- def _bsr_mm_list_coeffs(
1685
- copied_z_nnz: int,
1686
- x_nrow: int,
1687
- x_offsets: wp.array(dtype=int),
1688
- x_columns: wp.array(dtype=int),
1689
- y_offsets: wp.array(dtype=int),
1690
- y_columns: wp.array(dtype=int),
1691
- mm_row_min: wp.array(dtype=int),
1692
- mm_offsets: wp.array(dtype=int),
1693
- mm_rows: wp.array(dtype=int),
1694
- mm_cols: wp.array(dtype=int),
1695
- mm_src_blocks: wp.array(dtype=int),
1696
- ):
1697
- mm_block = wp.tid() + copied_z_nnz
1698
-
1699
- x_nnz = x_offsets[x_nrow]
1700
- x_block = wp.lower_bound(mm_offsets, 0, x_nnz + 1, mm_block + 1) - 1
1701
- pos = mm_block - mm_offsets[x_block]
1702
-
1703
- row = _bsr_row_index(x_offsets, x_nrow, x_block)
1704
-
1705
- row_min_col = mm_row_min[row]
1706
- if row_min_col == -1:
1707
- x_col = x_columns[x_block]
1708
- y_beg = y_offsets[x_col]
1709
- y_block = y_beg + pos
1710
- col = y_columns[y_block]
1711
- src_block = x_block
1712
- else:
1713
- col = row_min_col + pos
1714
- src_block = -1
1715
-
1716
- mm_cols[mm_block] = col
1717
- mm_rows[mm_block] = row
1718
- mm_src_blocks[mm_block] = src_block
1719
-
1720
-
1721
- @wp.func
1722
- def _bsr_mm_use_triplets(
1723
- row: int,
1724
- mm_block: int,
1725
- mm_row_min: wp.array(dtype=int),
1726
- row_offsets: wp.array(dtype=int),
1727
- summed_triplet_offsets: wp.array(dtype=int),
1728
- ):
1729
- x_beg = row_offsets[row]
1730
- x_end = row_offsets[row + 1]
1731
-
1732
- if mm_row_min:
1733
- if mm_row_min[row] == -1:
1734
- if mm_block == 0:
1735
- block_beg = 0
1736
- else:
1737
- block_beg = summed_triplet_offsets[mm_block - 1]
1738
- block_end = summed_triplet_offsets[mm_block]
1739
-
1740
- if x_end - x_beg > 3 * (block_end - block_beg):
1741
- return True, block_beg, block_end
1742
-
1743
- return False, x_beg, x_end
1744
-
1745
-
1746
- @wp.kernel(enable_backward=False)
1747
- def _bsr_mm_compute_values(
1748
- alpha: Any,
1749
- x_offsets: wp.array(dtype=int),
1750
- x_columns: wp.array(dtype=int),
1751
- x_values: wp.array(dtype=Any),
1752
- y_offsets: wp.array(dtype=int),
1753
- y_columns: wp.array(dtype=int),
1754
- y_values: wp.array(dtype=Any),
1755
- mm_row_min: wp.array(dtype=int),
1756
- summed_triplet_offsets: wp.array(dtype=int),
1757
- summed_triplet_src_blocks: wp.indexedarray(dtype=int),
1758
- mm_row_count: int,
1759
- mm_offsets: wp.array(dtype=int),
1760
- mm_cols: wp.array(dtype=int),
1761
- mm_values: wp.array(dtype=Any),
1762
- ):
1763
- mm_block = wp.tid()
1764
-
1765
- row = _bsr_row_index(mm_offsets, mm_row_count, mm_block)
1766
- if row == -1:
1767
- return
1768
-
1769
- use_triplets, block_beg, block_end = _bsr_mm_use_triplets(
1770
- row, mm_block, mm_row_min, x_offsets, summed_triplet_offsets
1771
- )
1772
-
1773
- mm_val = mm_values.dtype(type(alpha)(0.0))
1774
- col = mm_cols[mm_block]
1775
- if use_triplets:
1776
- for tpl_idx in range(block_beg, block_end):
1777
- x_block = summed_triplet_src_blocks[tpl_idx]
1778
- x_col = x_columns[x_block]
1779
- if x_block != -1:
1780
- y_block = _bsr_block_index(x_col, col, y_offsets, y_columns)
1781
- mm_val += x_values[x_block] * y_values[y_block]
1782
- else:
1783
- for x_block in range(block_beg, block_end):
1784
- x_col = x_columns[x_block]
1785
- y_block = _bsr_block_index(x_col, col, y_offsets, y_columns)
1786
- if y_block != -1:
1787
- mm_val += x_values[x_block] * y_values[y_block]
1788
-
1789
- mm_values[mm_block] += alpha * mm_val
1790
-
1791
-
1792
- def make_bsr_mm_compute_values_tiled_outer(subblock_rows, subblock_cols, block_depth, scalar_type, tile_size):
1793
- from warp.fem.cache import dynamic_func, dynamic_kernel
1794
-
1795
- mm_type = wp.mat(dtype=scalar_type, shape=(subblock_rows, subblock_cols))
1796
-
1797
- x_col_vec_t = wp.vec(dtype=scalar_type, length=subblock_rows)
1798
- y_row_vec_t = wp.vec(dtype=scalar_type, length=subblock_cols)
1799
-
1800
- suffix = f"{subblock_rows}{subblock_cols}{block_depth}{tile_size}{scalar_type.__name__}"
1801
-
1802
- @dynamic_func(suffix=suffix)
1803
- def _outer_product(
1804
- x_values: wp.array2d(dtype=scalar_type),
1805
- y_values: wp.array2d(dtype=scalar_type),
1806
- brow_off: int,
1807
- bcol_off: int,
1808
- block_col: int,
1809
- brow_count: int,
1810
- bcol_count: int,
1811
- ):
1812
- x_col_vec = x_col_vec_t()
1813
- y_row_vec = y_row_vec_t()
1814
-
1815
- for k in range(brow_count):
1816
- x_col_vec[k] = x_values[brow_off + k, block_col]
1817
- for k in range(bcol_count):
1818
- y_row_vec[k] = y_values[block_col, bcol_off + k]
1819
-
1820
- return wp.outer(x_col_vec, y_row_vec)
1821
-
1822
- @dynamic_kernel(suffix=suffix, kernel_options={"enable_backward": False})
1823
- def bsr_mm_compute_values(
1824
- alpha: scalar_type,
1825
- x_offsets: wp.array(dtype=int),
1826
- x_columns: wp.array(dtype=int),
1827
- x_values: wp.array3d(dtype=scalar_type),
1828
- y_offsets: wp.array(dtype=int),
1829
- y_columns: wp.array(dtype=int),
1830
- y_values: wp.array3d(dtype=scalar_type),
1831
- mm_row_min: wp.array(dtype=int),
1832
- summed_triplet_offsets: wp.array(dtype=int),
1833
- summed_triplet_src_blocks: wp.indexedarray(dtype=int),
1834
- mm_row_count: int,
1835
- mm_offsets: wp.array(dtype=int),
1836
- mm_cols: wp.array(dtype=int),
1837
- mm_values: wp.array3d(dtype=scalar_type),
1838
- ):
1839
- mm_block, subrow, subcol, lane = wp.tid()
1840
-
1841
- brow_off = subrow * wp.static(subblock_rows)
1842
- bcol_off = subcol * wp.static(subblock_cols)
1843
-
1844
- brow_count = wp.min(mm_values.shape[1] - brow_off, subblock_rows)
1845
- bcol_count = wp.min(mm_values.shape[2] - bcol_off, subblock_cols)
1846
-
1847
- mm_row = _bsr_row_index(mm_offsets, mm_row_count, mm_block)
1848
- if mm_row == -1:
1849
- return
1850
-
1851
- lane_val = mm_type()
1852
-
1853
- use_triplets, block_beg, block_end = _bsr_mm_use_triplets(
1854
- mm_row, mm_block, mm_row_min, x_offsets, summed_triplet_offsets
1855
- )
1856
-
1857
- col_count = (block_end - block_beg) * block_depth
1858
-
1859
- mm_col = mm_cols[mm_block]
1860
- if use_triplets:
1861
- for col in range(lane, col_count, tile_size):
1862
- tpl_block = col // wp.static(block_depth)
1863
- block_col = col - tpl_block * wp.static(block_depth)
1864
- tpl_block += block_beg
1865
-
1866
- x_block = summed_triplet_src_blocks[tpl_block]
1867
- if x_block != -1:
1868
- x_col = x_columns[x_block]
1869
- y_block = _bsr_block_index(x_col, mm_col, y_offsets, y_columns)
1870
- lane_val += _outer_product(
1871
- x_values[x_block], y_values[y_block], brow_off, bcol_off, block_col, brow_count, bcol_count
1872
- )
1873
- else:
1874
- for col in range(lane, col_count, tile_size):
1875
- x_block = col // wp.static(block_depth)
1876
- block_col = col - x_block * wp.static(block_depth)
1877
- x_block += block_beg
1878
-
1879
- x_col = x_columns[x_block]
1880
- y_block = _bsr_block_index(x_col, mm_col, y_offsets, y_columns)
1881
-
1882
- if y_block != -1:
1883
- lane_val += _outer_product(
1884
- x_values[x_block], y_values[y_block], brow_off, bcol_off, block_col, brow_count, bcol_count
1885
- )
1886
-
1887
- mm_val = wp.tile_sum(wp.tile(lane_val, preserve_type=True))[0]
1888
-
1889
- for coef in range(lane, wp.static(subblock_cols * subblock_rows), tile_size):
1890
- br = coef // subblock_cols
1891
- bc = coef - br * subblock_cols
1892
- if br < brow_count and bc < bcol_count:
1893
- mm_values[mm_block, br + brow_off, bc + bcol_off] += mm_val[br, bc] * alpha
1894
-
1895
- return bsr_mm_compute_values
1896
-
1897
-
1898
- class bsr_mm_work_arrays:
1899
- """Opaque structure for persisting :func:`bsr_mm` temporary work buffers across calls."""
1900
-
1901
- def __init__(self):
1902
- self._reset(None)
1903
-
1904
- def _reset(self, device):
1905
- self.device = device
1906
- self._mm_row_min = None
1907
- self._mm_block_counts = None
1908
- self._mm_rows = None
1909
- self._mm_cols = None
1910
- self._mm_src_blocks = None
1911
- self._old_z_values = None
1912
- self._old_z_offsets = None
1913
- self._old_z_columns = None
1914
- self._mm_nnz = 0
1915
-
1916
- def _allocate_stage_1(self, device, x_nnz: int, z: BsrMatrix, beta: float, z_aliasing: bool):
1917
- if self.device != device:
1918
- self._reset(device)
1919
-
1920
- # Allocations that do not depend on any computation
1921
- z_nnz = z.nnz_sync()
1922
- self._copied_z_nnz = z_nnz if beta != 0.0 or z_aliasing else 0
1923
-
1924
- if self._mm_row_min is None or self._mm_block_counts.size < z.nrow + 1:
1925
- self._mm_row_min = wp.empty(shape=(z.nrow + 1,), dtype=int, device=self.device)
1926
- if self._mm_block_counts is None or self._mm_block_counts.size < x_nnz + 1:
1927
- self._mm_block_counts = wp.empty(shape=(x_nnz + 1,), dtype=int, device=self.device)
1928
-
1929
- if self._copied_z_nnz > 0:
1930
- if self._old_z_values is None or self._old_z_values.size < self._copied_z_nnz:
1931
- self._old_z_values = wp.empty(shape=(self._copied_z_nnz,), dtype=z.values.dtype, device=self.device)
1932
-
1933
- if z_aliasing:
1934
- if self._old_z_columns is None or self._old_z_columns.size < z_nnz:
1935
- self._old_z_columns = wp.empty(shape=(z_nnz,), dtype=z.columns.dtype, device=self.device)
1936
- if self._old_z_offsets is None or self._old_z_offsets.size < z.nrow + 1:
1937
- self._old_z_offsets = wp.empty(shape=(z.nrow + 1,), dtype=z.offsets.dtype, device=self.device)
1938
-
1939
- def _allocate_stage_2(self, mm_nnz: int):
1940
- # Allocations that depend on unmerged nnz estimate
1941
- self._mm_nnz = mm_nnz
1942
- if self._mm_rows is None or self._mm_rows.size < mm_nnz:
1943
- self._mm_rows = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
1944
- if self._mm_cols is None or self._mm_cols.size < mm_nnz:
1945
- self._mm_cols = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
1946
- if self._mm_src_blocks is None or self._mm_src_blocks.size < mm_nnz:
1947
- self._mm_src_blocks = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
1948
-
1949
-
1950
- def bsr_mm(
1951
- x: BsrMatrixOrExpression[BlockType[Rows, Any, Scalar]],
1952
- y: BsrMatrixOrExpression[BlockType[Any, Cols, Scalar]],
1953
- z: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
1954
- alpha: Scalar = 1.0,
1955
- beta: Scalar = 0.0,
1956
- masked: bool = False,
1957
- work_arrays: Optional[bsr_mm_work_arrays] = None,
1958
- reuse_topology: bool = False,
1959
- tile_size: int = 0,
1960
- ) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
1961
- """
1962
- Perform the sparse matrix-matrix multiplication ``z := alpha * x @ y + beta * z`` on BSR matrices ``x``, ``y`` and ``z``, and return ``z``.
1963
-
1964
- The ``x``, ``y`` and ``z`` matrices are allowed to alias.
1965
- If the matrix ``z`` is not provided as input, it will be allocated and treated as zero.
1966
-
1967
- Args:
1968
- x: Read-only left operand of the matrix-matrix product.
1969
- y: Read-only right operand of the matrix-matrix product.
1970
- z: Mutable affine operand and result matrix. If ``z`` is not provided, it will be allocated and treated as zero.
1971
- alpha: Uniform scaling factor for the ``x @ y`` product
1972
- beta: Uniform scaling factor for ``z``
1973
- masked: If ``True``, ignore all blocks from ``x @ y`` which are not existing non-zeros of ``z``
1974
- work_arrays: In most cases, this function will require the use of temporary storage.
1975
- This storage can be reused across calls by passing an instance of
1976
- :class:`bsr_mm_work_arrays` in ``work_arrays``.
1977
- reuse_topology: If ``True``, reuse the product topology information
1978
- stored in ``work_arrays`` rather than recompute it from scratch.
1979
- The matrices ``x``, ``y`` and ``z`` must be structurally similar to
1980
- the previous call in which ``work_arrays`` were populated.
1981
- This is necessary for ``bsr_mm`` to be captured in a CUDA graph.
1982
- tile_size: If a positive integer, use tiles of this size to compute the matrix-matrix product.
1983
- If negative, disable tile-based computation. Defaults to ``0``, which determines whether to
1984
- use tiles using a heuristic based on the matrix shape and number of non-zeros.
1985
- """
1986
-
1987
- x, x_scale = _extract_matrix_and_scale(x)
1988
- alpha *= x_scale
1989
- y, y_scale = _extract_matrix_and_scale(y)
1990
- alpha *= y_scale
1991
-
1992
- if z is None:
1993
- if masked:
1994
- raise ValueError("Left-hand-side 'z' matrix must be provided for masked multiplication")
1995
-
1996
- # If no output matrix is provided, allocate it for convenience
1997
- z_block_shape = (x.block_shape[0], y.block_shape[1])
1998
- if z_block_shape == (1, 1):
1999
- z_block_type = x.scalar_type
2000
- else:
2001
- z_block_type = wp.mat(shape=z_block_shape, dtype=x.scalar_type)
2002
- z = bsr_zeros(x.nrow, y.ncol, block_type=z_block_type, device=x.values.device)
2003
- z.values.requires_grad = x.requires_grad or y.requires_grad
2004
- beta = 0.0
2005
-
2006
- if x.values.device != y.values.device or x.values.device != z.values.device:
2007
- raise ValueError(
2008
- f"All arguments must reside on the same device, got {x.values.device}, {y.values.device} and {z.values.device}"
2009
- )
2010
-
2011
- if x.scalar_type != y.scalar_type or x.scalar_type != z.scalar_type:
2012
- raise ValueError(
2013
- f"Matrices must have the same scalar type, got {x.scalar_type}, {y.scalar_type} and {z.scalar_type}"
2014
- )
2015
-
2016
- if (
2017
- x.block_shape[0] != z.block_shape[0]
2018
- or y.block_shape[1] != z.block_shape[1]
2019
- or x.block_shape[1] != y.block_shape[0]
2020
- ):
2021
- raise ValueError(
2022
- f"Incompatible block sizes for matrix multiplication, got ({x.block_shape}, {y.block_shape}) and ({z.block_shape})"
2023
- )
2024
-
2025
- if x.nrow != z.nrow or z.ncol != y.ncol or x.ncol != y.nrow:
2026
- raise ValueError(
2027
- f"Incompatible number of rows/columns for matrix multiplication, got ({x.nrow}, {x.ncol}) and ({y.nrow}, {y.ncol})"
2028
- )
2029
-
2030
- device = z.values.device
2031
-
2032
- if alpha == 0.0 or x.nnz == 0 or y.nnz == 0:
2033
- # Easy case
2034
- return bsr_scale(z, beta)
2035
-
2036
- z_aliasing = z == x or z == y
2037
-
2038
- if masked:
2039
- # no need to copy z, scale in-place
2040
- copied_z_nnz = 0
2041
- mm_nnz = z.nnz
2042
-
2043
- if z_aliasing:
2044
- raise ValueError("`masked=True` is not supported for aliased inputs")
2045
-
2046
- if beta == 0.0:
2047
- # do not bsr_scale(0), this would not preserve topology
2048
- z.values.zero_()
2049
- else:
2050
- bsr_scale(z, beta)
2051
- elif reuse_topology:
2052
- if work_arrays is None:
2053
- raise ValueError("`work_arrays` must not be ``None`` in order to reuse matrix-matrix product topology")
2054
-
2055
- copied_z_nnz = work_arrays._copied_z_nnz
2056
- mm_nnz = work_arrays._mm_nnz
2057
- else:
2058
- if device.is_capturing:
2059
- raise RuntimeError(
2060
- "`bsr_mm` requires either `reuse_topology=True` or `masked=True` for use in graph capture"
2061
- )
2062
-
2063
- if work_arrays is None:
2064
- work_arrays = bsr_mm_work_arrays()
2065
-
2066
- work_arrays._allocate_stage_1(device, x.nnz, z, beta, z_aliasing)
2067
- copied_z_nnz = work_arrays._copied_z_nnz
2068
-
2069
- # Prefix sum of number of (unmerged) mm blocks per row
2070
- # Use either a thread or a block per row depending on avg nnz/row
2071
- work_arrays._mm_block_counts.zero_()
2072
- count_tile_size = 32
2073
- if not device.is_cuda or x.nnz < 3 * count_tile_size * x.nrow:
2074
- count_tile_size = 1
2075
-
2076
- wp.launch(
2077
- kernel=make_bsr_mm_count_coeffs(count_tile_size),
2078
- device=device,
2079
- dim=(z.nrow, count_tile_size),
2080
- block_dim=count_tile_size if count_tile_size > 1 else 256,
2081
- inputs=[
2082
- y.ncol,
2083
- copied_z_nnz,
2084
- x.offsets,
2085
- x.columns,
2086
- y.offsets,
2087
- y.columns,
2088
- work_arrays._mm_row_min,
2089
- work_arrays._mm_block_counts,
2090
- ],
2091
- )
2092
- warp.utils.array_scan(work_arrays._mm_block_counts[: x.nnz + 1], work_arrays._mm_block_counts[: x.nnz + 1])
2093
-
2094
- # Get back total counts on host -- we need a synchronization here
2095
- # Use pinned buffer from z, we are going to need it later anyway
2096
- nnz_buf, _ = z._setup_nnz_transfer()
2097
- stream = wp.get_stream(device) if device.is_cuda else None
2098
- wp.copy(dest=nnz_buf, src=work_arrays._mm_block_counts, src_offset=x.nnz, count=1, stream=stream)
2099
- if device.is_cuda:
2100
- wp.synchronize_stream(stream)
2101
- mm_nnz = int(nnz_buf.numpy()[0])
2102
-
2103
- if mm_nnz == copied_z_nnz:
2104
- # x@y = 0
2105
- return bsr_scale(z, beta)
2106
-
2107
- work_arrays._allocate_stage_2(mm_nnz)
2108
-
2109
- # If z has a non-zero scale, save current data before overwriting it
2110
- if copied_z_nnz > 0:
2111
- # Copy z row and column indices
2112
- wp.copy(dest=work_arrays._mm_cols, src=z.columns, count=copied_z_nnz)
2113
- z.uncompress_rows(out=work_arrays._mm_rows)
2114
- work_arrays._mm_src_blocks[:copied_z_nnz].fill_(-1)
2115
- if z_aliasing:
2116
- # If z is aliasing with x or y, need to save topology as well
2117
- wp.copy(src=z.columns, dest=work_arrays._old_z_columns, count=copied_z_nnz)
2118
- wp.copy(src=z.offsets, dest=work_arrays._old_z_offsets, count=z.nrow + 1)
2119
-
2120
- # Fill unmerged mm blocks rows and columns
2121
- wp.launch(
2122
- kernel=_bsr_mm_list_coeffs,
2123
- device=device,
2124
- dim=mm_nnz - copied_z_nnz,
2125
- inputs=[
2126
- copied_z_nnz,
2127
- x.nrow,
2128
- x.offsets,
2129
- x.columns,
2130
- y.offsets,
2131
- y.columns,
2132
- work_arrays._mm_row_min,
2133
- work_arrays._mm_block_counts,
2134
- work_arrays._mm_rows,
2135
- work_arrays._mm_cols,
2136
- work_arrays._mm_src_blocks,
2137
- ],
2138
- )
2139
-
2140
- alpha = z.scalar_type(alpha)
2141
- beta = z.scalar_type(beta)
2142
-
2143
- if copied_z_nnz > 0:
2144
- # Save current z values in temporary buffer
2145
- wp.copy(src=z.values, dest=work_arrays._old_z_values, count=copied_z_nnz)
2146
-
2147
- if not masked:
2148
- # Increase dest array size if needed
2149
- if z.columns.shape[0] < mm_nnz:
2150
- z.columns = wp.empty(shape=(mm_nnz,), dtype=int, device=device)
2151
-
2152
- from warp.context import runtime
2153
-
2154
- if device.is_cpu:
2155
- native_func = runtime.core.wp_bsr_matrix_from_triplets_host
2156
- else:
2157
- native_func = runtime.core.wp_bsr_matrix_from_triplets_device
2158
-
2159
- nnz_buf, nnz_event = z._setup_nnz_transfer()
2160
- summed_triplet_offsets = wp.empty(shape=(mm_nnz,), dtype=wp.int32, device=device)
2161
- summed_triplet_indices = wp.empty(shape=(mm_nnz,), dtype=wp.int32, device=device)
2162
-
2163
- with wp.ScopedDevice(z.device):
2164
- native_func(
2165
- z.block_size,
2166
- 0, # scalar_size_in_bytes
2167
- z.nrow,
2168
- z.ncol,
2169
- mm_nnz,
2170
- None, # device nnz
2171
- ctypes.cast(work_arrays._mm_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
2172
- ctypes.cast(work_arrays._mm_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
2173
- None, # triplet values
2174
- 0, # zero_value_mask
2175
- False, # masked_topology
2176
- ctypes.cast(summed_triplet_offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
2177
- ctypes.cast(summed_triplet_indices.ptr, ctypes.POINTER(ctypes.c_int32)),
2178
- ctypes.cast(z.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
2179
- ctypes.cast(z.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
2180
- _optional_ctypes_pointer(nnz_buf, ctype=ctypes.c_int32),
2181
- _optional_ctypes_event(nnz_event),
2182
- )
2183
-
2184
- # Resize z to fit mm result if necessary
2185
- # If we are not reusing the product topology, this needs another synchronization
2186
- if not reuse_topology:
2187
- work_arrays.result_nnz = z.nnz_sync()
2188
-
2189
- _bsr_ensure_fits(z, nnz=work_arrays.result_nnz)
2190
- z.values.zero_()
2191
-
2192
- if copied_z_nnz > 0:
2193
- # Add back original z values
2194
- wp.launch(
2195
- kernel=_bsr_axpy_add_block,
2196
- device=device,
2197
- dim=(copied_z_nnz, z.block_shape[0], z.block_shape[1]),
2198
- inputs=[
2199
- 0,
2200
- beta,
2201
- work_arrays._mm_rows,
2202
- work_arrays._mm_cols,
2203
- z.offsets,
2204
- z.columns,
2205
- _as_3d_array(work_arrays._old_z_values, z.block_shape),
2206
- z.scalar_values,
2207
- ],
2208
- )
2209
-
2210
- max_subblock_dim = 12
2211
- if tile_size > 0:
2212
- use_tiles = True
2213
- elif tile_size < 0:
2214
- use_tiles = False
2215
- else:
2216
- # Heuristic for using tiled variant: few or very large blocks
2217
- tile_size = 64
2218
- max_tiles_per_sm = 2048 // tile_size # assume 64 resident warps per SM
2219
- use_tiles = device.is_cuda and (
2220
- max(x.block_size, y.block_size, z.block_size) > max_subblock_dim**2
2221
- or mm_nnz < max_tiles_per_sm * device.sm_count
2222
- )
2223
-
2224
- if use_tiles:
2225
- subblock_rows = min(max_subblock_dim, z.block_shape[0])
2226
- subblock_cols = min(max_subblock_dim, z.block_shape[1])
2227
-
2228
- wp.launch(
2229
- kernel=make_bsr_mm_compute_values_tiled_outer(
2230
- subblock_rows, subblock_cols, x.block_shape[1], z.scalar_type, tile_size
2231
- ),
2232
- device=device,
2233
- dim=(
2234
- z.nnz,
2235
- (z.block_shape[0] + subblock_rows - 1) // subblock_rows,
2236
- (z.block_shape[1] + subblock_cols - 1) // subblock_cols,
2237
- tile_size,
2238
- ),
2239
- block_dim=tile_size,
2240
- inputs=[
2241
- alpha,
2242
- work_arrays._old_z_offsets if x == z else x.offsets,
2243
- work_arrays._old_z_columns if x == z else x.columns,
2244
- _as_3d_array(work_arrays._old_z_values, z.block_shape) if x == z else x.scalar_values,
2245
- work_arrays._old_z_offsets if y == z else y.offsets,
2246
- work_arrays._old_z_columns if y == z else y.columns,
2247
- _as_3d_array(work_arrays._old_z_values, z.block_shape) if y == z else y.scalar_values,
2248
- None if masked else work_arrays._mm_row_min,
2249
- None if masked else summed_triplet_offsets,
2250
- None if masked else work_arrays._mm_src_blocks[summed_triplet_indices],
2251
- z.nrow,
2252
- z.offsets,
2253
- z.columns,
2254
- z.scalar_values,
2255
- ],
2256
- )
2257
-
2258
- return z
2259
-
2260
- # Add mm blocks to z values
2261
- if (type_is_matrix(x.values.dtype) or type_is_matrix(y.values.dtype)) and not (type_is_matrix(z.values.dtype)):
2262
- # Result block type is scalar, but operands are matrices
2263
- # Cast result to (1x1) matrix to perform multiplication
2264
- mm_values = z.values.view(wp.mat(shape=(1, 1), dtype=z.scalar_type))
2265
- else:
2266
- mm_values = z.values
2267
-
2268
- wp.launch(
2269
- kernel=_bsr_mm_compute_values,
2270
- device=device,
2271
- dim=z.nnz,
2272
- inputs=[
2273
- alpha,
2274
- work_arrays._old_z_offsets if x == z else x.offsets,
2275
- work_arrays._old_z_columns if x == z else x.columns,
2276
- work_arrays._old_z_values if x == z else x.values,
2277
- work_arrays._old_z_offsets if y == z else y.offsets,
2278
- work_arrays._old_z_columns if y == z else y.columns,
2279
- work_arrays._old_z_values if y == z else y.values,
2280
- None if masked else work_arrays._mm_row_min,
2281
- None if masked else summed_triplet_offsets,
2282
- None if masked else work_arrays._mm_src_blocks[summed_triplet_indices],
2283
- z.nrow,
2284
- z.offsets,
2285
- z.columns,
2286
- mm_values,
2287
- ],
2288
- )
2289
-
2290
- return z
2291
-
2292
-
2293
def make_bsr_mv_kernel(block_cols: int):
    """Build the kernel computing ``y := alpha * A * x + beta * y`` with one thread per scalar row.

    Args:
        block_cols: Number of columns in each block of ``A``. It is baked into the
            kernel through ``wp.static`` and used as the ``suffix`` handed to
            ``dynamic_kernel`` (presumably to tell specializations apart -- confirm
            against ``warp.fem.cache``).

    Returns:
        The kernel object; it expects a 2D launch indexed by ``(row, subrow)``, i.e.
        block row and scalar row within the block.
    """
    # Imported lazily inside the factory rather than at module level.
    from warp.fem.cache import dynamic_kernel

    @dynamic_kernel(suffix=f"{block_cols}", kernel_options={"enable_backward": False})
    def bsr_mv_kernel(
        alpha: Any,
        A_offsets: wp.array(dtype=int),
        A_columns: wp.array(dtype=int),
        A_values: wp.array3d(dtype=Any),
        x: wp.array(dtype=Any),
        beta: Any,
        y: wp.array(dtype=Any),
    ):
        # One thread per (block row, scalar row inside the block)
        row, subrow = wp.tid()

        # A_values is indexed as [block, row-in-block, column-in-block]
        block_rows = A_values.shape[1]

        # Flat index of the scalar entry of y written by this thread
        yi = row * block_rows + subrow

        # Accumulator starts at zero in alpha's scalar type
        # (assumed to match the element type of y -- confirm against callers)
        scalar_zero = type(alpha)(0)
        v = scalar_zero

        # Skip reading A and x entirely when alpha is zero
        if alpha != scalar_zero:
            # Blocks of this block row occupy [beg, end) in A_columns / A_values
            beg = A_offsets[row]
            end = A_offsets[row + 1]
            for block in range(beg, end):
                # First scalar index of x covered by this block
                xs = A_columns[block] * block_cols
                for col in range(wp.static(block_cols)):
                    v += A_values[block, subrow, col] * x[xs + col]
            v *= alpha

        # y is only read when beta is non-zero, so it may be uninitialized otherwise
        if beta != scalar_zero:
            v += beta * y[yi]

        y[yi] = v

    return bsr_mv_kernel
2331
-
2332
-
2333
def make_bsr_mv_tiled_kernel(tile_size: int):
    """Build the tile-based kernel computing ``y := alpha * A * x + beta * y``.

    Each scalar row of the result is handled by ``tile_size`` cooperating lanes:
    every lane accumulates a strided subset of the row's scalar columns and the
    partial sums are then combined with ``wp.tile_sum``.

    Args:
        tile_size: Stride between the columns visited by one lane. It is captured
            from this enclosing scope by the kernel body and also used as the
            ``suffix`` handed to ``dynamic_kernel``.

    Returns:
        The kernel object; it expects a 3D launch indexed by ``(row, subrow, lane)``.
    """
    # Imported lazily inside the factory rather than at module level.
    from warp.fem.cache import dynamic_kernel

    @dynamic_kernel(suffix=f"{tile_size}", kernel_options={"enable_backward": False})
    def bsr_mv_tiled_kernel(
        alpha: Any,
        A_offsets: wp.array(dtype=int),
        A_columns: wp.array(dtype=int),
        A_values: wp.array3d(dtype=Any),
        x: wp.array(dtype=Any),
        beta: Any,
        y: wp.array(dtype=Any),
    ):
        # (block row, scalar row inside the block, lane)
        row, subrow, lane = wp.tid()

        scalar_zero = type(alpha)(0)
        # A_values is indexed as [block, row-in-block, column-in-block]
        block_rows = A_values.shape[1]
        block_cols = A_values.shape[2]

        # Flat index of the scalar entry of y produced for this (row, subrow)
        yi = row * block_rows + subrow

        # Start from beta * y[yi]; y is not read at all when beta is zero
        if beta == scalar_zero:
            subrow_sum = wp.tile_zeros(shape=(1,), dtype=y.dtype)
        else:
            subrow_sum = beta * wp.tile_load(y, 1, yi)

        # Skip reading A and x entirely when alpha is zero
        if alpha != scalar_zero:
            block_beg = A_offsets[row]
            # Number of scalar columns stored for this block row
            col_count = (A_offsets[row + 1] - block_beg) * block_cols

            # NOTE(review): `col` is assigned ahead of the loop that rebinds it;
            # presumably required by Warp's code generation -- confirm before removing
            col = lane
            lane_sum = y.dtype(0)

            # Each lane visits columns lane, lane + tile_size, lane + 2 * tile_size, ...
            for col in range(lane, col_count, tile_size):
                # Split the row-local scalar column into (block, column inside block)
                block = col // block_cols
                block_col = col - block * block_cols
                # Convert the row-local block index into a global block index
                block += block_beg

                xi = x[A_columns[block] * block_cols + block_col]
                lane_sum += A_values[block, subrow, block_col] * xi

            lane_sum *= alpha
            # Gather the per-lane partial sums into a tile and reduce them
            subrow_sum += wp.tile_sum(wp.tile(lane_sum))

        wp.tile_store(y, subrow_sum, yi)

    return bsr_mv_tiled_kernel
2380
-
2381
-
2382
def make_bsr_mv_transpose_kernel(block_rows: int):
    """Build the kernel accumulating ``y += alpha * A^T * x``, one thread per block column entry.

    The kernel only adds to ``y``; it takes no ``beta`` argument, so any scaling or
    zeroing of ``y`` has to happen before it is launched.

    Args:
        block_rows: Number of rows in each block of ``A``. It is baked into the
            kernel through ``wp.static`` and used as the ``suffix`` handed to
            ``dynamic_kernel``.

    Returns:
        The kernel object; it expects a 2D launch indexed by ``(block, subcol)``.
    """
    # Imported lazily inside the factory rather than at module level.
    from warp.fem.cache import dynamic_kernel

    @dynamic_kernel(suffix=f"{block_rows}", kernel_options={"enable_backward": False})
    def bsr_mv_transpose_kernel(
        alpha: Any,
        A_row_count: int,
        A_offsets: wp.array(dtype=int),
        A_columns: wp.array(dtype=int),
        A_values: wp.array3d(dtype=Any),
        x: wp.array(dtype=Any),
        y: wp.array(dtype=Any),
    ):
        # One thread per (stored block, column inside the block)
        block, subcol = wp.tid()

        # Recover the block row owning this block; -1 presumably flags a block
        # index that belongs to no row (see _bsr_row_index -- confirm)
        row = _bsr_row_index(A_offsets, A_row_count, block)
        if row == -1:
            return

        # A_values is indexed as [block, row-in-block, column-in-block]
        block_cols = A_values.shape[2]

        A_block = A_values[block]

        # Dot product of column `subcol` of the block with the matching slice of x
        col_sum = type(alpha)(0)
        for subrow in range(wp.static(block_rows)):
            col_sum += A_block[subrow, subcol] * x[row * block_rows + subrow]

        # Distinct threads can target the same entry of y (blocks sharing a column),
        # hence the atomic accumulation
        wp.atomic_add(y, A_columns[block] * block_cols + subcol, alpha * col_sum)

    return bsr_mv_transpose_kernel
2412
-
2413
-
2414
def _vec_array_view(array: wp.array, dtype: type, expected_scalar_count: int) -> wp.array:
    """Return a 1D view of a 1D or 2D ``array`` reinterpreted with element type ``dtype``.

    No data is copied: the returned array shares the pointer and capacity of
    ``array`` (and, recursively, of its gradient array if it has one).

    Args:
        array: Source array; must be contiguous and have at most two dimensions
            unless it is already a 1D array of ``dtype``.
        dtype: Target element type of the view.
        expected_scalar_count: Total number of scalars ``array`` must contain.

    Returns:
        ``array`` itself if it is already a 1D array of ``dtype``, otherwise a new
        1D array aliasing the same memory.

    Raises:
        ValueError: If the scalar count does not match ``expected_scalar_count``,
            the scalar types differ, the array has more than two dimensions, is
            not contiguous, or its scalar count is not a multiple of the size of
            ``dtype``.
    """

    scalar_count = array.size * type_size(array.dtype)
    if scalar_count != expected_scalar_count:
        raise ValueError(f"Invalid array scalar size, expected {expected_scalar_count}, got {scalar_count}")

    # Nothing to do if the array already has the requested layout
    if array.ndim == 1 and types_equal(array.dtype, dtype):
        return array

    if type_scalar_type(array.dtype) != type_scalar_type(dtype):
        # `dtype` is the requested target and `array.dtype` is what the caller
        # supplied, so they are reported as "expected" and "got" respectively
        raise ValueError(f"Incompatible scalar types, expected {type_repr(dtype)}, got {type_repr(array.dtype)}")

    if array.ndim > 2:
        raise ValueError(f"Incompatible array number of dimensions, expected 1 or 2, got {array.ndim}")

    # Reinterpreting the memory as a flat array is only valid without strides/gaps
    if not array.is_contiguous:
        raise ValueError("Array must be contiguous")

    vec_length = type_size(dtype)
    vec_count = scalar_count // vec_length
    if vec_count * vec_length != scalar_count:
        raise ValueError(
            f"Array of shape {array.shape} and type {type_repr(array.dtype)} cannot be reshaped to an array of type {type_repr(dtype)}"
        )

    def vec_view(array):
        # Alias the same memory with the new dtype/shape; the gradient (if any)
        # is wrapped the same way
        return wp.array(
            data=None,
            ptr=array.ptr,
            capacity=array.capacity,
            device=array.device,
            dtype=dtype,
            shape=vec_count,
            grad=None if array.grad is None else vec_view(array.grad),
        )

    view = vec_view(array)
    # Hold a reference to the source array on the view (presumably so the aliased
    # memory outlives the view -- confirm against wp.array internals)
    view._ref = array
    return view
2454
-
2455
-
2456
def bsr_mv(
    A: BsrMatrixOrExpression[BlockType[Rows, Cols, Scalar]],
    x: "Array[Vector[Cols, Scalar] | Scalar]",
    y: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
    alpha: Scalar = 1.0,
    beta: Scalar = 0.0,
    transpose: bool = False,
    work_buffer: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
    tile_size: int = 0,
) -> "Array[Vector[Rows, Scalar] | Scalar]":
    """Compute the sparse matrix-vector product ``y := alpha * A * x + beta * y`` and return ``y``.

    ``x`` and ``y`` may refer to the same array.

    Args:
        A: Left operand of the product (a matrix or matrix expression); only read.
        x: Right vector operand of the product; only read.
        y: Affine operand, overwritten with the result. If omitted, a new array is
            allocated and treated as zero.
        alpha: Uniform scaling factor for ``x``. If zero, ``x`` will not be read and may be left uninitialized.
        beta: Uniform scaling factor for ``y``. If zero, ``y`` will not be read and may be left uninitialized.
        transpose: If ``True``, use the transpose of ``A``. In this case the result is **non-deterministic**.
        work_buffer: Temporary storage, required if and only if ``x`` and ``y`` are the same vector.
            If provided, it is used for this purpose; otherwise a temporary allocation is performed.
        tile_size: If a positive integer, use tiles of this size to compute the matrix-vector product.
            If negative, disable tile-based computation. Defaults to ``0``, in which case a heuristic
            based on the matrix shape and number of non-zeros decides whether to use tiles.

    Returns:
        The ``y`` array holding the result.

    Raises:
        ValueError: If ``A``, ``x`` and ``y`` are not on the same device, if a provided
            ``work_buffer`` is too small or has the wrong data type, or if ``x`` or ``y``
            is not compatible with the shape and scalar type of ``A``.
    """

    A, A_scale = _extract_matrix_and_scale(A)
    alpha *= A_scale

    # Block shape and block-row/column counts of the operator actually applied
    # (either A or its transpose)
    block_shape = (A.block_shape[1], A.block_shape[0]) if transpose else A.block_shape
    nrow, ncol = (A.ncol, A.nrow) if transpose else (A.nrow, A.ncol)

    if y is None:
        # No output given: allocate one. Its contents are undefined, so beta is
        # forced to zero to guarantee it is never read.
        out_len = block_shape[0]
        if out_len == 1:
            out_dtype = A.scalar_type
        else:
            out_dtype = wp.vec(length=out_len, dtype=A.scalar_type)
        y = wp.empty(shape=(nrow,), device=A.values.device, dtype=out_dtype, requires_grad=x.requires_grad)
        beta = 0.0

    # Bring both coefficients to the scalar type of the matrix
    alpha = A.scalar_type(alpha)
    beta = A.scalar_type(beta)

    device = A.values.device
    if device != x.device or device != y.device:
        raise ValueError(
            f"A, x, and y must reside on the same device, got {A.values.device}, {x.device} and {y.device}"
        )

    if x.ptr == y.ptr:
        # x and y alias: the input must be preserved in scratch storage
        if work_buffer is None:
            work_buffer = wp.empty_like(y)
        else:
            if work_buffer.size < y.size:
                raise ValueError(
                    f"Work buffer size is insufficient, needs to be at least {y.size}, got {work_buffer.size}"
                )
            if not types_equal(work_buffer.dtype, y.dtype):
                raise ValueError(
                    f"Work buffer must have same data type as y, {type_repr(y.dtype)} vs {type_repr(work_buffer.dtype)}"
                )

        # Snapshot the current values before y gets overwritten, then read from the snapshot
        wp.copy(dest=work_buffer, src=y, count=y.size)
        x = work_buffer

    # Flat scalar views of both vectors, validated against the operator's dimensions
    try:
        x_view = _vec_array_view(x, A.scalar_type, expected_scalar_count=ncol * block_shape[1])
    except ValueError as err:
        raise ValueError("Incompatible 'x' vector for bsr_mv") from err
    try:
        y_view = _vec_array_view(y, A.scalar_type, expected_scalar_count=nrow * block_shape[0])
    except ValueError as err:
        raise ValueError("Incompatible 'y' vector for bsr_mv") from err

    if tile_size == 0:
        # Automatic mode: use the tiled kernel for long rows on CUDA devices
        tile_size = 64
        use_tiles = device.is_cuda and A.nnz * A.block_size > 2 * tile_size * A.shape[0]
    else:
        # Explicit request: positive enables tiles, negative disables them
        use_tiles = tile_size > 0

    if transpose:
        # The transpose kernel only accumulates into y, so apply beta beforehand
        if beta.value == 0.0:
            y.zero_()
        elif beta.value != 1.0:
            wp.launch(
                kernel=_bsr_scale_kernel,
                device=y.device,
                dim=y_view.shape[0],
                inputs=[beta, y_view],
            )

        if alpha.value != 0.0:
            wp.launch(
                kernel=make_bsr_mv_transpose_kernel(block_rows=block_shape[1]),
                device=A.values.device,
                dim=(A.nnz, block_shape[0]),
                inputs=[alpha, A.nrow, A.offsets, A.columns, A.scalar_values, x_view, y_view],
            )

        return y

    # Both non-transposed kernels take the same argument list
    mv_inputs = [alpha, A.offsets, A.columns, A.scalar_values, x_view, beta, y_view]

    if use_tiles:
        wp.launch(
            kernel=make_bsr_mv_tiled_kernel(tile_size),
            device=A.values.device,
            dim=(nrow, block_shape[0], tile_size),
            block_dim=tile_size,
            inputs=mv_inputs,
        )
    else:
        wp.launch(
            kernel=make_bsr_mv_kernel(block_cols=block_shape[1]),
            device=A.values.device,
            dim=(nrow, block_shape[0]),
            inputs=mv_inputs,
        )

    return y
16
+ # isort: skip_file
17
+
18
+ from warp._src.sparse import BsrMatrix as BsrMatrix
19
+ from warp._src.sparse import bsr_assign as bsr_assign
20
+ from warp._src.sparse import bsr_axpy as bsr_axpy
21
+ from warp._src.sparse import bsr_axpy_work_arrays as bsr_axpy_work_arrays
22
+ from warp._src.sparse import bsr_block_index as bsr_block_index
23
+ from warp._src.sparse import bsr_copy as bsr_copy
24
+ from warp._src.sparse import bsr_diag as bsr_diag
25
+ from warp._src.sparse import bsr_from_triplets as bsr_from_triplets
26
+ from warp._src.sparse import bsr_get_diag as bsr_get_diag
27
+ from warp._src.sparse import bsr_identity as bsr_identity
28
+ from warp._src.sparse import bsr_matrix_t as bsr_matrix_t
29
+ from warp._src.sparse import bsr_mm as bsr_mm
30
+ from warp._src.sparse import bsr_mm_work_arrays as bsr_mm_work_arrays
31
+ from warp._src.sparse import bsr_mv as bsr_mv
32
+ from warp._src.sparse import bsr_row_index as bsr_row_index
33
+ from warp._src.sparse import bsr_scale as bsr_scale
34
+ from warp._src.sparse import bsr_set_diag as bsr_set_diag
35
+ from warp._src.sparse import bsr_set_from_triplets as bsr_set_from_triplets
36
+ from warp._src.sparse import bsr_set_identity as bsr_set_identity
37
+ from warp._src.sparse import bsr_set_transpose as bsr_set_transpose
38
+ from warp._src.sparse import bsr_set_zero as bsr_set_zero
39
+ from warp._src.sparse import bsr_transposed as bsr_transposed
40
+ from warp._src.sparse import bsr_zeros as bsr_zeros
41
+
42
+
43
+ # TODO: Remove after cleaning up the public API.
44
+
45
+ from warp._src import sparse as _sparse
46
+
47
+
48
def __getattr__(name):
    """Resolve module attributes that are not defined directly in this module.

    Python calls a module-level ``__getattr__`` only after normal attribute lookup
    fails, so the names re-exported above never reach this hook. The lookup is
    delegated to ``get_deprecated_api`` with the internal ``warp._src.sparse``
    module (presumably returning the attribute with a deprecation warning --
    confirm against ``warp._src.utils``).
    """
    # Imported lazily so the helper is only loaded when a fallback lookup happens
    from warp._src.utils import get_deprecated_api as _resolve_deprecated

    resolved = _resolve_deprecated(_sparse, "wp", name)
    return resolved