warp-lang 1.6.1__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (401) hide show
  1. warp/__init__.py +21 -7
  2. warp/autograd.py +14 -6
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +424 -6
  6. warp/build_dll.py +20 -20
  7. warp/builtins.py +467 -368
  8. warp/codegen.py +193 -125
  9. warp/config.py +56 -12
  10. warp/constants.py +14 -6
  11. warp/context.py +524 -277
  12. warp/dlpack.py +22 -12
  13. warp/examples/__init__.py +14 -6
  14. warp/examples/assets/nonuniform.usd +0 -0
  15. warp/examples/assets/nvidia_logo.png +0 -0
  16. warp/examples/benchmarks/benchmark_api.py +14 -6
  17. warp/examples/benchmarks/benchmark_cloth.py +14 -6
  18. warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
  19. warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
  20. warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
  21. warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
  22. warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
  23. warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
  24. warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
  25. warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
  26. warp/examples/benchmarks/benchmark_gemm.py +82 -48
  27. warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
  28. warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
  29. warp/examples/benchmarks/benchmark_launches.py +14 -6
  30. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  31. warp/examples/browse.py +14 -6
  32. warp/examples/core/example_cupy.py +14 -6
  33. warp/examples/core/example_dem.py +14 -6
  34. warp/examples/core/example_fluid.py +14 -6
  35. warp/examples/core/example_graph_capture.py +14 -6
  36. warp/examples/core/example_marching_cubes.py +14 -6
  37. warp/examples/core/example_mesh.py +14 -6
  38. warp/examples/core/example_mesh_intersect.py +14 -6
  39. warp/examples/core/example_nvdb.py +14 -6
  40. warp/examples/core/example_raycast.py +14 -6
  41. warp/examples/core/example_raymarch.py +14 -6
  42. warp/examples/core/example_render_opengl.py +14 -6
  43. warp/examples/core/example_sample_mesh.py +300 -0
  44. warp/examples/core/example_sph.py +14 -6
  45. warp/examples/core/example_torch.py +14 -6
  46. warp/examples/core/example_wave.py +14 -6
  47. warp/examples/fem/example_adaptive_grid.py +14 -6
  48. warp/examples/fem/example_apic_fluid.py +15 -7
  49. warp/examples/fem/example_burgers.py +16 -8
  50. warp/examples/fem/example_convection_diffusion.py +14 -6
  51. warp/examples/fem/example_convection_diffusion_dg.py +14 -6
  52. warp/examples/fem/example_deformed_geometry.py +15 -7
  53. warp/examples/fem/example_diffusion.py +14 -6
  54. warp/examples/fem/example_diffusion_3d.py +14 -6
  55. warp/examples/fem/example_diffusion_mgpu.py +14 -6
  56. warp/examples/fem/example_distortion_energy.py +15 -7
  57. warp/examples/fem/example_magnetostatics.py +20 -12
  58. warp/examples/fem/example_mixed_elasticity.py +14 -6
  59. warp/examples/fem/example_navier_stokes.py +14 -6
  60. warp/examples/fem/example_nonconforming_contact.py +14 -6
  61. warp/examples/fem/example_stokes.py +14 -6
  62. warp/examples/fem/example_stokes_transfer.py +14 -6
  63. warp/examples/fem/example_streamlines.py +14 -6
  64. warp/examples/fem/utils.py +24 -3
  65. warp/examples/interop/example_jax_callable.py +116 -0
  66. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  67. warp/examples/interop/example_jax_kernel.py +205 -0
  68. warp/examples/optim/example_bounce.py +14 -6
  69. warp/examples/optim/example_cloth_throw.py +14 -6
  70. warp/examples/optim/example_diffray.py +14 -6
  71. warp/examples/optim/example_drone.py +14 -6
  72. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  73. warp/examples/optim/example_inverse_kinematics.py +14 -6
  74. warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
  75. warp/examples/optim/example_softbody_properties.py +14 -6
  76. warp/examples/optim/example_spring_cage.py +14 -6
  77. warp/examples/optim/example_trajectory.py +14 -6
  78. warp/examples/sim/example_cartpole.py +14 -6
  79. warp/examples/sim/example_cloth.py +14 -6
  80. warp/examples/sim/example_cloth_self_contact.py +14 -6
  81. warp/examples/sim/example_granular.py +14 -6
  82. warp/examples/sim/example_granular_collision_sdf.py +14 -6
  83. warp/examples/sim/example_jacobian_ik.py +14 -6
  84. warp/examples/sim/example_particle_chain.py +14 -6
  85. warp/examples/sim/example_quadruped.py +14 -6
  86. warp/examples/sim/example_rigid_chain.py +14 -6
  87. warp/examples/sim/example_rigid_contact.py +14 -6
  88. warp/examples/sim/example_rigid_force.py +14 -6
  89. warp/examples/sim/example_rigid_gyroscopic.py +14 -6
  90. warp/examples/sim/example_rigid_soft_contact.py +14 -6
  91. warp/examples/sim/example_soft_body.py +14 -6
  92. warp/examples/tile/example_tile_cholesky.py +14 -6
  93. warp/examples/tile/example_tile_convolution.py +14 -6
  94. warp/examples/tile/example_tile_fft.py +14 -6
  95. warp/examples/tile/example_tile_filtering.py +14 -6
  96. warp/examples/tile/example_tile_matmul.py +16 -10
  97. warp/examples/tile/example_tile_mlp.py +14 -6
  98. warp/examples/tile/example_tile_nbody.py +14 -6
  99. warp/examples/tile/example_tile_walker.py +14 -6
  100. warp/fabric.py +15 -0
  101. warp/fem/__init__.py +26 -1
  102. warp/fem/adaptivity.py +19 -4
  103. warp/fem/cache.py +15 -0
  104. warp/fem/dirichlet.py +15 -0
  105. warp/fem/domain.py +15 -0
  106. warp/fem/field/__init__.py +15 -0
  107. warp/fem/field/field.py +15 -0
  108. warp/fem/field/nodal_field.py +37 -68
  109. warp/fem/field/restriction.py +15 -0
  110. warp/fem/field/virtual.py +77 -23
  111. warp/fem/geometry/__init__.py +15 -0
  112. warp/fem/geometry/adaptive_nanogrid.py +24 -10
  113. warp/fem/geometry/closest_point.py +16 -1
  114. warp/fem/geometry/deformed_geometry.py +20 -2
  115. warp/fem/geometry/element.py +15 -0
  116. warp/fem/geometry/geometry.py +20 -0
  117. warp/fem/geometry/grid_2d.py +27 -12
  118. warp/fem/geometry/grid_3d.py +27 -15
  119. warp/fem/geometry/hexmesh.py +20 -7
  120. warp/fem/geometry/nanogrid.py +24 -11
  121. warp/fem/geometry/partition.py +15 -0
  122. warp/fem/geometry/quadmesh.py +28 -13
  123. warp/fem/geometry/tetmesh.py +18 -4
  124. warp/fem/geometry/trimesh.py +18 -8
  125. warp/fem/integrate.py +277 -93
  126. warp/fem/linalg.py +20 -5
  127. warp/fem/operator.py +15 -0
  128. warp/fem/polynomial.py +15 -0
  129. warp/fem/quadrature/__init__.py +15 -0
  130. warp/fem/quadrature/pic_quadrature.py +52 -22
  131. warp/fem/quadrature/quadrature.py +209 -25
  132. warp/fem/space/__init__.py +16 -1
  133. warp/fem/space/basis_function_space.py +19 -2
  134. warp/fem/space/basis_space.py +40 -18
  135. warp/fem/space/dof_mapper.py +15 -0
  136. warp/fem/space/function_space.py +15 -0
  137. warp/fem/space/grid_2d_function_space.py +15 -0
  138. warp/fem/space/grid_3d_function_space.py +15 -0
  139. warp/fem/space/hexmesh_function_space.py +17 -2
  140. warp/fem/space/nanogrid_function_space.py +15 -0
  141. warp/fem/space/partition.py +21 -2
  142. warp/fem/space/quadmesh_function_space.py +23 -8
  143. warp/fem/space/restriction.py +15 -0
  144. warp/fem/space/shape/__init__.py +15 -0
  145. warp/fem/space/shape/cube_shape_function.py +38 -23
  146. warp/fem/space/shape/shape_function.py +15 -0
  147. warp/fem/space/shape/square_shape_function.py +27 -12
  148. warp/fem/space/shape/tet_shape_function.py +15 -0
  149. warp/fem/space/shape/triangle_shape_function.py +16 -1
  150. warp/fem/space/tetmesh_function_space.py +18 -3
  151. warp/fem/space/topology.py +15 -0
  152. warp/fem/space/trimesh_function_space.py +17 -2
  153. warp/fem/types.py +15 -0
  154. warp/fem/utils.py +27 -6
  155. warp/jax.py +28 -7
  156. warp/jax_experimental/__init__.py +16 -0
  157. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
  158. warp/jax_experimental/ffi.py +698 -0
  159. warp/jax_experimental/xla_ffi.py +602 -0
  160. warp/math.py +103 -6
  161. warp/native/array.h +28 -6
  162. warp/native/builtin.h +44 -9
  163. warp/native/bvh.cpp +18 -7
  164. warp/native/bvh.cu +57 -20
  165. warp/native/bvh.h +17 -7
  166. warp/native/clang/clang.cpp +45 -9
  167. warp/native/coloring.cpp +15 -6
  168. warp/native/crt.cpp +15 -6
  169. warp/native/crt.h +15 -6
  170. warp/native/cuda_crt.h +15 -6
  171. warp/native/cuda_util.cpp +29 -6
  172. warp/native/cuda_util.h +17 -6
  173. warp/native/error.cpp +15 -6
  174. warp/native/error.h +15 -6
  175. warp/native/exports.h +85 -63
  176. warp/native/fabric.h +15 -6
  177. warp/native/hashgrid.cpp +15 -6
  178. warp/native/hashgrid.cu +15 -6
  179. warp/native/hashgrid.h +15 -6
  180. warp/native/initializer_array.h +15 -6
  181. warp/native/intersect.h +41 -32
  182. warp/native/intersect_adj.h +48 -39
  183. warp/native/intersect_tri.h +17 -0
  184. warp/native/marching.cpp +16 -0
  185. warp/native/marching.cu +16 -7
  186. warp/native/marching.h +17 -0
  187. warp/native/mat.h +528 -15
  188. warp/native/mathdx.cpp +15 -6
  189. warp/native/matnn.h +15 -6
  190. warp/native/mesh.cpp +15 -6
  191. warp/native/mesh.cu +15 -6
  192. warp/native/mesh.h +25 -16
  193. warp/native/noise.h +15 -6
  194. warp/native/quat.h +114 -17
  195. warp/native/rand.h +21 -6
  196. warp/native/range.h +15 -6
  197. warp/native/reduce.cpp +15 -6
  198. warp/native/reduce.cu +15 -6
  199. warp/native/runlength_encode.cpp +15 -6
  200. warp/native/runlength_encode.cu +15 -6
  201. warp/native/scan.cpp +15 -6
  202. warp/native/scan.cu +15 -6
  203. warp/native/scan.h +15 -6
  204. warp/native/solid_angle.h +17 -0
  205. warp/native/sort.cpp +137 -65
  206. warp/native/sort.cu +167 -21
  207. warp/native/sort.h +23 -7
  208. warp/native/sparse.cpp +58 -28
  209. warp/native/sparse.cu +67 -23
  210. warp/native/spatial.h +15 -6
  211. warp/native/svd.h +131 -6
  212. warp/native/temp_buffer.h +15 -6
  213. warp/native/tile.h +316 -111
  214. warp/native/tile_reduce.h +61 -9
  215. warp/native/vec.h +83 -13
  216. warp/native/volume.cpp +100 -119
  217. warp/native/volume.cu +15 -6
  218. warp/native/volume.h +15 -6
  219. warp/native/volume_builder.cu +40 -16
  220. warp/native/volume_builder.h +21 -6
  221. warp/native/volume_impl.h +15 -6
  222. warp/native/warp.cpp +20 -12
  223. warp/native/warp.cu +114 -16
  224. warp/native/warp.h +34 -16
  225. warp/optim/__init__.py +14 -6
  226. warp/optim/adam.py +14 -6
  227. warp/optim/linear.py +25 -10
  228. warp/optim/sgd.py +14 -6
  229. warp/paddle.py +14 -6
  230. warp/render/__init__.py +14 -6
  231. warp/render/render_opengl.py +14 -6
  232. warp/render/render_usd.py +14 -6
  233. warp/render/utils.py +14 -6
  234. warp/sim/__init__.py +14 -7
  235. warp/sim/articulation.py +18 -10
  236. warp/sim/collide.py +35 -16
  237. warp/sim/graph_coloring.py +14 -6
  238. warp/sim/import_mjcf.py +463 -162
  239. warp/sim/import_snu.py +14 -7
  240. warp/sim/import_urdf.py +46 -18
  241. warp/sim/import_usd.py +14 -7
  242. warp/sim/inertia.py +14 -6
  243. warp/sim/integrator.py +14 -6
  244. warp/sim/integrator_euler.py +19 -11
  245. warp/sim/integrator_featherstone.py +17 -16
  246. warp/sim/integrator_vbd.py +222 -8
  247. warp/sim/integrator_xpbd.py +19 -11
  248. warp/sim/model.py +56 -19
  249. warp/sim/particles.py +14 -6
  250. warp/sim/render.py +14 -6
  251. warp/sim/utils.py +17 -2
  252. warp/sparse.py +657 -555
  253. warp/stubs.py +231 -19
  254. warp/tape.py +14 -6
  255. warp/tests/aux_test_class_kernel.py +14 -6
  256. warp/tests/aux_test_compile_consts_dummy.py +14 -6
  257. warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
  258. warp/tests/aux_test_dependent.py +14 -6
  259. warp/tests/aux_test_grad_customs.py +14 -6
  260. warp/tests/aux_test_instancing_gc.py +14 -6
  261. warp/tests/aux_test_module_unload.py +14 -6
  262. warp/tests/aux_test_name_clash1.py +14 -6
  263. warp/tests/aux_test_name_clash2.py +14 -6
  264. warp/tests/aux_test_unresolved_func.py +14 -6
  265. warp/tests/aux_test_unresolved_symbol.py +14 -6
  266. warp/tests/cuda/__init__.py +0 -0
  267. warp/tests/{test_async.py → cuda/test_async.py} +14 -6
  268. warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
  269. warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
  270. warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
  271. warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
  272. warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
  273. warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
  274. warp/tests/geometry/__init__.py +0 -0
  275. warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
  276. warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
  277. warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
  278. warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
  279. warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
  280. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
  281. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
  282. warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
  283. warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
  284. warp/tests/interop/__init__.py +0 -0
  285. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
  286. warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
  287. warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
  288. warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
  289. warp/tests/run_coverage_serial.py +14 -6
  290. warp/tests/sim/__init__.py +0 -0
  291. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
  292. warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
  293. warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
  294. warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
  295. warp/tests/{test_model.py → sim/test_model.py} +55 -7
  296. warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
  297. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
  298. warp/tests/sim/test_vbd.py +597 -0
  299. warp/tests/test_adam.py +14 -6
  300. warp/tests/test_arithmetic.py +14 -6
  301. warp/tests/test_array.py +14 -6
  302. warp/tests/test_array_reduce.py +14 -6
  303. warp/tests/test_assert.py +14 -6
  304. warp/tests/test_atomic.py +14 -6
  305. warp/tests/test_bool.py +15 -7
  306. warp/tests/test_builtins_resolution.py +14 -6
  307. warp/tests/test_closest_point_edge_edge.py +14 -6
  308. warp/tests/test_codegen.py +14 -6
  309. warp/tests/test_codegen_instancing.py +14 -6
  310. warp/tests/test_compile_consts.py +14 -6
  311. warp/tests/test_conditional.py +14 -6
  312. warp/tests/test_context.py +14 -6
  313. warp/tests/test_copy.py +14 -6
  314. warp/tests/test_ctypes.py +14 -6
  315. warp/tests/test_dense.py +14 -6
  316. warp/tests/test_devices.py +14 -6
  317. warp/tests/test_examples.py +42 -42
  318. warp/tests/test_fabricarray.py +14 -6
  319. warp/tests/test_fast_math.py +14 -6
  320. warp/tests/test_fem.py +37 -10
  321. warp/tests/test_fp16.py +14 -6
  322. warp/tests/test_func.py +14 -6
  323. warp/tests/test_future_annotations.py +14 -6
  324. warp/tests/test_generics.py +14 -6
  325. warp/tests/test_grad.py +14 -6
  326. warp/tests/test_grad_customs.py +14 -6
  327. warp/tests/test_grad_debug.py +14 -6
  328. warp/tests/test_implicit_init.py +14 -6
  329. warp/tests/test_import.py +14 -6
  330. warp/tests/test_indexedarray.py +14 -6
  331. warp/tests/test_intersect.py +14 -6
  332. warp/tests/test_iter.py +14 -6
  333. warp/tests/test_large.py +14 -6
  334. warp/tests/test_launch.py +14 -6
  335. warp/tests/test_lerp.py +14 -6
  336. warp/tests/test_linear_solvers.py +15 -11
  337. warp/tests/test_lvalue.py +14 -6
  338. warp/tests/test_mat.py +247 -85
  339. warp/tests/test_mat_lite.py +14 -6
  340. warp/tests/test_mat_scalar_ops.py +18 -10
  341. warp/tests/test_math.py +14 -6
  342. warp/tests/test_mlp.py +14 -6
  343. warp/tests/test_module_hashing.py +14 -6
  344. warp/tests/test_modules_lite.py +14 -6
  345. warp/tests/test_noise.py +14 -6
  346. warp/tests/test_operators.py +14 -6
  347. warp/tests/test_options.py +14 -6
  348. warp/tests/test_overwrite.py +15 -60
  349. warp/tests/test_print.py +14 -6
  350. warp/tests/test_quat.py +81 -52
  351. warp/tests/test_rand.py +58 -43
  352. warp/tests/test_reload.py +14 -6
  353. warp/tests/test_rounding.py +14 -6
  354. warp/tests/test_runlength_encode.py +14 -6
  355. warp/tests/test_scalar_ops.py +14 -6
  356. warp/tests/test_smoothstep.py +14 -6
  357. warp/tests/test_snippet.py +15 -0
  358. warp/tests/test_sparse.py +61 -12
  359. warp/tests/test_spatial.py +89 -6
  360. warp/tests/test_special_values.py +14 -6
  361. warp/tests/test_static.py +15 -7
  362. warp/tests/test_struct.py +14 -6
  363. warp/tests/test_tape.py +14 -6
  364. warp/tests/test_transient_module.py +14 -6
  365. warp/tests/test_triangle_closest_point.py +14 -6
  366. warp/tests/test_types.py +14 -6
  367. warp/tests/test_utils.py +98 -10
  368. warp/tests/test_vec.py +60 -40
  369. warp/tests/test_vec_lite.py +14 -6
  370. warp/tests/test_vec_scalar_ops.py +14 -6
  371. warp/tests/test_verify_fp.py +14 -6
  372. warp/tests/tile/__init__.py +0 -0
  373. warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
  374. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
  375. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
  376. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
  377. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
  378. warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
  379. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
  380. warp/tests/unittest_serial.py +15 -6
  381. warp/tests/unittest_suites.py +59 -65
  382. warp/tests/unittest_utils.py +16 -7
  383. warp/tests/walkthrough_debug.py +14 -6
  384. warp/thirdparty/unittest_parallel.py +15 -8
  385. warp/torch.py +14 -6
  386. warp/types.py +124 -664
  387. warp/utils.py +151 -78
  388. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
  389. warp_lang-1.7.0.dist-info/RECORD +429 -0
  390. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  391. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  392. warp/examples/optim/example_walker.py +0 -309
  393. warp/native/cutlass_gemm.cpp +0 -34
  394. warp/native/cutlass_gemm.cu +0 -373
  395. warp/tests/test_matmul.py +0 -503
  396. warp/tests/test_matmul_lite.py +0 -403
  397. warp/tests/test_vbd.py +0 -378
  398. warp/tests/unused_test_misc.py +0 -69
  399. warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
  400. warp_lang-1.6.1.dist-info/RECORD +0 -419
  401. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/sparse.py CHANGED
