warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
warp/sparse.py ADDED
@@ -0,0 +1,2057 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import ctypes
17
+ from typing import Any, Generic, Optional, Tuple, TypeVar, Union
18
+
19
+ import warp as wp
20
+ import warp.types
21
+ import warp.utils
22
+ from warp.types import (
23
+ Array,
24
+ Cols,
25
+ Rows,
26
+ Scalar,
27
+ Vector,
28
+ is_array,
29
+ scalar_types,
30
+ type_is_matrix,
31
+ type_length,
32
+ type_repr,
33
+ type_scalar_type,
34
+ type_to_warp,
35
+ types_equal,
36
+ )
37
+
38
# typing hints

# Type variable parametrizing the untyped ``BsrMatrix`` base class.
# The variable name intentionally differs from the TypeVar's declared name, hence the noqa.
_BlockType = TypeVar("BlockType")  # noqa: PLC0132


class _MatrixBlockType(Generic[Rows, Cols, Scalar]):
    """Type-hint placeholder for a matrix-valued block with ``Rows x Cols`` coefficients of type ``Scalar``."""


class _ScalarBlockType(Generic[Scalar]):
    """Type-hint placeholder for a scalar-valued block of type ``Scalar``."""


# A sparse-matrix block is either a small dense matrix (BSR) or a single scalar (CSR).
BlockType = Union[_MatrixBlockType[Rows, Cols, Scalar], _ScalarBlockType[Scalar]]

# Maps a string key derived from the block type to the Warp struct type built for it by ``bsr_matrix_t``.
_struct_cache: dict = {}
54
+
55
+
56
class BsrMatrix(Generic[_BlockType]):
    """Untyped base class for BSR and CSR matrices.

    Should not be constructed directly but through functions such as :func:`bsr_zeros`.

    Attributes:
        nrow (int): Number of rows of blocks.
        ncol (int): Number of columns of blocks.
        nnz (int): Upper bound for the number of non-zero blocks, used for
          dimensioning launches. The exact number is at ``offsets[nrow]``.
          See also :meth:`nnz_sync`.
        offsets (Array[int]): Array of size at least ``1 + nrow`` such that the
          start and end indices of the blocks of row ``r`` are ``offsets[r]``
          and ``offsets[r+1]``, respectively.
        columns (Array[int]): Array of size at least equal to ``nnz`` containing
          block column indices.
        values (Array[BlockType]): Array of size at least equal to ``nnz``
          containing block values.
    """

    @property
    def scalar_type(self) -> Scalar:
        """Scalar type for individual block coefficients. For CSR matrices, this is the same as the block type."""
        return type_scalar_type(self.values.dtype)

    @property
    def block_shape(self) -> Tuple[int, int]:
        """Shape of the individual blocks."""
        return getattr(self.values.dtype, "_shape_", (1, 1))

    @property
    def block_size(self) -> int:
        """Size of the individual blocks, i.e. number of rows per block times number of columns per block."""
        return type_length(self.values.dtype)

    @property
    def shape(self) -> Tuple[int, int]:
        """Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/columns per block."""
        block_shape = self.block_shape
        return (self.nrow * block_shape[0], self.ncol * block_shape[1])

    @property
    def dtype(self) -> type:
        """Data type for individual block values."""
        return self.values.dtype

    @property
    def device(self) -> wp.context.Device:
        """Device on which ``offsets``, ``columns``, and ``values`` are allocated -- assumed to be the same for all three arrays."""
        return self.values.device

    @property
    def scalar_values(self) -> wp.array:
        """Accesses the ``values`` array as a 3d scalar array.

        The returned array has shape ``(nnz, *block_shape)`` and dtype :attr:`scalar_type`. It aliases the
        memory of ``values`` (no copy is made), and its ``grad``, if ``values`` has one, aliases
        ``values.grad`` the same way.
        """
        # NOTE: 1x1 blocks deliberately go through the same aliasing path as larger blocks instead of
        # ``values.reshape((nnz, 1, 1))``: reshape requires the total size to match, which fails whenever
        # ``values`` was allocated larger than the current ``nnz``. For 1x1 BSR blocks, reshape would also
        # keep the matrix dtype instead of exposing the scalar type.
        view_shape = (self.nnz, *self.block_shape)
        view_dtype = self.scalar_type

        def _as_3d_array(arr):
            return wp.array(
                ptr=arr.ptr,
                capacity=arr.capacity,
                device=arr.device,
                dtype=view_dtype,
                shape=view_shape,
                grad=None if arr.grad is None else _as_3d_array(arr.grad),
            )

        values_view = _as_3d_array(self.values)
        values_view._ref = self.values  # keep ref in case we're garbage collected
        return values_view

    def uncompress_rows(self, out: wp.array = None) -> wp.array:
        """Compute the row index for each non-zero block from the compressed row offsets.

        Args:
            out: Optional output array of ints. If ``None``, a new array of size ``nnz`` is allocated
              on the matrix device.

        Returns:
            The array holding, for each of the first ``nnz`` blocks, its row index.
        """
        if out is None:
            out = wp.empty(self.nnz, dtype=int, device=self.device)

        wp.launch(
            kernel=_bsr_get_block_row,
            device=self.device,
            dim=self.nnz,
            inputs=[self.nrow, self.offsets, out],
        )
        return out

    def nnz_sync(self):
        """Ensure that any ongoing transfer of the exact nnz number from the device offsets array to the host has completed
        and update the nnz upper bound.

        See also :meth:`copy_nnz_async`.

        Returns:
            The (possibly updated) ``nnz`` value.
        """

        if self._is_nnz_transfer_setup():
            if self.device.is_cuda:
                wp.synchronize_event(self._nnz_event)
            self.nnz = int(self._nnz_buf.numpy()[0])
        return self.nnz

    def copy_nnz_async(self, known_nnz: int = None):
        """
        Start the asynchronous transfer of the exact nnz from the device offsets array to host and records an event for completion.

        Needs to be called whenever the offsets array has been modified from outside ``warp.sparse``.

        See also :meth:`nnz_sync`.

        Args:
            known_nnz: If not ``None``, immediately used as the new ``nnz``; a transfer is then only
              scheduled if a transfer buffer has already been set up.
        """
        if known_nnz is not None:
            self.nnz = int(known_nnz)
        else:
            self._setup_nnz_transfer()

        # If a transfer is already ongoing, or if the actual nnz is unknown, schedule a new transfer
        if self._is_nnz_transfer_setup():
            stream = wp.get_stream(self.device) if self.device.is_cuda else None
            # offsets[nrow] holds the exact number of non-zero blocks
            wp.copy(src=self.offsets, dest=self._nnz_buf, src_offset=self.nrow, count=1, stream=stream)
            if self.device.is_cuda:
                stream.record_event(self._nnz_event)

    def _setup_nnz_transfer(self):
        """Lazily allocate the host buffer (and, on CUDA devices, the event) used for nnz transfers."""
        if self._is_nnz_transfer_setup():
            return

        # Bypass the struct-instance attribute setter: these are not struct fields.
        BsrMatrix.__setattr__(
            self, "_nnz_buf", wp.empty(dtype=int, shape=(1,), device="cpu", pinned=self.device.is_cuda)
        )
        if self.device.is_cuda:
            BsrMatrix.__setattr__(self, "_nnz_event", wp.Event(self.device))

    def _is_nnz_transfer_setup(self):
        """Whether :meth:`_setup_nnz_transfer` has already allocated the transfer buffer."""
        return hasattr(self, "_nnz_buf")

    def _nnz_transfer_buf_and_event(self):
        """Return the host nnz buffer and the CUDA event handle (a null pointer on CPU devices)."""
        self._setup_nnz_transfer()

        if not self.device.is_cuda:
            return self._nnz_buf, ctypes.c_void_p(None)
        return self._nnz_buf, self._nnz_event.cuda_event

    # Overloaded math operators
    def __add__(self, y):
        return bsr_axpy(y, bsr_copy(self))

    def __iadd__(self, y):
        return bsr_axpy(y, self)

    def __radd__(self, x):
        return bsr_axpy(x, bsr_copy(self))

    def __sub__(self, y):
        return bsr_axpy(y, bsr_copy(self), alpha=-1.0)

    def __rsub__(self, x):
        return bsr_axpy(x, bsr_copy(self), beta=-1.0)

    def __isub__(self, y):
        return bsr_axpy(y, self, alpha=-1.0)

    def __mul__(self, y):
        return _BsrScalingExpression(self, y)

    def __rmul__(self, x):
        return _BsrScalingExpression(self, x)

    def __imul__(self, y):
        return bsr_scale(self, y)

    def __matmul__(self, y):
        # Matrix-vector product for dense arrays, matrix-matrix product otherwise
        if isinstance(y, wp.array):
            return bsr_mv(self, y)

        return bsr_mm(self, y)

    def __rmatmul__(self, x):
        # x @ A is computed as A^T x for dense arrays
        if isinstance(x, wp.array):
            return bsr_mv(self, x, transpose=True)

        return bsr_mm(x, self)

    def __imatmul__(self, y):
        return bsr_mm(self, y, self)

    def __truediv__(self, y):
        return _BsrScalingExpression(self, 1.0 / y)

    def __neg__(self):
        return _BsrScalingExpression(self, -1.0)

    def transpose(self):
        """Return a transposed copy of this matrix."""
        return bsr_transposed(self)
245
+
246
+
247
def bsr_matrix_t(dtype: BlockType):
    """Return the Warp struct type representing BSR/CSR matrices with the given block type.

    Results are memoized in the module-level ``_struct_cache``, so repeated calls with the same block
    type return the same struct type. Calling the returned type creates a new matrix instance whose
    fields are then to be assigned by the caller (see :func:`bsr_zeros`).

    Args:
        dtype: Block type, converted with :func:`warp.types.type_to_warp` first. A Warp matrix type
          yields a BSR matrix type; a scalar type yields a CSR matrix type.

    Raises:
        ValueError: If ``dtype`` is neither a Warp matrix type nor a Warp scalar type.
    """
    dtype = type_to_warp(dtype)

    if not type_is_matrix(dtype) and dtype not in scalar_types:
        raise ValueError(f"BsrMatrix block type must be either warp matrix or scalar; got {type_repr(dtype)}")

    if hasattr(dtype, "_shape_"):
        type_str = f"{type_scalar_type(dtype).__name__}_{dtype._shape_[0]}_{dtype._shape_[1]}"
    else:
        type_str = dtype.__name__
    key = f"{BsrMatrix.__qualname__}_{type_str}"

    # Only build the typed class and struct on a cache miss; on a hit they would be discarded anyway.
    if key not in _struct_cache:

        class BsrMatrixTyped(BsrMatrix):
            nrow: int
            """Number of rows of blocks."""
            ncol: int
            """Number of columns of blocks."""
            nnz: int
            """Upper bound for the number of non-zeros."""
            offsets: wp.array(dtype=int)
            """Array of size at least ``1 + nrow``."""
            columns: wp.array(dtype=int)
            """Array of size at least equal to ``nnz``."""
            values: wp.array(dtype=dtype)
            """Array of size at least equal to ``nnz``."""

        module = wp.get_module(BsrMatrix.__module__)

        _struct_cache[key] = wp.codegen.Struct(
            cls=BsrMatrixTyped,
            key=key,
            module=module,
        )

    return _struct_cache[key]
