warp-lang 1.9.1__py3-none-manylinux_2_34_aarch64.whl → 1.10.0rc2__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of warp-lang might be problematic; further details are linked from the original page.

Files changed (346):
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +794 -305
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1382 -377
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.so +0 -0
  92. warp/bin/warp.so +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -721
  95. warp/codegen.py +6 -4251
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -8062
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +1 -1
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -365
  136. warp/jax_experimental/ffi.py +17 -873
  137. warp/jax_experimental/xla_ffi.py +5 -605
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +314 -37
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sparse.cu +7 -3
  159. warp/native/spatial.h +12 -0
  160. warp/native/tile.h +681 -89
  161. warp/native/tile_radix_sort.h +1 -1
  162. warp/native/tile_reduce.h +394 -46
  163. warp/native/tile_scan.h +4 -4
  164. warp/native/vec.h +469 -0
  165. warp/native/version.h +23 -0
  166. warp/native/volume.cpp +1 -1
  167. warp/native/volume.cu +1 -0
  168. warp/native/volume.h +1 -1
  169. warp/native/volume_builder.cu +2 -0
  170. warp/native/warp.cpp +57 -29
  171. warp/native/warp.cu +253 -171
  172. warp/native/warp.h +11 -8
  173. warp/optim/__init__.py +6 -3
  174. warp/optim/adam.py +6 -145
  175. warp/optim/linear.py +14 -1585
  176. warp/optim/sgd.py +6 -94
  177. warp/paddle.py +6 -388
  178. warp/render/__init__.py +8 -4
  179. warp/render/imgui_manager.py +7 -267
  180. warp/render/render_opengl.py +6 -3618
  181. warp/render/render_usd.py +6 -919
  182. warp/render/utils.py +6 -142
  183. warp/sparse.py +37 -2563
  184. warp/tape.py +6 -1188
  185. warp/tests/__main__.py +1 -1
  186. warp/tests/cuda/test_async.py +4 -4
  187. warp/tests/cuda/test_conditional_captures.py +1 -1
  188. warp/tests/cuda/test_multigpu.py +1 -1
  189. warp/tests/cuda/test_streams.py +58 -1
  190. warp/tests/geometry/test_bvh.py +157 -22
  191. warp/tests/geometry/test_marching_cubes.py +0 -1
  192. warp/tests/geometry/test_mesh.py +5 -3
  193. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  194. warp/tests/geometry/test_mesh_query_point.py +5 -2
  195. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  196. warp/tests/geometry/test_volume_write.py +5 -5
  197. warp/tests/interop/test_dlpack.py +14 -14
  198. warp/tests/interop/test_jax.py +772 -49
  199. warp/tests/interop/test_paddle.py +1 -1
  200. warp/tests/test_adam.py +0 -1
  201. warp/tests/test_arithmetic.py +9 -9
  202. warp/tests/test_array.py +527 -100
  203. warp/tests/test_array_reduce.py +3 -3
  204. warp/tests/test_atomic.py +12 -8
  205. warp/tests/test_atomic_bitwise.py +209 -0
  206. warp/tests/test_atomic_cas.py +4 -4
  207. warp/tests/test_bool.py +2 -2
  208. warp/tests/test_builtins_resolution.py +5 -571
  209. warp/tests/test_codegen.py +33 -14
  210. warp/tests/test_conditional.py +1 -1
  211. warp/tests/test_context.py +6 -6
  212. warp/tests/test_copy.py +242 -161
  213. warp/tests/test_ctypes.py +3 -3
  214. warp/tests/test_devices.py +24 -2
  215. warp/tests/test_examples.py +16 -84
  216. warp/tests/test_fabricarray.py +35 -35
  217. warp/tests/test_fast_math.py +0 -2
  218. warp/tests/test_fem.py +56 -10
  219. warp/tests/test_fixedarray.py +3 -3
  220. warp/tests/test_func.py +8 -5
  221. warp/tests/test_generics.py +1 -1
  222. warp/tests/test_indexedarray.py +24 -24
  223. warp/tests/test_intersect.py +39 -9
  224. warp/tests/test_large.py +1 -1
  225. warp/tests/test_lerp.py +3 -1
  226. warp/tests/test_linear_solvers.py +1 -1
  227. warp/tests/test_map.py +35 -4
  228. warp/tests/test_mat.py +52 -62
  229. warp/tests/test_mat_constructors.py +4 -5
  230. warp/tests/test_mat_lite.py +1 -1
  231. warp/tests/test_mat_scalar_ops.py +121 -121
  232. warp/tests/test_math.py +34 -0
  233. warp/tests/test_module_aot.py +4 -4
  234. warp/tests/test_modules_lite.py +28 -2
  235. warp/tests/test_print.py +11 -11
  236. warp/tests/test_quat.py +93 -58
  237. warp/tests/test_runlength_encode.py +1 -1
  238. warp/tests/test_scalar_ops.py +38 -10
  239. warp/tests/test_smoothstep.py +1 -1
  240. warp/tests/test_sparse.py +126 -15
  241. warp/tests/test_spatial.py +105 -87
  242. warp/tests/test_special_values.py +6 -6
  243. warp/tests/test_static.py +7 -7
  244. warp/tests/test_struct.py +13 -2
  245. warp/tests/test_triangle_closest_point.py +48 -1
  246. warp/tests/test_types.py +27 -15
  247. warp/tests/test_utils.py +52 -52
  248. warp/tests/test_vec.py +29 -29
  249. warp/tests/test_vec_constructors.py +5 -5
  250. warp/tests/test_vec_scalar_ops.py +97 -97
  251. warp/tests/test_version.py +75 -0
  252. warp/tests/tile/test_tile.py +178 -0
  253. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  254. warp/tests/tile/test_tile_cholesky.py +7 -4
  255. warp/tests/tile/test_tile_load.py +26 -2
  256. warp/tests/tile/test_tile_mathdx.py +3 -3
  257. warp/tests/tile/test_tile_matmul.py +1 -1
  258. warp/tests/tile/test_tile_mlp.py +2 -4
  259. warp/tests/tile/test_tile_reduce.py +214 -13
  260. warp/tests/unittest_suites.py +6 -14
  261. warp/tests/unittest_utils.py +10 -9
  262. warp/tests/walkthrough_debug.py +3 -1
  263. warp/torch.py +6 -373
  264. warp/types.py +29 -5764
  265. warp/utils.py +10 -1659
  266. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
  267. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  268. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  269. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  270. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  271. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  283. warp/examples/assets/cartpole.urdf +0 -110
  284. warp/examples/assets/crazyflie.usd +0 -0
  285. warp/examples/assets/nv_ant.xml +0 -92
  286. warp/examples/assets/nv_humanoid.xml +0 -183
  287. warp/examples/assets/quadruped.urdf +0 -268
  288. warp/examples/optim/example_bounce.py +0 -266
  289. warp/examples/optim/example_cloth_throw.py +0 -228
  290. warp/examples/optim/example_drone.py +0 -870
  291. warp/examples/optim/example_inverse_kinematics.py +0 -182
  292. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  293. warp/examples/optim/example_softbody_properties.py +0 -400
  294. warp/examples/optim/example_spring_cage.py +0 -245
  295. warp/examples/optim/example_trajectory.py +0 -227
  296. warp/examples/sim/example_cartpole.py +0 -143
  297. warp/examples/sim/example_cloth.py +0 -225
  298. warp/examples/sim/example_cloth_self_contact.py +0 -316
  299. warp/examples/sim/example_granular.py +0 -130
  300. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  301. warp/examples/sim/example_jacobian_ik.py +0 -244
  302. warp/examples/sim/example_particle_chain.py +0 -124
  303. warp/examples/sim/example_quadruped.py +0 -203
  304. warp/examples/sim/example_rigid_chain.py +0 -203
  305. warp/examples/sim/example_rigid_contact.py +0 -195
  306. warp/examples/sim/example_rigid_force.py +0 -133
  307. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  308. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  309. warp/examples/sim/example_soft_body.py +0 -196
  310. warp/examples/tile/example_tile_walker.py +0 -327
  311. warp/sim/__init__.py +0 -74
  312. warp/sim/articulation.py +0 -793
  313. warp/sim/collide.py +0 -2570
  314. warp/sim/graph_coloring.py +0 -307
  315. warp/sim/import_mjcf.py +0 -791
  316. warp/sim/import_snu.py +0 -227
  317. warp/sim/import_urdf.py +0 -579
  318. warp/sim/import_usd.py +0 -898
  319. warp/sim/inertia.py +0 -357
  320. warp/sim/integrator.py +0 -245
  321. warp/sim/integrator_euler.py +0 -2000
  322. warp/sim/integrator_featherstone.py +0 -2101
  323. warp/sim/integrator_vbd.py +0 -2487
  324. warp/sim/integrator_xpbd.py +0 -3295
  325. warp/sim/model.py +0 -4821
  326. warp/sim/particles.py +0 -121
  327. warp/sim/render.py +0 -431
  328. warp/sim/utils.py +0 -431
  329. warp/tests/sim/disabled_kinematics.py +0 -244
  330. warp/tests/sim/test_cloth.py +0 -863
  331. warp/tests/sim/test_collision.py +0 -743
  332. warp/tests/sim/test_coloring.py +0 -347
  333. warp/tests/sim/test_inertia.py +0 -161
  334. warp/tests/sim/test_model.py +0 -226
  335. warp/tests/sim/test_sim_grad.py +0 -287
  336. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  337. warp/tests/sim/test_sim_kinematics.py +0 -98
  338. warp/thirdparty/__init__.py +0 -0
  339. warp_lang-1.9.1.dist-info/RECORD +0 -456
  340. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  341. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  342. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  343. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  344. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  345. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  346. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/tests/test_sparse.py CHANGED
