warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (315) hide show
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/sparse.py ADDED
@@ -0,0 +1,1227 @@
1
+ from typing import Any, Generic, Optional, Tuple, TypeVar, Union
2
+
3
+ import warp as wp
4
+ import warp.types
5
+ import warp.utils
6
+ from warp.types import Array, Cols, Matrix, Rows, Scalar, Vector
7
+
8
+ # typing hints
9
+
10
# TypeVar used as the generic parameter of BsrMatrix (typing hints only)
_BlockType = TypeVar("BlockType")


class _MatrixBlockType(Matrix):
    # Placeholder used only in annotations to denote a matrix-valued block
    pass


class _ScalarBlockType(Generic[Scalar]):
    # Placeholder used only in annotations to denote a scalar-valued block
    pass


# A block is either a (Rows x Cols) matrix of Scalar (BSR) or a single Scalar (CSR)
BlockType = Union[_MatrixBlockType[Rows, Cols, Scalar], _ScalarBlockType[Scalar]]

# Cache of generated warp struct types, keyed by block type -- see bsr_matrix_t()
_struct_cache = dict()
24
+
25
+
26
class BsrMatrix(Generic[_BlockType]):
    """Untyped base class for BSR and CSR matrices.

    Should not be constructed directly but through functions such as :func:`bsr_zeros`.

    Attributes:
        nrow (int): Number of rows of blocks
        ncol (int): Number of columns of blocks
        nnz (int): Number of non-zero blocks: must be equal to ``offsets[nrow-1]``, cached on host for convenience
        offsets (Array[int]): Array of size at least ``1 + nrows`` such that the start and end indices of the blocks of row ``r`` are ``offsets[r]`` and ``offsets[r+1]``, respectively.
        columns (Array[int]): Array of size at least equal to ``nnz`` containing block column indices
        values (Array[BlockType]): Array of size at least equal to ``nnz`` containing block values
    """

    @property
    def scalar_type(self) -> Scalar:
        """Scalar type of the individual block coefficients (identical to the block type for CSR matrices)."""
        return warp.types.type_scalar_type(self.values.dtype)

    @property
    def block_shape(self) -> Tuple[int, int]:
        """(rows, cols) shape of one block; scalar (CSR) blocks report ``(1, 1)``."""
        # Matrix dtypes expose `_shape_`; scalar dtypes do not, hence the default
        return getattr(self.values.dtype, "_shape_", (1, 1))

    @property
    def block_size(self) -> int:
        """Number of scalar coefficients per block (rows per block times columns per block)."""
        return warp.types.type_length(self.values.dtype)

    @property
    def shape(self) -> Tuple[int, int]:
        """Overall matrix shape counted in scalar coefficients rather than blocks."""
        rows_per_block, cols_per_block = self.block_shape
        return (self.nrow * rows_per_block, self.ncol * cols_per_block)

    @property
    def dtype(self) -> type:
        """Data type of the individual block values."""
        return self.values.dtype

    @property
    def device(self) -> wp.context.Device:
        """Device holding the offsets, columns and values arrays (assumed identical for all three)."""
        return self.values.device
70
+
71
+
72
def bsr_matrix_t(dtype: BlockType):
    """Returns the generated warp struct type describing a BSR/CSR matrix with the given block type.

    Args:
        dtype: Block type; either a warp matrix type (BSR) or a scalar type (CSR).

    Raises:
        ValueError: If ``dtype`` is neither a warp matrix type nor a scalar type.
    """
    dtype = wp.types.type_to_warp(dtype)

    # Idiom fix: `dtype not in` instead of `not dtype in`
    if not warp.types.type_is_matrix(dtype) and dtype not in warp.types.scalar_types:
        raise ValueError(
            f"BsrMatrix block type must be either warp matrix or scalar; got {warp.types.type_repr(dtype)}"
        )

    class BsrMatrixTyped(BsrMatrix):
        nrow: int
        """Number of rows of blocks"""
        ncol: int
        """Number of columns of blocks"""
        nnz: int
        """Number of non-zero blocks: equal to offsets[-1], cached on host for convenience"""
        offsets: wp.array(dtype=int)
        """Array of size at least 1 + nrows"""
        columns: wp.array(dtype=int)
        """Array of size at least equal to nnz"""
        values: wp.array(dtype=dtype)

    module = wp.get_module(BsrMatrix.__module__)

    # Build a unique cache key from the block's scalar type and shape
    if hasattr(dtype, "_shape_"):
        type_str = f"{warp.types.type_scalar_type(dtype).__name__}_{dtype._shape_[0]}_{dtype._shape_[1]}"
    else:
        type_str = dtype.__name__
    key = f"{BsrMatrix.__qualname__}_{type_str}"

    # Cache the generated struct so repeated calls with the same dtype share one type
    if key not in _struct_cache:
        _struct_cache[key] = wp.codegen.Struct(
            cls=BsrMatrixTyped,
            key=key,
            module=module,
        )

    return _struct_cache[key]
109
+
110
+
111
def bsr_zeros(
    rows_of_blocks: int,
    cols_of_blocks: int,
    block_type: BlockType,
    device: wp.context.Devicelike = None,
) -> BsrMatrix:
    """
    Constructs and returns an empty BSR or CSR matrix with the given shape

    Args:
        rows_of_blocks: Number of rows of blocks
        cols_of_blocks: Number of columns of blocks
        block_type: Type of individual blocks. For CSR matrices, this should be a scalar type;
            for BSR matrices, this should be a matrix type (e.g. from :func:`warp.mat`)
        device: Device on which to allocate the matrix arrays

    Returns:
        A zero matrix: no allocated blocks, and all row offsets set to zero.
    """
    # NOTE: the original docstring documented a nonexistent `bsr` argument
    # (copy-paste from bsr_set_zero); it has been removed.

    bsr = bsr_matrix_t(block_type)()

    bsr.nrow = rows_of_blocks
    bsr.ncol = cols_of_blocks
    bsr.nnz = 0
    # No non-zero blocks yet: columns/values are empty
    bsr.columns = wp.empty(shape=(0,), dtype=int, device=device)
    bsr.values = wp.empty(shape=(0,), dtype=block_type, device=device)
    # offsets must always hold nrow + 1 valid entries; all-zero means every row is empty
    bsr.offsets = wp.zeros(shape=(bsr.nrow + 1,), dtype=int, device=device)

    return bsr
