warp-lang 1.0.0b2__py3-none-win_amd64.whl → 1.0.0b6__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (271) hide show
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.dll +0 -0
  57. warp/bin/warp.dll +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/bin/warp-clang.so +0 -0
  257. warp/bin/warp.so +0 -0
  258. warp/fem/field/discrete_field.py +0 -80
  259. warp/fem/space/nodal_function_space.py +0 -233
  260. warp/tests/test_all.py +0 -223
  261. warp/tests/test_array_scan.py +0 -60
  262. warp/tests/test_base.py +0 -208
  263. warp/tests/test_unresolved_func.py +0 -7
  264. warp/tests/test_unresolved_symbol.py +0 -7
  265. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  266. warp_lang-1.0.0b2.dist-info/RECORD +0 -380
  267. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  268. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  269. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  270. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  271. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/sparse.py CHANGED
@@ -1,14 +1,29 @@
1
+ from typing import Any, Generic, Optional, Tuple, TypeVar, Union
2
+
1
3
  import warp as wp
2
4
  import warp.types
3
5
  import warp.utils
6
+ from warp.types import Array, Cols, Matrix, Rows, Scalar, Vector
7
+
8
+ # typing hints
9
+
10
+ _BlockType = TypeVar("BlockType")
11
+
12
+
13
+ class _MatrixBlockType(Matrix):
14
+ pass
4
15
 
5
- from typing import Tuple, Any, Union
6
16
 
17
+ class _ScalarBlockType(Generic[Scalar]):
18
+ pass
19
+
20
+
21
+ BlockType = Union[_MatrixBlockType[Rows, Cols, Scalar], _ScalarBlockType[Scalar]]
7
22
 
8
23
  _struct_cache = dict()
9
24
 
10
25
 
11
- class BsrMatrix:
26
+ class BsrMatrix(Generic[_BlockType]):
12
27
  """Untyped base class for BSR and CSR matrices.
13
28
 
14
29
  Should not be constructed directly but through functions such as :func:`bsr_zeros`.
@@ -16,15 +31,15 @@ class BsrMatrix:
16
31
  Attributes:
17
32
  nrow (int): Number of rows of blocks
18
33
  ncol (int): Number of columns of blocks
19
- nnz (int): Number of non-zero blocks: equal to `offsets[-1]`, cached on host for convenience
20
- offsets (wp.array(dtype=int)): Array of size at least 1 + nrows containing start and end offsets og blocks in each row
21
- columns (wp.array(dtype=int)): Array of size at least equal to nnz containing block column indices
22
- values (wp.array(dtype=dtype)): Array of size at least equal to nnz containing block values
34
+ nnz (int): Number of non-zero blocks: must be equal to ``offsets[nrow]``, cached on host for convenience
35
+ offsets (Array[int]): Array of size at least ``1 + nrow`` such that the start and end indices of the blocks of row ``r`` are ``offsets[r]`` and ``offsets[r+1]``, respectively.
36
+ columns (Array[int]): Array of size at least equal to ``nnz`` containing block column indices
37
+ values (Array[BlockType]): Array of size at least equal to ``nnz`` containing block values
23
38
  """
24
39
 
25
40
  @property
26
- def scalar_type(self) -> type:
27
- """Scalar type for each of the blocks' coefficients. FOr CSR matrices, this is equal to the block type"""
41
+ def scalar_type(self) -> Scalar:
42
+ """Scalar type for individual block coefficients. For CSR matrices, this is the same as the block type"""
28
43
  return warp.types.type_scalar_type(self.values.dtype)
29
44
 
30
45
  @property
@@ -33,20 +48,25 @@ class BsrMatrix:
33
48
  return getattr(self.values.dtype, "_shape_", (1, 1))
34
49
 
35
50
  @property
36
- def block_size(self) -> Tuple[int, int]:
37
- """Size of the individual blocks, i.e. number of rows per block times number of columsn per block"""
51
+ def block_size(self) -> int:
52
+ """Size of the individual blocks, i.e. number of rows per block times number of columns per block"""
38
53
  return warp.types.type_length(self.values.dtype)
39
54
 
40
55
  @property
41
56
  def shape(self) -> Tuple[int, int]:
42
- """Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/columsn per block"""
57
+ """Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/columns per block"""
43
58
  block_shape = self.block_shape
44
59
  return (self.nrow * block_shape[0], self.ncol * block_shape[1])
45
60
 
46
61
 
47
- def bsr_matrix_t(dtype: type):
62
+ def bsr_matrix_t(dtype: BlockType):
48
63
  dtype = wp.types.type_to_warp(dtype)
49
64
 
65
+ if not warp.types.type_is_matrix(dtype) and not dtype in warp.types.scalar_types:
66
+ raise ValueError(
67
+ f"BsrMatrix block type must be either warp matrix or scalar; got {warp.types.type_repr(dtype)}"
68
+ )
69
+
50
70
  class BsrMatrixTyped(BsrMatrix):
51
71
  nrow: int
52
72
  """Number of rows of blocks"""
@@ -79,11 +99,23 @@ def bsr_matrix_t(dtype: type):
79
99
 
80
100
 
81
101
  def bsr_zeros(
82
- rows_of_blocks: int, cols_of_blocks: int, block_type: type, device: wp.context.Devicelike = None
102
+ rows_of_blocks: int,
103
+ cols_of_blocks: int,
104
+ block_type: BlockType,
105
+ device: wp.context.Devicelike = None,
83
106
  ) -> BsrMatrix:
84
107
  """
85
- Constructs an empty BSR or CS matrix with the given shape
108
+ Constructs and returns an empty BSR or CSR matrix with the given shape
109
+
110
+ Args:
111
+ bsr: The BSR or CSR matrix to set to zero
112
+ rows_of_blocks: Number of rows of blocks
113
+ cols_of_blocks: Number of columns of blocks
114
+ block_type: Type of individual blocks. For CSR matrices, this should be a scalar type;
115
+ for BSR matrices, this should be a matrix type (e.g. from :func:`warp.mat`)
116
+ device: Device on which to allocate the matrix arrays
86
117
  """
118
+
87
119
  bsr = bsr_matrix_t(block_type)()
88
120
 
89
121
  bsr.nrow = rows_of_blocks
@@ -110,19 +142,42 @@ def _bsr_ensure_fits(bsr: BsrMatrix, nrow: int = None, nnz: int = None):
110
142
  bsr.values = wp.empty(shape=(nnz,), dtype=bsr.values.dtype, device=bsr.values.device)
111
143
 
112
144
 
