warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (300)
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/sparse.py CHANGED
@@ -1,14 +1,29 @@
+ from typing import Any, Generic, Optional, Tuple, TypeVar, Union
+
  import warp as wp
  import warp.types
  import warp.utils
+ from warp.types import Array, Cols, Matrix, Rows, Scalar, Vector
+
+ # typing hints
+
+ _BlockType = TypeVar("BlockType")
+
+
+ class _MatrixBlockType(Matrix):
+ pass

- from typing import Tuple, Any, Union

+ class _ScalarBlockType(Generic[Scalar]):
+ pass
+
+
+ BlockType = Union[_MatrixBlockType[Rows, Cols, Scalar], _ScalarBlockType[Scalar]]

  _struct_cache = dict()


- class BsrMatrix:
+ class BsrMatrix(Generic[_BlockType]):
  """Untyped base class for BSR and CSR matrices.

  Should not be constructed directly but through functions such as :func:`bsr_zeros`.
@@ -16,15 +31,15 @@ class BsrMatrix:
  Attributes:
  nrow (int): Number of rows of blocks
  ncol (int): Number of columns of blocks
- nnz (int): Number of non-zero blocks: equal to `offsets[-1]`, cached on host for convenience
- offsets (wp.array(dtype=int)): Array of size at least 1 + nrows containing start and end offsets og blocks in each row
- columns (wp.array(dtype=int)): Array of size at least equal to nnz containing block column indices
- values (wp.array(dtype=dtype)): Array of size at least equal to nnz containing block values
+ nnz (int): Number of non-zero blocks: must be equal to ``offsets[nrow-1]``, cached on host for convenience
+ offsets (Array[int]): Array of size at least ``1 + nrows`` such that the start and end indices of the blocks of row ``r`` are ``offsets[r]`` and ``offsets[r+1]``, respectively.
+ columns (Array[int]): Array of size at least equal to ``nnz`` containing block column indices
+ values (Array[BlockType]): Array of size at least equal to ``nnz`` containing block values
  """

  @property
- def scalar_type(self) -> type:
- """Scalar type for each of the blocks' coefficients. FOr CSR matrices, this is equal to the block type"""
+ def scalar_type(self) -> Scalar:
+ """Scalar type for individual block coefficients. For CSR matrices, this is the same as the block type"""
  return warp.types.type_scalar_type(self.values.dtype)

  @property
@@ -33,20 +48,35 @@ class BsrMatrix:
  return getattr(self.values.dtype, "_shape_", (1, 1))

  @property
- def block_size(self) -> Tuple[int, int]:
- """Size of the individual blocks, i.e. number of rows per block times number of columsn per block"""
+ def block_size(self) -> int:
+ """Size of the individual blocks, i.e. number of rows per block times number of columns per block"""
  return warp.types.type_length(self.values.dtype)

  @property
  def shape(self) -> Tuple[int, int]:
- """Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/columsn per block"""
+ """Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/columns per block"""
  block_shape = self.block_shape
  return (self.nrow * block_shape[0], self.ncol * block_shape[1])

+ @property
+ def dtype(self) -> type:
+ """Data type for individual block values"""
+ return self.values.dtype

- def bsr_matrix_t(dtype: type):
+ @property
+ def device(self) -> wp.context.Device:
+ """Device on which offsets, columns and values are allocated -- assumed to be the same for all three arrays """
+ return self.values.device
+
+
+ def bsr_matrix_t(dtype: BlockType):
  dtype = wp.types.type_to_warp(dtype)

+ if not warp.types.type_is_matrix(dtype) and not dtype in warp.types.scalar_types:
+ raise ValueError(
+ f"BsrMatrix block type must be either warp matrix or scalar; got {warp.types.type_repr(dtype)}"
+ )
+
  class BsrMatrixTyped(BsrMatrix):
  nrow: int
  """Number of rows of blocks"""
@@ -79,11 +109,23 @@ def bsr_matrix_t(dtype: type):


  def bsr_zeros(
- rows_of_blocks: int, cols_of_blocks: int, block_type: type, device: wp.context.Devicelike = None
+ rows_of_blocks: int,
+ cols_of_blocks: int,
+ block_type: BlockType,
+ device: wp.context.Devicelike = None,
  ) -> BsrMatrix:
  """
- Constructs an empty BSR or CS matrix with the given shape
+ Constructs and returns an empty BSR or CSR matrix with the given shape
+
+ Args:
+ bsr: The BSR or CSR matrix to set to zero
+ rows_of_blocks: Number of rows of blocks
+ cols_of_blocks: Number of columns of blocks
+ block_type: Type of individual blocks. For CSR matrices, this should be a scalar type;
+ for BSR matrices, this should be a matrix type (e.g. from :func:`warp.mat`)
+ device: Device on which to allocate the matrix arrays
  """
+
  bsr = bsr_matrix_t(block_type)()

  bsr.nrow = rows_of_blocks
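Illustration (not part of the diff): a minimal usage sketch of the new typed `bsr_zeros` entry point, assuming warp-lang 0.11.0; the block shape, matrix size and device below are arbitrary choices for the example.

import warp as wp
from warp.sparse import bsr_zeros

wp.init()

# 128x128 blocks of 2x2 float32 values; passing a scalar type such as wp.float32
# as block_type would produce a CSR matrix instead of a BSR matrix.
block_type = wp.mat(shape=(2, 2), dtype=wp.float32)
A = bsr_zeros(rows_of_blocks=128, cols_of_blocks=128, block_type=block_type, device="cpu")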
@@ -110,19 +152,42 @@ def _bsr_ensure_fits(bsr: BsrMatrix, nrow: int = None, nnz: int = None):
  bsr.values = wp.empty(shape=(nnz,), dtype=bsr.values.dtype, device=bsr.values.device)


+ def bsr_set_zero(bsr: BsrMatrix, rows_of_blocks: Optional[int] = None, cols_of_blocks: Optional[int] = None):
+ """
+ Sets a BSR matrix to zero, possibly changing its size
+
+ Args:
+ bsr: The BSR or CSR matrix to set to zero
+ rows_of_blocks: If not ``None``, the new number of rows of blocks
+ cols_of_blocks: If not ``None``, the new number of columns of blocks
+ """
+
+ if rows_of_blocks is not None:
+ bsr.nrow = rows_of_blocks
+ if cols_of_blocks is not None:
+ bsr.ncol = cols_of_blocks
+ bsr.nnz = 0
+ _bsr_ensure_fits(bsr)
+ bsr.offsets.zero_()
+
+
  def bsr_set_from_triplets(
- dest: BsrMatrix,
- rows: wp.array(dtype=int),
- columns: wp.array(dtype=int),
- values: wp.array(dtype=Any),
+ dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
+ rows: "Array[int]",
+ columns: "Array[int]",
+ values: "Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]",
  ):
  """
- Fills a BSR matrix `dest` with values defined by COO triplets `rows`, `columns`, `values`.
+ Fills a BSR matrix with values defined by coordinate-oriented (COO) triplets, discarding existing blocks.

- Values must be either one-dimensional with data type identical to the `dest` matrix block times,
- or a 3d array with data type equal to the `dest` matrix scalar type.
+ The first dimension of the three input arrays must match, and determines the number of non-zeros in the constructed matrix.

- Previous blocks of `dest` are discarded.
+ Args:
+ dest: Sparse matrix to populate
+ rows: Row index for each non-zero
+ columns: Columns index for each non-zero
+ values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
+ to the `dest` matrix's block type, or a 3d array with data type equal to the `dest` matrix's scalar type.
  """

  if values.device != columns.device or values.device != rows.device or values.device != dest.values.device:
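Illustration (not part of the diff): a sketch of populating the hypothetical matrix `A` from the previous example with COO triplets, using the 3d scalar-typed values layout described in the docstring above.

import numpy as np
from warp.sparse import bsr_set_from_triplets

# Three non-zero 2x2 blocks; row/column indices and values are made up for the example.
rows = wp.array([0, 1, 3], dtype=int)
cols = wp.array([1, 1, 2], dtype=int)
vals = wp.array(np.ones((3, 2, 2), dtype=np.float32), dtype=wp.float32)

bsr_set_from_triplets(A, rows=rows, columns=cols, values=vals)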
@@ -138,7 +203,7 @@ def bsr_set_from_triplets(
  elif values.ndim == 3:
  if values.shape[1:] != dest.block_shape:
  raise ValueError(
- f"Last two dimensions in values array ({values.shape[1:]}) shoudl correspond to matrix block shape {(dest.block_shape)})"
+ f"Last two dimensions in values array ({values.shape[1:]}) should correspond to matrix block shape {(dest.block_shape)})"
  )

  if warp.types.type_scalar_type(values.dtype) != dest.scalar_type:
@@ -150,6 +215,9 @@ def bsr_set_from_triplets(
  raise ValueError("Number of dimension for values array should be 1 or 3")

  nnz = rows.shape[0]
+ if nnz == 0:
+ bsr_set_zero(dest)
+ return

  # Increase dest array sizes if needed
  _bsr_ensure_fits(dest, nnz=nnz)
@@ -186,8 +254,8 @@ def bsr_set_from_triplets(
  )


- def bsr_assign(dest: BsrMatrix, src: BsrMatrix):
- """Copies the content of the `src` matrix to `dest`, possibly casting the block values."""
+ def bsr_assign(dest: BsrMatrix[BlockType[Rows, Cols, Scalar]], src: BsrMatrix[BlockType[Rows, Cols, Any]]):
+ """Copies the content of the `src` matrix to `dest`, casting the block values if the two matrices use distinct scalar types."""

  if dest.values.device != src.values.device:
  raise ValueError("Source and destination matrices must reside on the same device")
@@ -202,13 +270,17 @@ def bsr_assign(dest: BsrMatrix, src: BsrMatrix):
  _bsr_ensure_fits(dest)

  wp.copy(dest=dest.offsets, src=src.offsets, count=src.nrow + 1)
- wp.copy(dest=dest.columns, src=src.columns, count=src.nnz)
+ if src.nnz > 0:
+ wp.copy(dest=dest.columns, src=src.columns, count=src.nnz)
+ warp.utils.array_cast(out_array=dest.values, in_array=src.values, count=src.nnz)

- warp.utils.array_cast(out_array=dest.values, in_array=src.values, count=src.nnz)

+ def bsr_copy(A: BsrMatrix, scalar_type: Optional[Scalar] = None):
+ """Returns a copy of matrix ``A``, possibly changing its scalar type.

- def bsr_copy(A: BsrMatrix, scalar_type=None):
- """Returns a copy of matrix A, possibly asting values to a new scalar type"""
+ Args:
+ scalar_type: If provided, the returned matrix will use this scalar type instead of the one from `A`.
+ """
  if scalar_type is None:
  block_type = A.values.dtype
  elif A.block_shape == (1, 1):
@@ -221,7 +293,7 @@ def bsr_copy(A: BsrMatrix, scalar_type=None):
  return copy


- def bsr_set_transpose(dest: BsrMatrix, src: BsrMatrix):
+ def bsr_set_transpose(dest: BsrMatrix[BlockType[Cols, Rows, Scalar]], src: BsrMatrix[BlockType[Rows, Cols, Scalar]]):
  """Assigns the transposed matrix `src` to matrix `dest`"""

  if dest.values.device != src.values.device:
@@ -230,10 +302,7 @@ def bsr_set_transpose(dest: BsrMatrix, src: BsrMatrix):

  if dest.scalar_type != src.scalar_type:
  raise ValueError("All arguments must have the same scalar type")
- if src.block_shape == (1, 1):
- transpose_block_shape = (1, 1)
- else:
- transpose_block_shape = src.block_shape[::-1]
+ transpose_block_shape = src.block_shape[::-1]

  if dest.block_shape != transpose_block_shape:
  raise ValueError(f"Destination block shape must be {transpose_block_shape}")
@@ -242,6 +311,9 @@ def bsr_set_transpose(dest: BsrMatrix, src: BsrMatrix):
  dest.ncol = src.nrow
  dest.nnz = src.nnz

+ if src.nnz == 0:
+ return
+
  # Increase dest array sizes if needed
  _bsr_ensure_fits(dest)

@@ -301,27 +373,33 @@ def _bsr_get_diag_kernel(
  end = A_offsets[row + 1]

  diag = wp.lower_bound(A_columns, beg, end, row)
- if A_columns[diag] == row:
- out[row] = A_values[diag]
+ if diag < end:
+ if A_columns[diag] == row:
+ out[row] = A_values[diag]
+

+ def bsr_get_diag(A: BsrMatrix[_BlockType], out: "Optional[Array[BlockType]]" = None) -> "Array[BlockType]":
+ """Returns the array of blocks that constitute the diagonal of a sparse matrix.
+
+ Args:
+ A: the sparse matrix from which to extract the diagonal
+ out: if provided, the array into which to store the diagonal blocks
+ """

- def bsr_get_diag(A: BsrMatrix, out: wp.array = None):
- """Returns the block diagonal of a square sparse matrix"""
- if A.nrow != A.ncol:
- raise ValueError("bsr_get_diag is only available for square sparse matrices")
+ dim = min(A.nrow, A.ncol)

  if out is None:
- out = wp.zeros(shape=(A.nrow,), dtype=A.values.dtype, device=A.values.device)
+ out = wp.zeros(shape=(dim,), dtype=A.values.dtype, device=A.values.device)
  else:
  if out.dtype != A.values.dtype:
  raise ValueError(f"Output array must have type {A.values.dtype}")
  if out.device != A.values.device:
  raise ValueError(f"Output array must reside on device {A.values.device}")
- if out.shape[0] < A.nrow:
- raise ValueError(f"Output array must be of length at least {A.nrow}")
+ if out.shape[0] < dim:
+ raise ValueError(f"Output array must be of length at least {dim}")

  wp.launch(
- kernel=_bsr_get_diag_kernel, dim=A.nrow, device=A.values.device, inputs=[A.offsets, A.columns, A.values, out]
+ kernel=_bsr_get_diag_kernel, dim=dim, device=A.values.device, inputs=[A.offsets, A.columns, A.values, out]
  )

  return out
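Illustration (not part of the diff): with the relaxed `bsr_get_diag` above, rectangular matrices are now accepted as well; a sketch reusing the hypothetical `A`.

from warp.sparse import bsr_get_diag

# One block per diagonal entry, min(A.nrow, A.ncol) in total, with dtype A.values.dtype.
diag_blocks = bsr_get_diag(A)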
@@ -329,40 +407,205 @@ def bsr_get_diag(A: BsrMatrix, out: wp.array = None):

  @wp.kernel
  def _bsr_set_diag_kernel(
+ diag: wp.array(dtype=Any),
+ A_offsets: wp.array(dtype=int),
+ A_columns: wp.array(dtype=int),
+ A_values: wp.array(dtype=Any),
+ ):
+ row = wp.tid()
+ A_offsets[row + 1] = row + 1
+ A_columns[row] = row
+ A_values[row] = diag[row]
+
+ if row == 0:
+ A_offsets[0] = 0
+
+
+ @wp.kernel
+ def _bsr_set_diag_constant_kernel(
+ diag_value: Any,
  A_offsets: wp.array(dtype=int),
  A_columns: wp.array(dtype=int),
+ A_values: wp.array(dtype=Any),
  ):
  row = wp.tid()
  A_offsets[row + 1] = row + 1
  A_columns[row] = row
+ A_values[row] = diag_value

  if row == 0:
  A_offsets[0] = 0


- def bsr_set_diag(A: BsrMatrix, diag: wp.array):
- """Sets A as a block-diagonal square matrix"""
+ def bsr_set_diag(
+ A: BsrMatrix[BlockType],
+ diag: "Union[BlockType, Array[BlockType]]",
+ rows_of_blocks: Optional[int] = None,
+ cols_of_blocks: Optional[int] = None,
+ ):
+ """Sets `A` as a block-diagonal matrix
+
+ Args:
+ A: the sparse matrix to modify
+ diag: Either a warp array of type ``A.values.dtype``, in which case each element will define one block of the diagonal,
+ or a constant value of type ``A.values.dtype``, in which case it will get assigned to all diagonal blocks.
+ rows_of_blocks: If not ``None``, the new number of rows of blocks
+ cols_of_blocks: If not ``None``, the new number of columns of blocks
+
+ The shape of the matrix will be defined one of the following, in that order:
+ - `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal.
+ - the first dimension of `diag`, if `diag` is an array
+ - the current dimensions of `A` otherwise
+ """
+
+ if rows_of_blocks is None and cols_of_blocks is not None:
+ rows_of_blocks = cols_of_blocks
+ if cols_of_blocks is None and rows_of_blocks is not None:
+ cols_of_blocks = rows_of_blocks
+
+ if warp.types.is_array(diag):
+ if rows_of_blocks is None:
+ rows_of_blocks = diag.shape[0]
+ cols_of_blocks = diag.shape[0]
+
+ if rows_of_blocks is not None:
+ A.nrow = rows_of_blocks
+ A.ncol = cols_of_blocks
+
+ A.nnz = min(A.nrow, A.ncol)
+ _bsr_ensure_fits(A)
+
+ if warp.types.is_array(diag):
+ wp.launch(
+ kernel=_bsr_set_diag_kernel,
+ dim=A.nnz,
+ device=A.values.device,
+ inputs=[diag, A.offsets, A.columns, A.values],
+ )
+ else:
+ if not warp.types.type_is_value(type(diag)):
+ # Cast to launchable type
+ diag = A.values.dtype(diag)
+ wp.launch(
+ kernel=_bsr_set_diag_constant_kernel,
+ dim=A.nnz,
+ device=A.values.device,
+ inputs=[diag, A.offsets, A.columns, A.values],
+ )

- A.nrow = diag.shape[0]
- A.ncol = diag.shape[0]
- A.nnz = diag.shape[0]

- A.values = diag
- if A.columns.size < A.nrow:
- A.columns = wp.empty(shape=(A.nrow,), dtype=int, device=diag.device)
- if A.offsets.size < A.nrow + 1:
- A.offsets = wp.empty(shape=(A.nrow + 1,), dtype=int, device=diag.device)
+ def bsr_diag(
+ diag: "Union[BlockType, Array[BlockType]]",
+ rows_of_blocks: Optional[int] = None,
+ cols_of_blocks: Optional[int] = None,
+ ) -> BsrMatrix["BlockType"]:
+ """Creates and returns a block-diagonal BSR matrix from an given block value or array of block values.

- wp.launch(kernel=_bsr_set_diag_kernel, dim=A.nrow, device=A.values.device, inputs=[A.offsets, A.columns])
+ Args:
+ diag: Either a warp array of type ``A.values.dtype``, in which case each element will define one block of the diagonal,
+ or a constant value of type ``A.values.dtype``, in which case it will get assigned to all diagonal blocks.
+ rows_of_blocks: If not ``None``, the new number of rows of blocks
+ cols_of_blocks: If not ``None``, the new number of columns of blocks

+ The shape of the matrix will be defined one of the following, in that order:
+ - `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal.
+ - the first dimension of `diag`, if `diag` is an array
+ """
+
+ if rows_of_blocks is None and cols_of_blocks is not None:
+ rows_of_blocks = cols_of_blocks
+ if cols_of_blocks is None and rows_of_blocks is not None:
+ cols_of_blocks = rows_of_blocks
+
+ if warp.types.is_array(diag):
+ if rows_of_blocks is None:
+ rows_of_blocks = diag.shape[0]
+ cols_of_blocks = diag.shape[0]
+
+ A = bsr_zeros(
+ rows_of_blocks,
+ cols_of_blocks,
+ block_type=diag.dtype,
+ device=diag.device,
+ )
+ else:
+ if rows_of_blocks is None:
+ raise ValueError(
+ "rows_of_blocks and/or cols_of_blocks must be provided for constructing a diagonal matrix with uniform diagonal"
+ )
+
+ block_type = type(diag)
+ if not warp.types.type_is_matrix(block_type) and len(getattr(diag, "shape", ())) == 2:
+ block_type = wp.mat(shape=diag.shape, dtype=diag.dtype)
+
+ A = bsr_zeros(
+ rows_of_blocks,
+ cols_of_blocks,
+ block_type=block_type,
+ )

- def bsr_diag(diag: wp.array):
- """Creates a square block-diagonal BSR matrix from the values array `diag`"""
- A = bsr_zeros(rows_of_blocks=diag.shape[0], cols_of_blocks=diag.shape[0], block_type=diag.dtype, device=diag.device)
  bsr_set_diag(A, diag)
  return A


+ def bsr_set_identity(A: BsrMatrix, rows_of_blocks: Optional[int] = None):
+ """Sets `A` as the identity matrix
+
+ Args:
+ A: the sparse matrix to modify
+ rows_of_blocks: if provided, the matrix will be resized as a square matrix with `rows_of_blocks` rows and columns.
+ """
+
+ if A.block_shape == (1, 1):
+ identity = A.scalar_type(1.0)
+ else:
+ from numpy import eye
+
+ identity = eye(A.block_shape[0])
+
+ bsr_set_diag(A, diag=identity, rows_of_blocks=rows_of_blocks, cols_of_blocks=rows_of_blocks)
+
+
+ def bsr_identity(
+ rows_of_blocks: int, block_type: BlockType[Rows, Rows, Scalar], device: wp.context.Devicelike = None
+ ) -> BsrMatrix[BlockType[Rows, Rows, Scalar]]:
+ """Creates and returns a square identity matrix.
+
+ Args:
+ rows_of_blocks: Number of rows and columns of blocks in the created matrix.
+ block_type: Block type for the newly created matrix -- must be square
+ device: Device onto which to allocate the data arrays
+ """
+ A = bsr_zeros(rows_of_blocks=rows_of_blocks, cols_of_blocks=rows_of_blocks, block_type=block_type, device=device)
+ bsr_set_identity(A)
+ return A
+
+
+ @wp.kernel
+ def _bsr_scale_kernel(
+ alpha: Any,
+ values: wp.array(dtype=Any),
+ ):
+ values[wp.tid()] = alpha * values[wp.tid()]
+
+
+ def bsr_scale(x: BsrMatrix, alpha: Scalar) -> BsrMatrix:
+ """
+ Performs the operation ``x := alpha * x`` on BSR matrix `x` and returns `x`
+ """
+
+ if alpha != 1.0 and x.nnz > 0:
+ if alpha == 0.0:
+ bsr_set_zero(x)
+ else:
+ if not isinstance(alpha, x.scalar_type):
+ alpha = x.scalar_type(alpha)
+
+ wp.launch(kernel=_bsr_scale_kernel, dim=x.nnz, device=x.values.device, inputs=[alpha, x.values])
+
+ return x
+
+
  @wp.kernel
  def _bsr_get_block_row(dest_offset: int, bsr_offsets: wp.array(dtype=int), rows: wp.array(dtype=int)):
  i = wp.tid()
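Illustration (not part of the diff): a sketch of the new diagonal/identity/scaling helpers introduced above, assuming warp-lang 0.11.0; the block type and sizes are arbitrary choices for the example.

from warp.sparse import bsr_diag, bsr_identity, bsr_scale

block = wp.mat(shape=(2, 2), dtype=wp.float32)

# Uniform block diagonal: a constant block value requires an explicit size.
D = bsr_diag(block(2.0, 0.0, 0.0, 2.0), rows_of_blocks=8)

# Square identity with 2x2 blocks, then in-place scaling of D (returns D).
I8 = bsr_identity(rows_of_blocks=8, block_type=block)
bsr_scale(D, alpha=0.5)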
@@ -393,16 +636,75 @@ def _bsr_axpy_add_block(
  dst_values[block] = dst_values[block] + scale * src_values[i]


- def bsr_axpy(x: BsrMatrix, y: BsrMatrix, alpha: float = 1.0, beta: float = 1.0):
+ class bsr_axpy_work_arrays:
+ """Opaque structure for persisting :func:`bsr_axpy` temporary work buffers across calls"""
+
+ def __init__(self):
+ self._reset(None)
+
+ def _reset(self, device):
+ self.device = device
+ self._sum_rows = None
+ self._sum_cols = None
+ self._old_y_values = None
+ self._old_x_values = None
+
+ def _allocate(self, device, y: BsrMatrix, sum_nnz: int):
+ if self.device != device:
+ self._reset(device)
+
+ if self._sum_rows is None or self._sum_rows.size < sum_nnz:
+ self._sum_rows = wp.empty(shape=(sum_nnz), dtype=int, device=self.device)
+ if self._sum_cols is None or self._sum_cols.size < sum_nnz:
+ self._sum_cols = wp.empty(shape=(sum_nnz), dtype=int, device=self.device)
+
+ if self._old_y_values is None or self._old_y_values.size < y.nnz:
+ self._old_y_values = wp.empty(shape=(y.nnz), dtype=y.values.dtype, device=self.device)
+
+
+ def bsr_axpy(
+ x: BsrMatrix[BlockType[Rows, Cols, Scalar]],
+ y: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
+ alpha: Scalar = 1.0,
+ beta: Scalar = 1.0,
+ work_arrays: Optional[bsr_axpy_work_arrays] = None,
+ ) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
  """
- Performs the operation `y := alpha * X + beta * y` on BSR matrices `x` and `y`
+ Performs the sparse matrix addition ``y := alpha * X + beta * y`` on BSR matrices `x` and `y` and returns `y`.
+
+ The `x` and `y` matrices are allowed to alias.
+
+ Args:
+ x: Read-only right-hand-side.
+ y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero.
+ alpha: Uniform scaling factor for `x`
+ beta: Uniform scaling factor for `y`
+ work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_axpy_work_arrays` in `work_arrays`.
  """

  if y is None:
- y = bsr_zeros(x.nrow, x.ncol, block_type=x.block_type, device=x.values.device)
+ # If not output matrix is provided, allocate it for convenience
+ y = bsr_zeros(x.nrow, x.ncol, block_type=x.values.dtype, device=x.values.device)
  beta = 0.0

- device = y.values.device
+ # Handle easy cases first
+ if beta == 0.0 or y.nnz == 0:
+ bsr_assign(src=x, dest=y)
+ return bsr_scale(y, alpha=alpha)
+
+ if alpha == 0.0 or x.nnz == 0:
+ return bsr_scale(y, alpha=beta)
+
+ if not isinstance(alpha, y.scalar_type):
+ alpha = y.scalar_type(alpha)
+ if not isinstance(beta, y.scalar_type):
+ beta = y.scalar_type(beta)
+
+ if x == y:
+ # Aliasing case
+ return bsr_scale(y, alpha=alpha.value + beta.value)
+
+ # General case

  if x.values.device != y.values.device:
  raise ValueError("All arguments must reside on the same device")
@@ -413,20 +715,21 @@ def bsr_axpy(x: BsrMatrix, y: BsrMatrix, alpha: float = 1.0, beta: float = 1.0):
  if x.nrow != y.nrow or x.ncol != y.ncol:
  raise ValueError("Matrices must have the same number of rows and columns")

- alpha = y.scalar_type(alpha)
- beta = y.scalar_type(beta)
+ if work_arrays is None:
+ work_arrays = bsr_axpy_work_arrays()

  sum_nnz = x.nnz + y.nnz
- sum_rows = wp.empty(shape=(sum_nnz), dtype=int, device=device)
- sum_cols = wp.empty(shape=(sum_nnz), dtype=int, device=device)
+ device = y.values.device
+ work_arrays._allocate(device, y, sum_nnz)
+
+ wp.copy(work_arrays._sum_cols, y.columns, 0, 0, y.nnz)
+ wp.launch(kernel=_bsr_get_block_row, device=device, dim=y.nnz, inputs=[0, y.offsets, work_arrays._sum_rows])

- if y.nnz > 0:
- wp.copy(sum_cols, y.columns, 0, 0, y.nnz)
- wp.launch(kernel=_bsr_get_block_row, device=device, dim=y.nnz, inputs=[0, y.offsets, sum_rows])
+ wp.copy(work_arrays._sum_cols, x.columns, y.nnz, 0, x.nnz)
+ wp.launch(kernel=_bsr_get_block_row, device=device, dim=x.nnz, inputs=[y.nnz, x.offsets, work_arrays._sum_rows])

- if x.nnz > 0:
- wp.copy(sum_cols, x.columns, y.nnz, 0, x.nnz)
- wp.launch(kernel=_bsr_get_block_row, device=device, dim=x.nnz, inputs=[y.nnz, x.offsets, sum_rows])
+ # Save old y values before overwriting matrix
+ wp.copy(dest=work_arrays._old_y_values, src=y.values, count=y.nnz)

  # Increase dest array sizes if needed
  if y.columns.shape[0] < sum_nnz:
@@ -439,37 +742,55 @@ def bsr_axpy(x: BsrMatrix, y: BsrMatrix, alpha: float = 1.0, beta: float = 1.0):
  else:
  native_func = runtime.core.bsr_matrix_from_triplets_float_device

- sum_nnz = native_func(
+ old_y_nnz = y.nnz
+ y.nnz = native_func(
  y.block_shape[0],
  y.block_shape[1],
  y.nrow,
  sum_nnz,
- sum_rows.ptr,
- sum_cols.ptr,
+ work_arrays._sum_rows.ptr,
+ work_arrays._sum_cols.ptr,
  0,
  y.offsets.ptr,
  y.columns.ptr,
  0,
  )

- sum_values = wp.zeros(shape=(sum_nnz,), dtype=y.values.dtype, device=device)
+ _bsr_ensure_fits(y)
+ y.values.zero_()

  wp.launch(
  kernel=_bsr_axpy_add_block,
  device=device,
- dim=y.nnz,
- inputs=[0, beta, sum_rows, sum_cols, y.offsets, y.columns, y.values, sum_values],
+ dim=old_y_nnz,
+ inputs=[
+ 0,
+ beta,
+ work_arrays._sum_rows,
+ work_arrays._sum_cols,
+ y.offsets,
+ y.columns,
+ work_arrays._old_y_values,
+ y.values,
+ ],
  )
+
  wp.launch(
  kernel=_bsr_axpy_add_block,
  device=device,
  dim=x.nnz,
- inputs=[y.nnz, alpha, sum_rows, sum_cols, y.offsets, y.columns, x.values, sum_values],
+ inputs=[
+ old_y_nnz,
+ alpha,
+ work_arrays._sum_rows,
+ work_arrays._sum_cols,
+ y.offsets,
+ y.columns,
+ x.values,
+ y.values,
+ ],
  )

- y.values = sum_values
- y.nnz = sum_nnz
-
  return y


@@ -555,23 +876,77 @@ def _bsr_mm_compute_values(
  mm_values[mm_block] = mm_values[mm_block] + ax_val * y_values[y_block]


- _pinned_temp_count_buffer = {}
-
-
- def _get_pinned_temp_count_buffer(device):
- device = str(device)
- if device not in _pinned_temp_count_buffer:
- _pinned_temp_count_buffer[device] = wp.empty(shape=(1,), dtype=int, pinned=True, device="cpu")
-
- return _pinned_temp_count_buffer[device]
-
-
- def bsr_mm(x: BsrMatrix, y: BsrMatrix, z: BsrMatrix = None, alpha: float = 1.0, beta: float = 0.0):
+ class bsr_mm_work_arrays:
+ """Opaque structure for persisting :func:`bsr_mm` temporary work buffers across calls"""
+
+ def __init__(self):
+ self._reset(None)
+
+ def _reset(self, device):
+ self.device = device
+ self._pinned_count_buffer = None
+ self._mm_row_counts = None
+ self._mm_rows = None
+ self._mm_cols = None
+ self._old_z_values = None
+ self._old_z_offsets = None
+ self._old_z_columns = None
+
+ def _allocate_stage_1(self, device, z: BsrMatrix, copied_z_nnz: int, z_aliasing: bool):
+ if self.device != device:
+ self._reset(device)
+
+ # Allocations that do not depend on any computation
+ if self.device.is_cuda:
+ if self._pinned_count_buffer is None:
+ self._pinned_count_buffer = wp.empty(shape=(1,), dtype=int, pinned=True, device="cpu")
+
+ if self._mm_row_counts is None or self._mm_row_counts.size < z.nrow + 1:
+ self._mm_row_counts = wp.empty(shape=(z.nrow + 1,), dtype=int, device=self.device)
+
+ if copied_z_nnz > 0:
+ if self._old_z_values is None or self._old_z_values.size < copied_z_nnz:
+ self._old_z_values = wp.empty(shape=(copied_z_nnz,), dtype=z.values.dtype, device=self.device)
+
+ if z_aliasing:
+ if self._old_z_columns is None or self._old_z_columns.size < z.nnz:
+ self._old_z_columns = wp.empty(shape=(z.nnz,), dtype=z.columns.dtype, device=self.device)
+ if self._old_z_offsets is None or self._old_z_offsets.size < z.nrow + 1:
+ self._old_z_offsets = wp.empty(shape=(z.nrow + 1,), dtype=z.offsets.dtype, device=self.device)
+
+ def _allocate_stage_2(self, mm_nnz: int):
+ # Allocations that depend on unmerged nnz estimate
+ if self._mm_rows is None or self._mm_rows.size < mm_nnz:
+ self._mm_rows = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
+ if self._mm_cols is None or self._mm_cols.size < mm_nnz:
+ self._mm_cols = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
+
+
+ def bsr_mm(
+ x: BsrMatrix[BlockType[Rows, Any, Scalar]],
+ y: BsrMatrix[BlockType[Any, Cols, Scalar]],
+ z: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
+ alpha: Scalar = 1.0,
+ beta: Scalar = 0.0,
+ work_arrays: Optional[bsr_mm_work_arrays] = None,
+ ) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
  """
- Performs the operation `z := alpha * X * Y + beta * z` on BSR matrices `x`, `y` and `z`
+ Performs the sparse matrix-matrix multiplication ``z := alpha * x * y + beta * z`` on BSR matrices `x`, `y` and `z`, and returns `z`.
+
+ The `x`, `y` and `z` matrices are allowed to alias.
+ If the matrix `z` is not provided as input, it will be allocated and treated as zero.
+
+ Args:
+ x: Read-only left factor of the matrix-matrix product.
+ y: Read-only right factor of the matrix-matrix product.
+ z: Mutable left-hand-side. If `z` is not provided, it will be allocated and treated as zero.
+ alpha: Uniform scaling factor for the ``x * y`` product
+ beta: Uniform scaling factor for `z`
+ work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_mm_work_arrays` in `work_arrays`.
  """

  if z is None:
+ # If not output matrix is provided, allocate it for convenience
  z_block_shape = (x.block_shape[0], y.block_shape[1])
  if z_block_shape == (1, 1):
@@ -586,52 +961,85 @@ def bsr_mm(x: BsrMatrix, y: BsrMatrix, z: BsrMatrix = None, alpha: float = 1.0,
  if x.scalar_type != y.scalar_type or x.scalar_type != z.scalar_type:
  raise ValueError("Matrices must have the same scalar type")

- if x.block_shape[0] != z.block_shape[0] or y.block_shape[1] != z.block_shape[1]:
- raise ValueError("Incompatible blocks sizes for matrix multiplication")
+ if (
+ x.block_shape[0] != z.block_shape[0]
+ or y.block_shape[1] != z.block_shape[1]
+ or x.block_shape[1] != y.block_shape[0]
+ ):
+ raise ValueError("Incompatible block sizes for matrix multiplication")

- if x.nrow != z.nrow or z.ncol != y.ncol:
+ if x.nrow != z.nrow or z.ncol != y.ncol or x.ncol != y.nrow:
  raise ValueError("Incompatible number of rows/columns for matrix multiplication")

  device = z.values.device

- alpha = z.scalar_type(alpha)
- beta = z.scalar_type(beta)
+ if alpha == 0.0 or x.nnz == 0 or y.nnz == 0:
+ # Easy case
+ return bsr_scale(z, beta)
+
+ if not isinstance(alpha, z.scalar_type):
+ alpha = z.scalar_type(alpha)
+ if not isinstance(beta, z.scalar_type):
+ beta = z.scalar_type(beta)
+
+ if work_arrays is None:
+ work_arrays = bsr_mm_work_arrays()
+
+ z_aliasing = z == x or z == y
+ copied_z_nnz = z.nnz if beta != 0.0 or z_aliasing else 0
+
+ work_arrays._allocate_stage_1(device, z, copied_z_nnz, z_aliasing)

  # Prefix sum of number of (unmerged) mm blocks per row
- mm_row_counts = wp.empty(shape=(z.nrow + 1,), dtype=int, device=device)
  wp.launch(
  kernel=_bsr_mm_count_coeffs,
  device=device,
  dim=z.nrow,
- inputs=[z.nnz, x.offsets, x.columns, y.offsets, mm_row_counts],
+ inputs=[copied_z_nnz, x.offsets, x.columns, y.offsets, work_arrays._mm_row_counts],
  )
- warp.utils.array_scan(mm_row_counts, mm_row_counts)
+ warp.utils.array_scan(work_arrays._mm_row_counts, work_arrays._mm_row_counts)

  # Get back total counts on host
  if device.is_cuda:
- mm_tot_count = _get_pinned_temp_count_buffer(device)
- wp.copy(dest=mm_tot_count, src=mm_row_counts, src_offset=z.nrow, count=1)
- wp.synchronize_stream(wp.get_stream())
- mm_nnz = int(mm_tot_count.numpy()[0])
+ wp.copy(dest=work_arrays._pinned_count_buffer, src=work_arrays._mm_row_counts, src_offset=z.nrow, count=1)
+ wp.synchronize_stream(wp.get_stream(device))
+ mm_nnz = int(work_arrays._pinned_count_buffer.numpy()[0])
  else:
- mm_nnz = int(mm_row_counts.numpy()[z.nrow])
+ mm_nnz = int(work_arrays._mm_row_counts.numpy()[z.nrow])

- mm_rows = wp.empty(shape=(mm_nnz), dtype=int, device=device)
- mm_cols = wp.empty(shape=(mm_nnz), dtype=int, device=device)
+ work_arrays._allocate_stage_2(mm_nnz)

- # Copy z rows columns
- wp.copy(mm_cols, z.columns, 0, 0, z.nnz)
- wp.launch(kernel=_bsr_get_block_row, device=device, dim=z.nnz, inputs=[0, z.offsets, mm_rows])
+ # If z has a non-zero scale, save current data before overwriting it
+ if copied_z_nnz > 0:
+ # Copy z row and column indices
+ wp.copy(dest=work_arrays._mm_cols, src=z.columns, count=copied_z_nnz)
+ wp.launch(
+ kernel=_bsr_get_block_row, device=device, dim=copied_z_nnz, inputs=[0, z.offsets, work_arrays._mm_rows]
+ )
+ # Save current z values in temporary buffer
+ wp.copy(src=z.values, dest=work_arrays._old_z_values, count=copied_z_nnz)
+ if z_aliasing:
+ # If z is aliasing with x or y, need to save topology as well
+ wp.copy(src=z.columns, dest=work_arrays._old_z_columns, count=copied_z_nnz)
+ wp.copy(src=z.offsets, dest=work_arrays._old_z_offsets, count=z.nrow + 1)

  # Fill unmerged mm blocks rows and columns
  wp.launch(
  kernel=_bsr_mm_list_coeffs,
  device=device,
  dim=z.nrow,
- inputs=[x.offsets, x.columns, y.offsets, y.columns, mm_row_counts, mm_rows, mm_cols],
+ inputs=[
+ x.offsets,
+ x.columns,
+ y.offsets,
+ y.columns,
+ work_arrays._mm_row_counts,
+ work_arrays._mm_rows,
+ work_arrays._mm_cols,
+ ],
  )

- # Increase dest array sizes if needed
+ # Increase dest array size if needed
  if z.columns.shape[0] < mm_nnz:
  z.columns = wp.empty(shape=(mm_nnz,), dtype=int, device=device)

@@ -642,40 +1050,68 @@ def bsr_mm(x: BsrMatrix, y: BsrMatrix, z: BsrMatrix = None, alpha: float = 1.0,
  else:
  native_func = runtime.core.bsr_matrix_from_triplets_float_device

- mm_nnz = native_func(
+ z.nnz = native_func(
  z.block_shape[0],
  z.block_shape[1],
  z.nrow,
  mm_nnz,
- mm_rows.ptr,
- mm_cols.ptr,
+ work_arrays._mm_rows.ptr,
+ work_arrays._mm_cols.ptr,
  0,
  z.offsets.ptr,
  z.columns.ptr,
  0,
  )

- mm_values = wp.zeros(shape=(mm_nnz,), dtype=z.values.dtype, device=device)
+ _bsr_ensure_fits(z)
+ z.values.zero_()
+
+ if copied_z_nnz > 0:
+ # Add back original z values
+ wp.launch(
+ kernel=_bsr_axpy_add_block,
+ device=device,
+ dim=copied_z_nnz,
+ inputs=[
+ 0,
+ beta,
+ work_arrays._mm_rows,
+ work_arrays._mm_cols,
+ z.offsets,
+ z.columns,
+ work_arrays._old_z_values,
+ z.values,
+ ],
+ )

- # Copy blocks from z
- wp.launch(
- kernel=_bsr_axpy_add_block,
- device=device,
- dim=z.nnz,
- inputs=[0, beta, mm_rows, mm_cols, z.offsets, z.columns, z.values, mm_values],
- )
+ # Add mm blocks to z values
+ if (warp.types.type_is_matrix(x.values.dtype) or warp.types.type_is_matrix(y.values.dtype)) and not (
+ warp.types.type_is_matrix(z.values.dtype)
+ ):
+ # Result block type is scalar, but operands are matrices
+ # Cast result to (1x1) matrix to perform multiplication
+ mm_values = z.values.view(wp.types.matrix(shape=(1, 1), dtype=z.scalar_type))
+ else:
+ mm_values = z.values

- # Add mm blocks
  wp.launch(
  kernel=_bsr_mm_compute_values,
  device=device,
  dim=z.nrow,
- inputs=[alpha, x.offsets, x.columns, x.values, y.offsets, y.columns, y.values, z.offsets, z.columns, mm_values],
+ inputs=[
+ alpha,
+ work_arrays._old_z_offsets if x == z else x.offsets,
+ work_arrays._old_z_columns if x == z else x.columns,
+ work_arrays._old_z_values if x == z else x.values,
+ work_arrays._old_z_offsets if y == z else y.offsets,
+ work_arrays._old_z_columns if y == z else y.columns,
+ work_arrays._old_z_values if y == z else y.values,
+ z.offsets,
+ z.columns,
+ mm_values,
+ ],
  )

- z.values = mm_values
- z.nnz = mm_nnz
-
  return z


@@ -690,44 +1126,96 @@ def _bsr_mv_kernel(
  y: wp.array(dtype=Any),
  ):
  row = wp.tid()
- beg = A_offsets[row]
- end = A_offsets[row + 1]

- yr = y[row]
- v = yr - yr # WAR to get zero with correct type
- for block in range(beg, end):
- v = v + A_values[block] * x[A_columns[block]]
+ # zero-initialize with type of y elements
+ scalar_zero = type(alpha)(0)
+ v = y.dtype(scalar_zero)

- y[row] = beta * yr + alpha * v
+ if alpha != scalar_zero:
+ beg = A_offsets[row]
+ end = A_offsets[row + 1]
+ for block in range(beg, end):
+ v += A_values[block] * x[A_columns[block]]
+ v *= alpha

+ if beta != scalar_zero:
+ v += beta * y[row]

- def bsr_mv(A: BsrMatrix, x: wp.array, y: wp.array, alpha: float = 1.0, beta: float = 0.0):
+ y[row] = v
+
+
+ def bsr_mv(
+ A: BsrMatrix[BlockType[Rows, Cols, Scalar]],
+ x: "Array[Vector[Cols, Scalar] | Scalar]",
+ y: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
+ alpha: Scalar = 1.0,
+ beta: Scalar = 0.0,
+ work_buffer: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
+ ) -> "Array[Vector[Rows, Scalar] | Scalar]":
  """
- Naive implementation of sparse matrix-vector product, `y := alpha * A * x + beta * y`.
+ Performs the sparse matrix-vector product ``y := alpha * A * x + beta * y`` and returns `y`.
+
+ The `x` and `y` vectors are allowed to alias.
+
+ Args:
+ A: Read-only, left matrix factor of the matrix-vector product.
+ x: Read-only, right vector factor of the matrix-vector product.
+ y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero.
+ alpha: Uniform scaling factor for `x`. If zero, `x` will not be read and may be left uninitialized.
+ beta: Uniform scaling factor for `y`. If zero, `y` will not be read and may be left uninitialized.
+ work_buffer: Temporary storage is required if and only if `x` and `y` are the same vector. If provided the `work_buffer` array
+ will be used for this purpose, otherwise a temporary allocation will be performed.
  """
- alpha = A.scalar_type(alpha)
- beta = A.scalar_type(beta)

- # if A.scalar_type != x.dtype or A.scalar_type != y.dtype:
- # raise ValueError("A, x and y must have the same data types")
+ if y is None:
+ # If no output array is provided, allocate one for convenience
+ y_vec_len = A.block_shape[0]
+ y_dtype = A.scalar_type if y_vec_len == 1 else wp.vec(length=y_vec_len, dtype=A.scalar_type)
+ y = wp.empty(shape=(A.nrow,), device=A.values.device, dtype=y_dtype)
+ y.zero_()
+ beta = 0.0
+
+ if not isinstance(alpha, A.scalar_type):
+ alpha = A.scalar_type(alpha)
+ if not isinstance(beta, A.scalar_type):
+ beta = A.scalar_type(beta)

  if A.values.device != x.device or A.values.device != y.device:
- raise ValueError("A, x and y must reide on the same device")
+ raise ValueError("A, x and y must reside on the same device")

  if x.shape[0] != A.ncol:
  raise ValueError("Number of columns of A must match number of rows of x")
  if y.shape[0] != A.nrow:
  raise ValueError("Number of rows of A must match number of rows of y")

- # Promote scalar vectors to length-1 vecs
- block_shape = A.block_shape
- if block_shape != (1, 1):
- if block_shape[0] == 1:
+ if x == y:
+ # Aliasing case, need temporary storage
+ if work_buffer is None:
+ work_buffer = wp.empty_like(y)
+ elif work_buffer.size < y.size:
+ raise ValueError(f"Work buffer size is insufficient, needs to be at least {y.size}")
+ elif not wp.types.types_equal(work_buffer.dtype, y.dtype):
+ raise ValueError(f"Work buffer must have same data type as y, {wp.types.type_repr(y.dtype)}")
+
+ # Save old y values before overwriting vector
+ wp.copy(dest=work_buffer, src=y, count=y.size)
+ x = work_buffer
+
+ # Promote scalar vectors to length-1 vecs and conversely
+ if warp.types.type_is_matrix(A.values.dtype):
+ if A.block_shape[0] == 1:
  if y.dtype == A.scalar_type:
  y = y.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
- if block_shape[1] == 1:
+ if A.block_shape[1] == 1:
  if x.dtype == A.scalar_type:
  x = x.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
+ else:
+ if A.block_shape[0] == 1:
+ if y.dtype != A.scalar_type:
+ y = y.view(dtype=A.scalar_type)
+ if A.block_shape[1] == 1:
+ if x.dtype != A.scalar_type:
+ x = x.view(dtype=A.scalar_type)

  wp.launch(
  kernel=_bsr_mv_kernel,
@@ -735,3 +1223,5 @@ def bsr_mv(A: BsrMatrix, x: wp.array, y: wp.array, alpha: float = 1.0, beta: flo
  dim=A.nrow,
  inputs=[alpha, A.offsets, A.columns, A.values, x, beta, y],
  )
+
+ return y
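Illustration (not part of the diff): a sketch of the reworked `bsr_mv` for a matrix with 2x2 blocks, reusing the hypothetical `A` from the earlier sketches; the right-hand side holds one length-2 vector per block column and the returned result one length-2 vector per block row.

from warp.sparse import bsr_mv

x_vec = wp.zeros(A.ncol, dtype=wp.vec(length=2, dtype=wp.float32))
y_vec = bsr_mv(A, x_vec, alpha=1.0, beta=0.0)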