139
+
140
+
141
def _bsr_ensure_fits(bsr: BsrMatrix, nrow: int = None, nnz: int = None):
    """Grows the storage arrays of `bsr` so they can hold `nrow` rows and `nnz` non-zero blocks.

    Arrays are reallocated (uninitialized) only when too small; larger existing capacity is kept.
    Defaults to the matrix's current `nrow` / `nnz` when the arguments are omitted.
    """
    nrow = bsr.nrow if nrow is None else nrow
    nnz = bsr.nnz if nnz is None else nnz

    if bsr.offsets.size < nrow + 1:
        bsr.offsets = wp.empty(shape=(nrow + 1,), dtype=int, device=bsr.offsets.device)
    if bsr.columns.size < nnz:
        bsr.columns = wp.empty(shape=(nnz,), dtype=int, device=bsr.columns.device)
    if bsr.values.size < nnz:
        bsr.values = wp.empty(shape=(nnz,), dtype=bsr.values.dtype, device=bsr.values.device)
153
+
154
+
155
def bsr_set_zero(bsr: BsrMatrix, rows_of_blocks: Optional[int] = None, cols_of_blocks: Optional[int] = None):
    """
    Sets a BSR matrix to zero, possibly changing its size

    Args:
        bsr: The BSR or CSR matrix to set to zero
        rows_of_blocks: If not ``None``, the new number of rows of blocks
        cols_of_blocks: If not ``None``, the new number of columns of blocks
    """

    # Apply the optional resize before clearing storage
    if rows_of_blocks is not None:
        bsr.nrow = rows_of_blocks
    if cols_of_blocks is not None:
        bsr.ncol = cols_of_blocks

    bsr.nnz = 0
    _bsr_ensure_fits(bsr)
    # A zero matrix is encoded as all-zero row offsets (every row is empty)
    bsr.offsets.zero_()
172
+
173
+
174
def bsr_set_from_triplets(
    dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    rows: "Array[int]",
    columns: "Array[int]",
    values: "Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]",
):
    """
    Fills a BSR matrix with values defined by coordinate-oriented (COO) triplets, discarding existing blocks.

    The first dimension of the three input arrays must match, and determines the number of non-zeros in the constructed matrix.

    Args:
        dest: Sparse matrix to populate
        rows: Row index for each non-zero
        columns: Columns index for each non-zero
        values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
            to the `dest` matrix's block type, or a 3d array with data type equal to the `dest` matrix's scalar type.

    Raises:
        ValueError: If the input arrays mismatch in device, length, type or shape
        NotImplementedError: If the matrix scalar type is neither float32 nor float64
    """

    if values.device != columns.device or values.device != rows.device or values.device != dest.values.device:
        raise ValueError("All arguments must reside on the same device")

    if values.shape[0] != rows.shape[0] or values.shape[0] != columns.shape[0]:
        raise ValueError("All triplet arrays must have the same length")

    # Accept either array1d(dtype) or contiguous array3d(scalar_type) as values
    if values.ndim == 1:
        if values.dtype != dest.values.dtype:
            raise ValueError("Values array type must correspond to that of dest matrix")
    elif values.ndim == 3:
        if values.shape[1:] != dest.block_shape:
            raise ValueError(
                f"Last two dimensions in values array ({values.shape[1:]}) should correspond to matrix block shape {(dest.block_shape)})"
            )

        if warp.types.type_scalar_type(values.dtype) != dest.scalar_type:
            raise ValueError("Scalar type of values array should correspond to that of matrix")

        if not values.is_contiguous:
            raise ValueError("Multi-dimensional values array should be contiguous")
    else:
        raise ValueError("Number of dimension for values array should be 1 or 3")

    nnz = rows.shape[0]
    if nnz == 0:
        bsr_set_zero(dest)
        return

    # Increase dest array sizes if needed
    _bsr_ensure_fits(dest, nnz=nnz)

    device = dest.values.device
    scalar_type = dest.scalar_type
    from warp.context import runtime

    # BUGFIX: initialize to None so that unsupported scalar types reach the
    # NotImplementedError below instead of raising NameError on an unbound local.
    native_func = None
    if device.is_cpu:
        if scalar_type == wp.float32:
            native_func = runtime.core.bsr_matrix_from_triplets_float_host
        elif scalar_type == wp.float64:
            native_func = runtime.core.bsr_matrix_from_triplets_double_host
    else:
        if scalar_type == wp.float32:
            native_func = runtime.core.bsr_matrix_from_triplets_float_device
        elif scalar_type == wp.float64:
            native_func = runtime.core.bsr_matrix_from_triplets_double_device

    if not native_func:
        raise NotImplementedError(f"bsr_from_triplets not implemented for scalar type {scalar_type}")

    # The native routine sorts/merges the triplets and returns the final block count
    dest.nnz = native_func(
        dest.block_shape[0],
        dest.block_shape[1],
        dest.nrow,
        nnz,
        rows.ptr,
        columns.ptr,
        values.ptr,
        dest.offsets.ptr,
        dest.columns.ptr,
        dest.values.ptr,
    )
255
+
256
+
257
def bsr_assign(dest: BsrMatrix[BlockType[Rows, Cols, Scalar]], src: BsrMatrix[BlockType[Rows, Cols, Any]]):
    """Copies the content of the `src` matrix to `dest`, casting the block values if the two matrices use distinct scalar types."""

    if dest.values.device != src.values.device:
        raise ValueError("Source and destination matrices must reside on the same device")

    if dest.block_shape != src.block_shape:
        raise ValueError("Source and destination matrices must have the same block shape")

    # Mirror the source topology, growing destination storage as needed
    dest.nrow, dest.ncol, dest.nnz = src.nrow, src.ncol, src.nnz
    _bsr_ensure_fits(dest)

    wp.copy(dest=dest.offsets, src=src.offsets, count=src.nrow + 1)
    if src.nnz > 0:
        wp.copy(dest=dest.columns, src=src.columns, count=src.nnz)
        # array_cast handles the scalar-type conversion between value arrays
        warp.utils.array_cast(out_array=dest.values, in_array=src.values, count=src.nnz)