282
+
283
+
284
def bsr_zeros(
    rows_of_blocks: int,
    cols_of_blocks: int,
    block_type: BlockType,
    device: wp.context.Devicelike = None,
) -> BsrMatrix:
    """Construct and return an empty BSR or CSR matrix with the given shape.

    The returned matrix has ``nnz == 0``, empty ``columns`` and ``values`` arrays, and a zero-filled
    ``offsets`` array of size ``rows_of_blocks + 1``.

    Args:
        rows_of_blocks: Number of rows of blocks.
        cols_of_blocks: Number of columns of blocks.
        block_type: Type of individual blocks.
          For CSR matrices, this should be a scalar type.
          For BSR matrices, this should be a matrix type (e.g. from :func:`warp.mat`).
        device: Device on which to allocate the matrix arrays.

    Returns:
        A new matrix instance of type ``bsr_matrix_t(block_type)``.
    """

    bsr = bsr_matrix_t(block_type)()

    bsr.nrow = int(rows_of_blocks)
    bsr.ncol = int(cols_of_blocks)
    bsr.nnz = 0
    bsr.columns = wp.empty(shape=(0,), dtype=int, device=device)
    bsr.values = wp.empty(shape=(0,), dtype=block_type, device=device)
    # All row ranges [offsets[r], offsets[r+1]) are empty
    bsr.offsets = wp.zeros(shape=(bsr.nrow + 1,), dtype=int, device=device)

    return bsr
312
+
313
+
314
def _bsr_ensure_fits(bsr: BsrMatrix, nrow: int = None, nnz: int = None):
    """Grow the storage arrays of ``bsr`` so they can hold ``nrow`` rows and ``nnz`` blocks.

    When ``nnz`` is given it also becomes the new ``bsr.nnz`` upper bound; ``bsr.nrow`` is never modified
    here. Arrays that are already large enough are kept as-is; too-small arrays are replaced by freshly
    allocated (uninitialized) arrays on the same device, so their previous contents are not preserved.
    """
    row_count = bsr.nrow if nrow is None else nrow

    if nnz is None:
        block_count = bsr.nnz
    else:
        block_count = nnz
        bsr.nnz = int(nnz)  # update nnz upper bound

    offsets, columns, values = bsr.offsets, bsr.columns, bsr.values

    if not offsets.size >= row_count + 1:
        bsr.offsets = wp.empty(shape=(row_count + 1,), dtype=int, device=offsets.device)
    if not columns.size >= block_count:
        bsr.columns = wp.empty(shape=(block_count,), dtype=int, device=columns.device)
    if not values.size >= block_count:
        bsr.values = wp.empty(shape=(block_count,), dtype=values.dtype, device=values.device)
329
+
330
+
331
def bsr_set_zero(
    bsr: BsrMatrix,
    rows_of_blocks: Optional[int] = None,
    cols_of_blocks: Optional[int] = None,
):
    """Set a BSR matrix to zero, possibly changing its size.

    Existing ``columns``/``values`` storage is kept; only the offsets are cleared and ``nnz`` reset to 0.

    Args:
        bsr: The BSR or CSR matrix to set to zero.
        rows_of_blocks: If not ``None``, the new number of rows of blocks.
        cols_of_blocks: If not ``None``, the new number of columns of blocks.
    """

    # Apply the optional new dimensions (rows first, then columns)
    for field_name, new_size in (("nrow", rows_of_blocks), ("ncol", cols_of_blocks)):
        if new_size is not None:
            setattr(bsr, field_name, int(new_size))

    # Make sure offsets can hold nrow + 1 entries, then mark every row as empty
    _bsr_ensure_fits(bsr, nnz=0)
    bsr.offsets.zero_()

    # nnz is known to be zero; no need to wait for a device-to-host transfer
    bsr.copy_nnz_async(known_nnz=0)