@@ -18,6 +18,7 @@ import unittest
18
18
  import numpy as np
19
19
 
20
20
  import warp as wp
21
+ from warp._src.sparse import bsr_set_zero
21
22
  from warp.sparse import (
22
23
  bsr_assign,
23
24
  bsr_axpy,
@@ -59,6 +60,17 @@ def _triplets_to_dense(shape, rows, cols, values):
59
60
  return mat
60
61
 
61
62
 
63
+ def _bsr_pruned(bsr):
64
+ return bsr_from_triplets(
65
+ rows_of_blocks=bsr.nrow,
66
+ cols_of_blocks=bsr.ncol,
67
+ rows=bsr.uncompress_rows(),
68
+ columns=bsr.columns,
69
+ values=bsr.values,
70
+ prune_numerical_zeros=True,
71
+ )
72
+
73
+
62
74
  def _bsr_to_dense(bsr):
63
75
  mat = np.zeros(bsr.shape)
64
76
 
@@ -113,7 +125,7 @@ def test_bsr_from_triplets(test, device):
113
125
 
114
126
  ref = _triplets_to_dense(shape, rows, cols, vals)
115
127
 
116
- bsr = bsr_zeros(nrow, ncol, wp.types.matrix(shape=block_shape, dtype=float), device=device)
128
+ bsr = bsr_zeros(nrow, ncol, wp._src.types.matrix(shape=block_shape, dtype=float), device=device)
117
129
  bsr_set_from_triplets(bsr, rows, cols, vals)
118
130
  test.assertEqual(bsr.block_size, block_shape[0] * block_shape[1])
119
131
 
@@ -218,7 +230,7 @@ def test_bsr_get_set_diag(test, device):
218
230
  vals_np = rng.random(size=(nnz, block_shape[0], block_shape[1]))
219
231
  vals = wp.array(vals_np, dtype=float, device=device)
220
232
 
221
- bsr = bsr_zeros(nrow, ncol, wp.types.matrix(shape=block_shape, dtype=float), device=device)
233
+ bsr = bsr_zeros(nrow, ncol, wp._src.types.matrix(shape=block_shape, dtype=float), device=device)
222
234
  bsr_set_from_triplets(bsr, rows, cols, vals)
223
235
 
224
236
  diag = bsr_get_diag(bsr)
@@ -274,14 +286,13 @@ def test_bsr_split_merge(test, device):
274
286
  block_shape = (4, 2)
275
287
  nrow = 4
276
288
  ncol = 8
277
- shape = (block_shape[0] * nrow, block_shape[1] * ncol)
278
289
  n = 20
279
290
 
280
291
  rows = wp.array(rng.integers(0, high=nrow, size=n, dtype=int), dtype=int, device=device)
281
292
  cols = wp.array(rng.integers(0, high=ncol, size=n, dtype=int), dtype=int, device=device)
282
293
  vals = wp.array(rng.random(size=(n, block_shape[0], block_shape[1])), dtype=float, device=device)
283
294
 
284
- bsr = bsr_zeros(nrow, ncol, wp.types.matrix(shape=block_shape, dtype=float), device=device)
295
+ bsr = bsr_zeros(nrow, ncol, wp._src.types.matrix(shape=block_shape, dtype=float), device=device)
285
296
  bsr_set_from_triplets(bsr, rows, cols, vals)
286
297
  ref = _bsr_to_dense(bsr)
287
298
 
@@ -359,13 +370,13 @@ def make_test_bsr_transpose(block_shape, scalar_type):
359
370
  vals_np = rng.random(size=(nnz, block_shape[0], block_shape[1]))
360
371
  vals = wp.array(vals_np, dtype=scalar_type, device=device).reshape((nnz, block_shape[0], block_shape[1]))
361
372
 
362
- bsr = bsr_zeros(nrow, ncol, wp.types.matrix(shape=block_shape, dtype=scalar_type), device=device)
373
+ bsr = bsr_zeros(nrow, ncol, wp._src.types.matrix(shape=block_shape, dtype=scalar_type), device=device)
363
374
  bsr_set_from_triplets(bsr, rows, cols, vals)
364
375
  ref = 2.0 * np.transpose(_bsr_to_dense(bsr))
365
376
 
366
- bsr_transposed = (2.0 * bsr).transpose()
377
+ bsr_transposed = (2.0 * bsr).transpose().eval()
367
378
 
368
- res = _bsr_to_dense(bsr_transposed.eval())
379
+ res = _bsr_to_dense(bsr_transposed)
369
380
  assert_np_equal(res, ref, 0.0001)
370
381
 
371
382
  if block_shape[0] != block_shape[-1]:
@@ -373,6 +384,22 @@ def make_test_bsr_transpose(block_shape, scalar_type):
373
384
  with test.assertRaisesRegex(ValueError, "Destination block shape must be"):
374
385
  bsr_set_transpose(dest=bsr, src=bsr)
375
386
 
387
+ # test masked transpose
388
+ # remove some non zeros from src and dest matrices
389
+ bsr_set_from_triplets(bsr, rows[:3], cols[:3], vals[:3])
390
+ bsr_transposed = bsr_from_triplets(
391
+ bsr_transposed.nrow,
392
+ bsr_transposed.ncol,
393
+ bsr_transposed.uncompress_rows()[:3],
394
+ bsr_transposed.columns[:3],
395
+ bsr_transposed.values[:3],
396
+ )
397
+
398
+ assert_np_equal(bsr_transposed.uncompress_rows().numpy()[:3], [0, 1, 1])
399
+ assert_np_equal(bsr_transposed.columns.numpy()[:3], [2, 0, 2])
400
+ bsr_set_transpose(bsr_transposed, bsr, masked=True)
401
+ assert _bsr_pruned(bsr_transposed).nnz_sync() == 2
402
+
376
403
  return test_bsr_transpose
377
404
 
378
405
 
@@ -392,7 +419,7 @@ def make_test_bsr_axpy(block_shape, scalar_type):
392
419
  x_vals = wp.array(rng.random(size=(nnz, block_shape[0], block_shape[1])), dtype=scalar_type, device=device)
393
420
  x_vals = x_vals.reshape((nnz, block_shape[0], block_shape[1]))
394
421
 
395
- x = bsr_zeros(nrow, ncol, wp.types.matrix(shape=block_shape, dtype=scalar_type), device=device)
422
+ x = bsr_zeros(nrow, ncol, wp._src.types.matrix(shape=block_shape, dtype=scalar_type), device=device)
396
423
  bsr_set_from_triplets(x, x_rows, x_cols, x_vals)
397
424
 
398
425
  y_rows = wp.array(rng.integers(0, high=nrow, size=nnz, dtype=int), dtype=int, device=device)
@@ -400,7 +427,7 @@ def make_test_bsr_axpy(block_shape, scalar_type):
400
427
  y_vals = wp.array(rng.random(size=(nnz, block_shape[0], block_shape[1])), dtype=scalar_type, device=device)
401
428
  y_vals = y_vals.reshape((nnz, block_shape[0], block_shape[1]))
402
429
 
403
- y = bsr_zeros(nrow, ncol, wp.types.matrix(shape=block_shape, dtype=scalar_type), device=device)
430
+ y = bsr_zeros(nrow, ncol, wp._src.types.matrix(shape=block_shape, dtype=scalar_type), device=device)
404
431
  bsr_set_from_triplets(y, y_rows, y_cols, y_vals)
405
432
 
406
433
  work_arrays = bsr_axpy_work_arrays()
@@ -457,7 +484,7 @@ def make_test_bsr_mm(block_shape, scalar_type):
457
484
  x_vals = wp.array(rng.random(size=(nnz, x_block_shape[0], x_block_shape[1])), dtype=scalar_type, device=device)
458
485
  x_vals = x_vals.reshape((nnz, x_block_shape[0], x_block_shape[1]))
459
486
 
460
- x = bsr_zeros(x_nrow, x_ncol, wp.types.matrix(shape=x_block_shape, dtype=scalar_type), device=device)
487
+ x = bsr_zeros(x_nrow, x_ncol, wp._src.types.matrix(shape=x_block_shape, dtype=scalar_type), device=device)
461
488
  bsr_set_from_triplets(x, x_rows, x_cols, x_vals)
462
489
 
463
490
  y_rows = wp.array(rng.integers(0, high=y_nrow, size=nnz, dtype=int), dtype=int, device=device)
@@ -465,7 +492,7 @@ def make_test_bsr_mm(block_shape, scalar_type):
465
492
  y_vals = wp.array(rng.random(size=(nnz, y_block_shape[0], y_block_shape[1])), dtype=scalar_type, device=device)
466
493
  y_vals = y_vals.reshape((nnz, y_block_shape[0], y_block_shape[1]))
467
494
 
468
- y = bsr_zeros(y_nrow, y_ncol, wp.types.matrix(shape=y_block_shape, dtype=scalar_type), device=device)
495
+ y = bsr_zeros(y_nrow, y_ncol, wp._src.types.matrix(shape=y_block_shape, dtype=scalar_type), device=device)
469
496
  bsr_set_from_triplets(y, y_rows, y_cols, y_vals)
470
497
 
471
498
  z_rows = wp.array(rng.integers(0, high=z_nrow, size=nnz, dtype=int), dtype=int, device=device)
@@ -473,7 +500,7 @@ def make_test_bsr_mm(block_shape, scalar_type):
473
500
  z_vals = wp.array(rng.random(size=(nnz, z_block_shape[0], z_block_shape[1])), dtype=scalar_type, device=device)
474
501
  z_vals = z_vals.reshape((nnz, z_block_shape[0], z_block_shape[1]))
475
502
 
476
- z = bsr_zeros(z_nrow, z_ncol, wp.types.matrix(shape=z_block_shape, dtype=scalar_type), device=device)
503
+ z = bsr_zeros(z_nrow, z_ncol, wp._src.types.matrix(shape=z_block_shape, dtype=scalar_type), device=device)
477
504
  bsr_set_from_triplets(z, z_rows, z_cols, z_vals)
478
505
 
479
506
  work_arrays = bsr_mm_work_arrays()
@@ -544,7 +571,7 @@ def make_test_bsr_mv(block_shape, scalar_type):
544
571
  A_vals = wp.array(rng.random(size=(nnz, block_shape[0], block_shape[1])), dtype=scalar_type, device=device)
545
572
  A_vals = A_vals.reshape((nnz, block_shape[0], block_shape[1]))
546
573
 
547
- A = bsr_zeros(nrow, ncol, wp.types.matrix(shape=block_shape, dtype=scalar_type), device=device)
574
+ A = bsr_zeros(nrow, ncol, wp._src.types.matrix(shape=block_shape, dtype=scalar_type), device=device)
548
575
  bsr_set_from_triplets(A, A_rows, A_cols, A_vals)
549
576
 
550
577
  if block_shape[1] == 1:
@@ -664,6 +691,83 @@ def make_test_bsr_multiply_deep(block_shape, scalar_type):
664
691
  return test_bsr_multiply_deep
665
692
 
666
693
 
694
+ def test_bsr_mm_max_new_nnz(test, device):
695
+ """Test that BSR matrix multiplication with max_new_nnz works"""
696
+ A = bsr_from_triplets(
697
+ 2,
698
+ 2,
699
+ wp.array([0, 0, 1, 1], dtype=int, device=device),
700
+ wp.array([0, 1, 0, 1], dtype=int, device=device),
701
+ wp.array([1.0, 2.0, 3.0, 4.0], dtype=wp.float32, device=device),
702
+ )
703
+ B = bsr_from_triplets(
704
+ 2,
705
+ 2,
706
+ wp.array([0, 0, 1, 1], dtype=int, device=device),
707
+ wp.array([0, 1, 0, 1], dtype=int, device=device),
708
+ wp.array([1.0, 2.0, 3.0, 4.0], dtype=wp.float32, device=device),
709
+ )
710
+ C = bsr_zeros(2, 2, wp.float32, device=device)
711
+
712
+ # max_new_nnz big enough
713
+ bsr_mm(A, B, C, max_new_nnz=4)
714
+ test.assertEqual(C.nnz_sync(), 4)
715
+
716
+ bsr_set_zero(C)
717
+ test.assertEqual(C.nnz_sync(), 0)
718
+
719
+ # max_new_nnz too small, check warning
720
+ capture = StdOutCapture()
721
+ capture.begin()
722
+ bsr_mm(A, B, C, max_new_nnz=2)
723
+ test.assertEqual(C.nnz_sync(), 2)
724
+ output = capture.end()
725
+
726
+ # Check that the output contains warnings about "max_new_nnz" being exceeded.
727
+ # Older Windows C runtimes have a bug where stdout sometimes does not get properly flushed.
728
+ if output != "" or sys.platform != "win32":
729
+ test.assertRegex(output, r"exceeded")
730
+
731
+
732
+ def test_capturability(test, device):
733
+ """Test that BSR operations are graph-capturable"""
734
+
735
+ N = 5
736
+ M = 3
737
+
738
+ C = bsr_diag(wp.zeros(N, dtype=wp.mat33, device=device))
739
+
740
+ rows = wp.array([3, 4, 2, 0, 1], dtype=int, device=device)
741
+ columns = wp.array([2, 0, 1, 2, 1], dtype=int, device=device)
742
+ values = wp.ones(5, dtype=wp.mat33, device=device)
743
+
744
+ def test_body():
745
+ A = bsr_from_triplets(
746
+ N,
747
+ M,
748
+ rows=rows,
749
+ columns=columns,
750
+ values=values,
751
+ )
752
+ B = A + bsr_copy(A * 2.0)
753
+ bsr_mm(A, bsr_transposed(B), C, max_new_nnz=N * N)
754
+
755
+ # ensure necessary modules are loaded and reset result
756
+ test_body()
757
+ bsr_set_zero(C)
758
+ test.assertEqual(C.nnz_sync(), 0)
759
+
760
+ with wp.ScopedDevice(device):
761
+ with wp.ScopedCapture(force_module_load=False) as capture:
762
+ test_body()
763
+
764
+ assert_array_equal(bsr_get_diag(C), wp.zeros(N, dtype=wp.mat33, device=device))
765
+
766
+ wp.capture_launch(capture.graph)
767
+ test.assertEqual(C.nnz_sync(), 9)
768
+ assert_array_equal(bsr_get_diag(C), wp.full(N, value=wp.mat33(9.0), dtype=wp.mat33, device=device))
769
+
770
+
667
771
  devices = get_test_devices()
668
772
  cuda_test_devices = get_selected_cuda_test_devices()
669
773
 
@@ -676,7 +780,9 @@ class TestSparse(unittest.TestCase):
676
780
  diag_bsr = bsr_diag(diag=np.eye(bsize, dtype=float) * 2.0, rows_of_blocks=nrow)
677
781
  diag_copy = bsr_copy(diag_bsr, scalar_type=wp.float64)
678
782
 
679
- self.assertTrue(wp.types.types_equal(diag_copy.values.dtype, wp.mat(shape=(bsize, bsize), dtype=wp.float64)))
783
+ self.assertTrue(
784
+ wp._src.types.types_equal(diag_copy.values.dtype, wp.mat(shape=(bsize, bsize), dtype=wp.float64))
785
+ )
680
786
  bsr_scale(x=diag_copy, alpha=0.5)
681
787
 
682
788
  res = _bsr_to_dense(diag_copy)
@@ -686,7 +792,10 @@ class TestSparse(unittest.TestCase):
686
792
  bsr_scale(x=diag_copy, alpha=0.0)
687
793
  self.assertEqual(diag_copy.nrow, nrow)
688
794
  self.assertEqual(diag_copy.ncol, nrow)
689
- self.assertEqual(diag_copy.nnz, 0)
795
+ self.assertEqual(diag_copy.nnz, diag_bsr.nnz)
796
+
797
+ diag_pruned = _bsr_pruned(diag_copy)
798
+ self.assertEqual(diag_pruned.nnz_sync(), 0)
690
799
 
691
800
 
692
801
  add_function_test(TestSparse, "test_csr_from_triplets", test_csr_from_triplets, devices=devices)
@@ -728,6 +837,8 @@ add_function_test(TestSparse, "test_csr_mv", make_test_bsr_mv((1, 1), wp.float32
728
837
  add_function_test(TestSparse, "test_bsr_mv_1_3", make_test_bsr_mv((1, 3), wp.float32), devices=devices)
729
838
  add_function_test(TestSparse, "test_bsr_mv_3_3", make_test_bsr_mv((3, 3), wp.float64), devices=devices)
730
839
 
840
+ add_function_test(TestSparse, "test_capturability", test_capturability, devices=cuda_test_devices)
841
+ add_function_test(TestSparse, "test_bsr_mm_max_new_nnz", test_bsr_mm_max_new_nnz, devices=devices, check_output=False)
731
842
 
732
843
  if __name__ == "__main__":
733
844
  wp.clear_kernel_cache()