276
+
277
+
278
def bsr_copy(A: BsrMatrix, scalar_type: Optional[Scalar] = None):
    """Returns a copy of matrix ``A``, possibly changing its scalar type.

    Args:
        scalar_type: If provided, the returned matrix will use this scalar type instead of the one from `A`.
    """
    # Determine the block type of the copy: keep A's dtype unless a scalar
    # conversion is requested, in which case rebuild a matrix type (or use the
    # scalar directly for CSR matrices).
    if scalar_type is None:
        block_type = A.values.dtype
    else:
        block_type = scalar_type if A.block_shape == (1, 1) else wp.types.matrix(shape=A.block_shape, dtype=scalar_type)

    result = bsr_zeros(rows_of_blocks=A.nrow, cols_of_blocks=A.ncol, block_type=block_type, device=A.values.device)
    bsr_assign(dest=result, src=A)
    return result
294
+
295
+
296
def bsr_set_transpose(dest: BsrMatrix[BlockType[Cols, Rows, Scalar]], src: BsrMatrix[BlockType[Rows, Cols, Scalar]]):
    """Assigns the transposed matrix `src` to matrix `dest`

    Args:
        dest: Matrix receiving the transpose; its block shape must be the reverse of `src`'s
        src: Matrix to transpose

    Raises:
        ValueError: If the matrices mismatch in device, scalar type or block shape
        NotImplementedError: If the scalar type is neither float32 nor float64
    """

    if dest.values.device != src.values.device:
        raise ValueError("All arguments must reside on the same device")

    if dest.scalar_type != src.scalar_type:
        raise ValueError("All arguments must have the same scalar type")

    # Transposing swaps the per-block row/column dimensions
    transpose_block_shape = src.block_shape[::-1]

    if dest.block_shape != transpose_block_shape:
        raise ValueError(f"Destination block shape must be {transpose_block_shape}")

    dest.nrow = src.ncol
    dest.ncol = src.nrow
    dest.nnz = src.nnz

    if src.nnz == 0:
        return

    # Increase dest array sizes if needed
    _bsr_ensure_fits(dest)

    from warp.context import runtime

    # BUGFIX: initialize to None so that unsupported scalar types reach the
    # NotImplementedError below instead of raising NameError on an unbound local.
    native_func = None
    if dest.values.device.is_cpu:
        if dest.scalar_type == wp.float32:
            native_func = runtime.core.bsr_transpose_float_host
        elif dest.scalar_type == wp.float64:
            native_func = runtime.core.bsr_transpose_double_host
    else:
        if dest.scalar_type == wp.float32:
            native_func = runtime.core.bsr_transpose_float_device
        elif dest.scalar_type == wp.float64:
            native_func = runtime.core.bsr_transpose_double_device

    if not native_func:
        raise NotImplementedError(f"bsr_set_transpose not implemented for scalar type {dest.scalar_type}")

    native_func(
        src.block_shape[0],
        src.block_shape[1],
        src.nrow,
        src.ncol,
        src.nnz,
        src.offsets.ptr,
        src.columns.ptr,
        src.values.ptr,
        dest.offsets.ptr,
        dest.columns.ptr,
        dest.values.ptr,
    )
349
+
350
+
351
def bsr_transposed(A: BsrMatrix):
    """Returns a copy of the transposed matrix `A`"""

    # Scalar (CSR) blocks keep their dtype; matrix blocks get a reversed shape
    if A.block_shape == (1, 1):
        block_type = A.values.dtype
    else:
        block_type = wp.types.matrix(shape=A.block_shape[::-1], dtype=A.scalar_type)

    result = bsr_zeros(rows_of_blocks=A.ncol, cols_of_blocks=A.nrow, block_type=block_type, device=A.values.device)
    bsr_set_transpose(dest=result, src=A)
    return result
362
+
363
+
364
@wp.kernel
def _bsr_get_diag_kernel(
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
    out: wp.array(dtype=Any),
):
    # One thread per block row: copy that row's diagonal block into `out`, if stored.
    row = wp.tid()
    beg = A_offsets[row]
    end = A_offsets[row + 1]

    # Binary search for the diagonal column index within the row's column range
    # (assumes column indices are sorted within each row -- TODO confirm invariant)
    diag = wp.lower_bound(A_columns, beg, end, row)
    if diag < end:
        if A_columns[diag] == row:
            # Explicit diagonal block found; rows without one leave out[row] unchanged
            out[row] = A_values[diag]
379
+
380
+
381
def bsr_get_diag(A: BsrMatrix[_BlockType], out: "Optional[Array[BlockType]]" = None) -> "Array[BlockType]":
    """Returns the array of blocks that constitute the diagonal of a sparse matrix.

    Args:
        A: the sparse matrix from which to extract the diagonal
        out: if provided, the array into which to store the diagonal blocks
    """

    # The diagonal of a rectangular matrix has min(nrow, ncol) blocks
    dim = min(A.nrow, A.ncol)

    if out is None:
        # Zero-initialized so rows without an explicit diagonal block read as zero
        out = wp.zeros(shape=(dim,), dtype=A.values.dtype, device=A.values.device)
    elif out.dtype != A.values.dtype:
        raise ValueError(f"Output array must have type {A.values.dtype}")
    elif out.device != A.values.device:
        raise ValueError(f"Output array must reside on device {A.values.device}")
    elif out.shape[0] < dim:
        raise ValueError(f"Output array must be of length at least {dim}")

    wp.launch(
        kernel=_bsr_get_diag_kernel, dim=dim, device=A.values.device, inputs=[A.offsets, A.columns, A.values, out]
    )

    return out
406
+
407
+
408
# Per-row kernel: rewrites the matrix topology and values so that row `i`
# holds exactly one block, diag[i], stored at column i.
@wp.kernel
def _bsr_set_diag_kernel(
    diag: wp.array(dtype=Any),
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
):
    row = wp.tid()
    # Diagonal topology: offsets become 0, 1, 2, ..., n
    A_offsets[row + 1] = row + 1
    A_columns[row] = row
    A_values[row] = diag[row]

    # First thread also writes the leading offset entry
    if row == 0:
        A_offsets[0] = 0
422
+
423
+
424
# Per-row kernel: same as _bsr_set_diag_kernel, but assigns a single
# uniform block value to every diagonal entry.
@wp.kernel
def _bsr_set_diag_constant_kernel(
    diag_value: Any,
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
):
    row = wp.tid()
    # Diagonal topology: offsets become 0, 1, 2, ..., n
    A_offsets[row + 1] = row + 1
    A_columns[row] = row
    A_values[row] = diag_value

    # First thread also writes the leading offset entry
    if row == 0:
        A_offsets[0] = 0