352
+
353
+
354
def bsr_set_from_triplets(
    dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    rows: "Array[int]",
    columns: "Array[int]",
    values: Optional["Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]"] = None,
    prune_numerical_zeros: bool = True,
    masked: bool = False,
):
    """Fill a BSR matrix with values defined by coordinate-oriented (COO) triplets, discarding existing blocks.

    The first dimension of the three input arrays must match and indicates the number of COO triplets.

    Args:
        dest: Sparse matrix to populate.
        rows: Row index for each non-zero.
        columns: Column index for each non-zero.
        values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
          to the ``dest`` matrix's block type, or a 3d array with data type equal to the ``dest`` matrix's scalar type.
          If ``None``, the values array of the resulting matrix will be allocated but uninitialized.
        prune_numerical_zeros: If ``True``, will ignore the zero-valued blocks.
        masked: If ``True``, ignore blocks that are not existing non-zeros of ``dest``.

    Raises:
        ValueError: If the arrays reside on different devices, have mismatched lengths,
          or ``values`` has an incompatible type, shape, dimensionality or layout.
        NotImplementedError: If the scalar type of ``dest`` is neither ``float32`` nor ``float64``.
    """

    if rows.device != columns.device or rows.device != dest.device:
        raise ValueError("All arguments must reside on the same device")

    if rows.shape[0] != columns.shape[0]:
        raise ValueError("All triplet arrays must have the same length")

    # Accept either array1d(dtype) or contiguous array3d(scalar_type) as values
    if values is not None:
        if values.device != rows.device:
            raise ValueError("All arguments must reside on the same device")

        if values.shape[0] != rows.shape[0]:
            raise ValueError("All triplet arrays must have the same length")

        if values.ndim == 1:
            if values.dtype != dest.values.dtype:
                raise ValueError("Values array type must correspond to that of dest matrix")
        elif values.ndim == 3:
            if values.shape[1:] != dest.block_shape:
                raise ValueError(
                    f"Last two dimensions in values array ({values.shape[1:]}) should correspond to matrix block shape {(dest.block_shape)})"
                )

            if type_scalar_type(values.dtype) != dest.scalar_type:
                raise ValueError("Scalar type of values array should correspond to that of matrix")

            if not values.is_contiguous:
                raise ValueError("Multi-dimensional values array should be contiguous")
        else:
            raise ValueError("Number of dimension for values array should be 1 or 3")

    nnz = rows.shape[0]
    if nnz == 0:
        bsr_set_zero(dest)
        return

    # Increase dest array sizes if needed
    if not masked:
        _bsr_ensure_fits(dest, nnz=nnz)

    device = dest.values.device
    scalar_type = dest.scalar_type
    from warp.context import runtime

    # Native implementations only exist for float32 and float64. Start from None so that any
    # other scalar type reaches the NotImplementedError below instead of an UnboundLocalError.
    native_func = None
    if device.is_cpu:
        if scalar_type == wp.float32:
            native_func = runtime.core.bsr_matrix_from_triplets_float_host
        elif scalar_type == wp.float64:
            native_func = runtime.core.bsr_matrix_from_triplets_double_host
    else:
        if scalar_type == wp.float32:
            native_func = runtime.core.bsr_matrix_from_triplets_float_device
        elif scalar_type == wp.float64:
            native_func = runtime.core.bsr_matrix_from_triplets_double_device

    if not native_func:
        raise NotImplementedError(f"bsr_from_triplets not implemented for scalar type {scalar_type}")

    nnz_buf, nnz_event = dest._nnz_transfer_buf_and_event()

    with wp.ScopedDevice(device):
        native_func(
            dest.block_shape[0],
            dest.block_shape[1],
            dest.nrow,
            nnz,
            ctypes.cast(rows.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(columns.ptr, ctypes.POINTER(ctypes.c_int32)),
            None if values is None else ctypes.cast(values.ptr, ctypes.c_void_p),
            prune_numerical_zeros,
            masked,
            ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
            None if values is None else ctypes.cast(dest.values.ptr, ctypes.c_void_p),
            ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
            nnz_event,
        )
454
+
455
+
456
def bsr_from_triplets(
    rows_of_blocks: int,
    cols_of_blocks: int,
    rows: "Array[int]",
    columns: "Array[int]",
    values: "Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]",
    prune_numerical_zeros: bool = True,
):
    """Construct a new BSR matrix from coordinate-oriented (COO) triplets.

    The first dimension of the three input arrays must match and indicates the number of COO triplets.
    The block type and device of the result are taken from ``values``.

    Args:
        rows_of_blocks: Number of rows of blocks.
        cols_of_blocks: Number of columns of blocks.
        rows: Row index for each non-zero.
        columns: Column index for each non-zero.
        values: Block values for each non-zero. Either a one-dimensional array whose data type is the
          block type of the new matrix, or a 3d array whose data type is its scalar type (the trailing
          two dimensions then give the block shape).
        prune_numerical_zeros: If ``True``, will ignore the zero-valued blocks.
    """
    # A 3d scalar array describes matrix-valued blocks; otherwise the dtype already is the block type
    block_type = (
        wp.mat(shape=values.shape[1:], dtype=values.dtype) if values.ndim == 3 else values.dtype
    )

    matrix = bsr_zeros(
        rows_of_blocks=rows_of_blocks,
        cols_of_blocks=cols_of_blocks,
        block_type=block_type,
        device=values.device,
    )
    bsr_set_from_triplets(matrix, rows, columns, values, prune_numerical_zeros=prune_numerical_zeros)
    return matrix
488
+
489
+
490
class _BsrExpression(Generic[_BlockType]):
    """Marker base class for lazily evaluated expressions over BSR matrices.

    Subclasses (such as the scaling expression) wrap a matrix without materializing a new one.
    """

    pass
492
+
493
+
494
class _BsrScalingExpression(_BsrExpression):
    """Lazy expression representing ``scale * mat`` for a BSR matrix ``mat``.

    Structural properties are forwarded to the wrapped matrix. Arithmetic operators either fold
    the scale into a new expression (``*``, ``/``, unary ``-``) or evaluate eagerly through
    ``bsr_axpy``, ``bsr_mv`` or ``bsr_mm``, passing the scale as a multiplier.
    """

    def __init__(self, mat, scale):
        self.mat = mat
        self.scale = scale

    def eval(self):
        """Materialize the expression as a new BSR matrix."""
        return bsr_copy(self)

    @property
    def nrow(self) -> int:
        return self.mat.nrow

    @property
    def ncol(self) -> int:
        return self.mat.ncol

    @property
    def nnz(self) -> int:
        return self.mat.nnz

    @property
    def offsets(self) -> wp.array:
        return self.mat.offsets

    @property
    def columns(self) -> wp.array:
        return self.mat.columns

    @property
    def scalar_type(self) -> Scalar:
        return self.mat.scalar_type

    @property
    def block_shape(self) -> Tuple[int, int]:
        return self.mat.block_shape

    @property
    def block_size(self) -> int:
        return self.mat.block_size

    @property
    def shape(self) -> Tuple[int, int]:
        return self.mat.shape

    @property
    def dtype(self) -> type:
        return self.mat.dtype

    @property
    def device(self) -> wp.context.Device:
        return self.mat.device

    # Overloaded math operators
    #
    # ``bsr_axpy(x, y, alpha, beta)`` computes ``y := alpha * x + beta * y``. The wrapped matrix is
    # copied into the ``y`` slot so that it is left untouched, and its scale is therefore passed
    # as ``beta``. The other operand goes in the ``x`` slot with ``alpha`` equal to +1 or -1.
    def __add__(self, y):
        # (scale * mat) + y
        return bsr_axpy(y, bsr_copy(self.mat), beta=self.scale)

    def __radd__(self, x):
        # x + (scale * mat)
        return bsr_axpy(x, bsr_copy(self.mat), beta=self.scale)

    def __sub__(self, y):
        # (scale * mat) - y
        return bsr_axpy(y, bsr_copy(self.mat), alpha=-1.0, beta=self.scale)

    def __rsub__(self, x):
        # x - (scale * mat)
        return bsr_axpy(x, bsr_copy(self.mat), beta=-self.scale)

    def __mul__(self, y):
        return _BsrScalingExpression(self.mat, y * self.scale)

    def __rmul__(self, x):
        return _BsrScalingExpression(self.mat, x * self.scale)

    def __matmul__(self, y):
        # Dense vectors go through the matrix-vector product, anything else through matrix-matrix
        if isinstance(y, wp.array):
            return bsr_mv(self.mat, y, alpha=self.scale)

        return bsr_mm(self.mat, y, alpha=self.scale)

    def __rmatmul__(self, x):
        # x @ (scale * mat) == scale * mat^T x for a dense vector x
        if isinstance(x, wp.array):
            return bsr_mv(self.mat, x, alpha=self.scale, transpose=True)

        return bsr_mm(x, self.mat, alpha=self.scale)

    def __truediv__(self, y):
        return _BsrScalingExpression(self.mat, self.scale / y)

    def __neg__(self):
        return _BsrScalingExpression(self.mat, -self.scale)

    def transpose(self):
        """Returns a transposed copy of this matrix"""
        return _BsrScalingExpression(self.mat.transpose(), self.scale)
586
+
587
+
588
# Type accepted by operations that take either a concrete BSR matrix or a lazy expression
# wrapping one (e.g. a scaled matrix); such arguments are unpacked with _extract_matrix_and_scale.
BsrMatrixOrExpression = Union[BsrMatrix[_BlockType], _BsrExpression[_BlockType]]
589
+
590
+
591
def _extract_matrix_and_scale(bsr: BsrMatrixOrExpression):
    """Split a matrix-or-expression into ``(matrix, scale)``.

    A plain ``BsrMatrix`` yields a scale of ``1.0``; a scaling expression yields its wrapped
    matrix and factor. Any other input raises ``ValueError``.
    """
    if isinstance(bsr, BsrMatrix):
        matrix, scale = bsr, 1.0
    elif isinstance(bsr, _BsrScalingExpression):
        matrix, scale = bsr.mat, bsr.scale
    else:
        raise ValueError("Argument cannot be interpreted as a BsrMatrix")

    return matrix, scale
598
+
599
+
600
@wp.func
def _bsr_row_index(
    offsets: wp.array(dtype=int),
    row_count: int,
    block: int,
):
    """Index of the row containing a block, or -1 if non-existing."""
    # ``offsets[row_count]`` is the number of stored blocks: any ``block`` at or beyond it yields
    # 0 - 1 == -1. Otherwise the search for ``block + 1`` over the ``row_count + 1`` offsets finds
    # the first offset strictly greater than ``block``, which is one past the owning row, hence "- 1".
    return wp.where(block < offsets[row_count], wp.lower_bound(offsets, 0, row_count + 1, block + 1), 0) - 1
608
+
609
+
610
@wp.func
def _bsr_block_index(
    row: int,
    col: int,
    bsr_offsets: wp.array(dtype=int),
    bsr_columns: wp.array(dtype=int),
):
    """Index of the block at block-coordinates (row, col), or -1 if non-existing.
    Assumes bsr_columns is sorted.
    """

    # Negative rows (e.g. the -1 marker returned by _bsr_row_index) never match a block
    if row < 0:
        return -1

    mask_row_beg = bsr_offsets[row]
    mask_row_end = bsr_offsets[row + 1]

    # Empty row: nothing to search, and no column entry of this row may be read
    if mask_row_beg == mask_row_end:
        return -1

    # Binary search over the row's (sorted) column indices.
    # NOTE(review): correctness relies on wp.lower_bound returning an index inside
    # [mask_row_beg, mask_row_end) even when every column is < col; if it could return
    # mask_row_end, the comparison below would read the next row's first column — confirm.
    block_index = wp.lower_bound(bsr_columns, mask_row_beg, mask_row_end, col)
    return wp.where(bsr_columns[block_index] == col, block_index, -1)
632
+
633
+
634
@wp.kernel(enable_backward=False)
def _bsr_assign_list_blocks(
    src_subrows: int,
    src_subcols: int,
    dest_subrows: int,
    dest_subcols: int,
    src_row_count: int,
    src_offsets: wp.array(dtype=int),
    src_columns: wp.array(dtype=int),
    dest_rows: wp.array(dtype=int),
    dest_cols: wp.array(dtype=int),
):
    """Emit, for each (source block, sub-row, sub-column), the COO coordinates of its destination block.

    Writes -1 to both outputs for source block slots that hold no stored block.
    """
    block, subrow, subcol = wp.tid()
    # Flat output slot, ordered by source block, then sub-column, then sub-row
    dest_block = (block * src_subcols + subcol) * src_subrows + subrow

    row = _bsr_row_index(src_offsets, src_row_count, block)
    if row == -1:
        dest_rows[dest_block] = row  # invalid
        dest_cols[dest_block] = row
    else:
        # Coordinates in units of the common sub-block size, folded into destination block coordinates
        dest_subrow = row * src_subrows + subrow
        dest_subcol = src_columns[block] * src_subcols + subcol
        dest_rows[dest_block] = dest_subrow // dest_subrows
        dest_cols[dest_block] = dest_subcol // dest_subcols
658
+
659
+
660
@wp.kernel
def _bsr_assign_copy_blocks(
    scale: Any,
    src_subrows: int,
    src_subcols: int,
    dest_subrows: int,
    dest_subcols: int,
    src_row_count: int,
    src_offsets: wp.array(dtype=int),
    src_columns: wp.array(dtype=int),
    src_values: wp.array3d(dtype=Any),
    dest_offsets: wp.array(dtype=int),
    dest_columns: wp.array(dtype=int),
    dest_values: wp.array3d(dtype=Any),
):
    """Copy one scaled sub-block of a source block into the matching part of its destination block.

    Values are cast to the destination scalar type. Sub-blocks whose source block is not stored,
    or whose destination block does not exist in ``dest``, are skipped.
    """
    # (The previous dead ``src_block = wp.tid()`` store was immediately overwritten and has been removed.)
    src_block, subrow, subcol = wp.tid()

    src_row = _bsr_row_index(src_offsets, src_row_count, src_block)
    if src_row == -1:
        return

    src_col = src_columns[src_block]

    # Position in units of the common sub-block size, then the owning destination block
    dest_subrow = src_row * src_subrows + subrow
    dest_subcol = src_col * src_subcols + subcol
    dest_row = dest_subrow // dest_subrows
    dest_col = dest_subcol // dest_subcols

    dest_block = _bsr_block_index(dest_row, dest_col, dest_offsets, dest_columns)
    if dest_block == -1:
        return

    # Index of the sub-block within the destination block
    split_row = dest_subrow - dest_subrows * dest_row
    split_col = dest_subcol - dest_subcols * dest_col

    rows_per_subblock = src_values.shape[1] // src_subrows
    cols_per_subblock = src_values.shape[2] // src_subcols

    dest_base_i = split_row * rows_per_subblock
    dest_base_j = split_col * cols_per_subblock

    src_base_i = subrow * rows_per_subblock
    src_base_j = subcol * cols_per_subblock

    for i in range(rows_per_subblock):
        for j in range(cols_per_subblock):
            dest_values[dest_block, i + dest_base_i, j + dest_base_j] = dest_values.dtype(
                scale * src_values[src_block, i + src_base_i, j + src_base_j]
            )
710
+
711
+
712
def bsr_assign(
    dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
    src: BsrMatrixOrExpression[BlockType[Any, Any, Any]],
    structure_only: bool = False,
    masked: bool = False,
):
    """Copy the content of the ``src`` BSR matrix to ``dest``.

    Args:
        src: Matrix to be copied.
        dest: Destination matrix. May have a different block shape or scalar type
          than ``src``, in which case the required casting will be performed.
        structure_only: If ``True``, only the non-zeros indices are copied, and uninitialized value storage is allocated
          to accommodate at least ``src.nnz`` blocks. If ``structure_only`` is ``False``, values are also copied with implicit
          casting if the two matrices use distinct scalar types.
        masked: If ``True``, prevent the assignment operation from adding new non-zeros blocks to ``dest``.

    Raises:
        ValueError: On device mismatch, incompatible block shapes, or (when ``masked``) a
          destination whose size does not match the re-blocked source.
    """

    def _split_factors(src_size, dest_size, label):
        # Returns (src_split, dest_split) such that src_size / src_split == dest_size / dest_split,
        # i.e. the number of common-size sub-blocks along one block dimension. One factor is always 1.
        if src_size >= dest_size:
            src_split, dest_split = src_size // dest_size, 1
        else:
            src_split, dest_split = 1, dest_size // src_size

        if src_split * dest_size != src_size * dest_split:
            raise ValueError(
                f"Incompatible dest and src block shapes; block {label} must evenly divide one another (Got {src_size}, {dest_size})"
            )
        return src_split, dest_split

    src, src_scale = _extract_matrix_and_scale(src)

    if dest.values.device != src.values.device:
        raise ValueError("Source and destination matrices must reside on the same device")

    src_subrows, dest_subrows = _split_factors(src.block_shape[0], dest.block_shape[0], "rows")
    src_subcols, dest_subcols = _split_factors(src.block_shape[1], dest.block_shape[1], "columns")

    dest_nrow = (src.nrow * src_subrows) // dest_subrows
    dest_ncol = (src.ncol * src_subcols) // dest_subcols

    rows_divide = src.nrow * src_subrows == dest_nrow * dest_subrows
    cols_divide = src.ncol * src_subcols == dest_ncol * dest_subcols
    if not (rows_divide and cols_divide):
        raise ValueError("The requested block shape does not evenly divide the source matrix")

    nnz_alloc = src.nnz * src_subrows * src_subcols
    if not masked:
        dest.nrow = dest_nrow
        dest.ncol = dest_ncol
        _bsr_ensure_fits(dest, nnz=nnz_alloc)
    elif dest_nrow != dest.nrow or dest_ncol != dest.ncol:
        raise ValueError(
            f"Incompatible destination matrix size, expected ({dest_nrow}, {dest_ncol}), got ({dest.nrow}, {dest.ncol})"
        )

    direct_copy = dest.block_shape == src.block_shape and not masked
    if direct_copy:
        # Same block layout and no mask: duplicate the CSR structure, then cast and scale values
        wp.copy(dest=dest.offsets, src=src.offsets, count=src.nrow + 1)
        dest.copy_nnz_async()

        if nnz_alloc > 0:
            wp.copy(dest=dest.columns, src=src.columns, count=nnz_alloc)

            if not structure_only:
                warp.utils.array_cast(out_array=dest.values, in_array=src.values, count=nnz_alloc)
                bsr_scale(dest, src_scale)
        return

    # Masked and/or multiple src blocks per dest block: go through COO format.
    # Step 1: destination block coordinates of every source sub-block
    dest_rows = wp.empty(nnz_alloc, dtype=int, device=dest.device)
    dest_cols = wp.empty(nnz_alloc, dtype=int, device=dest.device)
    wp.launch(
        _bsr_assign_list_blocks,
        dim=(src.nnz, src_subrows, src_subcols),
        device=dest.device,
        inputs=[
            src_subrows,
            src_subcols,
            dest_subrows,
            dest_subcols,
            src.nrow,
            src.offsets,
            src.columns,
            dest_rows,
            dest_cols,
        ],
    )

    # Step 2: build destination offsets/columns from the triplets (structure only, no values)
    from warp.context import runtime

    native_func = (
        runtime.core.bsr_matrix_from_triplets_float_host
        if dest.device.is_cpu
        else runtime.core.bsr_matrix_from_triplets_float_device
    )

    nnz_buf, nnz_event = dest._nnz_transfer_buf_and_event()
    with wp.ScopedDevice(dest.device):
        native_func(
            dest.block_shape[0],
            dest.block_shape[1],
            dest.nrow,
            nnz_alloc,
            ctypes.cast(dest_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(dest_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
            0,
            False,
            masked,
            ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
            0,
            ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
            nnz_event,
        )

    # Step 3: merge block values
    if structure_only:
        return

    dest.values.zero_()
    wp.launch(
        _bsr_assign_copy_blocks,
        dim=(src.nnz, src_subrows, src_subcols),
        device=dest.device,
        inputs=[
            src.scalar_type(src_scale),
            src_subrows,
            src_subcols,
            dest_subrows,
            dest_subcols,
            src.nrow,
            src.offsets,
            src.columns,
            src.scalar_values,
            dest.offsets,
            dest.columns,
            dest.scalar_values,
        ],
    )
861
+
862
+
863
def bsr_copy(
    A: BsrMatrixOrExpression,
    scalar_type: Optional[Scalar] = None,
    block_shape: Optional[Tuple[int, int]] = None,
    structure_only: bool = False,
):
    """Return a copy of matrix ``A``, possibly changing its scalar type or block shape.

    Args:
        A: Matrix (or scaled matrix expression) to be copied.
        scalar_type: If provided, the returned matrix will use this scalar type instead of the one from ``A``.
        block_shape: If provided, the returned matrix will use blocks of this shape instead of the one from ``A``.
          Both dimensions of ``block_shape`` must be either a multiple or an exact divider of the ones from ``A``.
        structure_only: If ``True``, only the non-zeros indices are copied, and uninitialized value storage is allocated
          to accommodate at least ``src.nnz`` blocks. If ``structure_only`` is ``False``, values are also copied with implicit
          casting if the two matrices use distinct scalar types.
    """
    scalar_type = A.scalar_type if scalar_type is None else scalar_type
    block_shape = A.block_shape if block_shape is None else block_shape

    # 1x1 blocks are stored as plain scalars rather than 1x1 matrices
    block_type = scalar_type if block_shape == (1, 1) else wp.mat(shape=block_shape, dtype=scalar_type)

    result = bsr_zeros(
        rows_of_blocks=A.nrow,
        cols_of_blocks=A.ncol,
        block_type=block_type,
        device=A.device,
    )
    bsr_assign(dest=result, src=A, structure_only=structure_only)
    return result
898
+
899
+
900
def bsr_set_transpose(
    dest: BsrMatrix[BlockType[Cols, Rows, Scalar]],
    src: BsrMatrixOrExpression[BlockType[Rows, Cols, Scalar]],
):
    """Assign the transposed matrix ``src`` to matrix ``dest``.

    Args:
        dest: Matrix receiving the transpose; its block shape must be the reverse of ``src``'s.
        src: Matrix (or scaled matrix expression) to transpose. A scaling factor is applied to ``dest``.

    Raises:
        ValueError: If the matrices differ in device or scalar type, or ``dest`` has the wrong block shape.
        NotImplementedError: If the scalar type is neither ``float32`` nor ``float64``.
    """

    src, src_scale = _extract_matrix_and_scale(src)

    if dest.values.device != src.values.device:
        raise ValueError("All arguments must reside on the same device")

    if dest.scalar_type != src.scalar_type:
        raise ValueError("All arguments must have the same scalar type")

    transpose_block_shape = src.block_shape[::-1]

    if dest.block_shape != transpose_block_shape:
        raise ValueError(f"Destination block shape must be {transpose_block_shape}")

    nnz = src.nnz
    dest.nrow = src.ncol
    dest.ncol = src.nrow

    if nnz == 0:
        bsr_set_zero(dest)
        return

    # Increase dest array sizes if needed
    _bsr_ensure_fits(dest, nnz=nnz)

    from warp.context import runtime

    # Native implementations only exist for float32 and float64. Start from None so that any
    # other scalar type reaches the NotImplementedError below instead of an UnboundLocalError.
    native_func = None
    if dest.values.device.is_cpu:
        if dest.scalar_type == wp.float32:
            native_func = runtime.core.bsr_transpose_float_host
        elif dest.scalar_type == wp.float64:
            native_func = runtime.core.bsr_transpose_double_host
    else:
        if dest.scalar_type == wp.float32:
            native_func = runtime.core.bsr_transpose_float_device
        elif dest.scalar_type == wp.float64:
            native_func = runtime.core.bsr_transpose_double_device

    if not native_func:
        raise NotImplementedError(f"bsr_set_transpose not implemented for scalar type {dest.scalar_type}")

    with wp.ScopedDevice(dest.device):
        native_func(
            src.block_shape[0],
            src.block_shape[1],
            src.nrow,
            src.ncol,
            nnz,
            ctypes.cast(src.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(src.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(src.values.ptr, ctypes.c_void_p),
            ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(dest.values.ptr, ctypes.c_void_p),
        )

    dest.copy_nnz_async()
    bsr_scale(dest, src_scale)
963
+
964
+
965
def bsr_transposed(A: BsrMatrixOrExpression) -> BsrMatrix:
    """Return a copy of the transposed matrix ``A``.

    ``A`` may be a plain BSR matrix or a scaled matrix expression; any scaling factor is
    applied to the returned copy.
    """

    # Scaling expressions do not expose a ``values`` array, so read the block type from the
    # underlying matrix. The scale itself is still applied by bsr_set_transpose below.
    src_matrix, _ = _extract_matrix_and_scale(A)

    if A.block_shape == (1, 1):
        block_type = src_matrix.values.dtype
    else:
        block_type = wp.mat(shape=A.block_shape[::-1], dtype=A.scalar_type)

    transposed = bsr_zeros(
        rows_of_blocks=A.ncol,
        cols_of_blocks=A.nrow,
        block_type=block_type,
        device=A.device,
    )
    bsr_set_transpose(dest=transposed, src=A)
    return transposed
981
+
982
+
983
@wp.kernel
def _bsr_get_diag_kernel(
    scale: Any,
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
    out: wp.array(dtype=Any),
):
    """One thread per diagonal position: write ``scale * A[row, row]`` to ``out[row]`` if that block is stored."""
    row = wp.tid()

    diag = _bsr_block_index(row, row, A_offsets, A_columns)
    # Rows without a stored diagonal block leave ``out[row]`` untouched
    if diag != -1:
        out[row] = scale * A_values[diag]
996
+
997
+
998
def bsr_get_diag(A: BsrMatrixOrExpression[BlockType], out: "Optional[Array[BlockType]]" = None) -> "Array[BlockType]":
    """Return the array of blocks that constitute the diagonal of a sparse matrix.

    Diagonal positions without a stored block are returned as zero.

    Args:
        A: The sparse matrix from which to extract the diagonal.
        out: If provided, the array into which to store the diagonal blocks. Only its first
          ``min(A.nrow, A.ncol)`` entries are written.

    Raises:
        ValueError: If ``out`` has the wrong dtype, device, or is too short.
    """

    A, scale = _extract_matrix_and_scale(A)

    dim = min(A.nrow, A.ncol)

    if out is None:
        out = wp.zeros(shape=(dim,), dtype=A.values.dtype, device=A.values.device)
    else:
        if out.dtype != A.values.dtype:
            raise ValueError(f"Output array must have type {A.values.dtype}")
        if out.device != A.values.device:
            raise ValueError(f"Output array must reside on device {A.values.device}")
        if out.shape[0] < dim:
            raise ValueError(f"Output array must be of length at least {dim}")

        # The kernel only writes rows that store a diagonal block. Clear the output range first
        # so that the other rows read as zero instead of keeping stale content from earlier use.
        if dim > 0:
            out[:dim].zero_()

    wp.launch(
        kernel=_bsr_get_diag_kernel,
        dim=dim,
        device=A.values.device,
        inputs=[A.scalar_type(scale), A.offsets, A.columns, A.values, out],
    )

    return out
1028
+
1029
+
1030
@wp.kernel(enable_backward=False)
def _bsr_set_diag_kernel(
    nnz: int,
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
):
    """Write a block-diagonal structure with ``nnz`` diagonal blocks.

    Thread ``row`` sets ``A_offsets[row] = min(row, nnz)``: each of the first ``nnz`` rows owns
    exactly one block, and the rows after them are empty. Those rows also get their single column
    index set to ``row``. Every offset entry that should be written needs a thread.
    """
    row = wp.tid()
    A_offsets[row] = wp.min(row, nnz)
    if row < nnz:
        A_columns[row] = row
1040
+
1041
+
1042
def bsr_set_diag(
    A: BsrMatrix[BlockType],
    diag: "Union[BlockType, Array[BlockType]]",
    rows_of_blocks: Optional[int] = None,
    cols_of_blocks: Optional[int] = None,
) -> None:
    """Set ``A`` as a block-diagonal matrix.

    Args:
        A: The sparse matrix to modify.
        diag: Specifies the values for diagonal blocks. Can be one of:

          - A Warp array of type ``A.values.dtype``: Each element defines one block of the diagonal
          - A constant value of type ``A.values.dtype``: This value is assigned to all diagonal blocks
          - ``None``: Diagonal block values are left uninitialized

        rows_of_blocks: If not ``None``, the new number of rows of blocks.
        cols_of_blocks: If not ``None``, the new number of columns of blocks.

    The shape of the matrix will be defined one of the following, in this order:

      - ``rows_of_blocks`` and ``cols_of_blocks``, if provided.
        If only one is given, the second is assumed equal.
      - The first dimension of ``diag``, if ``diag`` is an array
      - The current dimensions of ``A`` otherwise
    """

    if rows_of_blocks is None and cols_of_blocks is not None:
        rows_of_blocks = cols_of_blocks
    if cols_of_blocks is None and rows_of_blocks is not None:
        cols_of_blocks = rows_of_blocks

    if is_array(diag):
        if rows_of_blocks is None:
            rows_of_blocks = diag.shape[0]
            cols_of_blocks = diag.shape[0]

    if rows_of_blocks is not None:
        A.nrow = int(rows_of_blocks)
        A.ncol = int(cols_of_blocks)

    nnz = min(A.nrow, A.ncol)
    _bsr_ensure_fits(A, nnz=nnz)

    # One thread per offset entry. All A.nrow + 1 offsets must be written, including the trailing
    # (empty) rows when A.nrow > A.ncol; launching only nnz + 1 threads would leave them uninitialized.
    wp.launch(
        kernel=_bsr_set_diag_kernel,
        dim=A.nrow + 1,
        device=A.offsets.device,
        inputs=[nnz, A.offsets, A.columns],
    )

    if is_array(diag):
        wp.copy(src=diag, dest=A.values, count=nnz)
    elif diag is not None:
        A.values.fill_(diag)

    A.copy_nnz_async(known_nnz=nnz)
1099
+
1100
+
1101
def bsr_diag(
    diag: Optional[Union[BlockType, Array[BlockType]]] = None,
    rows_of_blocks: Optional[int] = None,
    cols_of_blocks: Optional[int] = None,
    block_type: Optional[BlockType] = None,
    device=None,
) -> BsrMatrix["BlockType"]:
    """Create and return a block-diagonal BSR matrix from a given block value or array of block values.

    Args:
        diag: Specifies the values for diagonal blocks. Can be one of:

          - A Warp array of type ``A.values.dtype``: Each element defines one block of the diagonal
          - A constant value of type ``A.values.dtype``: This value is assigned to all diagonal blocks
        rows_of_blocks: If not ``None``, the new number of rows of blocks
        cols_of_blocks: If not ``None``, the new number of columns of blocks
        block_type: If ``diag`` is ``None``, block type of the matrix. Otherwise deduced from ``diag``
        device: If ``diag`` is not a Warp array, device on which to allocate the matrix. Otherwise deduced from ``diag``

    The shape of the matrix will be defined one of the following, in this order:

      - ``rows_of_blocks`` and ``cols_of_blocks``, if provided.
        If only one is given, the second is assumed equal.
      - The first dimension of ``diag`` if ``diag`` is an array.

    Raises:
        ValueError: If the shape cannot be determined, or neither ``diag`` nor ``block_type`` is given.
    """

    if rows_of_blocks is None and cols_of_blocks is not None:
        rows_of_blocks = cols_of_blocks
    if cols_of_blocks is None and rows_of_blocks is not None:
        cols_of_blocks = rows_of_blocks

    if is_array(diag):
        if rows_of_blocks is None:
            rows_of_blocks = diag.shape[0]
            cols_of_blocks = diag.shape[0]

        block_type = diag.dtype
        device = diag.device
    else:
        if rows_of_blocks is None:
            raise ValueError(
                "rows_of_blocks and/or cols_of_blocks must be provided for constructing a diagonal matrix with uniform diagonal"
            )

        if block_type is None:
            if diag is None:
                raise ValueError("Either `diag` or `block_type` needs to be provided")

            block_type = type(diag)
            # e.g. a 2d NumPy-like value: build the matching Warp matrix type
            if not type_is_matrix(block_type) and len(getattr(diag, "shape", ())) == 2:
                block_type = wp.mat(shape=diag.shape, dtype=diag.dtype)

    A = bsr_zeros(rows_of_blocks, cols_of_blocks, block_type=block_type, device=device)
    # Forward the resolved dimensions so that explicitly requested sizes take precedence over the
    # length of an array ``diag``; otherwise bsr_set_diag would resize A to ``diag.shape[0]``.
    bsr_set_diag(A, diag, rows_of_blocks=rows_of_blocks, cols_of_blocks=cols_of_blocks)
    return A
1156
+
1157
+
1158
def bsr_set_identity(A: BsrMatrix, rows_of_blocks: Optional[int] = None) -> None:
    """Overwrite ``A`` with the identity matrix.

    Args:
        A: The sparse matrix to modify.
        rows_of_blocks: When given, ``A`` is first resized to a square matrix with
            ``rows_of_blocks`` rows and columns of blocks.
    """

    if A.block_shape != (1, 1):
        # Matrix-valued blocks: dense identity block sized by the block row count
        import numpy as np

        unit_block = np.eye(A.block_shape[0])
    else:
        unit_block = A.scalar_type(1.0)

    bsr_set_diag(A, diag=unit_block, rows_of_blocks=rows_of_blocks, cols_of_blocks=rows_of_blocks)
1175
+
1176
+
1177
def bsr_identity(
    rows_of_blocks: int,
    block_type: BlockType[Rows, Rows, Scalar],
    device: wp.context.Devicelike = None,
) -> BsrMatrix[BlockType[Rows, Rows, Scalar]]:
    """Build and return a square block identity matrix.

    Args:
        rows_of_blocks: Number of rows and columns of blocks in the created matrix.
        block_type: Block type for the newly created matrix. Must be square
        device: Device onto which to allocate the data arrays
    """
    square_size = rows_of_blocks
    identity = bsr_zeros(
        rows_of_blocks=square_size,
        cols_of_blocks=square_size,
        block_type=block_type,
        device=device,
    )
    bsr_set_identity(identity)
    return identity
1197
+
1198
+
1199
@wp.kernel
def _bsr_scale_kernel(
    alpha: Any,
    values: wp.array(dtype=Any),
):
    """Multiply each block of ``values`` in place by ``alpha`` (one thread per block)."""
    block = wp.tid()
    values[block] = alpha * values[block]
1205
+
1206
+
1207
def bsr_scale(x: BsrMatrixOrExpression, alpha: Scalar) -> BsrMatrix:
    """Perform the operation ``x := alpha * x`` on BSR matrix ``x`` and return ``x``.

    If ``x`` is a scaled expression, its scale is folded into ``alpha`` and the underlying
    matrix is modified and returned. A zero factor clears the matrix structure.
    """

    matrix, own_scale = _extract_matrix_and_scale(x)
    alpha *= own_scale

    # Nothing to do for a unit factor or an empty matrix
    if not (alpha != 1.0 and matrix.nnz > 0):
        return matrix

    if alpha == 0.0:
        bsr_set_zero(matrix)
        return matrix

    wp.launch(
        kernel=_bsr_scale_kernel,
        dim=matrix.nnz,
        device=matrix.values.device,
        inputs=[matrix.scalar_type(alpha), matrix.values],
    )
    return matrix
1227
+
1228
+
1229
@wp.kernel(enable_backward=False)
def _bsr_get_block_row(row_count: int, bsr_offsets: wp.array(dtype=int), rows: wp.array(dtype=int)):
    """For each block slot, store the index of its row (or -1 when the slot is not a stored block)."""
    block_id = wp.tid()
    rows[block_id] = _bsr_row_index(bsr_offsets, row_count, block_id)
1233
+
1234
+
1235
@wp.kernel
def _bsr_axpy_add_block(
    src_offset: int,
    scale: Any,
    rows: wp.array(dtype=int),
    cols: wp.array(dtype=int),
    dst_offsets: wp.array(dtype=int),
    dst_columns: wp.array(dtype=int),
    src_values: wp.array(dtype=Any),
    dst_values: wp.array(dtype=Any),
):
    """Accumulate ``scale * src_values[i]`` into the destination block at coordinates
    ``(rows[i + src_offset], cols[i + src_offset])``, if that block exists in the destination.
    """
    i = wp.tid()
    # Coordinate lists are shared and indexed with an offset; values are indexed directly by i
    row = rows[i + src_offset]
    col = cols[i + src_offset]

    block = _bsr_block_index(row, col, dst_offsets, dst_columns)
    # Blocks absent from the destination structure (or with row == -1) are skipped
    if block != -1:
        dst_values[block] += scale * src_values[i]
1253
+
1254
+
1255
class bsr_axpy_work_arrays:
    """Opaque structure for persisting :func:`bsr_axpy` temporary work buffers across calls."""

    def __init__(self):
        self._reset(None)

    def _reset(self, device):
        # Drop every cached buffer and bind the structure to ``device``
        self.device = device
        self._sum_rows = None
        self._sum_cols = None
        self._old_y_values = None
        self._old_x_values = None

    def _allocate(self, device, y: BsrMatrix, sum_nnz: int):
        # Buffers are tied to a device: switching devices discards all of them
        if self.device != device:
            self._reset(device)

        def _needs_growth(buffer, required):
            return buffer is None or buffer.size < required

        if _needs_growth(self._sum_rows, sum_nnz):
            self._sum_rows = wp.empty(shape=(sum_nnz,), dtype=int, device=self.device)
        if _needs_growth(self._sum_cols, sum_nnz):
            self._sum_cols = wp.empty(shape=(sum_nnz,), dtype=int, device=self.device)

        if _needs_growth(self._old_y_values, y.nnz):
            self._old_y_values = wp.empty(shape=(y.nnz,), dtype=y.values.dtype, device=self.device)
1279
+
1280
+
1281
def bsr_axpy(
    x: BsrMatrixOrExpression,
    y: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
    alpha: Scalar = 1.0,
    beta: Scalar = 1.0,
    masked: bool = False,
    work_arrays: Optional[bsr_axpy_work_arrays] = None,
) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
    """
    Perform the sparse matrix addition ``y := alpha * x + beta * y`` on BSR matrices ``x`` and ``y`` and return ``y``.

    The ``x`` and ``y`` matrices are allowed to alias.

    Args:
        x: Read-only right-hand-side.
        y: Mutable left-hand-side. If ``y`` is not provided, it will be allocated and treated as zero.
        alpha: Uniform scaling factor for ``x``.
        beta: Uniform scaling factor for ``y``.
        masked: If ``True``, discard all blocks from ``x`` which are not
          existing non-zeros of ``y``. The sparsity pattern of ``y`` is then
          left unchanged, including when ``beta`` is zero.
        work_arrays: In most cases, this function will require the use of temporary storage.
          This storage can be reused across calls by passing an instance of
          :class:`bsr_axpy_work_arrays` in ``work_arrays``.

    Raises:
        ValueError: If ``masked`` is requested without ``y``, or if ``x`` and ``y``
          differ in device, block type, or dimensions.
    """

    x, x_scale = _extract_matrix_and_scale(x)
    alpha *= x_scale

    if y is None:
        if masked:
            raise ValueError("Left-hand-side 'y' matrix must be provided for masked addition")

        # If not output matrix is provided, allocate it for convenience
        y = bsr_zeros(x.nrow, x.ncol, block_type=x.values.dtype, device=x.values.device)
        beta = 0.0

    x_nnz = x.nnz
    y_nnz = y.nnz

    # Handle easy cases first
    if masked:
        # In masked mode the topology of y must be kept: y must never adopt x's topology,
        # and scaling by zero through bsr_scale would not preserve the topology either.
        if y_nnz == 0:
            # No existing non-zeros: every block of x is discarded
            return y
        if alpha == 0.0 or x_nnz == 0:
            if beta == 0.0:
                y.values.zero_()
                return y
            return bsr_scale(y, alpha=beta)
    else:
        if beta == 0.0 or y_nnz == 0:
            bsr_assign(src=x, dest=y)
            return bsr_scale(y, alpha=alpha)

        if alpha == 0.0 or x_nnz == 0:
            return bsr_scale(y, alpha=beta)

    if not isinstance(alpha, y.scalar_type):
        alpha = y.scalar_type(alpha)
    if not isinstance(beta, y.scalar_type):
        beta = y.scalar_type(beta)

    if x == y:
        # Aliasing case
        combined_scale = alpha.value + beta.value
        if masked and combined_scale == 0.0:
            y.values.zero_()
            return y
        return bsr_scale(y, alpha=combined_scale)

    # General case

    if x.values.device != y.values.device:
        raise ValueError("All arguments must reside on the same device")

    if x.scalar_type != y.scalar_type or x.block_shape != y.block_shape:
        raise ValueError("Matrices must have the same block type")

    if x.nrow != y.nrow or x.ncol != y.ncol:
        raise ValueError("Matrices must have the same number of rows and columns")

    if work_arrays is None:
        work_arrays = bsr_axpy_work_arrays()

    sum_nnz = x_nnz + y_nnz
    device = y.values.device
    work_arrays._allocate(device, y, sum_nnz)

    # Concatenate (row, col) coordinates: y's blocks first, then x's blocks
    wp.copy(work_arrays._sum_cols, y.columns, 0, 0, y_nnz)
    y.uncompress_rows(out=work_arrays._sum_rows)

    wp.copy(work_arrays._sum_cols, x.columns, y_nnz, 0, x_nnz)
    x.uncompress_rows(out=work_arrays._sum_rows[y_nnz:])

    # Save old y values before overwriting matrix
    wp.copy(dest=work_arrays._old_y_values, src=y.values, count=y_nnz)

    # Increase dest array sizes if needed
    if not masked:
        _bsr_ensure_fits(y, nnz=sum_nnz)

    from warp.context import runtime

    # The native routine only builds the merged topology (no values are passed),
    # so the float variant is used for every scalar type.
    if device.is_cpu:
        native_func = runtime.core.bsr_matrix_from_triplets_float_host
    else:
        native_func = runtime.core.bsr_matrix_from_triplets_float_device

    old_y_nnz = y_nnz
    nnz_buf, nnz_event = y._nnz_transfer_buf_and_event()

    with wp.ScopedDevice(y.device):
        native_func(
            y.block_shape[0],
            y.block_shape[1],
            y.nrow,
            sum_nnz,
            ctypes.cast(work_arrays._sum_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(work_arrays._sum_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
            0,
            False,
            masked,
            ctypes.cast(y.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(y.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
            0,
            ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
            nnz_event,
        )

    y.values.zero_()

    # Add back beta * old y values (only reachable with zero beta in masked mode,
    # in which case old y values are not read at all)
    if beta.value != 0.0:
        wp.launch(
            kernel=_bsr_axpy_add_block,
            device=device,
            dim=old_y_nnz,
            inputs=[
                0,
                beta,
                work_arrays._sum_rows,
                work_arrays._sum_cols,
                y.offsets,
                y.columns,
                work_arrays._old_y_values,
                y.values,
            ],
        )

    # Add alpha * x values; x's coordinates are stored after y's in the triplet arrays
    wp.launch(
        kernel=_bsr_axpy_add_block,
        device=device,
        dim=x_nnz,
        inputs=[
            old_y_nnz,
            alpha,
            work_arrays._sum_rows,
            work_arrays._sum_cols,
            y.offsets,
            y.columns,
            x.values,
            y.values,
        ],
    )

    return y
1431
+
1432
+
1433
@wp.kernel(enable_backward=False)
def _bsr_mm_count_coeffs(
    y_ncol: int,
    z_nnz: int,
    x_offsets: wp.array(dtype=int),
    x_columns: wp.array(dtype=int),
    y_offsets: wp.array(dtype=int),
    y_columns: wp.array(dtype=int),
    row_min: wp.array(dtype=int),
    block_counts: wp.array(dtype=int),
):
    # One thread per row of x. For each block of that row, store in
    # block_counts[x_block + 1] the number of (unmerged) product blocks it generates,
    # i.e. the block count of the matching row of y.
    # If listing all of them would take at least as many slots as the dense column
    # range [lowest_col, highest_col], store that range size on the row's last x block
    # instead (zeroing the others) and record lowest_col in row_min; otherwise set
    # row_min to -1. Thread 0 also stores z_nnz in block_counts[0], so the later prefix
    # sum starts after the copied z blocks.
    row = wp.tid()

    first_x_block = x_offsets[row]
    end_x_block = x_offsets[row + 1]

    # Typed declarations are required for variables mutated inside dynamic loops
    unmerged_count = int(0)
    lowest_col = y_ncol
    highest_col = int(0)

    for x_block in range(first_x_block, end_x_block):
        y_row = x_columns[x_block]
        y_first = y_offsets[y_row]
        y_end = y_offsets[y_row + 1]
        y_block_count = y_end - y_first
        if y_block_count != 0:
            # y columns within a row are assumed sorted -- TODO confirm
            lowest_col = wp.min(y_columns[y_first], lowest_col)
            highest_col = wp.max(y_columns[y_end - 1], highest_col)

        block_counts[x_block + 1] = y_block_count
        unmerged_count += y_block_count

    if unmerged_count > wp.max(0, highest_col - lowest_col):
        row_min[row] = lowest_col
        block_counts[end_x_block] = highest_col + 1 - lowest_col
        for x_block in range(first_x_block, end_x_block - 1):
            block_counts[x_block + 1] = 0
    else:
        row_min[row] = -1

    if row == 0:
        block_counts[0] = z_nnz
1475
+
1476
+
1477
@wp.kernel(enable_backward=False)
def _bsr_mm_list_coeffs(
    x_nrow: int,
    x_offsets: wp.array(dtype=int),
    x_columns: wp.array(dtype=int),
    y_offsets: wp.array(dtype=int),
    y_columns: wp.array(dtype=int),
    mm_row_min: wp.array(dtype=int),
    mm_offsets: wp.array(dtype=int),
    mm_rows: wp.array(dtype=int),
    mm_cols: wp.array(dtype=int),
):
    # One thread per block of x: write the (row, col) coordinates of the product blocks
    # generated by this x block, starting at mm_offsets[x_block].
    x_block = wp.tid()
    out_block = mm_offsets[x_block]

    row = _bsr_row_index(x_offsets, x_nrow, x_block)
    if row == -1:
        return

    # Row of y matched by this x block
    y_row = x_columns[x_block]
    y_first = y_offsets[y_row]
    y_end = y_offsets[y_row + 1]

    dense_min_col = mm_row_min[row]
    if dense_min_col != -1:
        # Dense row layout: each column has a fixed slot relative to the row's minimum
        # column, so x blocks of the same row may write identical values to the same slot.
        for y_block in range(y_first, y_end):
            col = y_columns[y_block]
            slot = out_block + col - dense_min_col
            mm_rows[slot] = row
            mm_cols[slot] = col
        return

    # Sparse layout: append this x block's product blocks contiguously
    for y_block in range(y_first, y_end):
        mm_cols[out_block] = y_columns[y_block]
        mm_rows[out_block] = row
        out_block += 1
1517
+
1518
+
1519
@wp.kernel
def _bsr_mm_compute_values(
    alpha: Any,
    x_offsets: wp.array(dtype=int),
    x_columns: wp.array(dtype=int),
    x_values: wp.array(dtype=Any),
    y_offsets: wp.array(dtype=int),
    y_columns: wp.array(dtype=int),
    y_values: wp.array(dtype=Any),
    mm_row_count: int,
    mm_offsets: wp.array(dtype=int),
    mm_cols: wp.array(dtype=int),
    mm_values: wp.array(dtype=Any),
):
    # One thread per block of the result: accumulate
    # alpha * sum_k x[row, k] * y[k, col] into mm_values[mm_block],
    # skipping k for which y has no (k, col) block.
    mm_block = wp.tid()

    row = _bsr_row_index(mm_offsets, mm_row_count, mm_block)
    if row == -1:
        return

    col = mm_cols[mm_block]

    acc = mm_values.dtype(type(alpha)(0.0))

    x_first = x_offsets[row]
    x_end = x_offsets[row + 1]
    for x_block in range(x_first, x_end):
        y_block = _bsr_block_index(x_columns[x_block], col, y_offsets, y_columns)
        if y_block != -1:
            acc += x_values[x_block] * y_values[y_block]

    mm_values[mm_block] += alpha * acc
1552
+
1553
+
1554
class bsr_mm_work_arrays:
    """Opaque structure for persisting :func:`bsr_mm` temporary work buffers across calls."""

    def __init__(self):
        self._reset(None)

    def _reset(self, device):
        # Forget all cached buffers; they are lazily re-created on ``device``.
        self.device = device
        self._mm_row_min = None
        self._mm_block_counts = None
        self._mm_rows = None
        self._mm_cols = None
        self._old_z_values = None
        self._old_z_offsets = None
        self._old_z_columns = None
        self._mm_nnz = 0
        # Number of pre-existing z blocks stored ahead of the product blocks (set by _allocate_stage_1)
        self._copied_z_nnz = 0
        # Exact nnz of the last computed product; assigned by bsr_mm and read when reusing topology
        self.result_nnz = None

    def _allocate_stage_1(self, device, x_nnz: int, z: BsrMatrix, beta: float, z_aliasing: bool):
        """Allocate buffers whose sizes are known before computing the product topology.

        Args:
            device: Device for the buffers. Changing it drops all cached buffers.
            x_nnz: Number of blocks of the left factor.
            z: Output matrix; synchronizes to get its exact number of blocks.
            beta: Scale of the existing ``z``; its blocks are only saved if non-zero or aliased.
            z_aliasing: Whether ``z`` aliases one of the factors, requiring its topology to be saved.
        """
        if self.device != device:
            self._reset(device)

        # Allocations that do not depend on any computation
        z_nnz = z.nnz_sync()
        self._copied_z_nnz = z_nnz if beta != 0.0 or z_aliasing else 0

        # Each buffer is resized based on its own current size
        if self._mm_row_min is None or self._mm_row_min.size < z.nrow + 1:
            self._mm_row_min = wp.empty(shape=(z.nrow + 1,), dtype=int, device=self.device)
        if self._mm_block_counts is None or self._mm_block_counts.size < x_nnz + 1:
            self._mm_block_counts = wp.empty(shape=(x_nnz + 1,), dtype=int, device=self.device)

        if self._copied_z_nnz > 0:
            # A cached values buffer is only reusable if it also matches z's block type
            value_type = z.values.dtype
            if (
                self._old_z_values is None
                or self._old_z_values.size < self._copied_z_nnz
                or not types_equal(self._old_z_values.dtype, value_type)
            ):
                self._old_z_values = wp.empty(shape=(self._copied_z_nnz,), dtype=value_type, device=self.device)

        if z_aliasing:
            if self._old_z_columns is None or self._old_z_columns.size < z_nnz:
                self._old_z_columns = wp.empty(shape=(z_nnz,), dtype=z.columns.dtype, device=self.device)
            if self._old_z_offsets is None or self._old_z_offsets.size < z.nrow + 1:
                self._old_z_offsets = wp.empty(shape=(z.nrow + 1,), dtype=z.offsets.dtype, device=self.device)

    def _allocate_stage_2(self, mm_nnz: int):
        """Allocate the unmerged product triplet buffers once their count ``mm_nnz`` is known."""
        # Allocations that depend on unmerged nnz estimate
        self._mm_nnz = mm_nnz
        if self._mm_rows is None or self._mm_rows.size < mm_nnz:
            self._mm_rows = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
        if self._mm_cols is None or self._mm_cols.size < mm_nnz:
            self._mm_cols = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
1601
+
1602
+
1603
def bsr_mm(
    x: BsrMatrixOrExpression[BlockType[Rows, Any, Scalar]],
    y: BsrMatrixOrExpression[BlockType[Any, Cols, Scalar]],
    z: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
    alpha: Scalar = 1.0,
    beta: Scalar = 0.0,
    masked: bool = False,
    work_arrays: Optional[bsr_mm_work_arrays] = None,
    reuse_topology: bool = False,
) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
    """
    Perform the sparse matrix-matrix multiplication ``z := alpha * x @ y + beta * z`` on BSR matrices ``x``, ``y`` and ``z``, and return ``z``.

    The ``x``, ``y`` and ``z`` matrices are allowed to alias (except in masked mode).
    If the matrix ``z`` is not provided as input, it will be allocated and treated as zero.

    Args:
        x: Read-only left factor of the matrix-matrix product.
        y: Read-only right factor of the matrix-matrix product.
        z: Mutable left-hand-side. If ``z`` is not provided, it will be allocated and treated as zero.
        alpha: Uniform scaling factor for the ``x @ y`` product
        beta: Uniform scaling factor for ``z``
        masked: If ``True``, ignore all blocks from ``x @ y`` which are not existing non-zeros of ``z``.
          The sparsity pattern of ``z`` is then left unchanged.
        work_arrays: In most cases, this function will require the use of temporary storage.
          This storage can be reused across calls by passing an instance of
          :class:`bsr_mm_work_arrays` in ``work_arrays``.
        reuse_topology: If ``True``, reuse the product topology information
          stored in ``work_arrays`` rather than recompute it from scratch.
          The matrices ``x``, ``y`` and ``z`` must be structurally similar to
          the previous call in which ``work_arrays`` were populated.
          This is necessary for ``bsr_mm`` to be captured in a CUDA graph.

    Raises:
        ValueError: On device, scalar type, block size or dimension mismatch; if ``masked`` is
          requested without ``z`` or with aliased inputs; or if ``reuse_topology`` is requested
          without ``work_arrays`` populated by a previous non-reusing call.
        RuntimeError: If called during graph capture without ``reuse_topology=True``.
    """

    x, x_scale = _extract_matrix_and_scale(x)
    alpha *= x_scale
    y, y_scale = _extract_matrix_and_scale(y)
    alpha *= y_scale

    if z is None:
        if masked:
            raise ValueError("Left-hand-side 'z' matrix must be provided for masked multiplication")

        # If not output matrix is provided, allocate it for convenience
        z_block_shape = (x.block_shape[0], y.block_shape[1])
        if z_block_shape == (1, 1):
            z_block_type = x.scalar_type
        else:
            z_block_type = wp.mat(shape=z_block_shape, dtype=x.scalar_type)
        z = bsr_zeros(x.nrow, y.ncol, block_type=z_block_type, device=x.values.device)
        beta = 0.0

    if x.values.device != y.values.device or x.values.device != z.values.device:
        raise ValueError("All arguments must reside on the same device")

    if x.scalar_type != y.scalar_type or x.scalar_type != z.scalar_type:
        raise ValueError("Matrices must have the same scalar type")

    if (
        x.block_shape[0] != z.block_shape[0]
        or y.block_shape[1] != z.block_shape[1]
        or x.block_shape[1] != y.block_shape[0]
    ):
        raise ValueError("Incompatible block sizes for matrix multiplication")

    if x.nrow != z.nrow or z.ncol != y.ncol or x.ncol != y.nrow:
        raise ValueError("Incompatible number of rows/columns for matrix multiplication")

    device = z.values.device

    if alpha == 0.0 or x.nnz == 0 or y.nnz == 0:
        # Easy case: the product vanishes, only the ``beta * z`` term remains
        if masked and beta == 0.0:
            # bsr_scale(0) would not preserve topology, which masked mode must keep
            z.values.zero_()
            return z
        return bsr_scale(z, beta)

    z_aliasing = z == x or z == y

    if masked:
        # no need to copy z, scale in-place
        copied_z_nnz = 0
        mm_nnz = z.nnz

        if z_aliasing:
            raise ValueError("`masked=True` is not supported for aliased inputs")

        if beta == 0.0:
            # do not bsr_scale(0), this would not preserve topology
            z.values.zero_()
        else:
            bsr_scale(z, beta)
    elif reuse_topology:
        if work_arrays is None:
            raise ValueError("`work_arrays` must not be ``None`` in order to reuse matrix-matrix product topology")
        if getattr(work_arrays, "result_nnz", None) is None:
            # The topology is only recorded by a previous, non-reusing, non-masked call
            raise ValueError(
                "`work_arrays` must have been populated by a previous `bsr_mm` call "
                "with `reuse_topology=False` in order to reuse matrix-matrix product topology"
            )

        copied_z_nnz = work_arrays._copied_z_nnz
        mm_nnz = work_arrays._mm_nnz
    else:
        if device.is_capturing:
            raise RuntimeError("`bsr_mm` requires `reuse_topology=True` for use in graph capture")

        if work_arrays is None:
            work_arrays = bsr_mm_work_arrays()

        work_arrays._allocate_stage_1(device, x.nnz, z, beta, z_aliasing)
        copied_z_nnz = work_arrays._copied_z_nnz

        # Prefix sum of number of (unmerged) mm blocks per row
        work_arrays._mm_block_counts.zero_()
        wp.launch(
            kernel=_bsr_mm_count_coeffs,
            device=device,
            dim=z.nrow,
            inputs=[
                y.ncol,
                copied_z_nnz,
                x.offsets,
                x.columns,
                y.offsets,
                y.columns,
                work_arrays._mm_row_min,
                work_arrays._mm_block_counts,
            ],
        )
        warp.utils.array_scan(work_arrays._mm_block_counts, work_arrays._mm_block_counts)

        # Get back total counts on host -- we need a synchronization here
        # Use pinned buffer from z, we are going to need it later anyway
        nnz_buf, _ = z._nnz_transfer_buf_and_event()
        stream = wp.get_stream(device) if device.is_cuda else None
        wp.copy(dest=nnz_buf, src=work_arrays._mm_block_counts, src_offset=x.nnz, count=1, stream=stream)
        if device.is_cuda:
            wp.synchronize_stream(stream)
        mm_nnz = int(nnz_buf.numpy()[0])

        if mm_nnz == copied_z_nnz:
            # x@y = 0
            return bsr_scale(z, beta)

        work_arrays._allocate_stage_2(mm_nnz)

        # If z has a non-zero scale, save current data before overwriting it
        if copied_z_nnz > 0:
            # Copy z row and column indices
            wp.copy(dest=work_arrays._mm_cols, src=z.columns, count=copied_z_nnz)
            z.uncompress_rows(out=work_arrays._mm_rows)
            if z_aliasing:
                # If z is aliasing with x or y, need to save topology as well
                wp.copy(src=z.columns, dest=work_arrays._old_z_columns, count=copied_z_nnz)
                wp.copy(src=z.offsets, dest=work_arrays._old_z_offsets, count=z.nrow + 1)

        # Fill unmerged mm blocks rows and columns
        work_arrays._mm_rows[copied_z_nnz:].fill_(-1)
        wp.launch(
            kernel=_bsr_mm_list_coeffs,
            device=device,
            dim=x.nnz,
            inputs=[
                x.nrow,
                x.offsets,
                x.columns,
                y.offsets,
                y.columns,
                work_arrays._mm_row_min,
                work_arrays._mm_block_counts,
                work_arrays._mm_rows,
                work_arrays._mm_cols,
            ],
        )

    alpha = z.scalar_type(alpha)
    beta = z.scalar_type(beta)

    if copied_z_nnz > 0:
        # Save current z values in temporary buffer
        wp.copy(src=z.values, dest=work_arrays._old_z_values, count=copied_z_nnz)

    if not masked:
        # Increase dest array size if needed
        if z.columns.shape[0] < mm_nnz:
            z.columns = wp.empty(shape=(mm_nnz,), dtype=int, device=device)

        from warp.context import runtime

        # The native routine only builds the merged topology (no values are passed)
        if device.is_cpu:
            native_func = runtime.core.bsr_matrix_from_triplets_float_host
        else:
            native_func = runtime.core.bsr_matrix_from_triplets_float_device

        nnz_buf, nnz_event = z._nnz_transfer_buf_and_event()

        with wp.ScopedDevice(z.device):
            native_func(
                z.block_shape[0],
                z.block_shape[1],
                z.nrow,
                mm_nnz,
                ctypes.cast(work_arrays._mm_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
                ctypes.cast(work_arrays._mm_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
                0,
                False,
                masked,
                ctypes.cast(z.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
                ctypes.cast(z.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
                0,
                ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
                nnz_event,
            )

        # Resize z to fit mm result if necessary
        # If we are not reusing the product topology, this needs another synchronization
        if not reuse_topology:
            work_arrays.result_nnz = z.nnz_sync()

        _bsr_ensure_fits(z, nnz=work_arrays.result_nnz)
        z.values.zero_()

        if copied_z_nnz > 0:
            # Add back original z values
            wp.launch(
                kernel=_bsr_axpy_add_block,
                device=device,
                dim=copied_z_nnz,
                inputs=[
                    0,
                    beta,
                    work_arrays._mm_rows,
                    work_arrays._mm_cols,
                    z.offsets,
                    z.columns,
                    work_arrays._old_z_values,
                    z.values,
                ],
            )

    # Add mm blocks to z values
    if (type_is_matrix(x.values.dtype) or type_is_matrix(y.values.dtype)) and not (type_is_matrix(z.values.dtype)):
        # Result block type is scalar, but operands are matrices
        # Cast result to (1x1) matrix to perform multiplication
        mm_values = z.values.view(wp.mat(shape=(1, 1), dtype=z.scalar_type))
    else:
        mm_values = z.values

    # When a factor aliases z, read that factor from the saved copy of z's old data
    wp.launch(
        kernel=_bsr_mm_compute_values,
        device=device,
        dim=z.nnz,
        inputs=[
            alpha,
            work_arrays._old_z_offsets if x == z else x.offsets,
            work_arrays._old_z_columns if x == z else x.columns,
            work_arrays._old_z_values if x == z else x.values,
            work_arrays._old_z_offsets if y == z else y.offsets,
            work_arrays._old_z_columns if y == z else y.columns,
            work_arrays._old_z_values if y == z else y.values,
            z.nrow,
            z.offsets,
            z.columns,
            mm_values,
        ],
    )

    return z
1863
+
1864
+
1865
@wp.kernel
def _bsr_mv_kernel(
    alpha: Any,
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    beta: Any,
    y: wp.array(dtype=Any),
):
    # One thread per row: y[row] = alpha * sum_j A[row, j] * x[j] + beta * y[row].
    # A term whose scale factor is zero is skipped entirely, so x (resp. y) is not read then.
    row = wp.tid()

    zero_scalar = type(alpha)(0)
    # Accumulator with the element type of y, zero-initialized
    acc = y.dtype(zero_scalar)

    if alpha != zero_scalar:
        row_begin = A_offsets[row]
        row_end = A_offsets[row + 1]
        for block in range(row_begin, row_end):
            acc += A_values[block] * x[A_columns[block]]
        acc *= alpha

    if beta != zero_scalar:
        acc += beta * y[row]

    y[row] = acc
1892
+
1893
+
1894
@wp.kernel
def _bsr_mv_transpose_kernel(
    alpha: Any,
    A_offsets: wp.array(dtype=int),
    A_columns: wp.array(dtype=int),
    A_values: wp.array(dtype=Any),
    x: wp.array(dtype=Any),
    y: wp.array(dtype=Any),
):
    # One thread per row of A: scatter transpose(A[row, col]) * (alpha * x[row]) into y[col].
    # Several rows may target the same y entry, hence the atomic accumulation.
    row = wp.tid()
    scaled_x = alpha * x[row]

    block_begin = A_offsets[row]
    block_end = A_offsets[row + 1]
    for block in range(block_begin, block_end):
        contribution = wp.transpose(A_values[block]) * scaled_x
        wp.atomic_add(y, A_columns[block], contribution)
1910
+
1911
+
1912
def _vec_array_view(array: wp.array, dtype: type, expected_scalar_count: int) -> wp.array:
    """Reinterpret a 1-D or 2-D array as a 1-D array of ``dtype`` elements.

    Args:
        array: Source array; its total scalar count must equal ``expected_scalar_count``.
        dtype: Target element type (scalar or vector) sharing the source's scalar type.
        expected_scalar_count: Required total number of scalars.

    Returns:
        ``array`` itself when it is already 1-D with a matching dtype, otherwise a
        zero-copy view aliasing the same memory (its gradient, if any, is viewed likewise).

    Raises:
        ValueError: On scalar count, scalar type, dimensionality, contiguity or
          divisibility mismatch.
    """
    total_scalars = array.size * type_length(array.dtype)
    if total_scalars != expected_scalar_count:
        raise ValueError(f"Invalid array scalar size, expected {expected_scalar_count}, got {total_scalars}")

    # Nothing to reinterpret
    if array.ndim == 1 and types_equal(array.dtype, dtype):
        return array

    if type_scalar_type(array.dtype) != type_scalar_type(dtype):
        raise ValueError(f"Incompatible scalar types, {type_repr(array.dtype)} vs {type_repr(dtype)}")

    if array.ndim > 2:
        raise ValueError(f"Incompatible array number of dimensions {array.ndim}")

    if not array.is_contiguous:
        raise ValueError("Array must be contiguous")

    element_length = type_length(dtype)
    element_count, leftover = divmod(total_scalars, element_length)
    if leftover != 0:
        raise ValueError(
            f"Array of shape {array.shape} and type {type_repr(array.dtype)} cannot be reshaped to an array of type {type_repr(dtype)}"
        )

    def _make_view(source):
        # Alias ``source``'s memory with the target dtype, recursing into its gradient
        grad_view = None if source.grad is None else _make_view(source.grad)
        return wp.array(
            data=None,
            ptr=source.ptr,
            capacity=source.capacity,
            device=source.device,
            dtype=dtype,
            shape=element_count,
            grad=grad_view,
        )

    result = _make_view(array)
    # Keep the source array alive for as long as the view is
    result._ref = array
    return result
1952
+
1953
+
1954
def bsr_mv(
    A: BsrMatrixOrExpression[BlockType[Rows, Cols, Scalar]],
    x: "Array[Vector[Cols, Scalar] | Scalar]",
    y: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
    alpha: Scalar = 1.0,
    beta: Scalar = 0.0,
    transpose: bool = False,
    work_buffer: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
) -> "Array[Vector[Rows, Scalar] | Scalar]":
    """Perform the sparse matrix-vector product ``y := alpha * A * x + beta * y`` and return ``y``.

    The ``x`` and ``y`` vectors are allowed to alias.

    Args:
        A: Read-only, left matrix factor of the matrix-vector product.
        x: Read-only, right vector factor of the matrix-vector product.
        y: Mutable left-hand-side. If ``y`` is not provided, it will be allocated and treated as zero.
        alpha: Uniform scaling factor for ``x``. If zero, ``x`` will not be read and may be left uninitialized.
        beta: Uniform scaling factor for ``y``. If zero, ``y`` will not be read and may be left uninitialized.
        transpose: If ``True``, use the transpose of the matrix ``A``. In this case the result is **non-deterministic**.
        work_buffer: Temporary storage is required if and only if ``x`` and ``y`` are the same vector.
          If provided, the ``work_buffer`` array will be used for this purpose,
          otherwise a temporary allocation will be performed. It must have the same
          dtype and device as ``y`` and at least ``y.size`` elements; only the
          leading ``y.size`` elements are used.

    Raises:
        ValueError: If the arrays are on different devices, have incompatible shapes or
          types, or if ``work_buffer`` is unsuitable.
    """

    A, A_scale = _extract_matrix_and_scale(A)
    alpha *= A_scale

    if transpose:
        block_shape = A.block_shape[1], A.block_shape[0]
        nrow, ncol = A.ncol, A.nrow
    else:
        block_shape = A.block_shape
        nrow, ncol = A.nrow, A.ncol

    if y is None:
        # If no output array is provided, allocate one for convenience
        y_vec_len = block_shape[0]
        y_dtype = A.scalar_type if y_vec_len == 1 else wp.vec(length=y_vec_len, dtype=A.scalar_type)
        y = wp.empty(shape=(nrow,), device=A.values.device, dtype=y_dtype)
        beta = 0.0

    alpha = A.scalar_type(alpha)
    beta = A.scalar_type(beta)

    if A.values.device != x.device or A.values.device != y.device:
        raise ValueError("A, x, and y must reside on the same device")

    if x.ptr == y.ptr:
        # Aliasing case, need temporary storage
        if work_buffer is None:
            work_buffer = wp.empty_like(y)
        elif work_buffer.size < y.size:
            raise ValueError(f"Work buffer size is insufficient, needs to be at least {y.size}")
        elif not types_equal(work_buffer.dtype, y.dtype):
            raise ValueError(f"Work buffer must have same data type as y, {type_repr(y.dtype)}")
        elif work_buffer.device != y.device:
            raise ValueError("Work buffer must reside on the same device as y")

        # Save old y values before overwriting vector
        wp.copy(dest=work_buffer, src=y, count=y.size)
        x = work_buffer
        if x.size != y.size:
            # Oversized buffer: only the leading y.size entries hold the saved values,
            # and the shape checks below require exactly that many
            x = x.flatten()[: y.size]

    # Promote scalar vectors to length-1 vecs and conversely
    if type_is_matrix(A.values.dtype):
        x_dtype = wp.vec(length=block_shape[1], dtype=A.scalar_type)
        y_dtype = wp.vec(length=block_shape[0], dtype=A.scalar_type)
    else:
        x_dtype = A.scalar_type
        y_dtype = A.scalar_type

    try:
        x_view = _vec_array_view(x, x_dtype, expected_scalar_count=ncol * block_shape[1])
    except ValueError as err:
        raise ValueError("Incompatible 'x' vector for bsr_mv") from err
    try:
        y_view = _vec_array_view(y, y_dtype, expected_scalar_count=nrow * block_shape[0])
    except ValueError as err:
        raise ValueError("Incompatible 'y' vector for bsr_mv") from err

    if transpose:
        if beta.value == 0.0:
            y.zero_()
        elif beta.value != 1.0:
            # Scale through the 1-D view: y itself may be a 2-D array
            wp.launch(
                kernel=_bsr_scale_kernel,
                device=y.device,
                dim=y_view.shape[0],
                inputs=[beta, y_view],
            )
        if alpha.value != 0.0:
            wp.launch(
                kernel=_bsr_mv_transpose_kernel,
                device=A.values.device,
                dim=ncol,
                inputs=[alpha, A.offsets, A.columns, A.values, x_view, y_view],
            )
    else:
        wp.launch(
            kernel=_bsr_mv_kernel,
            device=A.values.device,
            dim=nrow,
            inputs=[alpha, A.offsets, A.columns, A.values, x_view, beta, y_view],
        )

    return y