@@ -1,10 +1,39 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import ctypes
2
17
  from typing import Any, Generic, Optional, Tuple, TypeVar, Union
3
18
 
4
19
  import warp as wp
5
20
  import warp.types
6
21
  import warp.utils
7
- from warp.types import Array, Cols, Rows, Scalar, Vector
22
+ from warp.types import (
23
+ Array,
24
+ Cols,
25
+ Rows,
26
+ Scalar,
27
+ Vector,
28
+ is_array,
29
+ scalar_types,
30
+ type_is_matrix,
31
+ type_length,
32
+ type_repr,
33
+ type_scalar_type,
34
+ type_to_warp,
35
+ types_equal,
36
+ )
8
37
 
9
38
  # typing hints
10
39
 
@@ -30,50 +59,89 @@ class BsrMatrix(Generic[_BlockType]):
30
59
  Should not be constructed directly but through functions such as :func:`bsr_zeros`.
31
60
 
32
61
  Attributes:
33
- nrow (int): Number of rows of blocks
34
- ncol (int): Number of columns of blocks
35
- nnz (int): Upper bound for the number of non-zero blocks, used for dimensioning launches; the exact number is at ``offsets[nrow-1]``. See also :meth:`nnz_sync`.
36
- offsets (Array[int]): Array of size at least ``1 + nrows`` such that the start and end indices of the blocks of row ``r`` are ``offsets[r]`` and ``offsets[r+1]``, respectively.
37
- columns (Array[int]): Array of size at least equal to ``nnz`` containing block column indices
38
- values (Array[BlockType]): Array of size at least equal to ``nnz`` containing block values
62
+ nrow (int): Number of rows of blocks.
63
+ ncol (int): Number of columns of blocks.
64
+ nnz (int): Upper bound for the number of non-zero blocks, used for
65
+ dimensioning launches. The exact number is at ``offsets[nrow]``.
66
+ See also :meth:`nnz_sync`.
67
+ offsets (Array[int]): Array of size at least ``1 + nrow`` such that the
68
+ start and end indices of the blocks of row ``r`` are ``offsets[r]``
69
+ and ``offsets[r+1]``, respectively.
70
+ columns (Array[int]): Array of size at least equal to ``nnz`` containing
71
+ block column indices.
72
+ values (Array[BlockType]): Array of size at least equal to ``nnz``
73
+ containing block values.
39
74
  """
40
75
 
41
76
  @property
42
77
  def scalar_type(self) -> Scalar:
43
- """Scalar type for individual block coefficients. For CSR matrices, this is the same as the block type"""
44
- return warp.types.type_scalar_type(self.values.dtype)
78
+ """Scalar type for individual block coefficients. For CSR matrices, this is the same as the block type."""
79
+ return type_scalar_type(self.values.dtype)
45
80
 
46
81
  @property
47
82
  def block_shape(self) -> Tuple[int, int]:
48
- """Shape of the individual blocks"""
83
+ """Shape of the individual blocks."""
49
84
  return getattr(self.values.dtype, "_shape_", (1, 1))
50
85
 
51
86
  @property
52
87
  def block_size(self) -> int:
53
- """Size of the individual blocks, i.e. number of rows per block times number of columns per block"""
54
- return warp.types.type_length(self.values.dtype)
88
+ """Size of the individual blocks, i.e. number of rows per block times number of columns per block."""
89
+ return type_length(self.values.dtype)
55
90
 
56
91
  @property
57
92
  def shape(self) -> Tuple[int, int]:
58
- """Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/columns per block"""
93
+ """Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/columns per block."""
59
94
  block_shape = self.block_shape
60
95
  return (self.nrow * block_shape[0], self.ncol * block_shape[1])
61
96
 
62
97
  @property
63
98
  def dtype(self) -> type:
64
- """Data type for individual block values"""
99
+ """Data type for individual block values."""
65
100
  return self.values.dtype
66
101
 
67
102
  @property
68
103
  def device(self) -> wp.context.Device:
69
- """Device on which offsets, columns and values are allocated -- assumed to be the same for all three arrays"""
104
+ """Device on which ``offsets``, ``columns``, and ``values`` are allocated -- assumed to be the same for all three arrays."""
70
105
  return self.values.device
71
106
 
107
+ @property
108
+ def scalar_values(self) -> wp.array:
109
+ """Accesses the ``values`` array as a 3d scalar array."""
110
+ if self.block_shape == (1, 1):
111
+ return self.values.reshape((self.nnz, 1, 1))
112
+
113
+ def _as_3d_array(arr):
114
+ return wp.array(
115
+ ptr=arr.ptr,
116
+ capacity=arr.capacity,
117
+ device=arr.device,
118
+ dtype=self.scalar_type,
119
+ shape=(self.nnz, *self.block_shape),
120
+ grad=None if arr.grad is None else _as_3d_array(arr.grad),
121
+ )
122
+
123
+ values_view = _as_3d_array(self.values)
124
+ values_view._ref = self.values # keep ref in case we're garbage collected
125
+ return values_view
126
+
127
+ def uncompress_rows(self, out: wp.array = None) -> wp.array:
128
+ """Compute the row index for each non-zero block from the compressed row offsets."""
129
+ if out is None:
130
+ out = wp.empty(self.nnz, dtype=int, device=self.device)
131
+
132
+ wp.launch(
133
+ kernel=_bsr_get_block_row,
134
+ device=self.device,
135
+ dim=self.nnz,
136
+ inputs=[self.nrow, self.offsets, out],
137
+ )
138
+ return out
139
+
72
140
  def nnz_sync(self):
73
- """Ensures that any ongoing transfer of the exact nnz number from the device offsets array to the host has completed,
74
- and updates the nnz upper bound.
141
+ """Ensure that any ongoing transfer of the exact nnz number from the device offsets array to the host has completed
142
+ and update the nnz upper bound.
75
143
 
76
- See also :meth:`copy_nnz_async`
144
+ See also :meth:`copy_nnz_async`.
77
145
  """
78
146
 
79
147
  if self._is_nnz_transfer_setup():
@@ -84,10 +152,11 @@ class BsrMatrix(Generic[_BlockType]):
84
152
 
85
153
  def copy_nnz_async(self, known_nnz: int = None):
86
154
  """
87
- Starts the asynchronous transfer of the exact nnz from the device offsets array to host, and records an event for completion.
155
+ Start the asynchronous transfer of the exact nnz from the device offsets array to host and record an event for completion.
156
+
88
157
  Needs to be called whenever the offsets array has been modified from outside ``warp.sparse``.
89
158
 
90
- See also :meth:`nnz_sync`
159
+ See also :meth:`nnz_sync`.
91
160
  """
92
161
  if known_nnz is not None:
93
162
  self.nnz = int(known_nnz)
@@ -171,35 +240,33 @@ class BsrMatrix(Generic[_BlockType]):
171
240
  return _BsrScalingExpression(self, -1.0)
172
241
 
173
242
  def transpose(self):
174
- """Returns a transposed copy of this matrix"""
243
+ """Return a transposed copy of this matrix."""
175
244
  return bsr_transposed(self)
176
245
 
177
246
 
178
247
  def bsr_matrix_t(dtype: BlockType):
179
- dtype = wp.types.type_to_warp(dtype)
248
+ dtype = type_to_warp(dtype)
180
249
 
181
- if not warp.types.type_is_matrix(dtype) and dtype not in warp.types.scalar_types:
182
- raise ValueError(
183
- f"BsrMatrix block type must be either warp matrix or scalar; got {warp.types.type_repr(dtype)}"
184
- )
250
+ if not type_is_matrix(dtype) and dtype not in scalar_types:
251
+ raise ValueError(f"BsrMatrix block type must be either warp matrix or scalar; got {type_repr(dtype)}")
185
252
 
186
253
  class BsrMatrixTyped(BsrMatrix):
187
254
  nrow: int
188
- """Number of rows of blocks"""
255
+ """Number of rows of blocks."""
189
256
  ncol: int
190
- """Number of columns of blocks"""
257
+ """Number of columns of blocks."""
191
258
  nnz: int
192
- """Upper bound for the number of non-zeros"""
259
+ """Upper bound for the number of non-zeros."""
193
260
  offsets: wp.array(dtype=int)