438
+
439
+
440
def bsr_set_diag(
    A: BsrMatrix[BlockType],
    diag: "Union[BlockType, Array[BlockType]]",
    rows_of_blocks: Optional[int] = None,
    cols_of_blocks: Optional[int] = None,
):
    """Sets `A` as a block-diagonal matrix

    Args:
        A: the sparse matrix to modify
        diag: Either a warp array of type ``A.values.dtype``, in which case each element will define one block of the diagonal,
            or a constant value of type ``A.values.dtype``, in which case it will get assigned to all diagonal blocks.
        rows_of_blocks: If not ``None``, the new number of rows of blocks
        cols_of_blocks: If not ``None``, the new number of columns of blocks

    The shape of the matrix will be defined one of the following, in that order:
      - `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal.
      - the first dimension of `diag`, if `diag` is an array
      - the current dimensions of `A` otherwise
    """

    diag_is_array = warp.types.is_array(diag)

    # If only one dimension is given, make the matrix square
    if rows_of_blocks is None and cols_of_blocks is not None:
        rows_of_blocks = cols_of_blocks
    if cols_of_blocks is None and rows_of_blocks is not None:
        cols_of_blocks = rows_of_blocks

    # An array diagonal defines the size when none was given explicitly
    if diag_is_array and rows_of_blocks is None:
        rows_of_blocks = diag.shape[0]
        cols_of_blocks = diag.shape[0]

    if rows_of_blocks is not None:
        A.nrow = rows_of_blocks
        A.ncol = cols_of_blocks

    # A diagonal matrix has exactly one block per (shortest-dimension) row
    A.nnz = min(A.nrow, A.ncol)
    _bsr_ensure_fits(A)

    if diag_is_array:
        wp.launch(
            kernel=_bsr_set_diag_kernel,
            dim=A.nnz,
            device=A.values.device,
            inputs=[diag, A.offsets, A.columns, A.values],
        )
        return

    # Uniform diagonal: promote the constant to a launchable block type
    if not warp.types.type_is_value(type(diag)):
        # Cast to launchable type
        diag = A.values.dtype(diag)
    wp.launch(
        kernel=_bsr_set_diag_constant_kernel,
        dim=A.nnz,
        device=A.values.device,
        inputs=[diag, A.offsets, A.columns, A.values],
    )
495
+
496
+
497
def bsr_diag(
    diag: "Union[BlockType, Array[BlockType]]",
    rows_of_blocks: Optional[int] = None,
    cols_of_blocks: Optional[int] = None,
) -> BsrMatrix["BlockType"]:
    """Creates and returns a block-diagonal BSR matrix from an given block value or array of block values.

    Args:
        diag: Either a warp array of type ``A.values.dtype``, in which case each element will define one block of the diagonal,
            or a constant value of type ``A.values.dtype``, in which case it will get assigned to all diagonal blocks.
        rows_of_blocks: If not ``None``, the new number of rows of blocks
        cols_of_blocks: If not ``None``, the new number of columns of blocks

    The shape of the matrix will be defined one of the following, in that order:
      - `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal.
      - the first dimension of `diag`, if `diag` is an array
    """

    # If only one dimension is given, make the matrix square
    if rows_of_blocks is None and cols_of_blocks is not None:
        rows_of_blocks = cols_of_blocks
    if cols_of_blocks is None and rows_of_blocks is not None:
        cols_of_blocks = rows_of_blocks

    if warp.types.is_array(diag):
        # Array diagonal: the array length defines the matrix size if
        # no explicit dimensions were provided.
        if rows_of_blocks is None:
            rows_of_blocks = diag.shape[0]
            cols_of_blocks = diag.shape[0]

        A = bsr_zeros(
            rows_of_blocks,
            cols_of_blocks,
            block_type=diag.dtype,
            device=diag.device,
        )
    else:
        # Uniform diagonal: dimensions cannot be inferred from the value
        if rows_of_blocks is None:
            raise ValueError(
                "rows_of_blocks and/or cols_of_blocks must be provided for constructing a diagonal matrix with uniform diagonal"
            )

        block_type = type(diag)
        # 2d array-like constants (e.g. a numpy array) get wrapped in a warp matrix type
        if not warp.types.type_is_matrix(block_type) and len(getattr(diag, "shape", ())) == 2:
            block_type = wp.mat(shape=diag.shape, dtype=diag.dtype)

        A = bsr_zeros(
            rows_of_blocks,
            cols_of_blocks,
            block_type=block_type,
        )

    bsr_set_diag(A, diag)
    return A
549
+
550
+
551
def bsr_set_identity(A: BsrMatrix, rows_of_blocks: Optional[int] = None):
    """Sets `A` as the identity matrix

    Args:
        A: the sparse matrix to modify
        rows_of_blocks: if provided, the matrix will be resized as a square matrix with `rows_of_blocks` rows and columns.
    """

    # The identity diagonal block: scalar one for (1,1) blocks, the
    # square identity matrix otherwise.
    if A.block_shape == (1, 1):
        diag_block = A.scalar_type(1.0)
    else:
        from numpy import eye

        diag_block = eye(A.block_shape[0])

    bsr_set_diag(A, diag=diag_block, rows_of_blocks=rows_of_blocks, cols_of_blocks=rows_of_blocks)
567
+
568
+
569
def bsr_identity(
    rows_of_blocks: int, block_type: BlockType[Rows, Rows, Scalar], device: wp.context.Devicelike = None
) -> BsrMatrix[BlockType[Rows, Rows, Scalar]]:
    """Creates and returns a square identity matrix.

    Args:
        rows_of_blocks: Number of rows and columns of blocks in the created matrix.
        block_type: Block type for the newly created matrix -- must be square
        device: Device onto which to allocate the data arrays
    """
    result = bsr_zeros(
        rows_of_blocks=rows_of_blocks,
        cols_of_blocks=rows_of_blocks,
        block_type=block_type,
        device=device,
    )
    bsr_set_identity(result)
    return result
582
+
583
+
584
# Per-block kernel: multiplies every stored block by the scalar `alpha` in place.
@wp.kernel
def _bsr_scale_kernel(
    alpha: Any,
    values: wp.array(dtype=Any),
):
    values[wp.tid()] = alpha * values[wp.tid()]