145
+ def bsr_set_zero(bsr: BsrMatrix, rows_of_blocks: Optional[int] = None, cols_of_blocks: Optional[int] = None):
146
+ """
147
+ Sets a BSR matrix to zero, possibly changing its size
148
+
149
+ Args:
150
+ bsr: The BSR or CSR matrix to set to zero
151
+ rows_of_blocks: If not ``None``, the new number of rows of blocks
152
+ cols_of_blocks: If not ``None``, the new number of columns of blocks
153
+ """
154
+
155
+ if rows_of_blocks is not None:
156
+ bsr.nrow = rows_of_blocks
157
+ if cols_of_blocks is not None:
158
+ bsr.ncol = cols_of_blocks
159
+ bsr.nnz = 0
160
+ _bsr_ensure_fits(bsr)
161
+ bsr.offsets.zero_()
162
+
163
+
113
164
  def bsr_set_from_triplets(
114
- dest: BsrMatrix,
115
- rows: wp.array(dtype=int),
116
- columns: wp.array(dtype=int),
117
- values: wp.array(dtype=Any),
165
+ dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
166
+ rows: "Array[int]",
167
+ columns: "Array[int]",
168
+ values: "Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]",
118
169
  ):
119
170
  """
120
- Fills a BSR matrix `dest` with values defined by COO triplets `rows`, `columns`, `values`.
171
+ Fills a BSR matrix with values defined by coordinate-oriented (COO) triplets, discarding existing blocks.
121
172
 
122
- Values must be either one-dimensional with data type identical to the `dest` matrix block times,
123
- or a 3d array with data type equal to the `dest` matrix scalar type.
173
+ The first dimension of the three input arrays must match, and determines the number of non-zeros in the constructed matrix.
124
174
 
125
- Previous blocks of `dest` are discarded.
175
+ Args:
176
+ dest: Sparse matrix to populate
177
+ rows: Row index for each non-zero
178
+ columns: Columns index for each non-zero
179
+ values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
180
+ to the `dest` matrix's block type, or a 3d array with data type equal to the `dest` matrix's scalar type.
126
181
  """
127
182
 
128
183
  if values.device != columns.device or values.device != rows.device or values.device != dest.values.device:
@@ -138,7 +193,7 @@ def bsr_set_from_triplets(
138
193
  elif values.ndim == 3:
139
194
  if values.shape[1:] != dest.block_shape:
140
195
  raise ValueError(
141
- f"Last two dimensions in values array ({values.shape[1:]}) shoudl correspond to matrix block shape {(dest.block_shape)})"
196
+ f"Last two dimensions in values array ({values.shape[1:]}) should correspond to matrix block shape {(dest.block_shape)})"
142
197
  )
143
198
 
144
199
  if warp.types.type_scalar_type(values.dtype) != dest.scalar_type:
@@ -150,6 +205,9 @@ def bsr_set_from_triplets(
150
205
  raise ValueError("Number of dimension for values array should be 1 or 3")
151
206
 
152
207
  nnz = rows.shape[0]
208
+ if nnz == 0:
209
+ bsr_set_zero(dest)
210
+ return
153
211
 
154
212
  # Increase dest array sizes if needed
155
213
  _bsr_ensure_fits(dest, nnz=nnz)
@@ -186,8 +244,8 @@ def bsr_set_from_triplets(
186
244
  )
187
245
 
188
246
 
189
- def bsr_assign(dest: BsrMatrix, src: BsrMatrix):
190
- """Copies the content of the `src` matrix to `dest`, possibly casting the block values."""
247
+ def bsr_assign(dest: BsrMatrix[BlockType[Rows, Cols, Scalar]], src: BsrMatrix[BlockType[Rows, Cols, Any]]):
248
+ """Copies the content of the `src` matrix to `dest`, casting the block values if the two matrices use distinct scalar types."""
191
249
 
192
250
  if dest.values.device != src.values.device:
193
251
  raise ValueError("Source and destination matrices must reside on the same device")
@@ -207,8 +265,12 @@ def bsr_assign(dest: BsrMatrix, src: BsrMatrix):
207
265
  warp.utils.array_cast(out_array=dest.values, in_array=src.values, count=src.nnz)
208
266
 
209
267
 
210
- def bsr_copy(A: BsrMatrix, scalar_type=None):
211
- """Returns a copy of matrix A, possibly casting values to a new scalar type"""
268
+ def bsr_copy(A: BsrMatrix, scalar_type: Optional[Scalar] = None):
269
+ """Returns a copy of matrix ``A``, possibly changing its scalar type.
270
+
271
+ Args:
272
+ scalar_type: If provided, the returned matrix will use this scalar type instead of the one from `A`.
273
+ """
212
274
  if scalar_type is None:
213
275
  block_type = A.values.dtype
214
276
  elif A.block_shape == (1, 1):
@@ -221,7 +283,7 @@ def bsr_copy(A: BsrMatrix, scalar_type=None):
221
283
  return copy
222
284
 
223
285
 
224
- def bsr_set_transpose(dest: BsrMatrix, src: BsrMatrix):
286
+ def bsr_set_transpose(dest: BsrMatrix[BlockType[Cols, Rows, Scalar]], src: BsrMatrix[BlockType[Rows, Cols, Scalar]]):
225
287
  """Assigns the transposed matrix `src` to matrix `dest`"""
226
288
 
227
289
  if dest.values.device != src.values.device:
@@ -230,10 +292,7 @@ def bsr_set_transpose(dest: BsrMatrix, src: BsrMatrix):
230
292
  if dest.scalar_type != src.scalar_type:
231
293
  raise ValueError("All arguments must have the same scalar type")
232
294
 
233
- if src.block_shape == (1, 1):
234
- transpose_block_shape = (1, 1)
235
- else:
236
- transpose_block_shape = src.block_shape[::-1]
295
+ transpose_block_shape = src.block_shape[::-1]
237
296
 
238
297
  if dest.block_shape != transpose_block_shape:
239
298
  raise ValueError(f"Destination block shape must be {transpose_block_shape}")
@@ -242,6 +301,9 @@ def bsr_set_transpose(dest: BsrMatrix, src: BsrMatrix):
242
301
  dest.ncol = src.nrow
243
302
  dest.nnz = src.nnz
244
303
 
304
+ if src.nnz == 0:
305
+ return
306
+
245
307
  # Increase dest array sizes if needed
246
308
  _bsr_ensure_fits(dest)
247
309
 
@@ -301,27 +363,33 @@ def _bsr_get_diag_kernel(
301
363
  end = A_offsets[row + 1]
302
364
 
303
365
  diag = wp.lower_bound(A_columns, beg, end, row)
304
- if A_columns[diag] == row:
305
- out[row] = A_values[diag]
366
+ if diag < end:
367
+ if A_columns[diag] == row:
368
+ out[row] = A_values[diag]
369
+
370
+
371
+ def bsr_get_diag(A: BsrMatrix[_BlockType], out: "Optional[Array[BlockType]]" = None) -> "Array[BlockType]":
372
+ """Returns the array of blocks that constitute the diagonal of a sparse matrix.
306
373
 
374
+ Args:
375
+ A: the sparse matrix from which to extract the diagonal
376
+ out: if provided, the array into which to store the diagonal blocks
377
+ """
307
378
 
308
- def bsr_get_diag(A: BsrMatrix, out: wp.array = None):
309
- """Returns the block diagonal of a square sparse matrix"""
310
- if A.nrow != A.ncol:
311
- raise ValueError("bsr_get_diag is only available for square sparse matrices")
379
+ dim = min(A.nrow, A.ncol)
312
380
 
313
381
  if out is None:
314
- out = wp.zeros(shape=(A.nrow,), dtype=A.values.dtype, device=A.values.device)
382
+ out = wp.zeros(shape=(dim,), dtype=A.values.dtype, device=A.values.device)
315
383
  else:
316
384
  if out.dtype != A.values.dtype:
317
385
  raise ValueError(f"Output array must have type {A.values.dtype}")
318
386
  if out.device != A.values.device:
319
387
  raise ValueError(f"Output array must reside on device {A.values.device}")
320
- if out.shape[0] < A.nrow:
321
- raise ValueError(f"Output array must be of length at least {A.nrow}")
388
+ if out.shape[0] < dim:
389
+ raise ValueError(f"Output array must be of length at least {dim}")
322
390
 
323
391
  wp.launch(
324
- kernel=_bsr_get_diag_kernel, dim=A.nrow, device=A.values.device, inputs=[A.offsets, A.columns, A.values, out]
392
+ kernel=_bsr_get_diag_kernel, dim=dim, device=A.values.device, inputs=[A.offsets, A.columns, A.values, out]
325
393
  )
326
394
 
327
395
  return out
@@ -329,40 +397,205 @@ def bsr_get_diag(A: BsrMatrix, out: wp.array = None):
329
397
 
330
398
  @wp.kernel
331
399
  def _bsr_set_diag_kernel(
400
+ diag: wp.array(dtype=Any),
401
+ A_offsets: wp.array(dtype=int),
402
+ A_columns: wp.array(dtype=int),
403
+ A_values: wp.array(dtype=Any),
404
+ ):
405
+ row = wp.tid()
406
+ A_offsets[row + 1] = row + 1
407
+ A_columns[row] = row
408
+ A_values[row] = diag[row]
409
+
410
+ if row == 0:
411
+ A_offsets[0] = 0
412
+
413
+
414
+ @wp.kernel
415
+ def _bsr_set_diag_constant_kernel(
416
+ diag_value: Any,
332
417
  A_offsets: wp.array(dtype=int),
333
418
  A_columns: wp.array(dtype=int),
419
+ A_values: wp.array(dtype=Any),
334
420
  ):
335
421
  row = wp.tid()
336
422
  A_offsets[row + 1] = row + 1
337
423
  A_columns[row] = row
424
+ A_values[row] = diag_value
338
425
 
339
426
  if row == 0:
340
427
  A_offsets[0] = 0
341
428
 
342
429
 
343
- def bsr_set_diag(A: BsrMatrix, diag: wp.array):
344
- """Sets A as a block-diagonal square matrix"""
430
+ def bsr_set_diag(
431
+ A: BsrMatrix[BlockType],
432
+ diag: "Union[BlockType, Array[BlockType]]",
433
+ rows_of_blocks: Optional[int] = None,
434
+ cols_of_blocks: Optional[int] = None,
435
+ ):
436
+ """Sets `A` as a block-diagonal matrix
437
+
438
+ Args:
439
+ A: the sparse matrix to modify
440
+ diag: Either a warp array of type ``A.values.dtype``, in which case each element will define one block of the diagonal,
441
+ or a constant value of type ``A.values.dtype``, in which case it will get assigned to all diagonal blocks.
442
+ rows_of_blocks: If not ``None``, the new number of rows of blocks
443
+ cols_of_blocks: If not ``None``, the new number of columns of blocks
444
+
445
+ The shape of the matrix will be defined by one of the following, in that order:
446
+ - `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal.
447
+ - the first dimension of `diag`, if `diag` is an array
448
+ - the current dimensions of `A` otherwise
449
+ """
450
+
451
+ if rows_of_blocks is None and cols_of_blocks is not None:
452
+ rows_of_blocks = cols_of_blocks
453
+ if cols_of_blocks is None and rows_of_blocks is not None:
454
+ cols_of_blocks = rows_of_blocks
455
+
456
+ if warp.types.is_array(diag):
457
+ if rows_of_blocks is None:
458
+ rows_of_blocks = diag.shape[0]
459
+ cols_of_blocks = diag.shape[0]
460
+
461
+ if rows_of_blocks is not None:
462
+ A.nrow = rows_of_blocks
463
+ A.ncol = cols_of_blocks
464
+
465
+ A.nnz = min(A.nrow, A.ncol)
466
+ _bsr_ensure_fits(A)
467
+
468
+ if warp.types.is_array(diag):
469
+ wp.launch(
470
+ kernel=_bsr_set_diag_kernel,
471
+ dim=A.nnz,
472
+ device=A.values.device,
473
+ inputs=[diag, A.offsets, A.columns, A.values],
474
+ )
475
+ else:
476
+ if not warp.types.type_is_value(type(diag)):
477
+ # Cast to launchable type
478
+ diag = A.values.dtype(diag)
479
+ wp.launch(
480
+ kernel=_bsr_set_diag_constant_kernel,
481
+ dim=A.nnz,
482
+ device=A.values.device,
483
+ inputs=[diag, A.offsets, A.columns, A.values],
484
+ )
345
485
 
346
- A.nrow = diag.shape[0]
347
- A.ncol = diag.shape[0]
348
- A.nnz = diag.shape[0]
349
486
 
350
- A.values = diag
351
- if A.columns.size < A.nrow:
352
- A.columns = wp.empty(shape=(A.nrow,), dtype=int, device=diag.device)
353
- if A.offsets.size < A.nrow + 1:
354
- A.offsets = wp.empty(shape=(A.nrow + 1,), dtype=int, device=diag.device)
487
+ def bsr_diag(
488
+ diag: "Union[BlockType, Array[BlockType]]",
489
+ rows_of_blocks: Optional[int] = None,
490
+ cols_of_blocks: Optional[int] = None,
491
+ ) -> BsrMatrix["BlockType"]:
492
+ """Creates and returns a block-diagonal BSR matrix from a given block value or array of block values.
355
493
 
356
- wp.launch(kernel=_bsr_set_diag_kernel, dim=A.nrow, device=A.values.device, inputs=[A.offsets, A.columns])
494
+ Args:
495
+ diag: Either a warp array of type ``A.values.dtype``, in which case each element will define one block of the diagonal,
496
+ or a constant value of type ``A.values.dtype``, in which case it will get assigned to all diagonal blocks.
497
+ rows_of_blocks: If not ``None``, the new number of rows of blocks
498
+ cols_of_blocks: If not ``None``, the new number of columns of blocks
357
499
 
500
+ The shape of the matrix will be defined by one of the following, in that order:
501
+ - `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal.
502
+ - the first dimension of `diag`, if `diag` is an array
503
+ """
504
+
505
+ if rows_of_blocks is None and cols_of_blocks is not None:
506
+ rows_of_blocks = cols_of_blocks
507
+ if cols_of_blocks is None and rows_of_blocks is not None:
508
+ cols_of_blocks = rows_of_blocks
509
+
510
+ if warp.types.is_array(diag):
511
+ if rows_of_blocks is None:
512
+ rows_of_blocks = diag.shape[0]
513
+ cols_of_blocks = diag.shape[0]
514
+
515
+ A = bsr_zeros(
516
+ rows_of_blocks,
517
+ cols_of_blocks,
518
+ block_type=diag.dtype,
519
+ device=diag.device,
520
+ )
521
+ else:
522
+ if rows_of_blocks is None:
523
+ raise ValueError(
524
+ "rows_of_blocks and/or cols_of_blocks must be provided for constructing a diagonal matrix with uniform diagonal"
525
+ )
526
+
527
+ block_type = type(diag)
528
+ if not warp.types.type_is_matrix(block_type) and len(getattr(diag, "shape", ())) == 2:
529
+ block_type = wp.mat(shape=diag.shape, dtype=diag.dtype)
530
+
531
+ A = bsr_zeros(
532
+ rows_of_blocks,
533
+ cols_of_blocks,
534
+ block_type=block_type,
535
+ )
358
536
 
359
- def bsr_diag(diag: wp.array):
360
- """Creates a square block-diagonal BSR matrix from the values array `diag`"""
361
- A = bsr_zeros(rows_of_blocks=diag.shape[0], cols_of_blocks=diag.shape[0], block_type=diag.dtype, device=diag.device)
362
537
  bsr_set_diag(A, diag)
363
538
  return A
364
539
 
365
540
 
541
+ def bsr_set_identity(A: BsrMatrix, rows_of_blocks: Optional[int] = None):
542
+ """Sets `A` as the identity matrix
543
+
544
+ Args:
545
+ A: the sparse matrix to modify
546
+ rows_of_blocks: if provided, the matrix will be resized as a square matrix with `rows_of_blocks` rows and columns.
547
+ """
548
+
549
+ if A.block_shape == (1, 1):
550
+ identity = A.scalar_type(1.0)
551
+ else:
552
+ from numpy import eye
553
+
554
+ identity = eye(A.block_shape[0])
555
+
556
+ bsr_set_diag(A, diag=identity, rows_of_blocks=rows_of_blocks, cols_of_blocks=rows_of_blocks)
557
+
558
+
559
+ def bsr_identity(
560
+ rows_of_blocks: int, block_type: BlockType[Rows, Rows, Scalar], device: wp.context.Devicelike = None
561
+ ) -> BsrMatrix[BlockType[Rows, Rows, Scalar]]:
562
+ """Creates and returns a square identity matrix.
563
+
564
+ Args:
565
+ rows_of_blocks: Number of rows and columns of blocks in the created matrix.
566
+ block_type: Block type for the newly created matrix -- must be square
567
+ device: Device onto which to allocate the data arrays
568
+ """
569
+ A = bsr_zeros(rows_of_blocks=rows_of_blocks, cols_of_blocks=rows_of_blocks, block_type=block_type, device=device)
570
+ bsr_set_identity(A)
571
+ return A
572
+
573
+
574
+ @wp.kernel
575
+ def _bsr_scale_kernel(
576
+ alpha: Any,
577
+ values: wp.array(dtype=Any),
578
+ ):
579
+ values[wp.tid()] = alpha * values[wp.tid()]
580
+
581
+
582
+ def bsr_scale(x: BsrMatrix, alpha: Scalar) -> BsrMatrix:
583
+ """
584
+ Performs the operation ``x := alpha * x`` on BSR matrix `x` and returns `x`
585
+ """
586
+
587
+ if alpha != 1.0 and x.nnz > 0:
588
+ if alpha == 0.0:
589
+ bsr_set_zero(x)
590
+ else:
591
+ if not isinstance(alpha, x.scalar_type):
592
+ alpha = x.scalar_type(alpha)
593
+
594
+ wp.launch(kernel=_bsr_scale_kernel, dim=x.nnz, device=x.values.device, inputs=[alpha, x.values])
595
+
596
+ return x
597
+
598
+
366
599
  @wp.kernel
367
600
  def _bsr_get_block_row(dest_offset: int, bsr_offsets: wp.array(dtype=int), rows: wp.array(dtype=int)):
368
601
  i = wp.tid()
@@ -393,16 +626,75 @@ def _bsr_axpy_add_block(
393
626
  dst_values[block] = dst_values[block] + scale * src_values[i]
394
627
 
395
628
 
396
- def bsr_axpy(x: BsrMatrix, y: BsrMatrix, alpha: float = 1.0, beta: float = 1.0):
629
class bsr_axpy_work_arrays:
    """Opaque structure for persisting :func:`bsr_axpy` temporary work buffers across calls"""

    def __init__(self):
        # Start without a device and without any buffer
        self._reset(None)

    def _reset(self, device):
        """Drops every cached buffer and records `device` for later allocations."""
        self.device = device
        self._sum_rows = None
        self._sum_cols = None
        self._old_y_values = None
        # Reset together with the other buffers; this class never allocates it
        self._old_x_values = None

    def _allocate(self, device, y: BsrMatrix, sum_nnz: int):
        """Makes sure the work buffers can be used for a sum with `y` on `device`.

        Args:
            device: Device on which the buffers must live; a change of device discards all cached buffers.
            y: Matrix whose ``nnz`` and values data type size the ``_old_y_values`` buffer.
            sum_nnz: Minimum number of entries required in the row and column index buffers.
        """
        if self.device != device:
            self._reset(device)

        # Index buffers are only ever grown, never shrunk
        if self._sum_rows is None or self._sum_rows.size < sum_nnz:
            self._sum_rows = wp.empty(shape=(sum_nnz,), dtype=int, device=self.device)
        if self._sum_cols is None or self._sum_cols.size < sum_nnz:
            self._sum_cols = wp.empty(shape=(sum_nnz,), dtype=int, device=self.device)

        # The cached values buffer is typed: besides being large enough it must also
        # match the block type of `y`, otherwise reusing this object with a matrix of
        # a different block type would keep a buffer of the wrong data type.
        if (
            self._old_y_values is None
            or self._old_y_values.size < y.nnz
            or not wp.types.types_equal(self._old_y_values.dtype, y.values.dtype)
        ):
            self._old_y_values = wp.empty(shape=(y.nnz,), dtype=y.values.dtype, device=self.device)
653
+
654
+
655
+ def bsr_axpy(
656
+ x: BsrMatrix[BlockType[Rows, Cols, Scalar]],
657
+ y: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
658
+ alpha: Scalar = 1.0,
659
+ beta: Scalar = 1.0,
660
+ work_arrays: Optional[bsr_axpy_work_arrays] = None,
661
+ ) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
397
662
  """
398
- Performs the operation `y := alpha * X + beta * y` on BSR matrices `x` and `y`
663
+ Performs the sparse matrix addition ``y := alpha * X + beta * y`` on BSR matrices `x` and `y` and returns `y`.
664
+
665
+ The `x` and `y` matrices are allowed to alias.
666
+
667
+ Args:
668
+ x: Read-only right-hand-side.
669
+ y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero.
670
+ alpha: Uniform scaling factor for `x`
671
+ beta: Uniform scaling factor for `y`
672
+ work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_axpy_work_arrays` in `work_arrays`.
399
673
  """
400
674
 
401
675
  if y is None:
402
- y = bsr_zeros(x.nrow, x.ncol, block_type=x.block_type, device=x.values.device)
676
+ # If no output matrix is provided, allocate it for convenience
677
+ y = bsr_zeros(x.nrow, x.ncol, block_type=x.values.dtype, device=x.values.device)
403
678
  beta = 0.0
404
679
 
405
- device = y.values.device
680
+ # Handle easy cases first
681
+ if beta == 0.0 or y.nnz == 0:
682
+ bsr_assign(src=x, dest=y)
683
+ return bsr_scale(y, alpha=alpha)
684
+
685
+ if alpha == 0.0 or x.nnz == 0:
686
+ return bsr_scale(y, alpha=beta)
687
+
688
+ if not isinstance(alpha, y.scalar_type):
689
+ alpha = y.scalar_type(alpha)
690
+ if not isinstance(beta, y.scalar_type):
691
+ beta = y.scalar_type(beta)
692
+
693
+ if x == y:
694
+ # Aliasing case
695
+ return bsr_scale(y, alpha=alpha.value + beta.value)
696
+
697
+ # General case
406
698
 
407
699
  if x.values.device != y.values.device:
408
700
  raise ValueError("All arguments must reside on the same device")
@@ -413,20 +705,21 @@ def bsr_axpy(x: BsrMatrix, y: BsrMatrix, alpha: float = 1.0, beta: float = 1.0):
413
705
  if x.nrow != y.nrow or x.ncol != y.ncol:
414
706
  raise ValueError("Matrices must have the same number of rows and columns")
415
707
 
416
- alpha = y.scalar_type(alpha)
417
- beta = y.scalar_type(beta)
708
+ if work_arrays is None:
709
+ work_arrays = bsr_axpy_work_arrays()
418
710
 
419
711
  sum_nnz = x.nnz + y.nnz
420
- sum_rows = wp.empty(shape=(sum_nnz), dtype=int, device=device)
421
- sum_cols = wp.empty(shape=(sum_nnz), dtype=int, device=device)
712
+ device = y.values.device
713
+ work_arrays._allocate(device, y, sum_nnz)
714
+
715
+ wp.copy(work_arrays._sum_cols, y.columns, 0, 0, y.nnz)
716
+ wp.launch(kernel=_bsr_get_block_row, device=device, dim=y.nnz, inputs=[0, y.offsets, work_arrays._sum_rows])
422
717
 
423
- if y.nnz > 0:
424
- wp.copy(sum_cols, y.columns, 0, 0, y.nnz)
425
- wp.launch(kernel=_bsr_get_block_row, device=device, dim=y.nnz, inputs=[0, y.offsets, sum_rows])
718
+ wp.copy(work_arrays._sum_cols, x.columns, y.nnz, 0, x.nnz)
719
+ wp.launch(kernel=_bsr_get_block_row, device=device, dim=x.nnz, inputs=[y.nnz, x.offsets, work_arrays._sum_rows])
426
720
 
427
- if x.nnz > 0:
428
- wp.copy(sum_cols, x.columns, y.nnz, 0, x.nnz)
429
- wp.launch(kernel=_bsr_get_block_row, device=device, dim=x.nnz, inputs=[y.nnz, x.offsets, sum_rows])
721
+ # Save old y values before overwriting matrix
722
+ wp.copy(dest=work_arrays._old_y_values, src=y.values, count=y.nnz)
430
723
 
431
724
  # Increase dest array sizes if needed
432
725
  if y.columns.shape[0] < sum_nnz:
@@ -439,37 +732,55 @@ def bsr_axpy(x: BsrMatrix, y: BsrMatrix, alpha: float = 1.0, beta: float = 1.0):
439
732
  else:
440
733
  native_func = runtime.core.bsr_matrix_from_triplets_float_device
441
734
 
442
- sum_nnz = native_func(
735
+ old_y_nnz = y.nnz
736
+ y.nnz = native_func(
443
737
  y.block_shape[0],
444
738
  y.block_shape[1],
445
739
  y.nrow,
446
740
  sum_nnz,
447
- sum_rows.ptr,
448
- sum_cols.ptr,
741
+ work_arrays._sum_rows.ptr,
742
+ work_arrays._sum_cols.ptr,
449
743
  0,
450
744
  y.offsets.ptr,
451
745
  y.columns.ptr,
452
746
  0,
453
747
  )
454
748
 
455
- sum_values = wp.zeros(shape=(sum_nnz,), dtype=y.values.dtype, device=device)
749
+ _bsr_ensure_fits(y)
750
+ y.values.zero_()
456
751
 
457
752
  wp.launch(
458
753
  kernel=_bsr_axpy_add_block,
459
754
  device=device,
460
- dim=y.nnz,
461
- inputs=[0, beta, sum_rows, sum_cols, y.offsets, y.columns, y.values, sum_values],
755
+ dim=old_y_nnz,
756
+ inputs=[
757
+ 0,
758
+ beta,
759
+ work_arrays._sum_rows,
760
+ work_arrays._sum_cols,
761
+ y.offsets,
762
+ y.columns,
763
+ work_arrays._old_y_values,
764
+ y.values,
765
+ ],
462
766
  )
767
+
463
768
  wp.launch(
464
769
  kernel=_bsr_axpy_add_block,
465
770
  device=device,
466
771
  dim=x.nnz,
467
- inputs=[y.nnz, alpha, sum_rows, sum_cols, y.offsets, y.columns, x.values, sum_values],
772
+ inputs=[
773
+ old_y_nnz,
774
+ alpha,
775
+ work_arrays._sum_rows,
776
+ work_arrays._sum_cols,
777
+ y.offsets,
778
+ y.columns,
779
+ x.values,
780
+ y.values,
781
+ ],
468
782
  )
469
783
 
470
- y.values = sum_values
471
- y.nnz = sum_nnz
472
-
473
784
  return y
474
785
 
475
786
 
@@ -555,23 +866,77 @@ def _bsr_mm_compute_values(
555
866
  mm_values[mm_block] = mm_values[mm_block] + ax_val * y_values[y_block]
556
867
 
557
868
 
558
- _pinned_temp_count_buffer = {}
559
-
560
-
561
- def _get_pinned_temp_count_buffer(device):
562
- device = str(device)
563
- if device not in _pinned_temp_count_buffer:
564
- _pinned_temp_count_buffer[device] = wp.empty(shape=(1,), dtype=int, pinned=True, device="cpu")
565
-
566
- return _pinned_temp_count_buffer[device]
567
-
568
-
569
- def bsr_mm(x: BsrMatrix, y: BsrMatrix, z: BsrMatrix = None, alpha: float = 1.0, beta: float = 0.0):
869
class bsr_mm_work_arrays:
    """Opaque structure for persisting :func:`bsr_mm` temporary work buffers across calls"""

    def __init__(self):
        # Start without a device and without any buffer
        self._reset(None)

    def _reset(self, device):
        """Drops every cached buffer and records `device` for later allocations."""
        self.device = device
        self._pinned_count_buffer = None
        self._mm_row_counts = None
        self._mm_rows = None
        self._mm_cols = None
        self._old_z_values = None
        self._old_z_offsets = None
        self._old_z_columns = None

    def _allocate_stage_1(self, device, z: BsrMatrix, copied_z_nnz: int, z_aliasing: bool):
        """Allocates the buffers whose sizes are known before any computation.

        Args:
            device: Device on which the buffers must live; a change of device discards all cached buffers.
            z: Matrix whose row count, ``nnz`` and array data types size the buffers.
            copied_z_nnz: Number of value blocks of `z` to back up; no values buffer is allocated when zero.
            z_aliasing: When true, buffers for a copy of the columns and offsets of `z` are allocated as well.
        """
        if self.device != device:
            self._reset(device)

        # Allocations that do not depend on any computation
        if self.device.is_cuda:
            # Single-element pinned array, created on the "cpu" device
            if self._pinned_count_buffer is None:
                self._pinned_count_buffer = wp.empty(shape=(1,), dtype=int, pinned=True, device="cpu")

        row_count_size = z.nrow + 1
        if self._mm_row_counts is None or self._mm_row_counts.size < row_count_size:
            self._mm_row_counts = wp.empty(shape=(row_count_size,), dtype=int, device=self.device)

        if copied_z_nnz > 0:
            # The cached values buffer is typed: besides being large enough it must also
            # match the block type of `z`, otherwise reusing this object with a matrix of
            # a different block type would keep a buffer of the wrong data type.
            if (
                self._old_z_values is None
                or self._old_z_values.size < copied_z_nnz
                or not wp.types.types_equal(self._old_z_values.dtype, z.values.dtype)
            ):
                self._old_z_values = wp.empty(shape=(copied_z_nnz,), dtype=z.values.dtype, device=self.device)

        if z_aliasing:
            if self._old_z_columns is None or self._old_z_columns.size < z.nnz:
                self._old_z_columns = wp.empty(shape=(z.nnz,), dtype=z.columns.dtype, device=self.device)
            if self._old_z_offsets is None or self._old_z_offsets.size < row_count_size:
                self._old_z_offsets = wp.empty(shape=(row_count_size,), dtype=z.offsets.dtype, device=self.device)

    def _allocate_stage_2(self, mm_nnz: int):
        """Allocates the row and column index buffers so that each holds at least `mm_nnz` entries."""
        # Allocations that depend on unmerged nnz estimate
        if self._mm_rows is None or self._mm_rows.size < mm_nnz:
            self._mm_rows = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
        if self._mm_cols is None or self._mm_cols.size < mm_nnz:
            self._mm_cols = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
913
+
914
+
915
+ def bsr_mm(
916
+ x: BsrMatrix[BlockType[Rows, Any, Scalar]],
917
+ y: BsrMatrix[BlockType[Any, Cols, Scalar]],
918
+ z: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
919
+ alpha: Scalar = 1.0,
920
+ beta: Scalar = 0.0,
921
+ work_arrays: Optional[bsr_mm_work_arrays] = None,
922
+ ) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
570
923
  """
571
- Performs the operation `z := alpha * X * Y + beta * z` on BSR matrices `x`, `y` and `z`
924
+ Performs the sparse matrix-matrix multiplication ``z := alpha * x * y + beta * z`` on BSR matrices `x`, `y` and `z`, and returns `z`.
925
+
926
+ The `x`, `y` and `z` matrices are allowed to alias.
927
+ If the matrix `z` is not provided as input, it will be allocated and treated as zero.
928
+
929
+ Args:
930
+ x: Read-only left factor of the matrix-matrix product.
931
+ y: Read-only right factor of the matrix-matrix product.
932
+ z: Mutable left-hand-side. If `z` is not provided, it will be allocated and treated as zero.
933
+ alpha: Uniform scaling factor for the ``x * y`` product
934
+ beta: Uniform scaling factor for `z`
935
+ work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_mm_work_arrays` in `work_arrays`.
572
936
  """
573
937
 
574
938
  if z is None:
939
+ # If no output matrix is provided, allocate it for convenience
575
940
  z_block_shape = (x.block_shape[0], y.block_shape[1])
576
941
  if z_block_shape == (1, 1):
577
942
  z_block_type = x.scalar_type
@@ -586,52 +951,85 @@ def bsr_mm(x: BsrMatrix, y: BsrMatrix, z: BsrMatrix = None, alpha: float = 1.0,
586
951
  if x.scalar_type != y.scalar_type or x.scalar_type != z.scalar_type:
587
952
  raise ValueError("Matrices must have the same scalar type")
588
953
 
589
- if x.block_shape[0] != z.block_shape[0] or y.block_shape[1] != z.block_shape[1]:
590
- raise ValueError("Incompatible blocks sizes for matrix multiplication")
954
+ if (
955
+ x.block_shape[0] != z.block_shape[0]
956
+ or y.block_shape[1] != z.block_shape[1]
957
+ or x.block_shape[1] != y.block_shape[0]
958
+ ):
959
+ raise ValueError("Incompatible block sizes for matrix multiplication")
591
960
 
592
- if x.nrow != z.nrow or z.ncol != y.ncol:
961
+ if x.nrow != z.nrow or z.ncol != y.ncol or x.ncol != y.nrow:
593
962
  raise ValueError("Incompatible number of rows/columns for matrix multiplication")
594
963
 
595
964
  device = z.values.device
596
965
 
597
- alpha = z.scalar_type(alpha)
598
- beta = z.scalar_type(beta)
966
+ if alpha == 0.0 or x.nnz == 0 or y.nnz == 0:
967
+ # Easy case
968
+ return bsr_scale(z, beta)
969
+
970
+ if not isinstance(alpha, z.scalar_type):
971
+ alpha = z.scalar_type(alpha)
972
+ if not isinstance(beta, z.scalar_type):
973
+ beta = z.scalar_type(beta)
974
+
975
+ if work_arrays is None:
976
+ work_arrays = bsr_mm_work_arrays()
977
+
978
+ z_aliasing = z == x or z == y
979
+ copied_z_nnz = z.nnz if beta != 0.0 or z_aliasing else 0
980
+
981
+ work_arrays._allocate_stage_1(device, z, copied_z_nnz, z_aliasing)
599
982
 
600
983
  # Prefix sum of number of (unmerged) mm blocks per row
601
- mm_row_counts = wp.empty(shape=(z.nrow + 1,), dtype=int, device=device)
602
984
  wp.launch(
603
985
  kernel=_bsr_mm_count_coeffs,
604
986
  device=device,
605
987
  dim=z.nrow,
606
- inputs=[z.nnz, x.offsets, x.columns, y.offsets, mm_row_counts],
988
+ inputs=[copied_z_nnz, x.offsets, x.columns, y.offsets, work_arrays._mm_row_counts],
607
989
  )
608
- warp.utils.array_scan(mm_row_counts, mm_row_counts)
990
+ warp.utils.array_scan(work_arrays._mm_row_counts, work_arrays._mm_row_counts)
609
991
 
610
992
  # Get back total counts on host
611
993
  if device.is_cuda:
612
- mm_tot_count = _get_pinned_temp_count_buffer(device)
613
- wp.copy(dest=mm_tot_count, src=mm_row_counts, src_offset=z.nrow, count=1)
614
- wp.synchronize_stream(wp.get_stream())
615
- mm_nnz = int(mm_tot_count.numpy()[0])
994
+ wp.copy(dest=work_arrays._pinned_count_buffer, src=work_arrays._mm_row_counts, src_offset=z.nrow, count=1)
995
+ wp.synchronize_stream(wp.get_stream(device))
996
+ mm_nnz = int(work_arrays._pinned_count_buffer.numpy()[0])
616
997
  else:
617
- mm_nnz = int(mm_row_counts.numpy()[z.nrow])
998
+ mm_nnz = int(work_arrays._mm_row_counts.numpy()[z.nrow])
618
999
 
619
- mm_rows = wp.empty(shape=(mm_nnz), dtype=int, device=device)
620
- mm_cols = wp.empty(shape=(mm_nnz), dtype=int, device=device)
1000
+ work_arrays._allocate_stage_2(mm_nnz)
621
1001
 
622
- # Copy z rows columns
623
- wp.copy(mm_cols, z.columns, 0, 0, z.nnz)
624
- wp.launch(kernel=_bsr_get_block_row, device=device, dim=z.nnz, inputs=[0, z.offsets, mm_rows])
1002
+ # If z has a non-zero scale, save current data before overwriting it
1003
+ if copied_z_nnz > 0:
1004
+ # Copy z row and column indices
1005
+ wp.copy(dest=work_arrays._mm_cols, src=z.columns, count=copied_z_nnz)
1006
+ wp.launch(
1007
+ kernel=_bsr_get_block_row, device=device, dim=copied_z_nnz, inputs=[0, z.offsets, work_arrays._mm_rows]
1008
+ )
1009
+ # Save current z values in temporary buffer
1010
+ wp.copy(src=z.values, dest=work_arrays._old_z_values, count=copied_z_nnz)
1011
+ if z_aliasing:
1012
+ # If z is aliasing with x or y, need to save topology as well
1013
+ wp.copy(src=z.columns, dest=work_arrays._old_z_columns, count=copied_z_nnz)
1014
+ wp.copy(src=z.offsets, dest=work_arrays._old_z_offsets, count=z.nrow + 1)
625
1015
 
626
1016
  # Fill unmerged mm blocks rows and columns
627
1017
  wp.launch(
628
1018
  kernel=_bsr_mm_list_coeffs,
629
1019
  device=device,
630
1020
  dim=z.nrow,
631
- inputs=[x.offsets, x.columns, y.offsets, y.columns, mm_row_counts, mm_rows, mm_cols],
1021
+ inputs=[
1022
+ x.offsets,
1023
+ x.columns,
1024
+ y.offsets,
1025
+ y.columns,
1026
+ work_arrays._mm_row_counts,
1027
+ work_arrays._mm_rows,
1028
+ work_arrays._mm_cols,
1029
+ ],
632
1030
  )
633
1031
 
634
- # Increase dest array sizes if needed
1032
+ # Increase dest array size if needed
635
1033
  if z.columns.shape[0] < mm_nnz:
636
1034
  z.columns = wp.empty(shape=(mm_nnz,), dtype=int, device=device)
637
1035
 
@@ -642,45 +1040,66 @@ def bsr_mm(x: BsrMatrix, y: BsrMatrix, z: BsrMatrix = None, alpha: float = 1.0,
642
1040
  else:
643
1041
  native_func = runtime.core.bsr_matrix_from_triplets_float_device
644
1042
 
645
- mm_nnz = native_func(
1043
+ z.nnz = native_func(
646
1044
  z.block_shape[0],
647
1045
  z.block_shape[1],
648
1046
  z.nrow,
649
1047
  mm_nnz,
650
- mm_rows.ptr,
651
- mm_cols.ptr,
1048
+ work_arrays._mm_rows.ptr,
1049
+ work_arrays._mm_cols.ptr,
652
1050
  0,
653
1051
  z.offsets.ptr,
654
1052
  z.columns.ptr,
655
1053
  0,
656
1054
  )
657
1055
 
658
- mm_values = wp.zeros(shape=(mm_nnz,), dtype=z.values.dtype, device=device)
659
-
660
- # Copy blocks from z
661
- wp.launch(
662
- kernel=_bsr_axpy_add_block,
663
- device=device,
664
- dim=z.nnz,
665
- inputs=[0, beta, mm_rows, mm_cols, z.offsets, z.columns, z.values, mm_values],
666
- )
667
-
668
- # Update z to point to result blocks
669
- z.values = mm_values
670
- z.nnz = mm_nnz
1056
+ _bsr_ensure_fits(z)
1057
+ z.values.zero_()
1058
+
1059
+ if copied_z_nnz > 0:
1060
+ # Add back original z values
1061
+ wp.launch(
1062
+ kernel=_bsr_axpy_add_block,
1063
+ device=device,
1064
+ dim=copied_z_nnz,
1065
+ inputs=[
1066
+ 0,
1067
+ beta,
1068
+ work_arrays._mm_rows,
1069
+ work_arrays._mm_cols,
1070
+ z.offsets,
1071
+ z.columns,
1072
+ work_arrays._old_z_values,
1073
+ z.values,
1074
+ ],
1075
+ )
671
1076
 
672
1077
  # Add mm blocks to z values
673
-
674
- if z.block_shape == (1, 1) and x.block_shape != (1, 1):
1078
+ if (warp.types.type_is_matrix(x.values.dtype) or warp.types.type_is_matrix(y.values.dtype)) and not (
1079
+ warp.types.type_is_matrix(z.values.dtype)
1080
+ ):
675
1081
  # Result block type is scalar, but operands are matrices
676
1082
  # Cast result to (1x1) matrix to perform multiplication
677
- mm_values = mm_values.view(wp.types.matrix(shape=(1, 1), dtype=z.scalar_type))
1083
+ mm_values = z.values.view(wp.types.matrix(shape=(1, 1), dtype=z.scalar_type))
1084
+ else:
1085
+ mm_values = z.values
678
1086
 
679
1087
  wp.launch(
680
1088
  kernel=_bsr_mm_compute_values,
681
1089
  device=device,
682
1090
  dim=z.nrow,
683
- inputs=[alpha, x.offsets, x.columns, x.values, y.offsets, y.columns, y.values, z.offsets, z.columns, mm_values],
1091
+ inputs=[
1092
+ alpha,
1093
+ work_arrays._old_z_offsets if x == z else x.offsets,
1094
+ work_arrays._old_z_columns if x == z else x.columns,
1095
+ work_arrays._old_z_values if x == z else x.values,
1096
+ work_arrays._old_z_offsets if y == z else y.offsets,
1097
+ work_arrays._old_z_columns if y == z else y.columns,
1098
+ work_arrays._old_z_values if y == z else y.values,
1099
+ z.offsets,
1100
+ z.columns,
1101
+ mm_values,
1102
+ ],
684
1103
  )
685
1104
 
686
1105
  return z
@@ -697,44 +1116,96 @@ def _bsr_mv_kernel(
697
1116
  y: wp.array(dtype=Any),
698
1117
  ):
699
1118
  row = wp.tid()
700
- beg = A_offsets[row]
701
- end = A_offsets[row + 1]
702
1119
 
703
- yr = y[row]
704
- v = yr - yr # WAR to get zero with correct type
705
- for block in range(beg, end):
706
- v = v + A_values[block] * x[A_columns[block]]
1120
+ # zero-initialize with type of y elements
1121
+ scalar_zero = type(alpha)(0)
1122
+ v = y.dtype(scalar_zero)
707
1123
 
708
- y[row] = beta * yr + alpha * v
1124
+ if alpha != scalar_zero:
1125
+ beg = A_offsets[row]
1126
+ end = A_offsets[row + 1]
1127
+ for block in range(beg, end):
1128
+ v += A_values[block] * x[A_columns[block]]
1129
+ v *= alpha
709
1130
 
1131
+ if beta != scalar_zero:
1132
+ v += beta * y[row]
710
1133
 
711
- def bsr_mv(A: BsrMatrix, x: wp.array, y: wp.array, alpha: float = 1.0, beta: float = 0.0):
1134
+ y[row] = v
1135
+
1136
+
1137
+ def bsr_mv(
1138
+ A: BsrMatrix[BlockType[Rows, Cols, Scalar]],
1139
+ x: "Array[Vector[Cols, Scalar] | Scalar]",
1140
+ y: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
1141
+ alpha: Scalar = 1.0,
1142
+ beta: Scalar = 0.0,
1143
+ work_buffer: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
1144
+ ) -> "Array[Vector[Rows, Scalar] | Scalar]":
712
1145
  """
713
- Naive implementation of sparse matrix-vector product, `y := alpha * A * x + beta * y`.
1146
+ Performs the sparse matrix-vector product ``y := alpha * A * x + beta * y`` and returns `y`.
1147
+
1148
+ The `x` and `y` vectors are allowed to alias.
1149
+
1150
+ Args:
1151
+ A: Read-only, left matrix factor of the matrix-vector product.
1152
+ x: Read-only, right vector factor of the matrix-vector product.
1153
+ y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero.
1154
+ alpha: Uniform scaling factor for `x`. If zero, `x` will not be read and may be left uninitialized.
1155
+ beta: Uniform scaling factor for `y`. If zero, `y` will not be read and may be left uninitialized.
1156
+ work_buffer: Temporary storage is required if and only if `x` and `y` are the same vector. If provided the `work_buffer` array
1157
+ will be used for this purpose, otherwise a temporary allocation will be performed.
714
1158
  """
715
- alpha = A.scalar_type(alpha)
716
- beta = A.scalar_type(beta)
717
1159
 
718
- # if A.scalar_type != x.dtype or A.scalar_type != y.dtype:
719
- # raise ValueError("A, x and y must have the same data types")
1160
+ if y is None:
1161
+ # If no output array is provided, allocate one for convenience
1162
+ y_vec_len = A.block_shape[0]
1163
+ y_dtype = A.scalar_type if y_vec_len == 1 else wp.vec(length=y_vec_len, dtype=A.scalar_type)
1164
+ y = wp.empty(shape=(A.nrow,), device=A.values.device, dtype=y_dtype)
1165
+ y.zero_()
1166
+ beta = 0.0
1167
+
1168
+ if not isinstance(alpha, A.scalar_type):
1169
+ alpha = A.scalar_type(alpha)
1170
+ if not isinstance(beta, A.scalar_type):
1171
+ beta = A.scalar_type(beta)
720
1172
 
721
1173
  if A.values.device != x.device or A.values.device != y.device:
722
- raise ValueError("A, x and y must reide on the same device")
1174
+ raise ValueError("A, x and y must reside on the same device")
723
1175
 
724
1176
  if x.shape[0] != A.ncol:
725
1177
  raise ValueError("Number of columns of A must match number of rows of x")
726
1178
  if y.shape[0] != A.nrow:
727
1179
  raise ValueError("Number of rows of A must match number of rows of y")
728
1180
 
729
- # Promote scalar vectors to length-1 vecs
730
- block_shape = A.block_shape
731
- if block_shape != (1, 1):
732
- if block_shape[0] == 1:
1181
+ if x == y:
1182
+ # Aliasing case, need temporary storage
1183
+ if work_buffer is None:
1184
+ work_buffer = wp.empty_like(y)
1185
+ elif work_buffer.size < y.size:
1186
+ raise ValueError(f"Work buffer size is insufficient, needs to be at least {y.size}")
1187
+ elif not wp.types.types_equal(work_buffer.dtype, y.dtype):
1188
+ raise ValueError(f"Work buffer must have same data type as y, {wp.types.type_repr(y.dtype)}")
1189
+
1190
+ # Save old y values before overwriting vector
1191
+ wp.copy(dest=work_buffer, src=y, count=y.size)
1192
+ x = work_buffer
1193
+
1194
+ # Promote scalar vectors to length-1 vecs and conversely
1195
+ if warp.types.type_is_matrix(A.values.dtype):
1196
+ if A.block_shape[0] == 1:
733
1197
  if y.dtype == A.scalar_type:
734
1198
  y = y.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
735
- if block_shape[1] == 1:
1199
+ if A.block_shape[1] == 1:
736
1200
  if x.dtype == A.scalar_type:
737
1201
  x = x.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
1202
+ else:
1203
+ if A.block_shape[0] == 1:
1204
+ if y.dtype != A.scalar_type:
1205
+ y = y.view(dtype=A.scalar_type)
1206
+ if A.block_shape[1] == 1:
1207
+ if x.dtype != A.scalar_type:
1208
+ x = x.view(dtype=A.scalar_type)
738
1209
 
739
1210
  wp.launch(
740
1211
  kernel=_bsr_mv_kernel,
@@ -742,3 +1213,5 @@ def bsr_mv(A: BsrMatrix, x: wp.array, y: wp.array, alpha: float = 1.0, beta: flo
742
1213
  dim=A.nrow,
743
1214
  inputs=[alpha, A.offsets, A.columns, A.values, x, beta, y],
744
1215
  )
1216
+
1217
+ return y