194
- """Array of size at least 1 + nrows"""
261
+ """Array of size at least ``1 + nrow``."""
195
262
  columns: wp.array(dtype=int)
196
- """Array of size at least equal to nnz"""
263
+ """Array of size at least equal to ``nnz``."""
197
264
  values: wp.array(dtype=dtype)
198
265
 
199
266
  module = wp.get_module(BsrMatrix.__module__)
200
267
 
201
268
  if hasattr(dtype, "_shape_"):
202
- type_str = f"{warp.types.type_scalar_type(dtype).__name__}_{dtype._shape_[0]}_{dtype._shape_[1]}"
269
+ type_str = f"{type_scalar_type(dtype).__name__}_{dtype._shape_[0]}_{dtype._shape_[1]}"
203
270
  else:
204
271
  type_str = dtype.__name__
205
272
  key = f"{BsrMatrix.__qualname__}_{type_str}"
@@ -220,16 +287,16 @@ def bsr_zeros(
220
287
  block_type: BlockType,
221
288
  device: wp.context.Devicelike = None,
222
289
  ) -> BsrMatrix:
223
- """
224
- Constructs and returns an empty BSR or CSR matrix with the given shape
290
+ """Construct and return an empty BSR or CSR matrix with the given shape.
225
291
 
226
292
  Args:
227
- bsr: The BSR or CSR matrix to set to zero
228
- rows_of_blocks: Number of rows of blocks
229
- cols_of_blocks: Number of columns of blocks
230
- block_type: Type of individual blocks. For CSR matrices, this should be a scalar type;
231
- for BSR matrices, this should be a matrix type (e.g. from :func:`warp.mat`)
232
- device: Device on which to allocate the matrix arrays
293
+ bsr: The BSR or CSR matrix to set to zero.
294
+ rows_of_blocks: Number of rows of blocks.
295
+ cols_of_blocks: Number of columns of blocks.
296
+ block_type: Type of individual blocks.
297
+ For CSR matrices, this should be a scalar type.
298
+ For BSR matrices, this should be a matrix type (e.g. from :func:`warp.mat`).
299
+ device: Device on which to allocate the matrix arrays.
233
300
  """
234
301
 
235
302
  bsr = bsr_matrix_t(block_type)()
@@ -266,13 +333,12 @@ def bsr_set_zero(
266
333
  rows_of_blocks: Optional[int] = None,
267
334
  cols_of_blocks: Optional[int] = None,
268
335
  ):
269
- """
270
- Sets a BSR matrix to zero, possibly changing its size
336
+ """Set a BSR matrix to zero, possibly changing its size.
271
337
 
272
338
  Args:
273
- bsr: The BSR or CSR matrix to set to zero
274
- rows_of_blocks: If not ``None``, the new number of rows of blocks
275
- cols_of_blocks: If not ``None``, the new number of columns of blocks
339
+ bsr: The BSR or CSR matrix to set to zero.
340
+ rows_of_blocks: If not ``None``, the new number of rows of blocks.
341
+ cols_of_blocks: If not ``None``, the new number of columns of blocks.
276
342
  """
277
343
 
278
344
  if rows_of_blocks is not None:
@@ -289,46 +355,55 @@ def bsr_set_from_triplets(
289
355
  dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
290
356
  rows: "Array[int]",
291
357
  columns: "Array[int]",
292
- values: "Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]",
358
+ values: Optional["Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]"] = None,
293
359
  prune_numerical_zeros: bool = True,
360
+ masked: bool = False,
294
361
  ):
295
- """
296
- Fills a BSR matrix with values defined by coordinate-oriented (COO) triplets, discarding existing blocks.
362
+ """Fill a BSR matrix with values defined by coordinate-oriented (COO) triplets, discarding existing blocks.
297
363
 
298
364
  The first dimension of the three input arrays must match and indicates the number of COO triplets.
299
365
 
300
366
  Args:
301
- dest: Sparse matrix to populate
302
- rows: Row index for each non-zero
303
- columns: Columns index for each non-zero
367
+ dest: Sparse matrix to populate.
368
+ rows: Row index for each non-zero.
369
+ columns: Column index for each non-zero.
304
370
  values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
305
- to the `dest` matrix's block type, or a 3d array with data type equal to the `dest` matrix's scalar type.
306
- prune_numerical_zeros: If True, will ignore the zero-valued blocks
371
+ to the ``dest`` matrix's block type, or a 3d array with data type equal to the ``dest`` matrix's scalar type.
372
+ If ``None``, the values array of the resulting matrix will be allocated but uninitialized.
373
+ prune_numerical_zeros: If ``True``, will ignore the zero-valued blocks.
374
+ masked: If ``True``, ignore blocks that are not existing non-zeros of ``dest``.
307
375
  """
308
376
 
309
- if values.device != columns.device or values.device != rows.device or values.device != dest.values.device:
377
+ if rows.device != columns.device or rows.device != dest.device:
310
378
  raise ValueError("All arguments must reside on the same device")
311
379
 
312
- if values.shape[0] != rows.shape[0] or values.shape[0] != columns.shape[0]:
380
+ if rows.shape[0] != columns.shape[0]:
313
381
  raise ValueError("All triplet arrays must have the same length")
314
382
 
315
383
  # Accept either array1d(dtype) or contiguous array3d(scalar_type) as values
316
- if values.ndim == 1:
317
- if values.dtype != dest.values.dtype:
318
- raise ValueError("Values array type must correspond to that of dest matrix")
319
- elif values.ndim == 3:
320
- if values.shape[1:] != dest.block_shape:
321
- raise ValueError(
322
- f"Last two dimensions in values array ({values.shape[1:]}) should correspond to matrix block shape {(dest.block_shape)})"
323
- )
384
+ if values is not None:
385
+ if values.device != rows.device:
386
+ raise ValueError("All arguments must reside on the same device")
387
+
388
+ if values.shape[0] != rows.shape[0]:
389
+ raise ValueError("All triplet arrays must have the same length")
390
+
391
+ if values.ndim == 1:
392
+ if values.dtype != dest.values.dtype:
393
+ raise ValueError("Values array type must correspond to that of dest matrix")
394
+ elif values.ndim == 3:
395
+ if values.shape[1:] != dest.block_shape:
396
+ raise ValueError(
397
+ f"Last two dimensions in values array ({values.shape[1:]}) should correspond to matrix block shape {(dest.block_shape)})"
398
+ )
324
399
 
325
- if warp.types.type_scalar_type(values.dtype) != dest.scalar_type:
326
- raise ValueError("Scalar type of values array should correspond to that of matrix")
400
+ if type_scalar_type(values.dtype) != dest.scalar_type:
401
+ raise ValueError("Scalar type of values array should correspond to that of matrix")
327
402
 
328
- if not values.is_contiguous:
329
- raise ValueError("Multi-dimensional values array should be contiguous")
330
- else:
331
- raise ValueError("Number of dimension for values array should be 1 or 3")
403
+ if not values.is_contiguous:
404
+ raise ValueError("Multi-dimensional values array should be contiguous")
405
+ else:
406
+ raise ValueError("Number of dimension for values array should be 1 or 3")
332
407
 
333
408
  nnz = rows.shape[0]
334
409
  if nnz == 0:
@@ -336,7 +411,8 @@ def bsr_set_from_triplets(
336
411
  return
337
412
 
338
413
  # Increase dest array sizes if needed
339
- _bsr_ensure_fits(dest, nnz=nnz)
414
+ if not masked:
415
+ _bsr_ensure_fits(dest, nnz=nnz)
340
416
 
341
417
  device = dest.values.device
342
418
  scalar_type = dest.scalar_type
@@ -366,16 +442,51 @@ def bsr_set_from_triplets(
366
442
  nnz,
367
443
  ctypes.cast(rows.ptr, ctypes.POINTER(ctypes.c_int32)),
368
444
  ctypes.cast(columns.ptr, ctypes.POINTER(ctypes.c_int32)),
369
- ctypes.cast(values.ptr, ctypes.c_void_p),
445
+ None if values is None else ctypes.cast(values.ptr, ctypes.c_void_p),
370
446
  prune_numerical_zeros,
447
+ masked,
371
448
  ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
372
449
  ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
373
- ctypes.cast(dest.values.ptr, ctypes.c_void_p),
450
+ None if values is None else ctypes.cast(dest.values.ptr, ctypes.c_void_p),
374
451
  ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
375
452
  nnz_event,
376
453
  )
377
454
 
378
455
 
456
+ def bsr_from_triplets(
457
+ rows_of_blocks: int,
458
+ cols_of_blocks: int,
459
+ rows: "Array[int]",
460
+ columns: "Array[int]",
461
+ values: "Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]",
462
+ prune_numerical_zeros: bool = True,
463
+ ):
464
+ """Construct a BSR matrix with values defined by coordinate-oriented (COO) triplets.
465
+
466
+ The first dimension of the three input arrays must match and indicates the number of COO triplets.
467
+
468
+ Args:
469
+ rows_of_blocks: Number of rows of blocks.
470
+ cols_of_blocks: Number of columns of blocks.
471
+ rows: Row index for each non-zero.
472
+ columns: Column index for each non-zero.
473
+ values: Block values for each non-zero. Must be either a one-dimensional array of blocks, or a 3d array of
474
+ scalars whose last two dimensions define the block shape. The block type of the returned matrix is deduced from this array.
475
+ prune_numerical_zeros: If ``True``, will ignore the zero-valued blocks.
476
+ """
477
+
478
+ if values.ndim == 3:
479
+ block_type = wp.mat(shape=values.shape[1:], dtype=values.dtype)
480
+ else:
481
+ block_type = values.dtype
482
+
483
+ A = bsr_zeros(
484
+ rows_of_blocks=rows_of_blocks, cols_of_blocks=cols_of_blocks, block_type=block_type, device=values.device
485
+ )
486
+ bsr_set_from_triplets(A, rows, columns, values, prune_numerical_zeros=prune_numerical_zeros)
487
+ return A
488
+
489
+
379
490
  class _BsrExpression(Generic[_BlockType]):
380
491
  pass
381
492
 
@@ -486,96 +597,73 @@ def _extract_matrix_and_scale(bsr: BsrMatrixOrExpression):
486
597
  raise ValueError("Argument cannot be interpreted as a BsrMatrix")
487
598
 
488
599
 
489
- @wp.kernel
490
- def _bsr_assign_split_offsets(
491
- row_factor: int,
492
- col_factor: int,
493
- src_offsets: wp.array(dtype=int),
494
- dest_offsets: wp.array(dtype=int),
600
+ @wp.func
601
+ def _bsr_row_index(
602
+ offsets: wp.array(dtype=int),
603
+ row_count: int,
604
+ block: int,
495
605
  ):
496
- row = wp.tid()
606
+ """Index of the row containing a block, or -1 if non-existing."""
607
+ return wp.where(block < offsets[row_count], wp.lower_bound(offsets, 0, row_count + 1, block + 1), 0) - 1
497
608
 
498
- base_offset = src_offsets[row] * row_factor * col_factor
499
- row_count = src_offsets[1 + row] - src_offsets[row]
500
609
 
501
- for k in range(row_factor):
502
- dest_offsets[1 + k + row_factor * row] = base_offset + row_count * col_factor * (k + 1)
503
-
504
- if row == 0:
505
- dest_offsets[0] = 0
506
-
507
-
508
- @wp.kernel
509
- def _bsr_assign_split_blocks(
510
- structure_only: wp.bool,
511
- scale: Any,
512
- row_factor: int,
513
- col_factor: int,
514
- dest_row_count: int,
515
- src_offsets: wp.array(dtype=int),
516
- src_columns: wp.array(dtype=int),
517
- src_values: wp.array3d(dtype=Any),
518
- dest_offsets: wp.array(dtype=int),
519
- dest_columns: wp.array(dtype=int),
520
- dest_values: wp.array3d(dtype=Any),
610
+ @wp.func
611
+ def _bsr_block_index(
612
+ row: int,
613
+ col: int,
614
+ bsr_offsets: wp.array(dtype=int),
615
+ bsr_columns: wp.array(dtype=int),
521
616
  ):
522
- dest_block = wp.tid()
523
-
524
- if dest_block >= dest_offsets[dest_row_count]:
525
- return
526
-
527
- dest_row = wp.lower_bound(dest_offsets, 0, dest_row_count + 1, dest_block + 1) - 1
528
- src_row = dest_row // row_factor
529
-
530
- dest_col_in_row = dest_block - dest_offsets[dest_row]
531
- src_col_in_row = dest_col_in_row // col_factor
532
-
533
- src_block = src_offsets[src_row] + src_col_in_row
617
+ """Index of the block at block-coordinates (row, col), or -1 if non-existing.
618
+ Assumes bsr_columns is sorted.
619
+ """
534
620
 
535
- dest_rows_per_block = dest_values.shape[1]
536
- dest_cols_per_block = dest_values.shape[2]
621
+ if row < 0:
622
+ return -1
537
623
 
538
- split_row = dest_row - row_factor * src_row
539
- split_col = dest_col_in_row - col_factor * src_col_in_row
624
+ mask_row_beg = bsr_offsets[row]
625
+ mask_row_end = bsr_offsets[row + 1]
540
626
 
541
- dest_columns[dest_block] = src_columns[src_block] * col_factor + split_col
627
+ if mask_row_beg == mask_row_end:
628
+ return -1
542
629
 
543
- if not structure_only:
544
- src_base_i = split_row * dest_rows_per_block
545
- src_base_j = split_col * dest_cols_per_block
546
- for i in range(dest_rows_per_block):
547
- for j in range(dest_cols_per_block):
548
- dest_values[dest_block, i, j] = dest_values.dtype(
549
- scale * src_values[src_block, i + src_base_i, j + src_base_j]
550
- )
630
+ block_index = wp.lower_bound(bsr_columns, mask_row_beg, mask_row_end, col)
631
+ return wp.where(bsr_columns[block_index] == col, block_index, -1)
551
632
 
552
633
 
553
- @wp.kernel
554
- def _bsr_assign_merge_row_col(
555
- row_factor: int,
556
- col_factor: int,
634
+ @wp.kernel(enable_backward=False)
635
+ def _bsr_assign_list_blocks(
636
+ src_subrows: int,
637
+ src_subcols: int,
638
+ dest_subrows: int,
639
+ dest_subcols: int,
557
640
  src_row_count: int,
558
641
  src_offsets: wp.array(dtype=int),
559
642
  src_columns: wp.array(dtype=int),
560
643
  dest_rows: wp.array(dtype=int),
561
644
  dest_cols: wp.array(dtype=int),
562
645
  ):
563
- block = wp.tid()
646
+ block, subrow, subcol = wp.tid()
647
+ dest_block = (block * src_subcols + subcol) * src_subrows + subrow
564
648
 
565
- if block >= src_offsets[src_row_count]:
566
- dest_rows[block] = -1 # invalid
567
- dest_cols[block] = -1
649
+ row = _bsr_row_index(src_offsets, src_row_count, block)
650
+ if row == -1:
651
+ dest_rows[dest_block] = row # invalid
652
+ dest_cols[dest_block] = row
568
653
  else:
569
- row = wp.lower_bound(src_offsets, 0, src_row_count + 1, block + 1) - 1
570
- dest_rows[block] = row // row_factor
571
- dest_cols[block] = src_columns[block] // col_factor
654
+ dest_subrow = row * src_subrows + subrow
655
+ dest_subcol = src_columns[block] * src_subcols + subcol
656
+ dest_rows[dest_block] = dest_subrow // dest_subrows
657
+ dest_cols[dest_block] = dest_subcol // dest_subcols
572
658
 
573
659
 
574
660
  @wp.kernel
575
- def _bsr_assign_merge_blocks(
661
+ def _bsr_assign_copy_blocks(
576
662
  scale: Any,
577
- row_factor: int,
578
- col_factor: int,
663
+ src_subrows: int,
664
+ src_subcols: int,
665
+ dest_subrows: int,
666
+ dest_subcols: int,
579
667
  src_row_count: int,
580
668
  src_offsets: wp.array(dtype=int),
581
669
  src_columns: wp.array(dtype=int),
@@ -585,61 +673,58 @@ def _bsr_assign_merge_blocks(
585
673
  dest_values: wp.array3d(dtype=Any),
586
674
  ):
587
675
  src_block = wp.tid()
676
+ src_block, subrow, subcol = wp.tid()
588
677
 
589
- if src_block >= src_offsets[src_row_count]:
678
+ src_row = _bsr_row_index(src_offsets, src_row_count, src_block)
679
+ if src_row == -1:
590
680
  return
591
681
 
592
- src_row = wp.lower_bound(src_offsets, 0, src_row_count + 1, src_block + 1) - 1
593
682
  src_col = src_columns[src_block]
594
683
 
595
- dest_row = src_row // row_factor
596
- dest_col = src_col // col_factor
684
+ dest_subrow = src_row * src_subrows + subrow
685
+ dest_subcol = src_col * src_subcols + subcol
686
+ dest_row = dest_subrow // dest_subrows
687
+ dest_col = dest_subcol // dest_subcols
597
688
 
598
- dest_block = wp.lower_bound(dest_columns, dest_offsets[dest_row], dest_offsets[dest_row + 1], dest_col)
689
+ dest_block = _bsr_block_index(dest_row, dest_col, dest_offsets, dest_columns)
690
+ if dest_block == -1:
691
+ return
692
+
693
+ split_row = dest_subrow - dest_subrows * dest_row
694
+ split_col = dest_subcol - dest_subcols * dest_col
599
695
 
600
- src_rows_per_block = src_values.shape[1]
601
- src_cols_per_block = src_values.shape[2]
696
+ rows_per_subblock = src_values.shape[1] // src_subrows
697
+ cols_per_subblock = src_values.shape[2] // src_subcols
602
698
 
603
- split_row = src_row - row_factor * dest_row
604
- split_col = src_col - col_factor * dest_col
699
+ dest_base_i = split_row * rows_per_subblock
700
+ dest_base_j = split_col * cols_per_subblock
605
701
 
606
- dest_base_i = split_row * src_rows_per_block
607
- dest_base_j = split_col * src_cols_per_block
702
+ src_base_i = subrow * rows_per_subblock
703
+ src_base_j = subcol * cols_per_subblock
608
704
 
609
- for i in range(src_rows_per_block):
610
- for j in range(src_cols_per_block):
705
+ for i in range(rows_per_subblock):
706
+ for j in range(cols_per_subblock):
611
707
  dest_values[dest_block, i + dest_base_i, j + dest_base_j] = dest_values.dtype(
612
- scale * src_values[src_block, i, j]
708
+ scale * src_values[src_block, i + src_base_i, j + src_base_j]
613
709
  )
614
710
 
615
711
 
616
- def _bsr_values_as_3d_array(A: BsrMatrix) -> wp.array:
617
- if A.block_shape == (1, 1):
618
- return A.values.reshape((A.values.shape[0], 1, 1))
619
-
620
- return wp.array(
621
- data=None,
622
- ptr=A.values.ptr,
623
- capacity=A.values.capacity,
624
- device=A.device,
625
- dtype=A.scalar_type,
626
- shape=(A.values.shape[0], A.block_shape[0], A.block_shape[1]),
627
- )
628
-
629
-
630
712
  def bsr_assign(
631
713
  dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
632
714
  src: BsrMatrixOrExpression[BlockType[Any, Any, Any]],
633
715
  structure_only: bool = False,
716
+ masked: bool = False,
634
717
  ):
635
- """Copies the content of the `src` BSR matrix to `dest`.
718
+ """Copy the content of the ``src`` BSR matrix to ``dest``.
636
719
 
637
720
  Args:
638
- src: Matrix to be copied
639
- dest: Destination matrix. May have a different block shape of scalar type than `src`, in which case the required casting will be performed.
721
+ src: Matrix to be copied.
722
+ dest: Destination matrix. May have a different block shape or scalar type
723
+ than ``src``, in which case the required casting will be performed.
640
724
  structure_only: If ``True``, only the non-zeros indices are copied, and uninitialized value storage is allocated
641
- to accommodate at least `src.nnz` blocks. If `structure_only` is ``False``, values are also copied with implicit
725
+ to accommodate at least ``src.nnz`` blocks. If ``structure_only`` is ``False``, values are also copied with implicit
642
726
  casting if the two matrices use distinct scalar types.
727
+ masked: If ``True``, prevent the assignment operation from adding new non-zero blocks to ``dest``.
643
728
  """