590
+
591
+
592
def bsr_scale(x: BsrMatrix, alpha: Scalar) -> BsrMatrix:
    """
    Performs the operation ``x := alpha * x`` on BSR matrix `x` and returns `x`
    """

    # Identity scale or empty matrix: nothing to do
    if alpha == 1.0 or x.nnz == 0:
        return x

    # Zero scale clears the matrix entirely (topology and values)
    if alpha == 0.0:
        bsr_set_zero(x)
        return x

    # Promote the scale factor to the matrix scalar type if needed
    if not isinstance(alpha, x.scalar_type):
        alpha = x.scalar_type(alpha)

    wp.launch(kernel=_bsr_scale_kernel, dim=x.nnz, device=x.values.device, inputs=[alpha, x.values])
    return x
607
+
608
+
609
# Per-block kernel: recovers the row index of the i-th block from the CSR
# offsets array (the inverse of the offsets mapping), writing it at
# rows[dest_offset + i].
@wp.kernel
def _bsr_get_block_row(dest_offset: int, bsr_offsets: wp.array(dtype=int), rows: wp.array(dtype=int)):
    i = wp.tid()

    # First row whose end-offset exceeds i is the owning row
    row = wp.lower_bound(bsr_offsets, i + 1) - 1
    rows[dest_offset + i] = row
615
+
616
+
617
# Per-block kernel: scatters `scale * src_values[i]` into the destination
# matrix at the (row, col) position recorded in the triplet arrays.
# Assumes the destination topology already contains an entry for (row, col).
@wp.kernel
def _bsr_axpy_add_block(
    src_offset: int,
    scale: Any,
    rows: wp.array(dtype=int),
    cols: wp.array(dtype=int),
    dst_offsets: wp.array(dtype=int),
    dst_columns: wp.array(dtype=int),
    src_values: wp.array(dtype=Any),
    dst_values: wp.array(dtype=Any),
):
    i = wp.tid()
    row = rows[i + src_offset]
    col = cols[i + src_offset]
    beg = dst_offsets[row]
    end = dst_offsets[row + 1]

    # Destination columns are sorted within a row: binary-search the slot
    block = wp.lower_bound(dst_columns, beg, end, col)

    dst_values[block] = dst_values[block] + scale * src_values[i]
637
+
638
+
639
class bsr_axpy_work_arrays:
    """Opaque structure for persisting :func:`bsr_axpy` temporary work buffers across calls"""

    def __init__(self):
        self._reset(None)

    def _reset(self, device):
        # Invalidate all cached buffers; they are lazily re-allocated
        self.device = device
        self._sum_rows = None
        self._sum_cols = None
        self._old_y_values = None
        self._old_x_values = None

    def _allocate(self, device, y: BsrMatrix, sum_nnz: int):
        # Buffers are only reusable on the same device
        if self.device != device:
            self._reset(device)

        def grown(buf, count, dtype):
            # Grow-only policy: keep the buffer if it is already large enough
            if buf is not None and buf.size >= count:
                return buf
            return wp.empty(shape=(count,), dtype=dtype, device=self.device)

        self._sum_rows = grown(self._sum_rows, sum_nnz, int)
        self._sum_cols = grown(self._sum_cols, sum_nnz, int)
        self._old_y_values = grown(self._old_y_values, y.nnz, y.values.dtype)
663
+
664
+
665
def bsr_axpy(
    x: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    y: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
    alpha: Scalar = 1.0,
    beta: Scalar = 1.0,
    work_arrays: Optional[bsr_axpy_work_arrays] = None,
) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
    """
    Performs the sparse matrix addition ``y := alpha * x + beta * y`` on BSR matrices `x` and `y` and returns `y`.

    The `x` and `y` matrices are allowed to alias.

    Args:
        x: Read-only right-hand-side.
        y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero.
        alpha: Uniform scaling factor for `x`
        beta: Uniform scaling factor for `y`
        work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_axpy_work_arrays` in `work_arrays`.

    Raises:
        ValueError: if `x` and `y` live on different devices, or have mismatched block types or dimensions.
    """

    if y is None:
        # If no output matrix is provided, allocate it for convenience
        y = bsr_zeros(x.nrow, x.ncol, block_type=x.values.dtype, device=x.values.device)
        beta = 0.0

    # Handle easy cases first: one operand vanishes, so the result is
    # just a scaled copy of the other.
    if beta == 0.0 or y.nnz == 0:
        bsr_assign(src=x, dest=y)
        return bsr_scale(y, alpha=alpha)

    if alpha == 0.0 or x.nnz == 0:
        return bsr_scale(y, alpha=beta)

    if not isinstance(alpha, y.scalar_type):
        alpha = y.scalar_type(alpha)
    if not isinstance(beta, y.scalar_type):
        beta = y.scalar_type(beta)

    if x == y:
        # Aliasing case: (alpha + beta) * y, no topology change needed
        return bsr_scale(y, alpha=alpha.value + beta.value)

    # General case

    if x.values.device != y.values.device:
        raise ValueError("All arguments must reside on the same device")

    if x.scalar_type != y.scalar_type or x.block_shape != y.block_shape:
        raise ValueError("Matrices must have the same block type")

    if x.nrow != y.nrow or x.ncol != y.ncol:
        raise ValueError("Matrices must have the same number of rows and columns")

    if work_arrays is None:
        work_arrays = bsr_axpy_work_arrays()

    # Build the concatenated (row, col) triplet list of both operands:
    # y's blocks first, then x's blocks.
    sum_nnz = x.nnz + y.nnz
    device = y.values.device
    work_arrays._allocate(device, y, sum_nnz)

    wp.copy(work_arrays._sum_cols, y.columns, 0, 0, y.nnz)
    wp.launch(kernel=_bsr_get_block_row, device=device, dim=y.nnz, inputs=[0, y.offsets, work_arrays._sum_rows])

    wp.copy(work_arrays._sum_cols, x.columns, y.nnz, 0, x.nnz)
    wp.launch(kernel=_bsr_get_block_row, device=device, dim=x.nnz, inputs=[y.nnz, x.offsets, work_arrays._sum_rows])

    # Save old y values before overwriting matrix
    wp.copy(dest=work_arrays._old_y_values, src=y.values, count=y.nnz)

    # Increase dest array sizes if needed
    if y.columns.shape[0] < sum_nnz:
        y.columns = wp.empty(shape=(sum_nnz,), dtype=int, device=device)

    from warp.context import runtime

    if device.is_cpu:
        native_func = runtime.core.bsr_matrix_from_triplets_float_host
    else:
        native_func = runtime.core.bsr_matrix_from_triplets_float_device

    # Merge the triplets into a sorted, deduplicated CSR topology in y;
    # the native routine returns the merged block count.
    # NOTE(review): value pointers are passed as 0, so only topology is built here.
    old_y_nnz = y.nnz
    y.nnz = native_func(
        y.block_shape[0],
        y.block_shape[1],
        y.nrow,
        sum_nnz,
        work_arrays._sum_rows.ptr,
        work_arrays._sum_cols.ptr,
        0,
        y.offsets.ptr,
        y.columns.ptr,
        0,
    )

    _bsr_ensure_fits(y)
    y.values.zero_()

    # Scatter beta * (old y blocks) into the merged topology...
    wp.launch(
        kernel=_bsr_axpy_add_block,
        device=device,
        dim=old_y_nnz,
        inputs=[
            0,
            beta,
            work_arrays._sum_rows,
            work_arrays._sum_cols,
            y.offsets,
            y.columns,
            work_arrays._old_y_values,
            y.values,
        ],
    )

    # ...then accumulate alpha * (x blocks), whose triplets start at old_y_nnz
    wp.launch(
        kernel=_bsr_axpy_add_block,
        device=device,
        dim=x.nnz,
        inputs=[
            old_y_nnz,
            alpha,
            work_arrays._sum_rows,
            work_arrays._sum_cols,
            y.offsets,
            y.columns,
            x.values,
            y.values,
        ],
    )

    return y