644
729
 
645
730
  src, src_scale = _extract_matrix_and_scale(src)
@@ -647,13 +732,50 @@ def bsr_assign(
647
732
  if dest.values.device != src.values.device:
648
733
  raise ValueError("Source and destination matrices must reside on the same device")
649
734
 
650
- if dest.block_shape == src.block_shape:
651
- dest.nrow = src.nrow
652
- dest.ncol = src.ncol
735
+ if src.block_shape[0] >= dest.block_shape[0]:
736
+ src_subrows = src.block_shape[0] // dest.block_shape[0]
737
+ dest_subrows = 1
738
+ else:
739
+ dest_subrows = dest.block_shape[0] // src.block_shape[0]
740
+ src_subrows = 1
741
+
742
+ if src_subrows * dest.block_shape[0] != src.block_shape[0] * dest_subrows:
743
+ raise ValueError(
744
+ f"Incompatible dest and src block shapes; block rows must evenly divide one another (Got {src.block_shape[0]}, {dest.block_shape[0]})"
745
+ )
746
+
747
+ if src.block_shape[1] >= dest.block_shape[1]:
748
+ src_subcols = src.block_shape[1] // dest.block_shape[1]
749
+ dest_subcols = 1
750
+ else:
751
+ dest_subcols = dest.block_shape[1] // src.block_shape[1]
752
+ src_subcols = 1
753
+
754
+ if src_subcols * dest.block_shape[1] != src.block_shape[1] * dest_subcols:
755
+ raise ValueError(
756
+ f"Incompatible dest and src block shapes; block columns must evenly divide one another (Got {src.block_shape[1]}, {dest.block_shape[1]})"
757
+ )
653
758
 
654
- nnz_alloc = src.nnz
759
+ dest_nrow = (src.nrow * src_subrows) // dest_subrows
760
+ dest_ncol = (src.ncol * src_subcols) // dest_subcols
761
+
762
+ if src.nrow * src_subrows != dest_nrow * dest_subrows or src.ncol * src_subcols != dest_ncol * dest_subcols:
763
+ raise ValueError("The requested block shape does not evenly divide the source matrix")
764
+
765
+ nnz_alloc = src.nnz * src_subrows * src_subcols
766
+ if masked:
767
+ if dest_nrow != dest.nrow or dest_ncol != dest.ncol:
768
+ raise ValueError(
769
+ f"Incompatible destination matrix size, expected ({dest_nrow}, {dest_ncol}), got ({dest.nrow}, {dest.ncol})"
770
+ )
771
+ else:
772
+ dest.nrow = dest_nrow
773
+ dest.ncol = dest_ncol
655
774
  _bsr_ensure_fits(dest, nnz=nnz_alloc)
656
775
 
776
+ if dest.block_shape == src.block_shape and not masked:
777
+ # Direct copy
778
+
657
779
  wp.copy(dest=dest.offsets, src=src.offsets, count=src.nrow + 1)
658
780
  dest.copy_nnz_async()
659
781
 
@@ -664,86 +786,29 @@ def bsr_assign(
664
786
  warp.utils.array_cast(out_array=dest.values, in_array=src.values, count=nnz_alloc)
665
787
  bsr_scale(dest, src_scale)
666
788
 
667
- elif src.block_shape[0] >= dest.block_shape[0] and src.block_shape[1] >= dest.block_shape[1]:
668
- # Split blocks
669
-
670
- row_factor = src.block_shape[0] // dest.block_shape[0]
671
- col_factor = src.block_shape[1] // dest.block_shape[1]
672
-
673
- if (
674
- row_factor * dest.block_shape[0] != src.block_shape[0]
675
- or col_factor * dest.block_shape[1] != src.block_shape[1]
676
- ):
677
- raise ValueError(
678
- f"Dest block shape {dest.block_shape} is not an exact divider of src block shape {src.block_shape}"
679
- )
680
-
681
- dest.nrow = src.nrow * row_factor
682
- dest.ncol = src.ncol * col_factor
683
-
684
- nnz_alloc = src.nnz * row_factor * col_factor
685
- _bsr_ensure_fits(dest, nnz=nnz_alloc)
789
+ else:
790
+ # Masked and/or multiple src blocks per dest block, go through COO format
686
791
 
792
+ # Compute destination rows and columns
793
+ dest_rows = wp.empty(nnz_alloc, dtype=int, device=dest.device)
794
+ dest_cols = wp.empty(nnz_alloc, dtype=int, device=dest.device)
687
795
  wp.launch(
688
- _bsr_assign_split_offsets,
689
- dim=src.nrow,
690
- device=dest.device,
691
- inputs=[row_factor, col_factor, src.offsets, dest.offsets],
692
- )
693
- wp.launch(
694
- _bsr_assign_split_blocks,
695
- dim=dest.nnz,
796
+ _bsr_assign_list_blocks,
797
+ dim=(src.nnz, src_subrows, src_subcols),
696
798
  device=dest.device,
697
799
  inputs=[
698
- wp.bool(structure_only),
699
- src.scalar_type(src_scale),
700
- row_factor,
701
- col_factor,
702
- dest.nrow,
800
+ src_subrows,
801
+ src_subcols,
802
+ dest_subrows,
803
+ dest_subcols,
804
+ src.nrow,
703
805
  src.offsets,
704
806
  src.columns,
705
- _bsr_values_as_3d_array(src),
706
- dest.offsets,
707
- dest.columns,
708
- _bsr_values_as_3d_array(dest),
807
+ dest_rows,
808
+ dest_cols,
709
809
  ],
710
810
  )
711
811
 
712
- elif src.block_shape[0] <= dest.block_shape[0] and src.block_shape[1] <= dest.block_shape[1]:
713
- # Merge blocks
714
-
715
- row_factor = dest.block_shape[0] // src.block_shape[0]
716
- col_factor = dest.block_shape[1] // src.block_shape[1]
717
-
718
- if (
719
- row_factor * src.block_shape[0] != dest.block_shape[0]
720
- or col_factor * src.block_shape[1] != dest.block_shape[1]
721
- ):
722
- raise ValueError(
723
- f"Dest block shape {dest.block_shape} is not an exact multiple of src block shape {src.block_shape}"
724
- )
725
-
726
- if src.nrow % row_factor != 0 or src.ncol % col_factor != 0:
727
- raise ValueError(
728
- "The total rows and columns of the src matrix cannot be evenly divided using the requested block shape"
729
- )
730
-
731
- dest.nrow = src.nrow // row_factor
732
- dest.ncol = src.ncol // col_factor
733
-
734
- nnz_alloc = src.nnz # Conservative, in case all nnz in src belong to distinct merged blocks
735
- _bsr_ensure_fits(dest, nnz=nnz_alloc)
736
-
737
- # Compute destination rows and columns
738
- dest_rows = wp.empty_like(src.columns)
739
- dest_cols = wp.empty_like(src.columns)
740
- wp.launch(
741
- _bsr_assign_merge_row_col,
742
- dim=src.nnz,
743
- device=dest.device,
744
- inputs=[row_factor, col_factor, src.nrow, src.offsets, src.columns, dest_rows, dest_cols],
745
- )
746
-
747
812
  # Compute destination offsets from triplets
748
813
  from warp.context import runtime
749
814
 
@@ -758,11 +823,12 @@ def bsr_assign(
758
823
  dest.block_shape[0],
759
824
  dest.block_shape[1],
760
825
  dest.nrow,
761
- dest.nnz,
826
+ nnz_alloc,
762
827
  ctypes.cast(dest_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
763
828
  ctypes.cast(dest_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
764
829
  0,
765
830
  False,
831
+ masked,
766
832
  ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
767
833
  ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
768
834
  0,
@@ -774,26 +840,25 @@ def bsr_assign(
774
840
  if not structure_only:
775
841
  dest.values.zero_()
776
842
  wp.launch(
777
- _bsr_assign_merge_blocks,
778
- dim=src.nnz,
843
+ _bsr_assign_copy_blocks,
844
+ dim=(src.nnz, src_subrows, src_subcols),
779
845
  device=dest.device,
780
846
  inputs=[
781
847
  src.scalar_type(src_scale),
782
- row_factor,
783
- col_factor,
848
+ src_subrows,
849
+ src_subcols,
850
+ dest_subrows,
851
+ dest_subcols,
784
852
  src.nrow,
785
853
  src.offsets,
786
854
  src.columns,
787
- _bsr_values_as_3d_array(src),
855
+ src.scalar_values,
788
856
  dest.offsets,
789
857
  dest.columns,
790
- _bsr_values_as_3d_array(dest),
858
+ dest.scalar_values,
791
859
  ],
792
860
  )
793
861
 
794
- else:
795
- raise ValueError("Incompatible dest and src block shapes")
796
-
797
862
 
798
863
  def bsr_copy(
799
864
  A: BsrMatrixOrExpression,
@@ -801,15 +866,15 @@ def bsr_copy(
801
866
  block_shape: Optional[Tuple[int, int]] = None,
802
867
  structure_only: bool = False,
803
868
  ):
804
- """Returns a copy of matrix ``A``, possibly changing its scalar type.
869
+ """Return a copy of matrix ``A``, possibly changing its scalar type.
805
870
 
806
871
  Args:
807
- A: Matrix to be copied
808
- scalar_type: If provided, the returned matrix will use this scalar type instead of the one from `A`.
809
- block_shape: If provided, the returned matrix will use blocks of this shape instead of the one from `A`.
810
- Both dimensions of `block_shape` must be either a multiple or an exact divider of the ones from `A`.
872
+ A: Matrix to be copied.
873
+ scalar_type: If provided, the returned matrix will use this scalar type instead of the one from ``A``.
874
+ block_shape: If provided, the returned matrix will use blocks of this shape instead of the one from ``A``.
875
+ Both dimensions of ``block_shape`` must be either a multiple or an exact divider of the ones from ``A``.
811
876
  structure_only: If ``True``, only the non-zeros indices are copied, and uninitialized value storage is allocated
812
- to accommodate at least `src.nnz` blocks. If `structure_only` is ``False``, values are also copied with implicit
877
+ to accommodate at least ``src.nnz`` blocks. If ``structure_only`` is ``False``, values are also copied with implicit
813
878
  casting if the two matrices use distinct scalar types.
814
879
  """
815
880
  if scalar_type is None:
@@ -820,7 +885,7 @@ def bsr_copy(
820
885
  if block_shape == (1, 1):
821
886
  block_type = scalar_type
822
887
  else:
823
- block_type = wp.types.matrix(shape=block_shape, dtype=scalar_type)
888
+ block_type = wp.mat(shape=block_shape, dtype=scalar_type)
824
889
 
825
890
  copy = bsr_zeros(
826
891
  rows_of_blocks=A.nrow,
@@ -836,7 +901,7 @@ def bsr_set_transpose(
836
901
  dest: BsrMatrix[BlockType[Cols, Rows, Scalar]],
837
902
  src: BsrMatrixOrExpression[BlockType[Rows, Cols, Scalar]],
838
903
  ):
839
- """Assigns the transposed matrix `src` to matrix `dest`"""
904
+ """Assign the transposed matrix ``src`` to matrix ``dest``."""
840
905
 
841
906
  src, src_scale = _extract_matrix_and_scale(src)
842
907
 
@@ -897,13 +962,13 @@ def bsr_set_transpose(
897
962
  bsr_scale(dest, src_scale)
898
963
 
899
964
 
900
- def bsr_transposed(A: BsrMatrixOrExpression):
901
- """Returns a copy of the transposed matrix `A`"""
965
+ def bsr_transposed(A: BsrMatrixOrExpression) -> BsrMatrix:
966
+ """Return a copy of the transposed matrix ``A``."""
902
967
 
903
968
  if A.block_shape == (1, 1):
904
969
  block_type = A.values.dtype
905
970
  else:
906
- block_type = wp.types.matrix(shape=A.block_shape[::-1], dtype=A.scalar_type)
971
+ block_type = wp.mat(shape=A.block_shape[::-1], dtype=A.scalar_type)
907
972
 
908
973
  transposed = bsr_zeros(
909
974
  rows_of_blocks=A.ncol,
@@ -924,21 +989,18 @@ def _bsr_get_diag_kernel(
924
989
  out: wp.array(dtype=Any),
925
990
  ):
926
991
  row = wp.tid()
927
- beg = A_offsets[row]
928
- end = A_offsets[row + 1]
929
992
 
930
- diag = wp.lower_bound(A_columns, beg, end, row)
931
- if diag < end:
932
- if A_columns[diag] == row:
933
- out[row] = scale * A_values[diag]
993
+ diag = _bsr_block_index(row, row, A_offsets, A_columns)
994
+ if diag != -1:
995
+ out[row] = scale * A_values[diag]
934
996
 
935
997
 
936
998
  def bsr_get_diag(A: BsrMatrixOrExpression[BlockType], out: "Optional[Array[BlockType]]" = None) -> "Array[BlockType]":
937
- """Returns the array of blocks that constitute the diagonal of a sparse matrix.
999
+ """Return the array of blocks that constitute the diagonal of a sparse matrix.
938
1000
 
939
1001
  Args:
940
- A: the sparse matrix from which to extract the diagonal
941
- out: if provided, the array into which to store the diagonal blocks
1002
+ A: The sparse matrix from which to extract the diagonal.
1003
+ out: If provided, the array into which to store the diagonal blocks.
942
1004
  """
943
1005
 
944
1006
  A, scale = _extract_matrix_and_scale(A)
@@ -965,36 +1027,16 @@ def bsr_get_diag(A: BsrMatrixOrExpression[BlockType], out: "Optional[Array[Block
965
1027
  return out
966
1028
 
967
1029
 
968
- @wp.kernel
1030
+ @wp.kernel(enable_backward=False)
969
1031
  def _bsr_set_diag_kernel(
970
- diag: wp.array(dtype=Any),
1032
+ nnz: int,
971
1033
  A_offsets: wp.array(dtype=int),
972
1034
  A_columns: wp.array(dtype=int),
973
- A_values: wp.array(dtype=Any),
974
1035
  ):
975
1036
  row = wp.tid()
976
- A_offsets[row + 1] = row + 1
977
- A_columns[row] = row
978
- A_values[row] = diag[row]
979
-
980
- if row == 0:
981
- A_offsets[0] = 0
982
-
983
-
984
- @wp.kernel
985
- def _bsr_set_diag_constant_kernel(
986
- diag_value: Any,
987
- A_offsets: wp.array(dtype=int),
988
- A_columns: wp.array(dtype=int),
989
- A_values: wp.array(dtype=Any),
990
- ):
991
- row = wp.tid()
992
- A_offsets[row + 1] = row + 1
993
- A_columns[row] = row
994
- A_values[row] = diag_value
995
-
996
- if row == 0:
997
- A_offsets[0] = 0
1037
+ A_offsets[row] = wp.min(row, nnz)
1038
+ if row < nnz:
1039
+ A_columns[row] = row
998
1040
 
999
1041
 
1000
1042
  def bsr_set_diag(
@@ -1002,20 +1044,26 @@ def bsr_set_diag(
1002
1044
  diag: "Union[BlockType, Array[BlockType]]",
1003
1045
  rows_of_blocks: Optional[int] = None,
1004
1046
  cols_of_blocks: Optional[int] = None,
1005
- ):
1006
- """Sets `A` as a block-diagonal matrix
1047
+ ) -> None:
1048
+ """Set ``A`` as a block-diagonal matrix.
1007
1049
 
1008
1050
  Args:
1009
- A: the sparse matrix to modify
1010
- diag: Either a warp array of type ``A.values.dtype``, in which case each element will define one block of the diagonal,
1011
- or a constant value of type ``A.values.dtype``, in which case it will get assigned to all diagonal blocks.
1012
- rows_of_blocks: If not ``None``, the new number of rows of blocks
1013
- cols_of_blocks: If not ``None``, the new number of columns of blocks
1051
+ A: The sparse matrix to modify.
1052
+ diag: Specifies the values for diagonal blocks. Can be one of:
1053
+
1054
+ - A Warp array of type ``A.values.dtype``: Each element defines one block of the diagonal
1055
+ - A constant value of type ``A.values.dtype``: This value is assigned to all diagonal blocks
1056
+ - ``None``: Diagonal block values are left uninitialized
1057
+
1058
+ rows_of_blocks: If not ``None``, the new number of rows of blocks.
1059
+ cols_of_blocks: If not ``None``, the new number of columns of blocks.
1060
+
1061
+ The shape of the matrix will be defined by one of the following, in this order:
1014
1062
 
1015
- The shape of the matrix will be defined one of the following, in that order:
1016
- - `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal.
1017
- - the first dimension of `diag`, if `diag` is an array
1018
- - the current dimensions of `A` otherwise
1063
+ - ``rows_of_blocks`` and ``cols_of_blocks``, if provided.
1064
+ If only one is given, the second is assumed equal.
1065
+ - The first dimension of ``diag``, if ``diag`` is an array
1066
+ - The current dimensions of ``A`` otherwise
1019
1067
  """
1020
1068
 
1021
1069
  if rows_of_blocks is None and cols_of_blocks is not None:
@@ -1023,7 +1071,7 @@ def bsr_set_diag(
1023
1071
  if cols_of_blocks is None and rows_of_blocks is not None:
1024
1072
  cols_of_blocks = rows_of_blocks
1025
1073
 
1026
- if warp.types.is_array(diag):
1074
+ if is_array(diag):
1027
1075
  if rows_of_blocks is None:
1028
1076
  rows_of_blocks = diag.shape[0]
1029
1077
  cols_of_blocks = diag.shape[0]
@@ -1035,43 +1083,45 @@ def bsr_set_diag(
1035
1083
  nnz = min(A.nrow, A.ncol)
1036
1084
  _bsr_ensure_fits(A, nnz=nnz)
1037
1085
 
1038
- if warp.types.is_array(diag):
1039
- wp.launch(
1040
- kernel=_bsr_set_diag_kernel,
1041
- dim=nnz,
1042
- device=A.values.device,
1043
- inputs=[diag, A.offsets, A.columns, A.values],
1044
- )
1045
- else:
1046
- if not warp.types.type_is_value(type(diag)):
1047
- # Cast to launchable type
1048
- diag = A.values.dtype(diag)
1049
- wp.launch(
1050
- kernel=_bsr_set_diag_constant_kernel,
1051
- dim=nnz,
1052
- device=A.values.device,
1053
- inputs=[diag, A.offsets, A.columns, A.values],
1054
- )
1086
+ wp.launch(
1087
+ kernel=_bsr_set_diag_kernel,
1088
+ dim=nnz + 1,
1089
+ device=A.offsets.device,
1090
+ inputs=[nnz, A.offsets, A.columns],
1091
+ )
1092
+
1093
+ if is_array(diag):
1094
+ wp.copy(src=diag, dest=A.values, count=nnz)
1095
+ elif diag is not None:
1096
+ A.values.fill_(diag)
1055
1097
 
1056
1098
  A.copy_nnz_async(known_nnz=nnz)
1057
1099
 
1058
1100
 
1059
1101
  def bsr_diag(
1060
- diag: "Union[BlockType, Array[BlockType]]",
1102
+ diag: Optional[Union[BlockType, Array[BlockType]]] = None,
1061
1103
  rows_of_blocks: Optional[int] = None,
1062
1104
  cols_of_blocks: Optional[int] = None,
1105
+ block_type: Optional[BlockType] = None,
1106
+ device=None,
1063
1107
  ) -> BsrMatrix["BlockType"]:
1064
- """Creates and returns a block-diagonal BSR matrix from an given block value or array of block values.
1108
+ """Create and return a block-diagonal BSR matrix from a given block value or array of block values.
1065
1109
 
1066
1110
  Args:
1067
- diag: Either a warp array of type ``A.values.dtype``, in which case each element will define one block of the diagonal,
1068
- or a constant value of type ``A.values.dtype``, in which case it will get assigned to all diagonal blocks.
1111
+ diag: Specifies the values for diagonal blocks. Can be one of:
1112
+
1113
+ - A Warp array of type ``A.values.dtype``: Each element defines one block of the diagonal
1114
+ - A constant value of type ``A.values.dtype``: This value is assigned to all diagonal blocks
1069
1115
  rows_of_blocks: If not ``None``, the new number of rows of blocks
1070
1116
  cols_of_blocks: If not ``None``, the new number of columns of blocks
1117
+ block_type: If ``diag`` is ``None``, block type of the matrix. Otherwise deduced from ``diag``
1118
+ device: If ``diag`` is not a Warp array, device on which to allocate the matrix. Otherwise deduced from ``diag``
1119
+
1120
+ The shape of the matrix will be defined by one of the following, in this order:
1071
1121
 
1072
- The shape of the matrix will be defined one of the following, in that order:
1073
- - `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal.
1074
- - the first dimension of `diag`, if `diag` is an array
1122
+ - ``rows_of_blocks`` and ``cols_of_blocks``, if provided.
1123
+ If only one is given, the second is assumed equal.
1124
+ - The first dimension of ``diag`` if ``diag`` is an array.
1075
1125
  """
1076
1126
 
1077
1127
  if rows_of_blocks is None and cols_of_blocks is not None:
@@ -1079,43 +1129,39 @@ def bsr_diag(
1079
1129
  if cols_of_blocks is None and rows_of_blocks is not None:
1080
1130
  cols_of_blocks = rows_of_blocks
1081
1131
 
1082
- if warp.types.is_array(diag):
1132
+ if is_array(diag):
1083
1133
  if rows_of_blocks is None:
1084
1134
  rows_of_blocks = diag.shape[0]
1085
1135
  cols_of_blocks = diag.shape[0]
1086
1136
 
1087
- A = bsr_zeros(
1088
- rows_of_blocks,
1089
- cols_of_blocks,
1090
- block_type=diag.dtype,
1091
- device=diag.device,
1092
- )
1137
+ block_type = diag.dtype
1138
+ device = diag.device
1093
1139
  else:
1094
1140
  if rows_of_blocks is None:
1095
1141
  raise ValueError(
1096
1142
  "rows_of_blocks and/or cols_of_blocks must be provided for constructing a diagonal matrix with uniform diagonal"
1097
1143
  )
1098
1144
 
1145
+ if block_type is None:
1146
+ if diag is None:
1147
+ raise ValueError("Either `diag` or `block_type` needs to be provided")
1148
+
1099
1149
  block_type = type(diag)
1100
- if not warp.types.type_is_matrix(block_type) and len(getattr(diag, "shape", ())) == 2:
1150
+ if not type_is_matrix(block_type) and len(getattr(diag, "shape", ())) == 2:
1101
1151
  block_type = wp.mat(shape=diag.shape, dtype=diag.dtype)
1102
1152
 
1103
- A = bsr_zeros(
1104
- rows_of_blocks,
1105
- cols_of_blocks,
1106
- block_type=block_type,
1107
- )
1108
-
1153
+ A = bsr_zeros(rows_of_blocks, cols_of_blocks, block_type=block_type, device=device)
1109
1154
  bsr_set_diag(A, diag)
1110
1155
  return A
1111
1156
 
1112
1157
 
1113
- def bsr_set_identity(A: BsrMatrix, rows_of_blocks: Optional[int] = None):
1114
- """Sets `A` as the identity matrix
1158
+ def bsr_set_identity(A: BsrMatrix, rows_of_blocks: Optional[int] = None) -> None:
1159
+ """Set ``A`` as the identity matrix.
1115
1160
 
1116
1161
  Args:
1117
- A: the sparse matrix to modify
1118
- rows_of_blocks: if provided, the matrix will be resized as a square matrix with `rows_of_blocks` rows and columns.
1162
+ A: The sparse matrix to modify.
1163
+ rows_of_blocks: If provided, the matrix will be resized as a square
1164
+ matrix with ``rows_of_blocks`` rows and columns.
1119
1165
  """
1120
1166
 
1121
1167
  if A.block_shape == (1, 1):
@@ -1133,11 +1179,11 @@ def bsr_identity(
1133
1179
  block_type: BlockType[Rows, Rows, Scalar],
1134
1180
  device: wp.context.Devicelike = None,
1135
1181
  ) -> BsrMatrix[BlockType[Rows, Rows, Scalar]]:
1136
- """Creates and returns a square identity matrix.
1182
+ """Create and return a square identity matrix.
1137
1183
 
1138
1184
  Args:
1139
1185
  rows_of_blocks: Number of rows and columns of blocks in the created matrix.
1140
- block_type: Block type for the newly created matrix -- must be square
1186
+ block_type: Block type for the newly created matrix. Must be square
1141
1187
  device: Device onto which to allocate the data arrays
1142
1188
  """
1143
1189
  A = bsr_zeros(
@@ -1159,9 +1205,7 @@ def _bsr_scale_kernel(
1159
1205
 
1160
1206
 
1161
1207
  def bsr_scale(x: BsrMatrixOrExpression, alpha: Scalar) -> BsrMatrix:
1162
- """
1163
- Performs the operation ``x := alpha * x`` on BSR matrix `x` and returns `x`
1164
- """
1208
+ """Perform the operation ``x := alpha * x`` on BSR matrix ``x`` and return ``x``."""
1165
1209
 
1166
1210
  x, scale = _extract_matrix_and_scale(x)
1167
1211
  alpha *= scale
@@ -1170,8 +1214,7 @@ def bsr_scale(x: BsrMatrixOrExpression, alpha: Scalar) -> BsrMatrix:
1170
1214
  if alpha == 0.0:
1171
1215
  bsr_set_zero(x)
1172
1216
  else:
1173
- if not isinstance(alpha, x.scalar_type):
1174
- alpha = x.scalar_type(alpha)
1217
+ alpha = x.scalar_type(alpha)
1175
1218
 
1176
1219
  wp.launch(
1177
1220
  kernel=_bsr_scale_kernel,
@@ -1183,15 +1226,10 @@ def bsr_scale(x: BsrMatrixOrExpression, alpha: Scalar) -> BsrMatrix:
1183
1226
  return x
1184
1227
 
1185
1228
 
1186
- @wp.kernel
1187
- def _bsr_get_block_row(dest_offset: int, row_count: int, bsr_offsets: wp.array(dtype=int), rows: wp.array(dtype=int)):
1188
- i = wp.tid()
1189
-
1190
- if i >= bsr_offsets[row_count]:
1191
- rows[dest_offset + i] = -1 # invalid
1192
- else:
1193
- row = wp.lower_bound(bsr_offsets, 0, row_count + 1, i + 1) - 1
1194
- rows[dest_offset + i] = row
1229
+ @wp.kernel(enable_backward=False)
1230
+ def _bsr_get_block_row(row_count: int, bsr_offsets: wp.array(dtype=int), rows: wp.array(dtype=int)):
1231
+ block = wp.tid()
1232
+ rows[block] = _bsr_row_index(bsr_offsets, row_count, block)
1195
1233
 
1196
1234
 
1197
1235
  @wp.kernel
@@ -1207,21 +1245,15 @@ def _bsr_axpy_add_block(
1207
1245
  ):
1208
1246
  i = wp.tid()
1209
1247
  row = rows[i + src_offset]
1210
-
1211
- if row < 0:
1212
- return
1213
-
1214
1248
  col = cols[i + src_offset]
1215
- beg = dst_offsets[row]
1216
- end = dst_offsets[row + 1]
1217
1249
 
1218
- block = wp.lower_bound(dst_columns, beg, end, col)
1219
-
1220
- dst_values[block] = dst_values[block] + scale * src_values[i]
1250
+ block = _bsr_block_index(row, col, dst_offsets, dst_columns)
1251
+ if block != -1:
1252
+ dst_values[block] += scale * src_values[i]
1221
1253
 
1222
1254
 
1223
1255
  class bsr_axpy_work_arrays:
1224
- """Opaque structure for persisting :func:`bsr_axpy` temporary work buffers across calls"""
1256
+ """Opaque structure for persisting :func:`bsr_axpy` temporary work buffers across calls."""
1225
1257
 
1226
1258
  def __init__(self):
1227
1259
  self._reset(None)
@@ -1251,25 +1283,33 @@ def bsr_axpy(
1251
1283
  y: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
1252
1284
  alpha: Scalar = 1.0,
1253
1285
  beta: Scalar = 1.0,
1286
+ masked: bool = False,
1254
1287
  work_arrays: Optional[bsr_axpy_work_arrays] = None,
1255
1288
  ) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
1256
1289
  """
1257
- Performs the sparse matrix addition ``y := alpha * X + beta * y`` on BSR matrices `x` and `y` and returns `y`.
1290
+ Perform the sparse matrix addition ``y := alpha * X + beta * y`` on BSR matrices ``x`` and ``y`` and return ``y``.
1258
1291
 
1259
- The `x` and `y` matrices are allowed to alias.
1292
+ The ``x`` and ``y`` matrices are allowed to alias.
1260
1293
 
1261
1294
  Args:
1262
1295
  x: Read-only right-hand-side.
1263
- y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero.
1264
- alpha: Uniform scaling factor for `x`
1265
- beta: Uniform scaling factor for `y`
1266
- work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_axpy_work_arrays` in `work_arrays`.
1296
+ y: Mutable left-hand-side. If ``y`` is not provided, it will be allocated and treated as zero.
1297
+ alpha: Uniform scaling factor for ``x``.
1298
+ beta: Uniform scaling factor for ``y``.
1299
+ masked: If ``True``, discard all blocks from ``x`` which are not
1300
+ existing non-zeros of ``y``.
1301
+ work_arrays: In most cases, this function will require the use of temporary storage.
1302
+ This storage can be reused across calls by passing an instance of
1303
+ :class:`bsr_axpy_work_arrays` in ``work_arrays``.
1267
1304
  """
1268
1305
 
1269
1306
  x, x_scale = _extract_matrix_and_scale(x)
1270
1307
  alpha *= x_scale
1271
1308
 
1272
1309
  if y is None:
1310
+ if masked:
1311
+ raise ValueError("Left-hand-side 'y' matrix must be provided for masked addition")
1312
+
1273
1313
  # If not output matrix is provided, allocate it for convenience
1274
1314
  y = bsr_zeros(x.nrow, x.ncol, block_type=x.values.dtype, device=x.values.device)
1275
1315
  beta = 0.0
@@ -1313,27 +1353,17 @@ def bsr_axpy(
1313
1353
  work_arrays._allocate(device, y, sum_nnz)
1314
1354
 
1315
1355
  wp.copy(work_arrays._sum_cols, y.columns, 0, 0, y_nnz)
1316
- wp.launch(
1317
- kernel=_bsr_get_block_row,
1318
- device=device,
1319
- dim=y_nnz,
1320
- inputs=[0, y.nrow, y.offsets, work_arrays._sum_rows],
1321
- )
1356
+ y.uncompress_rows(out=work_arrays._sum_rows)
1322
1357
 
1323
1358
  wp.copy(work_arrays._sum_cols, x.columns, y_nnz, 0, x_nnz)
1324
- wp.launch(
1325
- kernel=_bsr_get_block_row,
1326
- device=device,
1327
- dim=x_nnz,
1328
- inputs=[y_nnz, x.nrow, x.offsets, work_arrays._sum_rows],
1329
- )
1359
+ x.uncompress_rows(out=work_arrays._sum_rows[y_nnz:])
1330
1360
 
1331
1361
  # Save old y values before overwriting matrix
1332
1362
  wp.copy(dest=work_arrays._old_y_values, src=y.values, count=y_nnz)
1333
1363
 
1334
1364
  # Increase dest array sizes if needed
1335
- if y.columns.shape[0] < sum_nnz:
1336
- y.columns = wp.empty(shape=(sum_nnz,), dtype=int, device=device)
1365
+ if not masked:
1366
+ _bsr_ensure_fits(y, nnz=sum_nnz)
1337
1367
 
1338
1368
  from warp.context import runtime
1339
1369
 
@@ -1355,6 +1385,7 @@ def bsr_axpy(
1355
1385
  ctypes.cast(work_arrays._sum_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
1356
1386
  0,
1357
1387
  False,
1388
+ masked,
1358
1389
  ctypes.cast(y.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
1359
1390
  ctypes.cast(y.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
1360
1391
  0,
@@ -1362,8 +1393,6 @@ def bsr_axpy(
1362
1393
  nnz_event,
1363
1394
  )
1364
1395
 
1365
- _bsr_ensure_fits(y, nnz=sum_nnz)
1366
-
1367
1396
  y.values.zero_()
1368
1397
 
1369
1398
  wp.launch(
@@ -1401,55 +1430,90 @@ def bsr_axpy(
1401
1430
  return y
1402
1431
 
1403
1432
 
1404
- @wp.kernel
1433
+ @wp.kernel(enable_backward=False)
1405
1434
  def _bsr_mm_count_coeffs(
1435
+ y_ncol: int,
1406
1436
  z_nnz: int,
1407
1437
  x_offsets: wp.array(dtype=int),
1408
1438
  x_columns: wp.array(dtype=int),
1409
1439
  y_offsets: wp.array(dtype=int),
1410
- counts: wp.array(dtype=int),
1440
+ y_columns: wp.array(dtype=int),
1441
+ row_min: wp.array(dtype=int),
1442
+ block_counts: wp.array(dtype=int),
1411
1443
  ):
1412
1444
  row = wp.tid()
1413
- count = int(0)
1445
+ row_count = int(0)
1414
1446
 
1415
1447
  x_beg = x_offsets[row]
1416
1448
  x_end = x_offsets[row + 1]
1417
1449
 
1450
+ min_col = y_ncol
1451
+ max_col = int(0)
1452
+
1418
1453
  for x_block in range(x_beg, x_end):
1419
1454
  x_col = x_columns[x_block]
1420
- count += y_offsets[x_col + 1] - y_offsets[x_col]
1421
-
1422
- counts[row + 1] = count
1455
+ y_row_end = y_offsets[x_col + 1]
1456
+ y_row_beg = y_offsets[x_col]
1457
+ block_count = y_row_end - y_row_beg
1458
+ if block_count != 0:
1459
+ min_col = wp.min(y_columns[y_row_beg], min_col)
1460
+ max_col = wp.max(y_columns[y_row_end - 1], max_col)
1461
+
1462
+ block_counts[x_block + 1] = block_count
1463
+ row_count += block_count
1464
+
1465
+ if row_count > wp.max(0, max_col - min_col):
1466
+ row_min[row] = min_col
1467
+ block_counts[x_end] = max_col + 1 - min_col
1468
+ for x_block in range(x_beg, x_end - 1):
1469
+ block_counts[x_block + 1] = 0
1470
+ else:
1471
+ row_min[row] = -1
1423
1472
 
1424
1473
  if row == 0:
1425
- counts[0] = z_nnz
1474
+ block_counts[0] = z_nnz
1426
1475
 
1427
1476
 
1428
- @wp.kernel
1477
+ @wp.kernel(enable_backward=False)
1429
1478
  def _bsr_mm_list_coeffs(
1479
+ x_nrow: int,
1430
1480
  x_offsets: wp.array(dtype=int),
1431
1481
  x_columns: wp.array(dtype=int),
1432
1482
  y_offsets: wp.array(dtype=int),
1433
1483
  y_columns: wp.array(dtype=int),
1484
+ mm_row_min: wp.array(dtype=int),
1434
1485
  mm_offsets: wp.array(dtype=int),
1435
1486
  mm_rows: wp.array(dtype=int),
1436
1487
  mm_cols: wp.array(dtype=int),
1437
1488
  ):
1438
- row = wp.tid()
1439
- mm_block = mm_offsets[row]
1489
+ x_block = wp.tid()
1490
+ mm_block = mm_offsets[x_block]
1440
1491
 
1441
- x_beg = x_offsets[row]
1442
- x_end = x_offsets[row + 1]
1492
+ row = _bsr_row_index(x_offsets, x_nrow, x_block)
1493
+ if row == -1:
1494
+ return
1443
1495
 
1444
- for x_block in range(x_beg, x_end):
1496
+ row_min_col = mm_row_min[row]
1497
+ if row_min_col != -1:
1445
1498
  x_col = x_columns[x_block]
1446
1499
 
1447
1500
  y_beg = y_offsets[x_col]
1448
1501
  y_end = y_offsets[x_col + 1]
1502
+
1449
1503
  for y_block in range(y_beg, y_end):
1450
- mm_cols[mm_block] = y_columns[y_block]
1451
- mm_rows[mm_block] = row
1452
- mm_block += 1
1504
+ col = y_columns[y_block]
1505
+ mm_rows[mm_block + col - row_min_col] = row
1506
+ mm_cols[mm_block + col - row_min_col] = col
1507
+
1508
+ return
1509
+
1510
+ x_col = x_columns[x_block]
1511
+ y_beg = y_offsets[x_col]
1512
+ y_end = y_offsets[x_col + 1]
1513
+ for y_block in range(y_beg, y_end):
1514
+ mm_cols[mm_block] = y_columns[y_block]
1515
+ mm_rows[mm_block] = row
1516
+ mm_block += 1
1453
1517
 
1454
1518
 
1455
1519
  @wp.kernel
@@ -1468,7 +1532,10 @@ def _bsr_mm_compute_values(
1468
1532
  ):
1469
1533
  mm_block = wp.tid()
1470
1534
 
1471
- row = wp.lower_bound(mm_offsets, 0, mm_row_count + 1, mm_block + 1) - 1
1535
+ row = _bsr_row_index(mm_offsets, mm_row_count, mm_block)
1536
+ if row == -1:
1537
+ return
1538
+
1472
1539
  col = mm_cols[mm_block]
1473
1540
 
1474
1541
  mm_val = mm_values.dtype(type(alpha)(0.0))
@@ -1477,26 +1544,23 @@ def _bsr_mm_compute_values(
1477
1544
  x_end = x_offsets[row + 1]
1478
1545
  for x_block in range(x_beg, x_end):
1479
1546
  x_col = x_columns[x_block]
1480
- y_beg = y_offsets[x_col]
1481
- y_end = y_offsets[x_col + 1]
1482
-
1483
- y_block = wp.lower_bound(y_columns, y_beg, y_end, col)
1484
- if y_block < y_end:
1485
- if y_columns[y_block] == col:
1486
- mm_val += x_values[x_block] * y_values[y_block]
1547
+ y_block = _bsr_block_index(x_col, col, y_offsets, y_columns)
1548
+ if y_block != -1:
1549
+ mm_val += x_values[x_block] * y_values[y_block]
1487
1550
 
1488
1551
  mm_values[mm_block] += alpha * mm_val
1489
1552
 
1490
1553
 
1491
1554
  class bsr_mm_work_arrays:
1492
- """Opaque structure for persisting :func:`bsr_mm` temporary work buffers across calls"""
1555
+ """Opaque structure for persisting :func:`bsr_mm` temporary work buffers across calls."""
1493
1556
 
1494
1557
  def __init__(self):
1495
1558
  self._reset(None)
1496
1559
 
1497
1560
  def _reset(self, device):
1498
1561
  self.device = device
1499
- self._mm_row_counts = None
1562
+ self._mm_row_min = None
1563
+ self._mm_block_counts = None
1500
1564
  self._mm_rows = None
1501
1565
  self._mm_cols = None
1502
1566
  self._old_z_values = None
@@ -1504,7 +1568,7 @@ class bsr_mm_work_arrays:
1504
1568
  self._old_z_columns = None
1505
1569
  self._mm_nnz = 0
1506
1570
 
1507
- def _allocate_stage_1(self, device, z: BsrMatrix, beta: float, z_aliasing: bool):
1571
+ def _allocate_stage_1(self, device, x_nnz: int, z: BsrMatrix, beta: float, z_aliasing: bool):
1508
1572
  if self.device != device:
1509
1573
  self._reset(device)
1510
1574
 
@@ -1512,8 +1576,10 @@ class bsr_mm_work_arrays:
1512
1576
  z_nnz = z.nnz_sync()
1513
1577
  self._copied_z_nnz = z_nnz if beta != 0.0 or z_aliasing else 0
1514
1578
 
1515
- if self._mm_row_counts is None or self._mm_row_counts.size < z.nrow + 1:
1516
- self._mm_row_counts = wp.empty(shape=(z.nrow + 1,), dtype=int, device=self.device)
1579
+ if self._mm_row_min is None or self._mm_block_counts.size < z.nrow + 1:
1580
+ self._mm_row_min = wp.empty(shape=(z.nrow + 1,), dtype=int, device=self.device)
1581
+ if self._mm_block_counts is None or self._mm_block_counts.size < x_nnz + 1:
1582
+ self._mm_block_counts = wp.empty(shape=(x_nnz + 1,), dtype=int, device=self.device)
1517
1583
 
1518
1584
  if self._copied_z_nnz > 0:
1519
1585
  if self._old_z_values is None or self._old_z_values.size < self._copied_z_nnz:
@@ -1540,25 +1606,31 @@ def bsr_mm(
1540
1606
  z: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
1541
1607
  alpha: Scalar = 1.0,
1542
1608
  beta: Scalar = 0.0,
1609
+ masked: bool = False,
1543
1610
  work_arrays: Optional[bsr_mm_work_arrays] = None,
1544
1611
  reuse_topology: bool = False,
1545
1612
  ) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
1546
1613
  """
1547
- Performs the sparse matrix-matrix multiplication ``z := alpha * x * y + beta * z`` on BSR matrices `x`, `y` and `z`, and returns `z`.
1614
+ Perform the sparse matrix-matrix multiplication ``z := alpha * x @ y + beta * z`` on BSR matrices ``x``, ``y`` and ``z``, and return ``z``.
1548
1615
 
1549
- The `x`, `y` and `z` matrices are allowed to alias.
1550
- If the matrix `z` is not provided as input, it will be allocated and treated as zero.
1616
+ The ``x``, ``y`` and ``z`` matrices are allowed to alias.
1617
+ If the matrix ``z`` is not provided as input, it will be allocated and treated as zero.
1551
1618
 
1552
1619
  Args:
1553
1620
  x: Read-only left factor of the matrix-matrix product.
1554
1621
  y: Read-only right factor of the matrix-matrix product.
1555
- z: Mutable left-hand-side. If `z` is not provided, it will be allocated and treated as zero.
1556
- alpha: Uniform scaling factor for the ``x * y`` product
1557
- beta: Uniform scaling factor for `z`
1558
- work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_mm_work_arrays` in `work_arrays`.
1559
- reuse_topology: If True, reuse the product topology information stored in `work_arrays` rather than recompute it from scratch.
1560
- The matrices x, y and z must be structurally similar to the previous call in which `work_arrays` were populated.
1561
- This is necessary for `bsr_mm` to be captured in a CUDA graph.
1622
+ z: Mutable left-hand-side. If ``z`` is not provided, it will be allocated and treated as zero.
1623
+ alpha: Uniform scaling factor for the ``x @ y`` product
1624
+ beta: Uniform scaling factor for ``z``
1625
+ masked: If ``True``, ignore all blocks from ``x @ y`` which are not existing non-zeros of ``y``
1626
+ work_arrays: In most cases, this function will require the use of temporary storage.
1627
+ This storage can be reused across calls by passing an instance of
1628
+ :class:`bsr_mm_work_arrays` in ``work_arrays``.
1629
+ reuse_topology: If ``True``, reuse the product topology information
1630
+ stored in ``work_arrays`` rather than recompute it from scratch.
1631
+ The matrices ``x``, ``y`` and ``z`` must be structurally similar to
1632
+ the previous call in which ``work_arrays`` were populated.
1633
+ This is necessary for ``bsr_mm`` to be captured in a CUDA graph.
1562
1634
  """
1563
1635
 
1564
1636
  x, x_scale = _extract_matrix_and_scale(x)
@@ -1567,12 +1639,15 @@ def bsr_mm(
1567
1639
  alpha *= y_scale
1568
1640
 
1569
1641
  if z is None:
1642
+ if masked:
1643
+ raise ValueError("Left-hand-side 'z' matrix must be provided for masked multiplication")
1644
+
1570
1645
  # If not output matrix is provided, allocate it for convenience
1571
1646
  z_block_shape = (x.block_shape[0], y.block_shape[1])
1572
1647
  if z_block_shape == (1, 1):
1573
1648
  z_block_type = x.scalar_type
1574
1649
  else:
1575
- z_block_type = wp.types.matrix(shape=z_block_shape, dtype=x.scalar_type)
1650
+ z_block_type = wp.mat(shape=z_block_shape, dtype=x.scalar_type)
1576
1651
  z = bsr_zeros(x.nrow, y.ncol, block_type=z_block_type, device=x.values.device)
1577
1652
  beta = 0.0
1578
1653
 
@@ -1598,14 +1673,22 @@ def bsr_mm(
1598
1673
  # Easy case
1599
1674
  return bsr_scale(z, beta)
1600
1675
 
1601
- if not isinstance(alpha, z.scalar_type):
1602
- alpha = z.scalar_type(alpha)
1603
- if not isinstance(beta, z.scalar_type):
1604
- beta = z.scalar_type(beta)
1605
-
1606
1676
  z_aliasing = z == x or z == y
1607
1677
 
1608
- if reuse_topology:
1678
+ if masked:
1679
+ # no need to copy z, scale in-place
1680
+ copied_z_nnz = 0
1681
+ mm_nnz = z.nnz
1682
+
1683
+ if z_aliasing:
1684
+ raise ValueError("`masked=True` is not supported for aliased inputs")
1685
+
1686
+ if beta == 0.0:
1687
+ # do not bsr_scale(0), this would not preserve topology
1688
+ z.values.zero_()
1689
+ else:
1690
+ bsr_scale(z, beta)
1691
+ elif reuse_topology:
1609
1692
  if work_arrays is None:
1610
1693
  raise ValueError("`work_arrays` must not be ``None`` in order to reuse matrix-matrix product topology")
1611
1694
 
@@ -1618,133 +1701,142 @@ def bsr_mm(
1618
1701
  if work_arrays is None:
1619
1702
  work_arrays = bsr_mm_work_arrays()
1620
1703
 
1621
- work_arrays._allocate_stage_1(device, z, beta, z_aliasing)
1704
+ work_arrays._allocate_stage_1(device, x.nnz, z, beta, z_aliasing)
1622
1705
  copied_z_nnz = work_arrays._copied_z_nnz
1623
1706
 
1624
1707
  # Prefix sum of number of (unmerged) mm blocks per row
1708
+ work_arrays._mm_block_counts.zero_()
1625
1709
  wp.launch(
1626
1710
  kernel=_bsr_mm_count_coeffs,
1627
1711
  device=device,
1628
1712
  dim=z.nrow,
1629
1713
  inputs=[
1714
+ y.ncol,
1630
1715
  copied_z_nnz,
1631
1716
  x.offsets,
1632
1717
  x.columns,
1633
1718
  y.offsets,
1634
- work_arrays._mm_row_counts,
1719
+ y.columns,
1720
+ work_arrays._mm_row_min,
1721
+ work_arrays._mm_block_counts,
1635
1722
  ],
1636
1723
  )
1637
- warp.utils.array_scan(work_arrays._mm_row_counts, work_arrays._mm_row_counts)
1724
+ warp.utils.array_scan(work_arrays._mm_block_counts, work_arrays._mm_block_counts)
1638
1725
 
1639
1726
  # Get back total counts on host -- we need a synchronization here
1640
1727
  # Use pinned buffer from z, we are going to need it later anyway
1641
1728
  nnz_buf, _ = z._nnz_transfer_buf_and_event()
1642
1729
  stream = wp.get_stream(device) if device.is_cuda else None
1643
- wp.copy(dest=nnz_buf, src=work_arrays._mm_row_counts, src_offset=z.nrow, count=1, stream=stream)
1730
+ wp.copy(dest=nnz_buf, src=work_arrays._mm_block_counts, src_offset=x.nnz, count=1, stream=stream)
1644
1731
  if device.is_cuda:
1645
1732
  wp.synchronize_stream(stream)
1646
1733
  mm_nnz = int(nnz_buf.numpy()[0])
1647
1734
 
1735
+ if mm_nnz == copied_z_nnz:
1736
+ # x@y = 0
1737
+ return bsr_scale(z, beta)
1738
+
1648
1739
  work_arrays._allocate_stage_2(mm_nnz)
1649
1740
 
1650
1741
  # If z has a non-zero scale, save current data before overwriting it
1651
1742
  if copied_z_nnz > 0:
1652
1743
  # Copy z row and column indices
1653
1744
  wp.copy(dest=work_arrays._mm_cols, src=z.columns, count=copied_z_nnz)
1654
- wp.launch(
1655
- kernel=_bsr_get_block_row,
1656
- device=device,
1657
- dim=copied_z_nnz,
1658
- inputs=[0, z.nrow, z.offsets, work_arrays._mm_rows],
1659
- )
1745
+ z.uncompress_rows(out=work_arrays._mm_rows)
1660
1746
  if z_aliasing:
1661
1747
  # If z is aliasing with x or y, need to save topology as well
1662
1748
  wp.copy(src=z.columns, dest=work_arrays._old_z_columns, count=copied_z_nnz)
1663
1749
  wp.copy(src=z.offsets, dest=work_arrays._old_z_offsets, count=z.nrow + 1)
1664
1750
 
1665
1751
  # Fill unmerged mm blocks rows and columns
1752
+ work_arrays._mm_rows[copied_z_nnz:].fill_(-1)
1666
1753
  wp.launch(
1667
1754
  kernel=_bsr_mm_list_coeffs,
1668
1755
  device=device,
1669
- dim=z.nrow,
1756
+ dim=x.nnz,
1670
1757
  inputs=[
1758
+ x.nrow,
1671
1759
  x.offsets,
1672
1760
  x.columns,
1673
1761
  y.offsets,
1674
1762
  y.columns,
1675
- work_arrays._mm_row_counts,
1763
+ work_arrays._mm_row_min,
1764
+ work_arrays._mm_block_counts,
1676
1765
  work_arrays._mm_rows,
1677
1766
  work_arrays._mm_cols,
1678
1767
  ],
1679
1768
  )
1680
1769
 
1770
+ alpha = z.scalar_type(alpha)
1771
+ beta = z.scalar_type(beta)
1772
+
1681
1773
  if copied_z_nnz > 0:
1682
1774
  # Save current z values in temporary buffer
1683
1775
  wp.copy(src=z.values, dest=work_arrays._old_z_values, count=copied_z_nnz)
1684
1776
 
1685
- # Increase dest array size if needed
1686
- if z.columns.shape[0] < mm_nnz:
1687
- z.columns = wp.empty(shape=(mm_nnz,), dtype=int, device=device)
1777
+ if not masked:
1778
+ # Increase dest array size if needed
1779
+ if z.columns.shape[0] < mm_nnz:
1780
+ z.columns = wp.empty(shape=(mm_nnz,), dtype=int, device=device)
1688
1781
 
1689
- from warp.context import runtime
1782
+ from warp.context import runtime
1690
1783
 
1691
- if device.is_cpu:
1692
- native_func = runtime.core.bsr_matrix_from_triplets_float_host
1693
- else:
1694
- native_func = runtime.core.bsr_matrix_from_triplets_float_device
1784
+ if device.is_cpu:
1785
+ native_func = runtime.core.bsr_matrix_from_triplets_float_host
1786
+ else:
1787
+ native_func = runtime.core.bsr_matrix_from_triplets_float_device
1695
1788
 
1696
- nnz_buf, nnz_event = z._nnz_transfer_buf_and_event()
1789
+ nnz_buf, nnz_event = z._nnz_transfer_buf_and_event()
1697
1790
 
1698
- with wp.ScopedDevice(z.device):
1699
- native_func(
1700
- z.block_shape[0],
1701
- z.block_shape[1],
1702
- z.nrow,
1703
- mm_nnz,
1704
- ctypes.cast(work_arrays._mm_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
1705
- ctypes.cast(work_arrays._mm_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
1706
- 0,
1707
- False,
1708
- ctypes.cast(z.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
1709
- ctypes.cast(z.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
1710
- 0,
1711
- ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
1712
- nnz_event,
1713
- )
1791
+ with wp.ScopedDevice(z.device):
1792
+ native_func(
1793
+ z.block_shape[0],
1794
+ z.block_shape[1],
1795
+ z.nrow,
1796
+ mm_nnz,
1797
+ ctypes.cast(work_arrays._mm_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
1798
+ ctypes.cast(work_arrays._mm_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
1799
+ 0,
1800
+ False,
1801
+ masked,
1802
+ ctypes.cast(z.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
1803
+ ctypes.cast(z.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
1804
+ 0,
1805
+ ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
1806
+ nnz_event,
1807
+ )
1714
1808
 
1715
- # Resize z to fit mm result if necessary
1716
- # If we are not reusing the product topology, this needs another synchronization
1717
- if not reuse_topology:
1718
- work_arrays.result_nnz = z.nnz_sync()
1719
- _bsr_ensure_fits(z, nnz=work_arrays.result_nnz)
1809
+ # Resize z to fit mm result if necessary
1810
+ # If we are not reusing the product topology, this needs another synchronization
1811
+ if not reuse_topology:
1812
+ work_arrays.result_nnz = z.nnz_sync()
1720
1813
 
1721
- z.values.zero_()
1814
+ _bsr_ensure_fits(z, nnz=work_arrays.result_nnz)
1815
+ z.values.zero_()
1722
1816
 
1723
- if copied_z_nnz > 0:
1724
- # Add back original z values
1725
- wp.launch(
1726
- kernel=_bsr_axpy_add_block,
1727
- device=device,
1728
- dim=copied_z_nnz,
1729
- inputs=[
1730
- 0,
1731
- beta,
1732
- work_arrays._mm_rows,
1733
- work_arrays._mm_cols,
1734
- z.offsets,
1735
- z.columns,
1736
- work_arrays._old_z_values,
1737
- z.values,
1738
- ],
1739
- )
1817
+ if copied_z_nnz > 0:
1818
+ # Add back original z values
1819
+ wp.launch(
1820
+ kernel=_bsr_axpy_add_block,
1821
+ device=device,
1822
+ dim=copied_z_nnz,
1823
+ inputs=[
1824
+ 0,
1825
+ beta,
1826
+ work_arrays._mm_rows,
1827
+ work_arrays._mm_cols,
1828
+ z.offsets,
1829
+ z.columns,
1830
+ work_arrays._old_z_values,
1831
+ z.values,
1832
+ ],
1833
+ )
1740
1834
 
1741
1835
  # Add mm blocks to z values
1742
- if (warp.types.type_is_matrix(x.values.dtype) or warp.types.type_is_matrix(y.values.dtype)) and not (
1743
- warp.types.type_is_matrix(z.values.dtype)
1744
- ):
1836
+ if (type_is_matrix(x.values.dtype) or type_is_matrix(y.values.dtype)) and not (type_is_matrix(z.values.dtype)):
1745
1837
  # Result block type is scalar, but operands are matrices
1746
1838
  # Cast result to (1x1) matrix to perform multiplication
1747
- mm_values = z.values.view(wp.types.matrix(shape=(1, 1), dtype=z.scalar_type))
1839
+ mm_values = z.values.view(wp.mat(shape=(1, 1), dtype=z.scalar_type))
1748
1840
  else:
1749
1841
  mm_values = z.values
1750
1842
 
@@ -1817,15 +1909,31 @@ def _bsr_mv_transpose_kernel(
1817
1909
  wp.atomic_add(y, A_columns[block], v)
1818
1910
 
1819
1911
 
1820
- def _bsr_mv_as_vec_array(array: wp.array) -> wp.array:
1821
- if array.ndim == 1:
1912
+ def _vec_array_view(array: wp.array, dtype: type, expected_scalar_count: int) -> wp.array:
1913
+ # cast a 1d or 2d array to a 1d array with the target dtype, adjusting shape as required
1914
+
1915
+ scalar_count = array.size * type_length(array.dtype)
1916
+ if scalar_count != expected_scalar_count:
1917
+ raise ValueError(f"Invalid array scalar size, expected {expected_scalar_count}, got {scalar_count}")
1918
+
1919
+ if array.ndim == 1 and types_equal(array.dtype, dtype):
1822
1920
  return array
1823
1921
 
1922
+ if type_scalar_type(array.dtype) != type_scalar_type(dtype):
1923
+ raise ValueError(f"Incompatible scalar types, {type_repr(array.dtype)} vs {type_repr(dtype)}")
1924
+
1824
1925
  if array.ndim > 2:
1825
1926
  raise ValueError(f"Incompatible array number of dimensions {array.ndim}")
1826
1927
 
1827
1928
  if not array.is_contiguous:
1828
- raise ValueError("2d array must be contiguous")
1929
+ raise ValueError("Array must be contiguous")
1930
+
1931
+ vec_length = type_length(dtype)
1932
+ vec_count = scalar_count // vec_length
1933
+ if vec_count * vec_length != scalar_count:
1934
+ raise ValueError(
1935
+ f"Array of shape {array.shape} and type {type_repr(array.dtype)} cannot be reshaped to an array of type {type_repr(dtype)}"
1936
+ )
1829
1937
 
1830
1938
  def vec_view(array):
1831
1939
  return wp.array(
@@ -1833,8 +1941,8 @@ def _bsr_mv_as_vec_array(array: wp.array) -> wp.array:
1833
1941
  ptr=array.ptr,
1834
1942
  capacity=array.capacity,
1835
1943
  device=array.device,
1836
- dtype=wp.vec(length=array.shape[1], dtype=array.dtype),
1837
- shape=array.shape[0],
1944
+ dtype=dtype,
1945
+ shape=vec_count,
1838
1946
  grad=None if array.grad is None else vec_view(array.grad),
1839
1947
  )
1840
1948
 
@@ -1852,20 +1960,20 @@ def bsr_mv(
1852
1960
  transpose: bool = False,
1853
1961
  work_buffer: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
1854
1962
  ) -> "Array[Vector[Rows, Scalar] | Scalar]":
1855
- """
1856
- Performs the sparse matrix-vector product ``y := alpha * A * x + beta * y`` and returns `y`.
1963
+ """Perform the sparse matrix-vector product ``y := alpha * A * x + beta * y`` and return ``y``.
1857
1964
 
1858
- The `x` and `y` vectors are allowed to alias.
1965
+ The ``x`` and ``y`` vectors are allowed to alias.
1859
1966
 
1860
1967
  Args:
1861
1968
  A: Read-only, left matrix factor of the matrix-vector product.
1862
1969
  x: Read-only, right vector factor of the matrix-vector product.
1863
- y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero.
1864
- alpha: Uniform scaling factor for `x`. If zero, `x` will not be read and may be left uninitialized.
1865
- beta: Uniform scaling factor for `y`. If zero, `y` will not be read and may be left uninitialized.
1866
- transpose: If ``True``, use the transpose of the matrix `A`. In this case the result is **non-deterministic**.
1867
- work_buffer: Temporary storage is required if and only if `x` and `y` are the same vector. If provided the `work_buffer` array
1868
- will be used for this purpose, otherwise a temporary allocation will be performed.
1970
+ y: Mutable left-hand-side. If ``y`` is not provided, it will be allocated and treated as zero.
1971
+ alpha: Uniform scaling factor for ``x``. If zero, ``x`` will not be read and may be left uninitialized.
1972
+ beta: Uniform scaling factor for ``y``. If zero, ``y`` will not be read and may be left uninitialized.
1973
+ transpose: If ``True``, use the transpose of the matrix ``A``. In this case the result is **non-deterministic**.
1974
+ work_buffer: Temporary storage is required if and only if ``x`` and ``y`` are the same vector.
1975
+ If provided, the ``work_buffer`` array will be used for this purpose,
1976
+ otherwise a temporary allocation will be performed.
1869
1977
  """
1870
1978
 
1871
1979
  A, A_scale = _extract_matrix_and_scale(A)
@@ -1885,22 +1993,11 @@ def bsr_mv(
1885
1993
  y = wp.empty(shape=(nrow,), device=A.values.device, dtype=y_dtype)
1886
1994
  beta = 0.0
1887
1995
 
1888
- if not isinstance(alpha, A.scalar_type):
1889
- alpha = A.scalar_type(alpha)
1890
- if not isinstance(beta, A.scalar_type):
1891
- beta = A.scalar_type(beta)
1996
+ alpha = A.scalar_type(alpha)
1997
+ beta = A.scalar_type(beta)
1892
1998
 
1893
1999
  if A.values.device != x.device or A.values.device != y.device:
1894
- raise ValueError("A, x and y must reside on the same device")
1895
-
1896
- if x.shape[0] != ncol:
1897
- raise ValueError("Number of columns of A must match number of rows of x")
1898
- if y.shape[0] != nrow:
1899
- raise ValueError("Number of rows of A must match number of rows of y")
1900
-
1901
- # View 2d arrays as arrays of vecs
1902
- x = _bsr_mv_as_vec_array(x)
1903
- y = _bsr_mv_as_vec_array(y)
2000
+ raise ValueError("A, x, and y must reside on the same device")
1904
2001
 
1905
2002
  if x.ptr == y.ptr:
1906
2003
  # Aliasing case, need temporary storage
@@ -1908,24 +2005,29 @@ def bsr_mv(
1908
2005
  work_buffer = wp.empty_like(y)
1909
2006
  elif work_buffer.size < y.size:
1910
2007
  raise ValueError(f"Work buffer size is insufficient, needs to be at least {y.size}")
1911
- elif not wp.types.types_equal(work_buffer.dtype, y.dtype):
1912
- raise ValueError(f"Work buffer must have same data type as y, {wp.types.type_repr(y.dtype)}")
2008
+ elif not types_equal(work_buffer.dtype, y.dtype):
2009
+ raise ValueError(f"Work buffer must have same data type as y, {type_repr(y.dtype)}")
1913
2010
 
1914
2011
  # Save old y values before overwriting vector
1915
2012
  wp.copy(dest=work_buffer, src=y, count=y.size)
1916
2013
  x = work_buffer
1917
2014
 
1918
2015
  # Promote scalar vectors to length-1 vecs and conversely
1919
- if warp.types.type_is_matrix(A.values.dtype):
1920
- if block_shape[0] == 1 and y.dtype == A.scalar_type:
1921
- y = y.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
1922
- if block_shape[1] == 1 and x.dtype == A.scalar_type:
1923
- x = x.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
2016
+ if type_is_matrix(A.values.dtype):
2017
+ x_dtype = wp.vec(length=block_shape[1], dtype=A.scalar_type)
2018
+ y_dtype = wp.vec(length=block_shape[0], dtype=A.scalar_type)
1924
2019
  else:
1925
- if block_shape[0] == 1 and y.dtype != A.scalar_type:
1926
- y = y.view(dtype=A.scalar_type)
1927
- if block_shape[1] == 1 and x.dtype != A.scalar_type:
1928
- x = x.view(dtype=A.scalar_type)
2020
+ x_dtype = A.scalar_type
2021
+ y_dtype = A.scalar_type
2022
+
2023
+ try:
2024
+ x_view = _vec_array_view(x, x_dtype, expected_scalar_count=ncol * block_shape[1])
2025
+ except ValueError as err:
2026
+ raise ValueError("Incompatible 'x' vector for bsr_mv") from err
2027
+ try:
2028
+ y_view = _vec_array_view(y, y_dtype, expected_scalar_count=nrow * block_shape[0])
2029
+ except ValueError as err:
2030
+ raise ValueError("Incompatible 'y' vector for bsr_mv") from err
1929
2031
 
1930
2032
  if transpose:
1931
2033
  if beta.value == 0.0:
@@ -1942,14 +2044,14 @@ def bsr_mv(
1942
2044
  kernel=_bsr_mv_transpose_kernel,
1943
2045
  device=A.values.device,
1944
2046
  dim=ncol,
1945
- inputs=[alpha, A.offsets, A.columns, A.values, x, y],
2047
+ inputs=[alpha, A.offsets, A.columns, A.values, x_view, y_view],
1946
2048
  )
1947
2049
  else:
1948
2050
  wp.launch(
1949
2051
  kernel=_bsr_mv_kernel,
1950
2052
  device=A.values.device,
1951
2053
  dim=nrow,
1952
- inputs=[alpha, A.offsets, A.columns, A.values, x, beta, y],
2054
+ inputs=[alpha, A.offsets, A.columns, A.values, x_view, beta, y_view],
1953
2055
  )
1954
2056
 
1955
2057
  return y