795
+
796
+
797
# Per-row kernel: counts the number of unmerged product blocks each row of
# x*y will generate (sum over x-blocks of the corresponding y-row lengths).
# counts[0] is seeded with z_nnz so a subsequent prefix-scan reserves space
# for the preserved z blocks at the front of the triplet list.
@wp.kernel
def _bsr_mm_count_coeffs(
    z_nnz: int,
    x_offsets: wp.array(dtype=int),
    x_columns: wp.array(dtype=int),
    y_offsets: wp.array(dtype=int),
    counts: wp.array(dtype=int),
):
    row = wp.tid()
    count = int(0)

    x_beg = x_offsets[row]
    x_end = x_offsets[row + 1]

    for x_block in range(x_beg, x_end):
        x_col = x_columns[x_block]
        # Each block in y's row `x_col` contributes one product block
        count += y_offsets[x_col + 1] - y_offsets[x_col]

    counts[row + 1] = count

    if row == 0:
        counts[0] = z_nnz
819
+
820
+
821
# Per-row kernel: expands the (row, col) indices of all unmerged product
# blocks of x*y, writing them at the per-row offsets computed by
# _bsr_mm_count_coeffs + prefix scan.
@wp.kernel
def _bsr_mm_list_coeffs(
    x_offsets: wp.array(dtype=int),
    x_columns: wp.array(dtype=int),
    y_offsets: wp.array(dtype=int),
    y_columns: wp.array(dtype=int),
    mm_offsets: wp.array(dtype=int),
    mm_rows: wp.array(dtype=int),
    mm_cols: wp.array(dtype=int),
):
    row = wp.tid()
    mm_block = mm_offsets[row]

    x_beg = x_offsets[row]
    x_end = x_offsets[row + 1]

    for x_block in range(x_beg, x_end):
        x_col = x_columns[x_block]

        # Every block of y's row `x_col` pairs with this x block
        y_beg = y_offsets[x_col]
        y_end = y_offsets[x_col + 1]
        for y_block in range(y_beg, y_end):
            mm_cols[mm_block] = y_columns[y_block]
            mm_rows[mm_block] = row
            mm_block += 1
846
+
847
+
848
# Per-row kernel: accumulates alpha * x * y block products into the merged
# output topology (mm_offsets/mm_cols), locating each destination slot by
# binary search over the sorted columns of the output row.
@wp.kernel
def _bsr_mm_compute_values(
    alpha: Any,
    x_offsets: wp.array(dtype=int),
    x_columns: wp.array(dtype=int),
    x_values: wp.array(dtype=Any),
    y_offsets: wp.array(dtype=int),
    y_columns: wp.array(dtype=int),
    y_values: wp.array(dtype=Any),
    mm_offsets: wp.array(dtype=int),
    mm_cols: wp.array(dtype=int),
    mm_values: wp.array(dtype=Any),
):
    row = wp.tid()
    mm_beg = mm_offsets[row]
    mm_end = mm_offsets[row + 1]

    x_beg = x_offsets[row]
    x_end = x_offsets[row + 1]
    for x_block in range(x_beg, x_end):
        x_col = x_columns[x_block]
        # Pre-scale the x block once for this whole y row
        ax_val = alpha * x_values[x_block]

        y_beg = y_offsets[x_col]
        y_end = y_offsets[x_col + 1]

        for y_block in range(y_beg, y_end):
            mm_block = wp.lower_bound(mm_cols, mm_beg, mm_end, y_columns[y_block])
            mm_values[mm_block] = mm_values[mm_block] + ax_val * y_values[y_block]
877
+
878
+
879
class bsr_mm_work_arrays:
    """Opaque structure for persisting :func:`bsr_mm` temporary work buffers across calls"""

    def __init__(self):
        self._reset(None)

    def _reset(self, device):
        # Drop all cached buffers; they are re-created lazily on demand
        self.device = device
        self._pinned_count_buffer = None
        self._mm_row_counts = None
        self._mm_rows = None
        self._mm_cols = None
        self._old_z_values = None
        self._old_z_offsets = None
        self._old_z_columns = None

    def _grown(self, buf, count, dtype):
        # Grow-only policy: reuse an existing buffer when large enough
        if buf is not None and buf.size >= count:
            return buf
        return wp.empty(shape=(count,), dtype=dtype, device=self.device)

    def _allocate_stage_1(self, device, z: BsrMatrix, copied_z_nnz: int, z_aliasing: bool):
        # Buffers are only reusable on the same device
        if self.device != device:
            self._reset(device)

        # Allocations that do not depend on any computation
        if self.device.is_cuda:
            # Pinned host staging buffer for the async device->host nnz readback
            if self._pinned_count_buffer is None:
                self._pinned_count_buffer = wp.empty(shape=(1,), dtype=int, pinned=True, device="cpu")

        self._mm_row_counts = self._grown(self._mm_row_counts, z.nrow + 1, int)

        if copied_z_nnz > 0:
            self._old_z_values = self._grown(self._old_z_values, copied_z_nnz, z.values.dtype)

        if z_aliasing:
            # When z aliases x or y, its topology must be preserved as well
            self._old_z_columns = self._grown(self._old_z_columns, z.nnz, z.columns.dtype)
            self._old_z_offsets = self._grown(self._old_z_offsets, z.nrow + 1, z.offsets.dtype)

    def _allocate_stage_2(self, mm_nnz: int):
        # Allocations that depend on unmerged nnz estimate
        self._mm_rows = self._grown(self._mm_rows, mm_nnz, int)
        self._mm_cols = self._grown(self._mm_cols, mm_nnz, int)
923
+
924
+
925
def bsr_mm(
    x: BsrMatrix[BlockType[Rows, Any, Scalar]],
    y: BsrMatrix[BlockType[Any, Cols, Scalar]],
    z: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
    alpha: Scalar = 1.0,
    beta: Scalar = 0.0,
    work_arrays: Optional[bsr_mm_work_arrays] = None,
) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
    """
    Performs the sparse matrix-matrix multiplication ``z := alpha * x * y + beta * z`` on BSR matrices `x`, `y` and `z`, and returns `z`.

    The `x`, `y` and `z` matrices are allowed to alias.
    If the matrix `z` is not provided as input, it will be allocated and treated as zero.

    Args:
        x: Read-only left factor of the matrix-matrix product.
        y: Read-only right factor of the matrix-matrix product.
        z: Mutable left-hand-side. If `z` is not provided, it will be allocated and treated as zero.
        alpha: Uniform scaling factor for the ``x * y`` product
        beta: Uniform scaling factor for `z`
        work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_mm_work_arrays` in `work_arrays`.

    Raises:
        ValueError: if the operands live on different devices, or have incompatible scalar types, block shapes, or dimensions.
    """

    if z is None:
        # If no output matrix is provided, allocate it for convenience
        z_block_shape = (x.block_shape[0], y.block_shape[1])
        if z_block_shape == (1, 1):
            z_block_type = x.scalar_type
        else:
            z_block_type = wp.types.matrix(shape=z_block_shape, dtype=x.scalar_type)
        z = bsr_zeros(x.nrow, y.ncol, block_type=z_block_type, device=x.values.device)
        beta = 0.0

    if x.values.device != y.values.device or x.values.device != z.values.device:
        raise ValueError("All arguments must reside on the same device")

    if x.scalar_type != y.scalar_type or x.scalar_type != z.scalar_type:
        raise ValueError("Matrices must have the same scalar type")

    if (
        x.block_shape[0] != z.block_shape[0]
        or y.block_shape[1] != z.block_shape[1]
        or x.block_shape[1] != y.block_shape[0]
    ):
        raise ValueError("Incompatible block sizes for matrix multiplication")

    if x.nrow != z.nrow or z.ncol != y.ncol or x.ncol != y.nrow:
        raise ValueError("Incompatible number of rows/columns for matrix multiplication")

    device = z.values.device

    if alpha == 0.0 or x.nnz == 0 or y.nnz == 0:
        # Easy case: the product term vanishes, so only beta * z remains
        return bsr_scale(z, beta)

    if not isinstance(alpha, z.scalar_type):
        alpha = z.scalar_type(alpha)
    if not isinstance(beta, z.scalar_type):
        beta = z.scalar_type(beta)

    if work_arrays is None:
        work_arrays = bsr_mm_work_arrays()

    # z's current blocks must be preserved if they contribute (beta != 0)
    # or if z aliases an operand (its data is needed to compute the product)
    z_aliasing = z == x or z == y
    copied_z_nnz = z.nnz if beta != 0.0 or z_aliasing else 0

    work_arrays._allocate_stage_1(device, z, copied_z_nnz, z_aliasing)

    # Prefix sum of number of (unmerged) mm blocks per row
    wp.launch(
        kernel=_bsr_mm_count_coeffs,
        device=device,
        dim=z.nrow,
        inputs=[copied_z_nnz, x.offsets, x.columns, y.offsets, work_arrays._mm_row_counts],
    )
    warp.utils.array_scan(work_arrays._mm_row_counts, work_arrays._mm_row_counts)

    # Get back total counts on host
    if device.is_cuda:
        # Async copy into pinned memory, then sync the stream before reading
        wp.copy(dest=work_arrays._pinned_count_buffer, src=work_arrays._mm_row_counts, src_offset=z.nrow, count=1)
        wp.synchronize_stream(wp.get_stream(device))
        mm_nnz = int(work_arrays._pinned_count_buffer.numpy()[0])
    else:
        mm_nnz = int(work_arrays._mm_row_counts.numpy()[z.nrow])

    work_arrays._allocate_stage_2(mm_nnz)

    # If z has a non-zero scale, save current data before overwriting it
    if copied_z_nnz > 0:
        # Copy z row and column indices (they occupy the first copied_z_nnz
        # slots of the triplet list, as reserved by _bsr_mm_count_coeffs)
        wp.copy(dest=work_arrays._mm_cols, src=z.columns, count=copied_z_nnz)
        wp.launch(
            kernel=_bsr_get_block_row, device=device, dim=copied_z_nnz, inputs=[0, z.offsets, work_arrays._mm_rows]
        )
        # Save current z values in temporary buffer
        wp.copy(src=z.values, dest=work_arrays._old_z_values, count=copied_z_nnz)
        if z_aliasing:
            # If z is aliasing with x or y, need to save topology as well
            wp.copy(src=z.columns, dest=work_arrays._old_z_columns, count=copied_z_nnz)
            wp.copy(src=z.offsets, dest=work_arrays._old_z_offsets, count=z.nrow + 1)

    # Fill unmerged mm blocks rows and columns
    wp.launch(
        kernel=_bsr_mm_list_coeffs,
        device=device,
        dim=z.nrow,
        inputs=[
            x.offsets,
            x.columns,
            y.offsets,
            y.columns,
            work_arrays._mm_row_counts,
            work_arrays._mm_rows,
            work_arrays._mm_cols,
        ],
    )

    # Increase dest array size if needed
    if z.columns.shape[0] < mm_nnz:
        z.columns = wp.empty(shape=(mm_nnz,), dtype=int, device=device)

    from warp.context import runtime

    if device.is_cpu:
        native_func = runtime.core.bsr_matrix_from_triplets_float_host
    else:
        native_func = runtime.core.bsr_matrix_from_triplets_float_device

    # Merge the triplets into a sorted, deduplicated CSR topology for z;
    # value pointers are passed as 0, so only topology is built here.
    z.nnz = native_func(
        z.block_shape[0],
        z.block_shape[1],
        z.nrow,
        mm_nnz,
        work_arrays._mm_rows.ptr,
        work_arrays._mm_cols.ptr,
        0,
        z.offsets.ptr,
        z.columns.ptr,
        0,
    )

    _bsr_ensure_fits(z)
    z.values.zero_()

    if copied_z_nnz > 0:
        # Add back original z values, scaled by beta
        wp.launch(
            kernel=_bsr_axpy_add_block,
            device=device,
            dim=copied_z_nnz,
            inputs=[
                0,
                beta,
                work_arrays._mm_rows,
                work_arrays._mm_cols,
                z.offsets,
                z.columns,
                work_arrays._old_z_values,
                z.values,
            ],
        )

    # Add mm blocks to z values
    if (warp.types.type_is_matrix(x.values.dtype) or warp.types.type_is_matrix(y.values.dtype)) and not (
        warp.types.type_is_matrix(z.values.dtype)
    ):
        # Result block type is scalar, but operands are matrices
        # Cast result to (1x1) matrix to perform multiplication
        mm_values = z.values.view(wp.types.matrix(shape=(1, 1), dtype=z.scalar_type))
    else:
        mm_values = z.values

    # When z aliases an operand, read that operand from the saved copies
    wp.launch(
        kernel=_bsr_mm_compute_values,
        device=device,
        dim=z.nrow,
        inputs=[
            alpha,
            work_arrays._old_z_offsets if x == z else x.offsets,
            work_arrays._old_z_columns if x == z else x.columns,
            work_arrays._old_z_values if x == z else x.values,
            work_arrays._old_z_offsets if y == z else y.offsets,
            work_arrays._old_z_columns if y == z else y.columns,
            work_arrays._old_z_values if y == z else y.values,
            z.offsets,
            z.columns,
            mm_values,
        ],
    )

    return z
1116
+
1117
+
1118
# Per-row kernel: computes y[row] = alpha * (A @ x)[row] + beta * y[row].
# Skips the matrix traversal entirely when alpha == 0 and the old-y read
# when beta == 0, so either operand may be left uninitialized in that case.
@wp.kernel
def _bsr_mv_kernel(
    alpha: Any,
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    beta: Any,
    y: wp.array(dtype=Any),
):
    row = wp.tid()

    # zero-initialize with type of y elements
    scalar_zero = type(alpha)(0)
    v = y.dtype(scalar_zero)

    if alpha != scalar_zero:
        beg = A_offsets[row]
        end = A_offsets[row + 1]
        for block in range(beg, end):
            v += A_values[block] * x[A_columns[block]]
        v *= alpha

    if beta != scalar_zero:
        v += beta * y[row]

    y[row] = v
1145
+
1146
+
1147
def bsr_mv(
    A: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    x: "Array[Vector[Cols, Scalar] | Scalar]",
    y: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
    alpha: Scalar = 1.0,
    beta: Scalar = 0.0,
    work_buffer: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
) -> "Array[Vector[Rows, Scalar] | Scalar]":
    """
    Performs the sparse matrix-vector product ``y := alpha * A * x + beta * y`` and returns `y`.

    The `x` and `y` vectors are allowed to alias.

    Args:
        A: Read-only, left matrix factor of the matrix-vector product.
        x: Read-only, right vector factor of the matrix-vector product.
        y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero.
        alpha: Uniform scaling factor for `x`. If zero, `x` will not be read and may be left uninitialized.
        beta: Uniform scaling factor for `y`. If zero, `y` will not be read and may be left uninitialized.
        work_buffer: Temporary storage is required if and only if `x` and `y` are the same vector. If provided the `work_buffer` array
            will be used for this purpose, otherwise a temporary allocation will be performed.

    Raises:
        ValueError: if `A`, `x` and `y` live on different devices, have mismatched dimensions, or `work_buffer` is unsuitable.
    """

    if y is None:
        # If no output array is provided, allocate one for convenience
        y_vec_len = A.block_shape[0]
        y_dtype = A.scalar_type if y_vec_len == 1 else wp.vec(length=y_vec_len, dtype=A.scalar_type)
        y = wp.empty(shape=(A.nrow,), device=A.values.device, dtype=y_dtype)
        y.zero_()
        beta = 0.0

    if not isinstance(alpha, A.scalar_type):
        alpha = A.scalar_type(alpha)
    if not isinstance(beta, A.scalar_type):
        beta = A.scalar_type(beta)

    if A.values.device != x.device or A.values.device != y.device:
        raise ValueError("A, x and y must reside on the same device")

    if x.shape[0] != A.ncol:
        raise ValueError("Number of columns of A must match number of rows of x")
    if y.shape[0] != A.nrow:
        raise ValueError("Number of rows of A must match number of rows of y")

    if x == y:
        # Aliasing case, need temporary storage
        if work_buffer is None:
            work_buffer = wp.empty_like(y)
        elif work_buffer.size < y.size:
            raise ValueError(f"Work buffer size is insufficient, needs to be at least {y.size}")
        elif not wp.types.types_equal(work_buffer.dtype, y.dtype):
            raise ValueError(f"Work buffer must have same data type as y, {wp.types.type_repr(y.dtype)}")

        # Save old y values before overwriting vector; the copy then
        # serves as the read-only x operand for the kernel
        wp.copy(dest=work_buffer, src=y, count=y.size)
        x = work_buffer

    # Promote scalar vectors to length-1 vecs and conversely, so the
    # kernel's element types always match the block type of A
    if warp.types.type_is_matrix(A.values.dtype):
        if A.block_shape[0] == 1:
            if y.dtype == A.scalar_type:
                y = y.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
        if A.block_shape[1] == 1:
            if x.dtype == A.scalar_type:
                x = x.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
    else:
        if A.block_shape[0] == 1:
            if y.dtype != A.scalar_type:
                y = y.view(dtype=A.scalar_type)
        if A.block_shape[1] == 1:
            if x.dtype != A.scalar_type:
                x = x.view(dtype=A.scalar_type)

    wp.launch(
        kernel=_bsr_mv_kernel,
        device=A.values.device,
        dim=A.nrow,
        inputs=[alpha, A.offsets, A.columns, A.values, x, beta, y],
    )

    return y