warp-lang 1.6.2__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (179) hide show
  1. warp/__init__.py +7 -1
  2. warp/bin/libwarp-clang.dylib +0 -0
  3. warp/bin/libwarp.dylib +0 -0
  4. warp/build.py +410 -0
  5. warp/build_dll.py +6 -14
  6. warp/builtins.py +452 -362
  7. warp/codegen.py +179 -119
  8. warp/config.py +42 -6
  9. warp/context.py +490 -271
  10. warp/dlpack.py +8 -6
  11. warp/examples/assets/nonuniform.usd +0 -0
  12. warp/examples/assets/nvidia_logo.png +0 -0
  13. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  14. warp/examples/core/example_sample_mesh.py +300 -0
  15. warp/examples/fem/example_apic_fluid.py +1 -1
  16. warp/examples/fem/example_burgers.py +2 -2
  17. warp/examples/fem/example_deformed_geometry.py +1 -1
  18. warp/examples/fem/example_distortion_energy.py +1 -1
  19. warp/examples/fem/example_magnetostatics.py +6 -6
  20. warp/examples/fem/utils.py +9 -3
  21. warp/examples/interop/example_jax_callable.py +116 -0
  22. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  23. warp/examples/interop/example_jax_kernel.py +205 -0
  24. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  25. warp/examples/tile/example_tile_matmul.py +2 -4
  26. warp/fem/__init__.py +11 -1
  27. warp/fem/adaptivity.py +4 -4
  28. warp/fem/field/nodal_field.py +22 -68
  29. warp/fem/field/virtual.py +62 -23
  30. warp/fem/geometry/adaptive_nanogrid.py +9 -10
  31. warp/fem/geometry/closest_point.py +1 -1
  32. warp/fem/geometry/deformed_geometry.py +5 -2
  33. warp/fem/geometry/geometry.py +5 -0
  34. warp/fem/geometry/grid_2d.py +12 -12
  35. warp/fem/geometry/grid_3d.py +12 -15
  36. warp/fem/geometry/hexmesh.py +5 -7
  37. warp/fem/geometry/nanogrid.py +9 -11
  38. warp/fem/geometry/quadmesh.py +13 -13
  39. warp/fem/geometry/tetmesh.py +3 -4
  40. warp/fem/geometry/trimesh.py +3 -8
  41. warp/fem/integrate.py +262 -93
  42. warp/fem/linalg.py +5 -5
  43. warp/fem/quadrature/pic_quadrature.py +37 -22
  44. warp/fem/quadrature/quadrature.py +194 -25
  45. warp/fem/space/__init__.py +1 -1
  46. warp/fem/space/basis_function_space.py +4 -2
  47. warp/fem/space/basis_space.py +25 -18
  48. warp/fem/space/hexmesh_function_space.py +2 -2
  49. warp/fem/space/partition.py +6 -2
  50. warp/fem/space/quadmesh_function_space.py +8 -8
  51. warp/fem/space/shape/cube_shape_function.py +23 -23
  52. warp/fem/space/shape/square_shape_function.py +12 -12
  53. warp/fem/space/shape/triangle_shape_function.py +1 -1
  54. warp/fem/space/tetmesh_function_space.py +3 -3
  55. warp/fem/space/trimesh_function_space.py +2 -2
  56. warp/fem/utils.py +12 -6
  57. warp/jax.py +14 -1
  58. warp/jax_experimental/__init__.py +16 -0
  59. warp/{jax_experimental.py → jax_experimental/custom_call.py} +14 -27
  60. warp/jax_experimental/ffi.py +698 -0
  61. warp/jax_experimental/xla_ffi.py +602 -0
  62. warp/math.py +89 -0
  63. warp/native/array.h +13 -0
  64. warp/native/builtin.h +29 -3
  65. warp/native/bvh.cpp +3 -1
  66. warp/native/bvh.cu +42 -14
  67. warp/native/bvh.h +2 -1
  68. warp/native/clang/clang.cpp +30 -3
  69. warp/native/cuda_util.cpp +14 -0
  70. warp/native/cuda_util.h +2 -0
  71. warp/native/exports.h +68 -63
  72. warp/native/intersect.h +26 -26
  73. warp/native/intersect_adj.h +33 -33
  74. warp/native/marching.cu +1 -1
  75. warp/native/mat.h +513 -9
  76. warp/native/mesh.h +10 -10
  77. warp/native/quat.h +99 -11
  78. warp/native/rand.h +6 -0
  79. warp/native/sort.cpp +122 -59
  80. warp/native/sort.cu +152 -15
  81. warp/native/sort.h +8 -1
  82. warp/native/sparse.cpp +43 -22
  83. warp/native/sparse.cu +52 -17
  84. warp/native/svd.h +116 -0
  85. warp/native/tile.h +301 -105
  86. warp/native/tile_reduce.h +46 -3
  87. warp/native/vec.h +68 -7
  88. warp/native/volume.cpp +85 -113
  89. warp/native/volume_builder.cu +25 -10
  90. warp/native/volume_builder.h +6 -0
  91. warp/native/warp.cpp +5 -6
  92. warp/native/warp.cu +99 -10
  93. warp/native/warp.h +19 -10
  94. warp/optim/linear.py +10 -10
  95. warp/sim/articulation.py +4 -4
  96. warp/sim/collide.py +21 -10
  97. warp/sim/import_mjcf.py +449 -155
  98. warp/sim/import_urdf.py +32 -12
  99. warp/sim/integrator_euler.py +5 -5
  100. warp/sim/integrator_featherstone.py +3 -10
  101. warp/sim/integrator_vbd.py +207 -2
  102. warp/sim/integrator_xpbd.py +5 -5
  103. warp/sim/model.py +42 -13
  104. warp/sim/utils.py +2 -2
  105. warp/sparse.py +642 -555
  106. warp/stubs.py +216 -19
  107. warp/tests/__main__.py +0 -15
  108. warp/tests/cuda/__init__.py +0 -0
  109. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  110. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  111. warp/tests/geometry/__init__.py +0 -0
  112. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  113. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  114. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  115. warp/tests/interop/__init__.py +0 -0
  116. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  117. warp/tests/sim/__init__.py +0 -0
  118. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  119. warp/tests/{test_collision.py → sim/test_collision.py} +2 -2
  120. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  121. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  122. warp/tests/sim/test_vbd.py +597 -0
  123. warp/tests/test_bool.py +1 -1
  124. warp/tests/test_examples.py +28 -36
  125. warp/tests/test_fem.py +23 -4
  126. warp/tests/test_linear_solvers.py +0 -11
  127. warp/tests/test_mat.py +233 -79
  128. warp/tests/test_mat_scalar_ops.py +4 -4
  129. warp/tests/test_overwrite.py +0 -60
  130. warp/tests/test_quat.py +67 -46
  131. warp/tests/test_rand.py +44 -37
  132. warp/tests/test_sparse.py +47 -6
  133. warp/tests/test_spatial.py +75 -0
  134. warp/tests/test_static.py +1 -1
  135. warp/tests/test_utils.py +84 -4
  136. warp/tests/test_vec.py +46 -34
  137. warp/tests/tile/__init__.py +0 -0
  138. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  139. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +1 -1
  140. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  141. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  142. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  143. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  144. warp/tests/unittest_serial.py +1 -0
  145. warp/tests/unittest_suites.py +45 -59
  146. warp/tests/unittest_utils.py +2 -1
  147. warp/thirdparty/unittest_parallel.py +3 -1
  148. warp/types.py +110 -658
  149. warp/utils.py +137 -72
  150. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/METADATA +29 -7
  151. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/RECORD +172 -162
  152. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  153. warp/examples/optim/example_walker.py +0 -317
  154. warp/native/cutlass_gemm.cpp +0 -43
  155. warp/native/cutlass_gemm.cu +0 -382
  156. warp/tests/test_matmul.py +0 -511
  157. warp/tests/test_matmul_lite.py +0 -411
  158. warp/tests/test_vbd.py +0 -386
  159. warp/tests/unused_test_misc.py +0 -77
  160. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  161. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  162. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  163. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  164. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  165. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  166. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  167. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  168. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  169. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  170. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  171. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  172. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  173. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  174. /warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +0 -0
  175. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  176. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  177. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  178. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info/licenses}/LICENSE.md +0 -0
  179. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/sparse.py CHANGED
@@ -19,7 +19,21 @@ from typing import Any, Generic, Optional, Tuple, TypeVar, Union
19
19
  import warp as wp
20
20
  import warp.types
21
21
  import warp.utils
22
- from warp.types import Array, Cols, Rows, Scalar, Vector
22
+ from warp.types import (
23
+ Array,
24
+ Cols,
25
+ Rows,
26
+ Scalar,
27
+ Vector,
28
+ is_array,
29
+ scalar_types,
30
+ type_is_matrix,
31
+ type_length,
32
+ type_repr,
33
+ type_scalar_type,
34
+ type_to_warp,
35
+ types_equal,
36
+ )
23
37
 
24
38
  # typing hints
25
39
 
@@ -45,50 +59,89 @@ class BsrMatrix(Generic[_BlockType]):
45
59
  Should not be constructed directly but through functions such as :func:`bsr_zeros`.
46
60
 
47
61
  Attributes:
48
- nrow (int): Number of rows of blocks
49
- ncol (int): Number of columns of blocks
50
- nnz (int): Upper bound for the number of non-zero blocks, used for dimensioning launches; the exact number is at ``offsets[nrow-1]``. See also :meth:`nnz_sync`.
51
- offsets (Array[int]): Array of size at least ``1 + nrows`` such that the start and end indices of the blocks of row ``r`` are ``offsets[r]`` and ``offsets[r+1]``, respectively.
52
- columns (Array[int]): Array of size at least equal to ``nnz`` containing block column indices
53
- values (Array[BlockType]): Array of size at least equal to ``nnz`` containing block values
62
+ nrow (int): Number of rows of blocks.
63
+ ncol (int): Number of columns of blocks.
64
+ nnz (int): Upper bound for the number of non-zero blocks, used for
65
+ dimensioning launches. The exact number is at ``offsets[nrow]``.
66
+ See also :meth:`nnz_sync`.
67
+ offsets (Array[int]): Array of size at least ``1 + nrow`` such that the
68
+ start and end indices of the blocks of row ``r`` are ``offsets[r]``
69
+ and ``offsets[r+1]``, respectively.
70
+ columns (Array[int]): Array of size at least equal to ``nnz`` containing
71
+ block column indices.
72
+ values (Array[BlockType]): Array of size at least equal to ``nnz``
73
+ containing block values.
54
74
  """
55
75
 
56
76
  @property
57
77
  def scalar_type(self) -> Scalar:
58
- """Scalar type for individual block coefficients. For CSR matrices, this is the same as the block type"""
59
- return warp.types.type_scalar_type(self.values.dtype)
78
+ """Scalar type for individual block coefficients. For CSR matrices, this is the same as the block type."""
79
+ return type_scalar_type(self.values.dtype)
60
80
 
61
81
  @property
62
82
  def block_shape(self) -> Tuple[int, int]:
63
- """Shape of the individual blocks"""
83
+ """Shape of the individual blocks."""
64
84
  return getattr(self.values.dtype, "_shape_", (1, 1))
65
85
 
66
86
  @property
67
87
  def block_size(self) -> int:
68
- """Size of the individual blocks, i.e. number of rows per block times number of columns per block"""
69
- return warp.types.type_length(self.values.dtype)
88
+ """Size of the individual blocks, i.e. number of rows per block times number of columns per block."""
89
+ return type_length(self.values.dtype)
70
90
 
71
91
  @property
72
92
  def shape(self) -> Tuple[int, int]:
73
- """Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/columns per block"""
93
+ """Shape of the matrix, i.e. number of rows/columns of blocks times number of rows/columns per block."""
74
94
  block_shape = self.block_shape
75
95
  return (self.nrow * block_shape[0], self.ncol * block_shape[1])
76
96
 
77
97
  @property
78
98
  def dtype(self) -> type:
79
- """Data type for individual block values"""
99
+ """Data type for individual block values."""
80
100
  return self.values.dtype
81
101
 
82
102
  @property
83
103
  def device(self) -> wp.context.Device:
84
- """Device on which offsets, columns and values are allocated -- assumed to be the same for all three arrays"""
104
+ """Device on which ``offsets``, ``columns``, and ``values`` are allocated -- assumed to be the same for all three arrays."""
85
105
  return self.values.device
86
106
 
107
+ @property
108
+ def scalar_values(self) -> wp.array:
109
+ """Accesses the ``values`` array as a 3d scalar array."""
110
+ if self.block_shape == (1, 1):
111
+ return self.values.reshape((self.nnz, 1, 1))
112
+
113
+ def _as_3d_array(arr):
114
+ return wp.array(
115
+ ptr=arr.ptr,
116
+ capacity=arr.capacity,
117
+ device=arr.device,
118
+ dtype=self.scalar_type,
119
+ shape=(self.nnz, *self.block_shape),
120
+ grad=None if arr.grad is None else _as_3d_array(arr.grad),
121
+ )
122
+
123
+ values_view = _as_3d_array(self.values)
124
+ values_view._ref = self.values # keep ref in case we're garbage collected
125
+ return values_view
126
+
127
+ def uncompress_rows(self, out: wp.array = None) -> wp.array:
128
+ """Compute the row index for each non-zero block from the compressed row offsets."""
129
+ if out is None:
130
+ out = wp.empty(self.nnz, dtype=int, device=self.device)
131
+
132
+ wp.launch(
133
+ kernel=_bsr_get_block_row,
134
+ device=self.device,
135
+ dim=self.nnz,
136
+ inputs=[self.nrow, self.offsets, out],
137
+ )
138
+ return out
139
+
87
140
  def nnz_sync(self):
88
- """Ensures that any ongoing transfer of the exact nnz number from the device offsets array to the host has completed,
89
- and updates the nnz upper bound.
141
+ """Ensure that any ongoing transfer of the exact nnz number from the device offsets array to the host has completed
142
+ and update the nnz upper bound.
90
143
 
91
- See also :meth:`copy_nnz_async`
144
+ See also :meth:`copy_nnz_async`.
92
145
  """
93
146
 
94
147
  if self._is_nnz_transfer_setup():
@@ -99,10 +152,11 @@ class BsrMatrix(Generic[_BlockType]):
99
152
 
100
153
  def copy_nnz_async(self, known_nnz: int = None):
101
154
  """
102
- Starts the asynchronous transfer of the exact nnz from the device offsets array to host, and records an event for completion.
155
+ Start the asynchronous transfer of the exact nnz from the device offsets array to host and record an event for completion.
156
+
103
157
  Needs to be called whenever the offsets array has been modified from outside ``warp.sparse``.
104
158
 
105
- See also :meth:`nnz_sync`
159
+ See also :meth:`nnz_sync`.
106
160
  """
107
161
  if known_nnz is not None:
108
162
  self.nnz = int(known_nnz)
@@ -186,35 +240,33 @@ class BsrMatrix(Generic[_BlockType]):
186
240
  return _BsrScalingExpression(self, -1.0)
187
241
 
188
242
  def transpose(self):
189
- """Returns a transposed copy of this matrix"""
243
+ """Return a transposed copy of this matrix."""
190
244
  return bsr_transposed(self)
191
245
 
192
246
 
193
247
  def bsr_matrix_t(dtype: BlockType):
194
- dtype = wp.types.type_to_warp(dtype)
248
+ dtype = type_to_warp(dtype)
195
249
 
196
- if not warp.types.type_is_matrix(dtype) and dtype not in warp.types.scalar_types:
197
- raise ValueError(
198
- f"BsrMatrix block type must be either warp matrix or scalar; got {warp.types.type_repr(dtype)}"
199
- )
250
+ if not type_is_matrix(dtype) and dtype not in scalar_types:
251
+ raise ValueError(f"BsrMatrix block type must be either warp matrix or scalar; got {type_repr(dtype)}")
200
252
 
201
253
  class BsrMatrixTyped(BsrMatrix):
202
254
  nrow: int
203
- """Number of rows of blocks"""
255
+ """Number of rows of blocks."""
204
256
  ncol: int
205
- """Number of columns of blocks"""
257
+ """Number of columns of blocks."""
206
258
  nnz: int
207
- """Upper bound for the number of non-zeros"""
259
+ """Upper bound for the number of non-zeros."""
208
260
  offsets: wp.array(dtype=int)
209
- """Array of size at least 1 + nrows"""
261
+ """Array of size at least ``1 + nrow``."""
210
262
  columns: wp.array(dtype=int)
211
- """Array of size at least equal to nnz"""
263
+ """Array of size at least equal to ``nnz``."""
212
264
  values: wp.array(dtype=dtype)
213
265
 
214
266
  module = wp.get_module(BsrMatrix.__module__)
215
267
 
216
268
  if hasattr(dtype, "_shape_"):
217
- type_str = f"{warp.types.type_scalar_type(dtype).__name__}_{dtype._shape_[0]}_{dtype._shape_[1]}"
269
+ type_str = f"{type_scalar_type(dtype).__name__}_{dtype._shape_[0]}_{dtype._shape_[1]}"
218
270
  else:
219
271
  type_str = dtype.__name__
220
272
  key = f"{BsrMatrix.__qualname__}_{type_str}"
@@ -235,16 +287,16 @@ def bsr_zeros(
235
287
  block_type: BlockType,
236
288
  device: wp.context.Devicelike = None,
237
289
  ) -> BsrMatrix:
238
- """
239
- Constructs and returns an empty BSR or CSR matrix with the given shape
290
+ """Construct and return an empty BSR or CSR matrix with the given shape.
240
291
 
241
292
  Args:
242
- bsr: The BSR or CSR matrix to set to zero
243
- rows_of_blocks: Number of rows of blocks
244
- cols_of_blocks: Number of columns of blocks
245
- block_type: Type of individual blocks. For CSR matrices, this should be a scalar type;
246
- for BSR matrices, this should be a matrix type (e.g. from :func:`warp.mat`)
247
- device: Device on which to allocate the matrix arrays
293
+ bsr: The BSR or CSR matrix to set to zero.
294
+ rows_of_blocks: Number of rows of blocks.
295
+ cols_of_blocks: Number of columns of blocks.
296
+ block_type: Type of individual blocks.
297
+ For CSR matrices, this should be a scalar type.
298
+ For BSR matrices, this should be a matrix type (e.g. from :func:`warp.mat`).
299
+ device: Device on which to allocate the matrix arrays.
248
300
  """
249
301
 
250
302
  bsr = bsr_matrix_t(block_type)()
@@ -281,13 +333,12 @@ def bsr_set_zero(
281
333
  rows_of_blocks: Optional[int] = None,
282
334
  cols_of_blocks: Optional[int] = None,
283
335
  ):
284
- """
285
- Sets a BSR matrix to zero, possibly changing its size
336
+ """Set a BSR matrix to zero, possibly changing its size.
286
337
 
287
338
  Args:
288
- bsr: The BSR or CSR matrix to set to zero
289
- rows_of_blocks: If not ``None``, the new number of rows of blocks
290
- cols_of_blocks: If not ``None``, the new number of columns of blocks
339
+ bsr: The BSR or CSR matrix to set to zero.
340
+ rows_of_blocks: If not ``None``, the new number of rows of blocks.
341
+ cols_of_blocks: If not ``None``, the new number of columns of blocks.
291
342
  """
292
343
 
293
344
  if rows_of_blocks is not None:
@@ -304,46 +355,55 @@ def bsr_set_from_triplets(
304
355
  dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
305
356
  rows: "Array[int]",
306
357
  columns: "Array[int]",
307
- values: "Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]",
358
+ values: Optional["Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]"] = None,
308
359
  prune_numerical_zeros: bool = True,
360
+ masked: bool = False,
309
361
  ):
310
- """
311
- Fills a BSR matrix with values defined by coordinate-oriented (COO) triplets, discarding existing blocks.
362
+ """Fill a BSR matrix with values defined by coordinate-oriented (COO) triplets, discarding existing blocks.
312
363
 
313
364
  The first dimension of the three input arrays must match and indicates the number of COO triplets.
314
365
 
315
366
  Args:
316
- dest: Sparse matrix to populate
317
- rows: Row index for each non-zero
318
- columns: Columns index for each non-zero
367
+ dest: Sparse matrix to populate.
368
+ rows: Row index for each non-zero.
369
+ columns: Column index for each non-zero.
319
370
  values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
320
- to the `dest` matrix's block type, or a 3d array with data type equal to the `dest` matrix's scalar type.
321
- prune_numerical_zeros: If True, will ignore the zero-valued blocks
371
+ to the ``dest`` matrix's block type, or a 3d array with data type equal to the ``dest`` matrix's scalar type.
372
+ If ``None``, the values array of the resulting matrix will be allocated but uninitialized.
373
+ prune_numerical_zeros: If ``True``, will ignore the zero-valued blocks.
374
+ masked: If ``True``, ignore blocks that are not existing non-zeros of ``dest``.
322
375
  """
323
376
 
324
- if values.device != columns.device or values.device != rows.device or values.device != dest.values.device:
377
+ if rows.device != columns.device or rows.device != dest.device:
325
378
  raise ValueError("All arguments must reside on the same device")
326
379
 
327
- if values.shape[0] != rows.shape[0] or values.shape[0] != columns.shape[0]:
380
+ if rows.shape[0] != columns.shape[0]:
328
381
  raise ValueError("All triplet arrays must have the same length")
329
382
 
330
383
  # Accept either array1d(dtype) or contiguous array3d(scalar_type) as values
331
- if values.ndim == 1:
332
- if values.dtype != dest.values.dtype:
333
- raise ValueError("Values array type must correspond to that of dest matrix")
334
- elif values.ndim == 3:
335
- if values.shape[1:] != dest.block_shape:
336
- raise ValueError(
337
- f"Last two dimensions in values array ({values.shape[1:]}) should correspond to matrix block shape {(dest.block_shape)})"
338
- )
384
+ if values is not None:
385
+ if values.device != rows.device:
386
+ raise ValueError("All arguments must reside on the same device")
387
+
388
+ if values.shape[0] != rows.shape[0]:
389
+ raise ValueError("All triplet arrays must have the same length")
390
+
391
+ if values.ndim == 1:
392
+ if values.dtype != dest.values.dtype:
393
+ raise ValueError("Values array type must correspond to that of dest matrix")
394
+ elif values.ndim == 3:
395
+ if values.shape[1:] != dest.block_shape:
396
+ raise ValueError(
397
+ f"Last two dimensions in values array ({values.shape[1:]}) should correspond to matrix block shape {(dest.block_shape)})"
398
+ )
339
399
 
340
- if warp.types.type_scalar_type(values.dtype) != dest.scalar_type:
341
- raise ValueError("Scalar type of values array should correspond to that of matrix")
400
+ if type_scalar_type(values.dtype) != dest.scalar_type:
401
+ raise ValueError("Scalar type of values array should correspond to that of matrix")
342
402
 
343
- if not values.is_contiguous:
344
- raise ValueError("Multi-dimensional values array should be contiguous")
345
- else:
346
- raise ValueError("Number of dimension for values array should be 1 or 3")
403
+ if not values.is_contiguous:
404
+ raise ValueError("Multi-dimensional values array should be contiguous")
405
+ else:
406
+ raise ValueError("Number of dimension for values array should be 1 or 3")
347
407
 
348
408
  nnz = rows.shape[0]
349
409
  if nnz == 0:
@@ -351,7 +411,8 @@ def bsr_set_from_triplets(
351
411
  return
352
412
 
353
413
  # Increase dest array sizes if needed
354
- _bsr_ensure_fits(dest, nnz=nnz)
414
+ if not masked:
415
+ _bsr_ensure_fits(dest, nnz=nnz)
355
416
 
356
417
  device = dest.values.device
357
418
  scalar_type = dest.scalar_type
@@ -381,16 +442,51 @@ def bsr_set_from_triplets(
381
442
  nnz,
382
443
  ctypes.cast(rows.ptr, ctypes.POINTER(ctypes.c_int32)),
383
444
  ctypes.cast(columns.ptr, ctypes.POINTER(ctypes.c_int32)),
384
- ctypes.cast(values.ptr, ctypes.c_void_p),
445
+ None if values is None else ctypes.cast(values.ptr, ctypes.c_void_p),
385
446
  prune_numerical_zeros,
447
+ masked,
386
448
  ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
387
449
  ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
388
- ctypes.cast(dest.values.ptr, ctypes.c_void_p),
450
+ None if values is None else ctypes.cast(dest.values.ptr, ctypes.c_void_p),
389
451
  ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
390
452
  nnz_event,
391
453
  )
392
454
 
393
455
 
456
+ def bsr_from_triplets(
457
+ rows_of_blocks: int,
458
+ cols_of_blocks: int,
459
+ rows: "Array[int]",
460
+ columns: "Array[int]",
461
+ values: "Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]",
462
+ prune_numerical_zeros: bool = True,
463
+ ):
464
+ """Construct a BSR matrix with values defined by coordinate-oriented (COO) triplets.
465
+
466
+ The first dimension of the three input arrays must match and indicates the number of COO triplets.
467
+
468
+ Args:
469
+ rows_of_blocks: Number of rows of blocks.
470
+ cols_of_blocks: Number of columns of blocks.
471
+ rows: Row index for each non-zero.
472
+ columns: Column index for each non-zero.
473
+ values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
474
+ to the resulting matrix's block type, or a 3d array with data type equal to the resulting matrix's scalar type.
475
+ prune_numerical_zeros: If ``True``, will ignore the zero-valued blocks.
476
+ """
477
+
478
+ if values.ndim == 3:
479
+ block_type = wp.mat(shape=values.shape[1:], dtype=values.dtype)
480
+ else:
481
+ block_type = values.dtype
482
+
483
+ A = bsr_zeros(
484
+ rows_of_blocks=rows_of_blocks, cols_of_blocks=cols_of_blocks, block_type=block_type, device=values.device
485
+ )
486
+ bsr_set_from_triplets(A, rows, columns, values, prune_numerical_zeros=prune_numerical_zeros)
487
+ return A
488
+
489
+
394
490
  class _BsrExpression(Generic[_BlockType]):
395
491
  pass
396
492
 
@@ -501,96 +597,73 @@ def _extract_matrix_and_scale(bsr: BsrMatrixOrExpression):
501
597
  raise ValueError("Argument cannot be interpreted as a BsrMatrix")
502
598
 
503
599
 
504
- @wp.kernel
505
- def _bsr_assign_split_offsets(
506
- row_factor: int,
507
- col_factor: int,
508
- src_offsets: wp.array(dtype=int),
509
- dest_offsets: wp.array(dtype=int),
600
+ @wp.func
601
+ def _bsr_row_index(
602
+ offsets: wp.array(dtype=int),
603
+ row_count: int,
604
+ block: int,
510
605
  ):
511
- row = wp.tid()
512
-
513
- base_offset = src_offsets[row] * row_factor * col_factor
514
- row_count = src_offsets[1 + row] - src_offsets[row]
606
+ """Index of the row containing a block, or -1 if non-existing."""
607
+ return wp.where(block < offsets[row_count], wp.lower_bound(offsets, 0, row_count + 1, block + 1), 0) - 1
515
608
 
516
- for k in range(row_factor):
517
- dest_offsets[1 + k + row_factor * row] = base_offset + row_count * col_factor * (k + 1)
518
609
 
519
- if row == 0:
520
- dest_offsets[0] = 0
521
-
522
-
523
- @wp.kernel
524
- def _bsr_assign_split_blocks(
525
- structure_only: wp.bool,
526
- scale: Any,
527
- row_factor: int,
528
- col_factor: int,
529
- dest_row_count: int,
530
- src_offsets: wp.array(dtype=int),
531
- src_columns: wp.array(dtype=int),
532
- src_values: wp.array3d(dtype=Any),
533
- dest_offsets: wp.array(dtype=int),
534
- dest_columns: wp.array(dtype=int),
535
- dest_values: wp.array3d(dtype=Any),
610
+ @wp.func
611
+ def _bsr_block_index(
612
+ row: int,
613
+ col: int,
614
+ bsr_offsets: wp.array(dtype=int),
615
+ bsr_columns: wp.array(dtype=int),
536
616
  ):
537
- dest_block = wp.tid()
538
-
539
- if dest_block >= dest_offsets[dest_row_count]:
540
- return
541
-
542
- dest_row = wp.lower_bound(dest_offsets, 0, dest_row_count + 1, dest_block + 1) - 1
543
- src_row = dest_row // row_factor
544
-
545
- dest_col_in_row = dest_block - dest_offsets[dest_row]
546
- src_col_in_row = dest_col_in_row // col_factor
547
-
548
- src_block = src_offsets[src_row] + src_col_in_row
617
+ """Index of the block at block-coordinates (row, col), or -1 if non-existing.
618
+ Assumes bsr_columns is sorted.
619
+ """
549
620
 
550
- dest_rows_per_block = dest_values.shape[1]
551
- dest_cols_per_block = dest_values.shape[2]
621
+ if row < 0:
622
+ return -1
552
623
 
553
- split_row = dest_row - row_factor * src_row
554
- split_col = dest_col_in_row - col_factor * src_col_in_row
624
+ mask_row_beg = bsr_offsets[row]
625
+ mask_row_end = bsr_offsets[row + 1]
555
626
 
556
- dest_columns[dest_block] = src_columns[src_block] * col_factor + split_col
627
+ if mask_row_beg == mask_row_end:
628
+ return -1
557
629
 
558
- if not structure_only:
559
- src_base_i = split_row * dest_rows_per_block
560
- src_base_j = split_col * dest_cols_per_block
561
- for i in range(dest_rows_per_block):
562
- for j in range(dest_cols_per_block):
563
- dest_values[dest_block, i, j] = dest_values.dtype(
564
- scale * src_values[src_block, i + src_base_i, j + src_base_j]
565
- )
630
+ block_index = wp.lower_bound(bsr_columns, mask_row_beg, mask_row_end, col)
631
+ return wp.where(bsr_columns[block_index] == col, block_index, -1)
566
632
 
567
633
 
568
- @wp.kernel
569
- def _bsr_assign_merge_row_col(
570
- row_factor: int,
571
- col_factor: int,
634
+ @wp.kernel(enable_backward=False)
635
+ def _bsr_assign_list_blocks(
636
+ src_subrows: int,
637
+ src_subcols: int,
638
+ dest_subrows: int,
639
+ dest_subcols: int,
572
640
  src_row_count: int,
573
641
  src_offsets: wp.array(dtype=int),
574
642
  src_columns: wp.array(dtype=int),
575
643
  dest_rows: wp.array(dtype=int),
576
644
  dest_cols: wp.array(dtype=int),
577
645
  ):
578
- block = wp.tid()
646
+ block, subrow, subcol = wp.tid()
647
+ dest_block = (block * src_subcols + subcol) * src_subrows + subrow
579
648
 
580
- if block >= src_offsets[src_row_count]:
581
- dest_rows[block] = -1 # invalid
582
- dest_cols[block] = -1
649
+ row = _bsr_row_index(src_offsets, src_row_count, block)
650
+ if row == -1:
651
+ dest_rows[dest_block] = row # invalid
652
+ dest_cols[dest_block] = row
583
653
  else:
584
- row = wp.lower_bound(src_offsets, 0, src_row_count + 1, block + 1) - 1
585
- dest_rows[block] = row // row_factor
586
- dest_cols[block] = src_columns[block] // col_factor
654
+ dest_subrow = row * src_subrows + subrow
655
+ dest_subcol = src_columns[block] * src_subcols + subcol
656
+ dest_rows[dest_block] = dest_subrow // dest_subrows
657
+ dest_cols[dest_block] = dest_subcol // dest_subcols
587
658
 
588
659
 
589
660
  @wp.kernel
590
- def _bsr_assign_merge_blocks(
661
+ def _bsr_assign_copy_blocks(
591
662
  scale: Any,
592
- row_factor: int,
593
- col_factor: int,
663
+ src_subrows: int,
664
+ src_subcols: int,
665
+ dest_subrows: int,
666
+ dest_subcols: int,
594
667
  src_row_count: int,
595
668
  src_offsets: wp.array(dtype=int),
596
669
  src_columns: wp.array(dtype=int),
@@ -600,61 +673,58 @@ def _bsr_assign_merge_blocks(
600
673
  dest_values: wp.array3d(dtype=Any),
601
674
  ):
602
675
  src_block = wp.tid()
676
+ src_block, subrow, subcol = wp.tid()
603
677
 
604
- if src_block >= src_offsets[src_row_count]:
678
+ src_row = _bsr_row_index(src_offsets, src_row_count, src_block)
679
+ if src_row == -1:
605
680
  return
606
681
 
607
- src_row = wp.lower_bound(src_offsets, 0, src_row_count + 1, src_block + 1) - 1
608
682
  src_col = src_columns[src_block]
609
683
 
610
- dest_row = src_row // row_factor
611
- dest_col = src_col // col_factor
684
+ dest_subrow = src_row * src_subrows + subrow
685
+ dest_subcol = src_col * src_subcols + subcol
686
+ dest_row = dest_subrow // dest_subrows
687
+ dest_col = dest_subcol // dest_subcols
612
688
 
613
- dest_block = wp.lower_bound(dest_columns, dest_offsets[dest_row], dest_offsets[dest_row + 1], dest_col)
689
+ dest_block = _bsr_block_index(dest_row, dest_col, dest_offsets, dest_columns)
690
+ if dest_block == -1:
691
+ return
614
692
 
615
- src_rows_per_block = src_values.shape[1]
616
- src_cols_per_block = src_values.shape[2]
693
+ split_row = dest_subrow - dest_subrows * dest_row
694
+ split_col = dest_subcol - dest_subcols * dest_col
617
695
 
618
- split_row = src_row - row_factor * dest_row
619
- split_col = src_col - col_factor * dest_col
696
+ rows_per_subblock = src_values.shape[1] // src_subrows
697
+ cols_per_subblock = src_values.shape[2] // src_subcols
620
698
 
621
- dest_base_i = split_row * src_rows_per_block
622
- dest_base_j = split_col * src_cols_per_block
699
+ dest_base_i = split_row * rows_per_subblock
700
+ dest_base_j = split_col * cols_per_subblock
623
701
 
624
- for i in range(src_rows_per_block):
625
- for j in range(src_cols_per_block):
702
+ src_base_i = subrow * rows_per_subblock
703
+ src_base_j = subcol * cols_per_subblock
704
+
705
+ for i in range(rows_per_subblock):
706
+ for j in range(cols_per_subblock):
626
707
  dest_values[dest_block, i + dest_base_i, j + dest_base_j] = dest_values.dtype(
627
- scale * src_values[src_block, i, j]
708
+ scale * src_values[src_block, i + src_base_i, j + src_base_j]
628
709
  )
629
710
 
630
711
 
631
- def _bsr_values_as_3d_array(A: BsrMatrix) -> wp.array:
632
- if A.block_shape == (1, 1):
633
- return A.values.reshape((A.values.shape[0], 1, 1))
634
-
635
- return wp.array(
636
- data=None,
637
- ptr=A.values.ptr,
638
- capacity=A.values.capacity,
639
- device=A.device,
640
- dtype=A.scalar_type,
641
- shape=(A.values.shape[0], A.block_shape[0], A.block_shape[1]),
642
- )
643
-
644
-
645
712
  def bsr_assign(
646
713
  dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
647
714
  src: BsrMatrixOrExpression[BlockType[Any, Any, Any]],
648
715
  structure_only: bool = False,
716
+ masked: bool = False,
649
717
  ):
650
- """Copies the content of the `src` BSR matrix to `dest`.
718
+ """Copy the content of the ``src`` BSR matrix to ``dest``.
651
719
 
652
720
  Args:
653
- src: Matrix to be copied
654
- dest: Destination matrix. May have a different block shape of scalar type than `src`, in which case the required casting will be performed.
721
+ src: Matrix to be copied.
722
+ dest: Destination matrix. May have a different block shape or scalar type
723
+ than ``src``, in which case the required casting will be performed.
655
724
  structure_only: If ``True``, only the non-zeros indices are copied, and uninitialized value storage is allocated
656
- to accommodate at least `src.nnz` blocks. If `structure_only` is ``False``, values are also copied with implicit
725
+ to accommodate at least ``src.nnz`` blocks. If ``structure_only`` is ``False``, values are also copied with implicit
657
726
  casting if the two matrices use distinct scalar types.
727
+ masked: If ``True``, prevent the assignment operation from adding new non-zeros blocks to ``dest``.
658
728
  """
659
729
 
660
730
  src, src_scale = _extract_matrix_and_scale(src)
@@ -662,13 +732,50 @@ def bsr_assign(
662
732
  if dest.values.device != src.values.device:
663
733
  raise ValueError("Source and destination matrices must reside on the same device")
664
734
 
665
- if dest.block_shape == src.block_shape:
666
- dest.nrow = src.nrow
667
- dest.ncol = src.ncol
735
+ if src.block_shape[0] >= dest.block_shape[0]:
736
+ src_subrows = src.block_shape[0] // dest.block_shape[0]
737
+ dest_subrows = 1
738
+ else:
739
+ dest_subrows = dest.block_shape[0] // src.block_shape[0]
740
+ src_subrows = 1
741
+
742
+ if src_subrows * dest.block_shape[0] != src.block_shape[0] * dest_subrows:
743
+ raise ValueError(
744
+ f"Incompatible dest and src block shapes; block rows must evenly divide one another (Got {src.block_shape[0]}, {dest.block_shape[0]})"
745
+ )
746
+
747
+ if src.block_shape[1] >= dest.block_shape[1]:
748
+ src_subcols = src.block_shape[1] // dest.block_shape[1]
749
+ dest_subcols = 1
750
+ else:
751
+ dest_subcols = dest.block_shape[1] // src.block_shape[1]
752
+ src_subcols = 1
753
+
754
+ if src_subcols * dest.block_shape[1] != src.block_shape[1] * dest_subcols:
755
+ raise ValueError(
756
+ f"Incompatible dest and src block shapes; block columns must evenly divide one another (Got {src.block_shape[1]}, {dest.block_shape[1]})"
757
+ )
758
+
759
+ dest_nrow = (src.nrow * src_subrows) // dest_subrows
760
+ dest_ncol = (src.ncol * src_subcols) // dest_subcols
668
761
 
669
- nnz_alloc = src.nnz
762
+ if src.nrow * src_subrows != dest_nrow * dest_subrows or src.ncol * src_subcols != dest_ncol * dest_subcols:
763
+ raise ValueError("The requested block shape does not evenly divide the source matrix")
764
+
765
+ nnz_alloc = src.nnz * src_subrows * src_subcols
766
+ if masked:
767
+ if dest_nrow != dest.nrow or dest_ncol != dest.ncol:
768
+ raise ValueError(
769
+ f"Incompatible destination matrix size, expected ({dest_nrow}, {dest_ncol}), got ({dest.nrow}, {dest.ncol})"
770
+ )
771
+ else:
772
+ dest.nrow = dest_nrow
773
+ dest.ncol = dest_ncol
670
774
  _bsr_ensure_fits(dest, nnz=nnz_alloc)
671
775
 
776
+ if dest.block_shape == src.block_shape and not masked:
777
+ # Direct copy
778
+
672
779
  wp.copy(dest=dest.offsets, src=src.offsets, count=src.nrow + 1)
673
780
  dest.copy_nnz_async()
674
781
 
@@ -679,86 +786,29 @@ def bsr_assign(
679
786
  warp.utils.array_cast(out_array=dest.values, in_array=src.values, count=nnz_alloc)
680
787
  bsr_scale(dest, src_scale)
681
788
 
682
- elif src.block_shape[0] >= dest.block_shape[0] and src.block_shape[1] >= dest.block_shape[1]:
683
- # Split blocks
684
-
685
- row_factor = src.block_shape[0] // dest.block_shape[0]
686
- col_factor = src.block_shape[1] // dest.block_shape[1]
687
-
688
- if (
689
- row_factor * dest.block_shape[0] != src.block_shape[0]
690
- or col_factor * dest.block_shape[1] != src.block_shape[1]
691
- ):
692
- raise ValueError(
693
- f"Dest block shape {dest.block_shape} is not an exact divider of src block shape {src.block_shape}"
694
- )
695
-
696
- dest.nrow = src.nrow * row_factor
697
- dest.ncol = src.ncol * col_factor
698
-
699
- nnz_alloc = src.nnz * row_factor * col_factor
700
- _bsr_ensure_fits(dest, nnz=nnz_alloc)
789
+ else:
790
+ # Masked and/or multiple src blocks per dest block, go through COO format
701
791
 
792
+ # Compute destination rows and columns
793
+ dest_rows = wp.empty(nnz_alloc, dtype=int, device=dest.device)
794
+ dest_cols = wp.empty(nnz_alloc, dtype=int, device=dest.device)
702
795
  wp.launch(
703
- _bsr_assign_split_offsets,
704
- dim=src.nrow,
705
- device=dest.device,
706
- inputs=[row_factor, col_factor, src.offsets, dest.offsets],
707
- )
708
- wp.launch(
709
- _bsr_assign_split_blocks,
710
- dim=dest.nnz,
796
+ _bsr_assign_list_blocks,
797
+ dim=(src.nnz, src_subrows, src_subcols),
711
798
  device=dest.device,
712
799
  inputs=[
713
- wp.bool(structure_only),
714
- src.scalar_type(src_scale),
715
- row_factor,
716
- col_factor,
717
- dest.nrow,
800
+ src_subrows,
801
+ src_subcols,
802
+ dest_subrows,
803
+ dest_subcols,
804
+ src.nrow,
718
805
  src.offsets,
719
806
  src.columns,
720
- _bsr_values_as_3d_array(src),
721
- dest.offsets,
722
- dest.columns,
723
- _bsr_values_as_3d_array(dest),
807
+ dest_rows,
808
+ dest_cols,
724
809
  ],
725
810
  )
726
811
 
727
- elif src.block_shape[0] <= dest.block_shape[0] and src.block_shape[1] <= dest.block_shape[1]:
728
- # Merge blocks
729
-
730
- row_factor = dest.block_shape[0] // src.block_shape[0]
731
- col_factor = dest.block_shape[1] // src.block_shape[1]
732
-
733
- if (
734
- row_factor * src.block_shape[0] != dest.block_shape[0]
735
- or col_factor * src.block_shape[1] != dest.block_shape[1]
736
- ):
737
- raise ValueError(
738
- f"Dest block shape {dest.block_shape} is not an exact multiple of src block shape {src.block_shape}"
739
- )
740
-
741
- if src.nrow % row_factor != 0 or src.ncol % col_factor != 0:
742
- raise ValueError(
743
- "The total rows and columns of the src matrix cannot be evenly divided using the requested block shape"
744
- )
745
-
746
- dest.nrow = src.nrow // row_factor
747
- dest.ncol = src.ncol // col_factor
748
-
749
- nnz_alloc = src.nnz # Conservative, in case all nnz in src belong to distinct merged blocks
750
- _bsr_ensure_fits(dest, nnz=nnz_alloc)
751
-
752
- # Compute destination rows and columns
753
- dest_rows = wp.empty_like(src.columns)
754
- dest_cols = wp.empty_like(src.columns)
755
- wp.launch(
756
- _bsr_assign_merge_row_col,
757
- dim=src.nnz,
758
- device=dest.device,
759
- inputs=[row_factor, col_factor, src.nrow, src.offsets, src.columns, dest_rows, dest_cols],
760
- )
761
-
762
812
  # Compute destination offsets from triplets
763
813
  from warp.context import runtime
764
814
 
@@ -773,11 +823,12 @@ def bsr_assign(
773
823
  dest.block_shape[0],
774
824
  dest.block_shape[1],
775
825
  dest.nrow,
776
- dest.nnz,
826
+ nnz_alloc,
777
827
  ctypes.cast(dest_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
778
828
  ctypes.cast(dest_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
779
829
  0,
780
830
  False,
831
+ masked,
781
832
  ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
782
833
  ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
783
834
  0,
@@ -789,26 +840,25 @@ def bsr_assign(
789
840
  if not structure_only:
790
841
  dest.values.zero_()
791
842
  wp.launch(
792
- _bsr_assign_merge_blocks,
793
- dim=src.nnz,
843
+ _bsr_assign_copy_blocks,
844
+ dim=(src.nnz, src_subrows, src_subcols),
794
845
  device=dest.device,
795
846
  inputs=[
796
847
  src.scalar_type(src_scale),
797
- row_factor,
798
- col_factor,
848
+ src_subrows,
849
+ src_subcols,
850
+ dest_subrows,
851
+ dest_subcols,
799
852
  src.nrow,
800
853
  src.offsets,
801
854
  src.columns,
802
- _bsr_values_as_3d_array(src),
855
+ src.scalar_values,
803
856
  dest.offsets,
804
857
  dest.columns,
805
- _bsr_values_as_3d_array(dest),
858
+ dest.scalar_values,
806
859
  ],
807
860
  )
808
861
 
809
- else:
810
- raise ValueError("Incompatible dest and src block shapes")
811
-
812
862
 
813
863
  def bsr_copy(
814
864
  A: BsrMatrixOrExpression,
@@ -816,15 +866,15 @@ def bsr_copy(
816
866
  block_shape: Optional[Tuple[int, int]] = None,
817
867
  structure_only: bool = False,
818
868
  ):
819
- """Returns a copy of matrix ``A``, possibly changing its scalar type.
869
+ """Return a copy of matrix ``A``, possibly changing its scalar type.
820
870
 
821
871
  Args:
822
- A: Matrix to be copied
823
- scalar_type: If provided, the returned matrix will use this scalar type instead of the one from `A`.
824
- block_shape: If provided, the returned matrix will use blocks of this shape instead of the one from `A`.
825
- Both dimensions of `block_shape` must be either a multiple or an exact divider of the ones from `A`.
872
+ A: Matrix to be copied.
873
+ scalar_type: If provided, the returned matrix will use this scalar type instead of the one from ``A``.
874
+ block_shape: If provided, the returned matrix will use blocks of this shape instead of the one from ``A``.
875
+ Both dimensions of ``block_shape`` must be either a multiple or an exact divider of the ones from ``A``.
826
876
  structure_only: If ``True``, only the non-zeros indices are copied, and uninitialized value storage is allocated
827
- to accommodate at least `src.nnz` blocks. If `structure_only` is ``False``, values are also copied with implicit
877
+ to accommodate at least ``A.nnz`` blocks. If ``structure_only`` is ``False``, values are also copied with implicit
828
878
  casting if the two matrices use distinct scalar types.
829
879
  """
830
880
  if scalar_type is None:
@@ -835,7 +885,7 @@ def bsr_copy(
835
885
  if block_shape == (1, 1):
836
886
  block_type = scalar_type
837
887
  else:
838
- block_type = wp.types.matrix(shape=block_shape, dtype=scalar_type)
888
+ block_type = wp.mat(shape=block_shape, dtype=scalar_type)
839
889
 
840
890
  copy = bsr_zeros(
841
891
  rows_of_blocks=A.nrow,
@@ -851,7 +901,7 @@ def bsr_set_transpose(
851
901
  dest: BsrMatrix[BlockType[Cols, Rows, Scalar]],
852
902
  src: BsrMatrixOrExpression[BlockType[Rows, Cols, Scalar]],
853
903
  ):
854
- """Assigns the transposed matrix `src` to matrix `dest`"""
904
+ """Assign the transposed matrix ``src`` to matrix ``dest``."""
855
905
 
856
906
  src, src_scale = _extract_matrix_and_scale(src)
857
907
 
@@ -912,13 +962,13 @@ def bsr_set_transpose(
912
962
  bsr_scale(dest, src_scale)
913
963
 
914
964
 
915
- def bsr_transposed(A: BsrMatrixOrExpression):
916
- """Returns a copy of the transposed matrix `A`"""
965
+ def bsr_transposed(A: BsrMatrixOrExpression) -> BsrMatrix:
966
+ """Return a copy of the transposed matrix ``A``."""
917
967
 
918
968
  if A.block_shape == (1, 1):
919
969
  block_type = A.values.dtype
920
970
  else:
921
- block_type = wp.types.matrix(shape=A.block_shape[::-1], dtype=A.scalar_type)
971
+ block_type = wp.mat(shape=A.block_shape[::-1], dtype=A.scalar_type)
922
972
 
923
973
  transposed = bsr_zeros(
924
974
  rows_of_blocks=A.ncol,
@@ -939,21 +989,18 @@ def _bsr_get_diag_kernel(
939
989
  out: wp.array(dtype=Any),
940
990
  ):
941
991
  row = wp.tid()
942
- beg = A_offsets[row]
943
- end = A_offsets[row + 1]
944
992
 
945
- diag = wp.lower_bound(A_columns, beg, end, row)
946
- if diag < end:
947
- if A_columns[diag] == row:
948
- out[row] = scale * A_values[diag]
993
+ diag = _bsr_block_index(row, row, A_offsets, A_columns)
994
+ if diag != -1:
995
+ out[row] = scale * A_values[diag]
949
996
 
950
997
 
951
998
  def bsr_get_diag(A: BsrMatrixOrExpression[BlockType], out: "Optional[Array[BlockType]]" = None) -> "Array[BlockType]":
952
- """Returns the array of blocks that constitute the diagonal of a sparse matrix.
999
+ """Return the array of blocks that constitute the diagonal of a sparse matrix.
953
1000
 
954
1001
  Args:
955
- A: the sparse matrix from which to extract the diagonal
956
- out: if provided, the array into which to store the diagonal blocks
1002
+ A: The sparse matrix from which to extract the diagonal.
1003
+ out: If provided, the array into which to store the diagonal blocks.
957
1004
  """
958
1005
 
959
1006
  A, scale = _extract_matrix_and_scale(A)
@@ -980,36 +1027,16 @@ def bsr_get_diag(A: BsrMatrixOrExpression[BlockType], out: "Optional[Array[Block
980
1027
  return out
981
1028
 
982
1029
 
983
- @wp.kernel
1030
+ @wp.kernel(enable_backward=False)
984
1031
  def _bsr_set_diag_kernel(
985
- diag: wp.array(dtype=Any),
986
- A_offsets: wp.array(dtype=int),
987
- A_columns: wp.array(dtype=int),
988
- A_values: wp.array(dtype=Any),
989
- ):
990
- row = wp.tid()
991
- A_offsets[row + 1] = row + 1
992
- A_columns[row] = row
993
- A_values[row] = diag[row]
994
-
995
- if row == 0:
996
- A_offsets[0] = 0
997
-
998
-
999
- @wp.kernel
1000
- def _bsr_set_diag_constant_kernel(
1001
- diag_value: Any,
1032
+ nnz: int,
1002
1033
  A_offsets: wp.array(dtype=int),
1003
1034
  A_columns: wp.array(dtype=int),
1004
- A_values: wp.array(dtype=Any),
1005
1035
  ):
1006
1036
  row = wp.tid()
1007
- A_offsets[row + 1] = row + 1
1008
- A_columns[row] = row
1009
- A_values[row] = diag_value
1010
-
1011
- if row == 0:
1012
- A_offsets[0] = 0
1037
+ A_offsets[row] = wp.min(row, nnz)
1038
+ if row < nnz:
1039
+ A_columns[row] = row
1013
1040
 
1014
1041
 
1015
1042
  def bsr_set_diag(
@@ -1017,20 +1044,26 @@ def bsr_set_diag(
1017
1044
  diag: "Union[BlockType, Array[BlockType]]",
1018
1045
  rows_of_blocks: Optional[int] = None,
1019
1046
  cols_of_blocks: Optional[int] = None,
1020
- ):
1021
- """Sets `A` as a block-diagonal matrix
1047
+ ) -> None:
1048
+ """Set ``A`` as a block-diagonal matrix.
1022
1049
 
1023
1050
  Args:
1024
- A: the sparse matrix to modify
1025
- diag: Either a warp array of type ``A.values.dtype``, in which case each element will define one block of the diagonal,
1026
- or a constant value of type ``A.values.dtype``, in which case it will get assigned to all diagonal blocks.
1027
- rows_of_blocks: If not ``None``, the new number of rows of blocks
1028
- cols_of_blocks: If not ``None``, the new number of columns of blocks
1051
+ A: The sparse matrix to modify.
1052
+ diag: Specifies the values for diagonal blocks. Can be one of:
1053
+
1054
+ - A Warp array of type ``A.values.dtype``: Each element defines one block of the diagonal
1055
+ - A constant value of type ``A.values.dtype``: This value is assigned to all diagonal blocks
1056
+ - ``None``: Diagonal block values are left uninitialized
1057
+
1058
+ rows_of_blocks: If not ``None``, the new number of rows of blocks.
1059
+ cols_of_blocks: If not ``None``, the new number of columns of blocks.
1060
+
1061
+ The shape of the matrix will be defined by one of the following, in this order:
1029
1062
 
1030
- The shape of the matrix will be defined one of the following, in that order:
1031
- - `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal.
1032
- - the first dimension of `diag`, if `diag` is an array
1033
- - the current dimensions of `A` otherwise
1063
+ - ``rows_of_blocks`` and ``cols_of_blocks``, if provided.
1064
+ If only one is given, the second is assumed equal.
1065
+ - The first dimension of ``diag``, if ``diag`` is an array
1066
+ - The current dimensions of ``A`` otherwise
1034
1067
  """
1035
1068
 
1036
1069
  if rows_of_blocks is None and cols_of_blocks is not None:
@@ -1038,7 +1071,7 @@ def bsr_set_diag(
1038
1071
  if cols_of_blocks is None and rows_of_blocks is not None:
1039
1072
  cols_of_blocks = rows_of_blocks
1040
1073
 
1041
- if warp.types.is_array(diag):
1074
+ if is_array(diag):
1042
1075
  if rows_of_blocks is None:
1043
1076
  rows_of_blocks = diag.shape[0]
1044
1077
  cols_of_blocks = diag.shape[0]
@@ -1050,43 +1083,45 @@ def bsr_set_diag(
1050
1083
  nnz = min(A.nrow, A.ncol)
1051
1084
  _bsr_ensure_fits(A, nnz=nnz)
1052
1085
 
1053
- if warp.types.is_array(diag):
1054
- wp.launch(
1055
- kernel=_bsr_set_diag_kernel,
1056
- dim=nnz,
1057
- device=A.values.device,
1058
- inputs=[diag, A.offsets, A.columns, A.values],
1059
- )
1060
- else:
1061
- if not warp.types.type_is_value(type(diag)):
1062
- # Cast to launchable type
1063
- diag = A.values.dtype(diag)
1064
- wp.launch(
1065
- kernel=_bsr_set_diag_constant_kernel,
1066
- dim=nnz,
1067
- device=A.values.device,
1068
- inputs=[diag, A.offsets, A.columns, A.values],
1069
- )
1086
+ wp.launch(
1087
+ kernel=_bsr_set_diag_kernel,
1088
+ dim=nnz + 1,
1089
+ device=A.offsets.device,
1090
+ inputs=[nnz, A.offsets, A.columns],
1091
+ )
1092
+
1093
+ if is_array(diag):
1094
+ wp.copy(src=diag, dest=A.values, count=nnz)
1095
+ elif diag is not None:
1096
+ A.values.fill_(diag)
1070
1097
 
1071
1098
  A.copy_nnz_async(known_nnz=nnz)
1072
1099
 
1073
1100
 
1074
1101
  def bsr_diag(
1075
- diag: "Union[BlockType, Array[BlockType]]",
1102
+ diag: Optional[Union[BlockType, Array[BlockType]]] = None,
1076
1103
  rows_of_blocks: Optional[int] = None,
1077
1104
  cols_of_blocks: Optional[int] = None,
1105
+ block_type: Optional[BlockType] = None,
1106
+ device=None,
1078
1107
  ) -> BsrMatrix["BlockType"]:
1079
- """Creates and returns a block-diagonal BSR matrix from an given block value or array of block values.
1108
+ """Create and return a block-diagonal BSR matrix from a given block value or array of block values.
1080
1109
 
1081
1110
  Args:
1082
- diag: Either a warp array of type ``A.values.dtype``, in which case each element will define one block of the diagonal,
1083
- or a constant value of type ``A.values.dtype``, in which case it will get assigned to all diagonal blocks.
1111
+ diag: Specifies the values for diagonal blocks. Can be one of:
1112
+
1113
+ - A Warp array of type ``A.values.dtype``: Each element defines one block of the diagonal
1114
+ - A constant value of type ``A.values.dtype``: This value is assigned to all diagonal blocks
1084
1115
  rows_of_blocks: If not ``None``, the new number of rows of blocks
1085
1116
  cols_of_blocks: If not ``None``, the new number of columns of blocks
1117
+ block_type: If ``diag`` is ``None``, block type of the matrix. Otherwise deduced from ``diag``
1118
+ device: If ``diag`` is not a Warp array, device on which to allocate the matrix. Otherwise deduced from ``diag``
1119
+
1120
+ The shape of the matrix will be defined by one of the following, in this order:
1086
1121
 
1087
- The shape of the matrix will be defined one of the following, in that order:
1088
- - `rows_of_blocks` and `cols_of_blocks`, if provided. If only one is given, the second is assumed equal.
1089
- - the first dimension of `diag`, if `diag` is an array
1122
+ - ``rows_of_blocks`` and ``cols_of_blocks``, if provided.
1123
+ If only one is given, the second is assumed equal.
1124
+ - The first dimension of ``diag`` if ``diag`` is an array.
1090
1125
  """
1091
1126
 
1092
1127
  if rows_of_blocks is None and cols_of_blocks is not None:
@@ -1094,43 +1129,39 @@ def bsr_diag(
1094
1129
  if cols_of_blocks is None and rows_of_blocks is not None:
1095
1130
  cols_of_blocks = rows_of_blocks
1096
1131
 
1097
- if warp.types.is_array(diag):
1132
+ if is_array(diag):
1098
1133
  if rows_of_blocks is None:
1099
1134
  rows_of_blocks = diag.shape[0]
1100
1135
  cols_of_blocks = diag.shape[0]
1101
1136
 
1102
- A = bsr_zeros(
1103
- rows_of_blocks,
1104
- cols_of_blocks,
1105
- block_type=diag.dtype,
1106
- device=diag.device,
1107
- )
1137
+ block_type = diag.dtype
1138
+ device = diag.device
1108
1139
  else:
1109
1140
  if rows_of_blocks is None:
1110
1141
  raise ValueError(
1111
1142
  "rows_of_blocks and/or cols_of_blocks must be provided for constructing a diagonal matrix with uniform diagonal"
1112
1143
  )
1113
1144
 
1145
+ if block_type is None:
1146
+ if diag is None:
1147
+ raise ValueError("Either `diag` or `block_type` needs to be provided")
1148
+
1114
1149
  block_type = type(diag)
1115
- if not warp.types.type_is_matrix(block_type) and len(getattr(diag, "shape", ())) == 2:
1150
+ if not type_is_matrix(block_type) and len(getattr(diag, "shape", ())) == 2:
1116
1151
  block_type = wp.mat(shape=diag.shape, dtype=diag.dtype)
1117
1152
 
1118
- A = bsr_zeros(
1119
- rows_of_blocks,
1120
- cols_of_blocks,
1121
- block_type=block_type,
1122
- )
1123
-
1153
+ A = bsr_zeros(rows_of_blocks, cols_of_blocks, block_type=block_type, device=device)
1124
1154
  bsr_set_diag(A, diag)
1125
1155
  return A
1126
1156
 
1127
1157
 
1128
- def bsr_set_identity(A: BsrMatrix, rows_of_blocks: Optional[int] = None):
1129
- """Sets `A` as the identity matrix
1158
+ def bsr_set_identity(A: BsrMatrix, rows_of_blocks: Optional[int] = None) -> None:
1159
+ """Set ``A`` as the identity matrix.
1130
1160
 
1131
1161
  Args:
1132
- A: the sparse matrix to modify
1133
- rows_of_blocks: if provided, the matrix will be resized as a square matrix with `rows_of_blocks` rows and columns.
1162
+ A: The sparse matrix to modify.
1163
+ rows_of_blocks: If provided, the matrix will be resized as a square
1164
+ matrix with ``rows_of_blocks`` rows and columns.
1134
1165
  """
1135
1166
 
1136
1167
  if A.block_shape == (1, 1):
@@ -1148,11 +1179,11 @@ def bsr_identity(
1148
1179
  block_type: BlockType[Rows, Rows, Scalar],
1149
1180
  device: wp.context.Devicelike = None,
1150
1181
  ) -> BsrMatrix[BlockType[Rows, Rows, Scalar]]:
1151
- """Creates and returns a square identity matrix.
1182
+ """Create and return a square identity matrix.
1152
1183
 
1153
1184
  Args:
1154
1185
  rows_of_blocks: Number of rows and columns of blocks in the created matrix.
1155
- block_type: Block type for the newly created matrix -- must be square
1186
+ block_type: Block type for the newly created matrix. Must be square
1156
1187
  device: Device onto which to allocate the data arrays
1157
1188
  """
1158
1189
  A = bsr_zeros(
@@ -1174,9 +1205,7 @@ def _bsr_scale_kernel(
1174
1205
 
1175
1206
 
1176
1207
  def bsr_scale(x: BsrMatrixOrExpression, alpha: Scalar) -> BsrMatrix:
1177
- """
1178
- Performs the operation ``x := alpha * x`` on BSR matrix `x` and returns `x`
1179
- """
1208
+ """Perform the operation ``x := alpha * x`` on BSR matrix ``x`` and return ``x``."""
1180
1209
 
1181
1210
  x, scale = _extract_matrix_and_scale(x)
1182
1211
  alpha *= scale
@@ -1185,8 +1214,7 @@ def bsr_scale(x: BsrMatrixOrExpression, alpha: Scalar) -> BsrMatrix:
1185
1214
  if alpha == 0.0:
1186
1215
  bsr_set_zero(x)
1187
1216
  else:
1188
- if not isinstance(alpha, x.scalar_type):
1189
- alpha = x.scalar_type(alpha)
1217
+ alpha = x.scalar_type(alpha)
1190
1218
 
1191
1219
  wp.launch(
1192
1220
  kernel=_bsr_scale_kernel,
@@ -1198,15 +1226,10 @@ def bsr_scale(x: BsrMatrixOrExpression, alpha: Scalar) -> BsrMatrix:
1198
1226
  return x
1199
1227
 
1200
1228
 
1201
- @wp.kernel
1202
- def _bsr_get_block_row(dest_offset: int, row_count: int, bsr_offsets: wp.array(dtype=int), rows: wp.array(dtype=int)):
1203
- i = wp.tid()
1204
-
1205
- if i >= bsr_offsets[row_count]:
1206
- rows[dest_offset + i] = -1 # invalid
1207
- else:
1208
- row = wp.lower_bound(bsr_offsets, 0, row_count + 1, i + 1) - 1
1209
- rows[dest_offset + i] = row
1229
+ @wp.kernel(enable_backward=False)
1230
+ def _bsr_get_block_row(row_count: int, bsr_offsets: wp.array(dtype=int), rows: wp.array(dtype=int)):
1231
+ block = wp.tid()
1232
+ rows[block] = _bsr_row_index(bsr_offsets, row_count, block)
1210
1233
 
1211
1234
 
1212
1235
  @wp.kernel
@@ -1222,21 +1245,15 @@ def _bsr_axpy_add_block(
1222
1245
  ):
1223
1246
  i = wp.tid()
1224
1247
  row = rows[i + src_offset]
1225
-
1226
- if row < 0:
1227
- return
1228
-
1229
1248
  col = cols[i + src_offset]
1230
- beg = dst_offsets[row]
1231
- end = dst_offsets[row + 1]
1232
1249
 
1233
- block = wp.lower_bound(dst_columns, beg, end, col)
1234
-
1235
- dst_values[block] = dst_values[block] + scale * src_values[i]
1250
+ block = _bsr_block_index(row, col, dst_offsets, dst_columns)
1251
+ if block != -1:
1252
+ dst_values[block] += scale * src_values[i]
1236
1253
 
1237
1254
 
1238
1255
  class bsr_axpy_work_arrays:
1239
- """Opaque structure for persisting :func:`bsr_axpy` temporary work buffers across calls"""
1256
+ """Opaque structure for persisting :func:`bsr_axpy` temporary work buffers across calls."""
1240
1257
 
1241
1258
  def __init__(self):
1242
1259
  self._reset(None)
@@ -1266,25 +1283,33 @@ def bsr_axpy(
1266
1283
  y: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
1267
1284
  alpha: Scalar = 1.0,
1268
1285
  beta: Scalar = 1.0,
1286
+ masked: bool = False,
1269
1287
  work_arrays: Optional[bsr_axpy_work_arrays] = None,
1270
1288
  ) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
1271
1289
  """
1272
- Performs the sparse matrix addition ``y := alpha * X + beta * y`` on BSR matrices `x` and `y` and returns `y`.
1290
+ Perform the sparse matrix addition ``y := alpha * x + beta * y`` on BSR matrices ``x`` and ``y`` and return ``y``.
1273
1291
 
1274
- The `x` and `y` matrices are allowed to alias.
1292
+ The ``x`` and ``y`` matrices are allowed to alias.
1275
1293
 
1276
1294
  Args:
1277
1295
  x: Read-only right-hand-side.
1278
- y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero.
1279
- alpha: Uniform scaling factor for `x`
1280
- beta: Uniform scaling factor for `y`
1281
- work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_axpy_work_arrays` in `work_arrays`.
1296
+ y: Mutable left-hand-side. If ``y`` is not provided, it will be allocated and treated as zero.
1297
+ alpha: Uniform scaling factor for ``x``.
1298
+ beta: Uniform scaling factor for ``y``.
1299
+ masked: If ``True``, discard all blocks from ``x`` which are not
1300
+ existing non-zeros of ``y``.
1301
+ work_arrays: In most cases, this function will require the use of temporary storage.
1302
+ This storage can be reused across calls by passing an instance of
1303
+ :class:`bsr_axpy_work_arrays` in ``work_arrays``.
1282
1304
  """
1283
1305
 
1284
1306
  x, x_scale = _extract_matrix_and_scale(x)
1285
1307
  alpha *= x_scale
1286
1308
 
1287
1309
  if y is None:
1310
+ if masked:
1311
+ raise ValueError("Left-hand-side 'y' matrix must be provided for masked addition")
1312
+
1288
1313
  # If not output matrix is provided, allocate it for convenience
1289
1314
  y = bsr_zeros(x.nrow, x.ncol, block_type=x.values.dtype, device=x.values.device)
1290
1315
  beta = 0.0
@@ -1328,27 +1353,17 @@ def bsr_axpy(
1328
1353
  work_arrays._allocate(device, y, sum_nnz)
1329
1354
 
1330
1355
  wp.copy(work_arrays._sum_cols, y.columns, 0, 0, y_nnz)
1331
- wp.launch(
1332
- kernel=_bsr_get_block_row,
1333
- device=device,
1334
- dim=y_nnz,
1335
- inputs=[0, y.nrow, y.offsets, work_arrays._sum_rows],
1336
- )
1356
+ y.uncompress_rows(out=work_arrays._sum_rows)
1337
1357
 
1338
1358
  wp.copy(work_arrays._sum_cols, x.columns, y_nnz, 0, x_nnz)
1339
- wp.launch(
1340
- kernel=_bsr_get_block_row,
1341
- device=device,
1342
- dim=x_nnz,
1343
- inputs=[y_nnz, x.nrow, x.offsets, work_arrays._sum_rows],
1344
- )
1359
+ x.uncompress_rows(out=work_arrays._sum_rows[y_nnz:])
1345
1360
 
1346
1361
  # Save old y values before overwriting matrix
1347
1362
  wp.copy(dest=work_arrays._old_y_values, src=y.values, count=y_nnz)
1348
1363
 
1349
1364
  # Increase dest array sizes if needed
1350
- if y.columns.shape[0] < sum_nnz:
1351
- y.columns = wp.empty(shape=(sum_nnz,), dtype=int, device=device)
1365
+ if not masked:
1366
+ _bsr_ensure_fits(y, nnz=sum_nnz)
1352
1367
 
1353
1368
  from warp.context import runtime
1354
1369
 
@@ -1370,6 +1385,7 @@ def bsr_axpy(
1370
1385
  ctypes.cast(work_arrays._sum_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
1371
1386
  0,
1372
1387
  False,
1388
+ masked,
1373
1389
  ctypes.cast(y.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
1374
1390
  ctypes.cast(y.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
1375
1391
  0,
@@ -1377,8 +1393,6 @@ def bsr_axpy(
1377
1393
  nnz_event,
1378
1394
  )
1379
1395
 
1380
- _bsr_ensure_fits(y, nnz=sum_nnz)
1381
-
1382
1396
  y.values.zero_()
1383
1397
 
1384
1398
  wp.launch(
@@ -1416,55 +1430,90 @@ def bsr_axpy(
1416
1430
  return y
1417
1431
 
1418
1432
 
1419
- @wp.kernel
1433
+ @wp.kernel(enable_backward=False)
1420
1434
  def _bsr_mm_count_coeffs(
1435
+ y_ncol: int,
1421
1436
  z_nnz: int,
1422
1437
  x_offsets: wp.array(dtype=int),
1423
1438
  x_columns: wp.array(dtype=int),
1424
1439
  y_offsets: wp.array(dtype=int),
1425
- counts: wp.array(dtype=int),
1440
+ y_columns: wp.array(dtype=int),
1441
+ row_min: wp.array(dtype=int),
1442
+ block_counts: wp.array(dtype=int),
1426
1443
  ):
1427
1444
  row = wp.tid()
1428
- count = int(0)
1445
+ row_count = int(0)
1429
1446
 
1430
1447
  x_beg = x_offsets[row]
1431
1448
  x_end = x_offsets[row + 1]
1432
1449
 
1450
+ min_col = y_ncol
1451
+ max_col = int(0)
1452
+
1433
1453
  for x_block in range(x_beg, x_end):
1434
1454
  x_col = x_columns[x_block]
1435
- count += y_offsets[x_col + 1] - y_offsets[x_col]
1436
-
1437
- counts[row + 1] = count
1455
+ y_row_end = y_offsets[x_col + 1]
1456
+ y_row_beg = y_offsets[x_col]
1457
+ block_count = y_row_end - y_row_beg
1458
+ if block_count != 0:
1459
+ min_col = wp.min(y_columns[y_row_beg], min_col)
1460
+ max_col = wp.max(y_columns[y_row_end - 1], max_col)
1461
+
1462
+ block_counts[x_block + 1] = block_count
1463
+ row_count += block_count
1464
+
1465
+ if row_count > wp.max(0, max_col - min_col):
1466
+ row_min[row] = min_col
1467
+ block_counts[x_end] = max_col + 1 - min_col
1468
+ for x_block in range(x_beg, x_end - 1):
1469
+ block_counts[x_block + 1] = 0
1470
+ else:
1471
+ row_min[row] = -1
1438
1472
 
1439
1473
  if row == 0:
1440
- counts[0] = z_nnz
1474
+ block_counts[0] = z_nnz
1441
1475
 
1442
1476
 
1443
- @wp.kernel
1477
+ @wp.kernel(enable_backward=False)
1444
1478
  def _bsr_mm_list_coeffs(
1479
+ x_nrow: int,
1445
1480
  x_offsets: wp.array(dtype=int),
1446
1481
  x_columns: wp.array(dtype=int),
1447
1482
  y_offsets: wp.array(dtype=int),
1448
1483
  y_columns: wp.array(dtype=int),
1484
+ mm_row_min: wp.array(dtype=int),
1449
1485
  mm_offsets: wp.array(dtype=int),
1450
1486
  mm_rows: wp.array(dtype=int),
1451
1487
  mm_cols: wp.array(dtype=int),
1452
1488
  ):
1453
- row = wp.tid()
1454
- mm_block = mm_offsets[row]
1489
+ x_block = wp.tid()
1490
+ mm_block = mm_offsets[x_block]
1455
1491
 
1456
- x_beg = x_offsets[row]
1457
- x_end = x_offsets[row + 1]
1492
+ row = _bsr_row_index(x_offsets, x_nrow, x_block)
1493
+ if row == -1:
1494
+ return
1458
1495
 
1459
- for x_block in range(x_beg, x_end):
1496
+ row_min_col = mm_row_min[row]
1497
+ if row_min_col != -1:
1460
1498
  x_col = x_columns[x_block]
1461
1499
 
1462
1500
  y_beg = y_offsets[x_col]
1463
1501
  y_end = y_offsets[x_col + 1]
1502
+
1464
1503
  for y_block in range(y_beg, y_end):
1465
- mm_cols[mm_block] = y_columns[y_block]
1466
- mm_rows[mm_block] = row
1467
- mm_block += 1
1504
+ col = y_columns[y_block]
1505
+ mm_rows[mm_block + col - row_min_col] = row
1506
+ mm_cols[mm_block + col - row_min_col] = col
1507
+
1508
+ return
1509
+
1510
+ x_col = x_columns[x_block]
1511
+ y_beg = y_offsets[x_col]
1512
+ y_end = y_offsets[x_col + 1]
1513
+ for y_block in range(y_beg, y_end):
1514
+ mm_cols[mm_block] = y_columns[y_block]
1515
+ mm_rows[mm_block] = row
1516
+ mm_block += 1
1468
1517
 
1469
1518
 
1470
1519
  @wp.kernel
@@ -1483,7 +1532,10 @@ def _bsr_mm_compute_values(
1483
1532
  ):
1484
1533
  mm_block = wp.tid()
1485
1534
 
1486
- row = wp.lower_bound(mm_offsets, 0, mm_row_count + 1, mm_block + 1) - 1
1535
+ row = _bsr_row_index(mm_offsets, mm_row_count, mm_block)
1536
+ if row == -1:
1537
+ return
1538
+
1487
1539
  col = mm_cols[mm_block]
1488
1540
 
1489
1541
  mm_val = mm_values.dtype(type(alpha)(0.0))
@@ -1492,26 +1544,23 @@ def _bsr_mm_compute_values(
1492
1544
  x_end = x_offsets[row + 1]
1493
1545
  for x_block in range(x_beg, x_end):
1494
1546
  x_col = x_columns[x_block]
1495
- y_beg = y_offsets[x_col]
1496
- y_end = y_offsets[x_col + 1]
1497
-
1498
- y_block = wp.lower_bound(y_columns, y_beg, y_end, col)
1499
- if y_block < y_end:
1500
- if y_columns[y_block] == col:
1501
- mm_val += x_values[x_block] * y_values[y_block]
1547
+ y_block = _bsr_block_index(x_col, col, y_offsets, y_columns)
1548
+ if y_block != -1:
1549
+ mm_val += x_values[x_block] * y_values[y_block]
1502
1550
 
1503
1551
  mm_values[mm_block] += alpha * mm_val
1504
1552
 
1505
1553
 
1506
1554
  class bsr_mm_work_arrays:
1507
- """Opaque structure for persisting :func:`bsr_mm` temporary work buffers across calls"""
1555
+ """Opaque structure for persisting :func:`bsr_mm` temporary work buffers across calls."""
1508
1556
 
1509
1557
  def __init__(self):
1510
1558
  self._reset(None)
1511
1559
 
1512
1560
  def _reset(self, device):
1513
1561
  self.device = device
1514
- self._mm_row_counts = None
1562
+ self._mm_row_min = None
1563
+ self._mm_block_counts = None
1515
1564
  self._mm_rows = None
1516
1565
  self._mm_cols = None
1517
1566
  self._old_z_values = None
@@ -1519,7 +1568,7 @@ class bsr_mm_work_arrays:
1519
1568
  self._old_z_columns = None
1520
1569
  self._mm_nnz = 0
1521
1570
 
1522
- def _allocate_stage_1(self, device, z: BsrMatrix, beta: float, z_aliasing: bool):
1571
+ def _allocate_stage_1(self, device, x_nnz: int, z: BsrMatrix, beta: float, z_aliasing: bool):
1523
1572
  if self.device != device:
1524
1573
  self._reset(device)
1525
1574
 
@@ -1527,8 +1576,10 @@ class bsr_mm_work_arrays:
1527
1576
  z_nnz = z.nnz_sync()
1528
1577
  self._copied_z_nnz = z_nnz if beta != 0.0 or z_aliasing else 0
1529
1578
 
1530
- if self._mm_row_counts is None or self._mm_row_counts.size < z.nrow + 1:
1531
- self._mm_row_counts = wp.empty(shape=(z.nrow + 1,), dtype=int, device=self.device)
1579
+ if self._mm_row_min is None or self._mm_block_counts.size < z.nrow + 1:
1580
+ self._mm_row_min = wp.empty(shape=(z.nrow + 1,), dtype=int, device=self.device)
1581
+ if self._mm_block_counts is None or self._mm_block_counts.size < x_nnz + 1:
1582
+ self._mm_block_counts = wp.empty(shape=(x_nnz + 1,), dtype=int, device=self.device)
1532
1583
 
1533
1584
  if self._copied_z_nnz > 0:
1534
1585
  if self._old_z_values is None or self._old_z_values.size < self._copied_z_nnz:
@@ -1555,25 +1606,31 @@ def bsr_mm(
1555
1606
  z: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
1556
1607
  alpha: Scalar = 1.0,
1557
1608
  beta: Scalar = 0.0,
1609
+ masked: bool = False,
1558
1610
  work_arrays: Optional[bsr_mm_work_arrays] = None,
1559
1611
  reuse_topology: bool = False,
1560
1612
  ) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
1561
1613
  """
1562
- Performs the sparse matrix-matrix multiplication ``z := alpha * x * y + beta * z`` on BSR matrices `x`, `y` and `z`, and returns `z`.
1614
+ Perform the sparse matrix-matrix multiplication ``z := alpha * x @ y + beta * z`` on BSR matrices ``x``, ``y`` and ``z``, and return ``z``.
1563
1615
 
1564
- The `x`, `y` and `z` matrices are allowed to alias.
1565
- If the matrix `z` is not provided as input, it will be allocated and treated as zero.
1616
+ The ``x``, ``y`` and ``z`` matrices are allowed to alias.
1617
+ If the matrix ``z`` is not provided as input, it will be allocated and treated as zero.
1566
1618
 
1567
1619
  Args:
1568
1620
  x: Read-only left factor of the matrix-matrix product.
1569
1621
  y: Read-only right factor of the matrix-matrix product.
1570
- z: Mutable left-hand-side. If `z` is not provided, it will be allocated and treated as zero.
1571
- alpha: Uniform scaling factor for the ``x * y`` product
1572
- beta: Uniform scaling factor for `z`
1573
- work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_mm_work_arrays` in `work_arrays`.
1574
- reuse_topology: If True, reuse the product topology information stored in `work_arrays` rather than recompute it from scratch.
1575
- The matrices x, y and z must be structurally similar to the previous call in which `work_arrays` were populated.
1576
- This is necessary for `bsr_mm` to be captured in a CUDA graph.
1622
+ z: Mutable left-hand-side. If ``z`` is not provided, it will be allocated and treated as zero.
1623
+ alpha: Uniform scaling factor for the ``x @ y`` product
1624
+ beta: Uniform scaling factor for ``z``
1625
+ masked: If ``True``, ignore all blocks from ``x @ y`` which are not existing non-zeros of ``z``
1626
+ work_arrays: In most cases, this function will require the use of temporary storage.
1627
+ This storage can be reused across calls by passing an instance of
1628
+ :class:`bsr_mm_work_arrays` in ``work_arrays``.
1629
+ reuse_topology: If ``True``, reuse the product topology information
1630
+ stored in ``work_arrays`` rather than recompute it from scratch.
1631
+ The matrices ``x``, ``y`` and ``z`` must be structurally similar to
1632
+ the previous call in which ``work_arrays`` were populated.
1633
+ This is necessary for ``bsr_mm`` to be captured in a CUDA graph.
1577
1634
  """
1578
1635
 
1579
1636
  x, x_scale = _extract_matrix_and_scale(x)
@@ -1582,12 +1639,15 @@ def bsr_mm(
1582
1639
  alpha *= y_scale
1583
1640
 
1584
1641
  if z is None:
1642
+ if masked:
1643
+ raise ValueError("Left-hand-side 'z' matrix must be provided for masked multiplication")
1644
+
1585
1645
  # If no output matrix is provided, allocate it for convenience
1586
1646
  z_block_shape = (x.block_shape[0], y.block_shape[1])
1587
1647
  if z_block_shape == (1, 1):
1588
1648
  z_block_type = x.scalar_type
1589
1649
  else:
1590
- z_block_type = wp.types.matrix(shape=z_block_shape, dtype=x.scalar_type)
1650
+ z_block_type = wp.mat(shape=z_block_shape, dtype=x.scalar_type)
1591
1651
  z = bsr_zeros(x.nrow, y.ncol, block_type=z_block_type, device=x.values.device)
1592
1652
  beta = 0.0
1593
1653
 
@@ -1613,14 +1673,22 @@ def bsr_mm(
1613
1673
  # Easy case
1614
1674
  return bsr_scale(z, beta)
1615
1675
 
1616
- if not isinstance(alpha, z.scalar_type):
1617
- alpha = z.scalar_type(alpha)
1618
- if not isinstance(beta, z.scalar_type):
1619
- beta = z.scalar_type(beta)
1620
-
1621
1676
  z_aliasing = z == x or z == y
1622
1677
 
1623
- if reuse_topology:
1678
+ if masked:
1679
+ # no need to copy z, scale in-place
1680
+ copied_z_nnz = 0
1681
+ mm_nnz = z.nnz
1682
+
1683
+ if z_aliasing:
1684
+ raise ValueError("`masked=True` is not supported for aliased inputs")
1685
+
1686
+ if beta == 0.0:
1687
+ # do not bsr_scale(0), this would not preserve topology
1688
+ z.values.zero_()
1689
+ else:
1690
+ bsr_scale(z, beta)
1691
+ elif reuse_topology:
1624
1692
  if work_arrays is None:
1625
1693
  raise ValueError("`work_arrays` must not be ``None`` in order to reuse matrix-matrix product topology")
1626
1694
 
@@ -1633,133 +1701,142 @@ def bsr_mm(
1633
1701
  if work_arrays is None:
1634
1702
  work_arrays = bsr_mm_work_arrays()
1635
1703
 
1636
- work_arrays._allocate_stage_1(device, z, beta, z_aliasing)
1704
+ work_arrays._allocate_stage_1(device, x.nnz, z, beta, z_aliasing)
1637
1705
  copied_z_nnz = work_arrays._copied_z_nnz
1638
1706
 
1639
1707
  # Prefix sum of number of (unmerged) mm blocks per row
1708
+ work_arrays._mm_block_counts.zero_()
1640
1709
  wp.launch(
1641
1710
  kernel=_bsr_mm_count_coeffs,
1642
1711
  device=device,
1643
1712
  dim=z.nrow,
1644
1713
  inputs=[
1714
+ y.ncol,
1645
1715
  copied_z_nnz,
1646
1716
  x.offsets,
1647
1717
  x.columns,
1648
1718
  y.offsets,
1649
- work_arrays._mm_row_counts,
1719
+ y.columns,
1720
+ work_arrays._mm_row_min,
1721
+ work_arrays._mm_block_counts,
1650
1722
  ],
1651
1723
  )
1652
- warp.utils.array_scan(work_arrays._mm_row_counts, work_arrays._mm_row_counts)
1724
+ warp.utils.array_scan(work_arrays._mm_block_counts, work_arrays._mm_block_counts)
1653
1725
 
1654
1726
  # Get back total counts on host -- we need a synchronization here
1655
1727
  # Use pinned buffer from z, we are going to need it later anyway
1656
1728
  nnz_buf, _ = z._nnz_transfer_buf_and_event()
1657
1729
  stream = wp.get_stream(device) if device.is_cuda else None
1658
- wp.copy(dest=nnz_buf, src=work_arrays._mm_row_counts, src_offset=z.nrow, count=1, stream=stream)
1730
+ wp.copy(dest=nnz_buf, src=work_arrays._mm_block_counts, src_offset=x.nnz, count=1, stream=stream)
1659
1731
  if device.is_cuda:
1660
1732
  wp.synchronize_stream(stream)
1661
1733
  mm_nnz = int(nnz_buf.numpy()[0])
1662
1734
 
1735
+ if mm_nnz == copied_z_nnz:
1736
+ # x@y = 0
1737
+ return bsr_scale(z, beta)
1738
+
1663
1739
  work_arrays._allocate_stage_2(mm_nnz)
1664
1740
 
1665
1741
  # If z has a non-zero scale, save current data before overwriting it
1666
1742
  if copied_z_nnz > 0:
1667
1743
  # Copy z row and column indices
1668
1744
  wp.copy(dest=work_arrays._mm_cols, src=z.columns, count=copied_z_nnz)
1669
- wp.launch(
1670
- kernel=_bsr_get_block_row,
1671
- device=device,
1672
- dim=copied_z_nnz,
1673
- inputs=[0, z.nrow, z.offsets, work_arrays._mm_rows],
1674
- )
1745
+ z.uncompress_rows(out=work_arrays._mm_rows)
1675
1746
  if z_aliasing:
1676
1747
  # If z is aliasing with x or y, need to save topology as well
1677
1748
  wp.copy(src=z.columns, dest=work_arrays._old_z_columns, count=copied_z_nnz)
1678
1749
  wp.copy(src=z.offsets, dest=work_arrays._old_z_offsets, count=z.nrow + 1)
1679
1750
 
1680
1751
  # Fill unmerged mm blocks rows and columns
1752
+ work_arrays._mm_rows[copied_z_nnz:].fill_(-1)
1681
1753
  wp.launch(
1682
1754
  kernel=_bsr_mm_list_coeffs,
1683
1755
  device=device,
1684
- dim=z.nrow,
1756
+ dim=x.nnz,
1685
1757
  inputs=[
1758
+ x.nrow,
1686
1759
  x.offsets,
1687
1760
  x.columns,
1688
1761
  y.offsets,
1689
1762
  y.columns,
1690
- work_arrays._mm_row_counts,
1763
+ work_arrays._mm_row_min,
1764
+ work_arrays._mm_block_counts,
1691
1765
  work_arrays._mm_rows,
1692
1766
  work_arrays._mm_cols,
1693
1767
  ],
1694
1768
  )
1695
1769
 
1770
+ alpha = z.scalar_type(alpha)
1771
+ beta = z.scalar_type(beta)
1772
+
1696
1773
  if copied_z_nnz > 0:
1697
1774
  # Save current z values in temporary buffer
1698
1775
  wp.copy(src=z.values, dest=work_arrays._old_z_values, count=copied_z_nnz)
1699
1776
 
1700
- # Increase dest array size if needed
1701
- if z.columns.shape[0] < mm_nnz:
1702
- z.columns = wp.empty(shape=(mm_nnz,), dtype=int, device=device)
1777
+ if not masked:
1778
+ # Increase dest array size if needed
1779
+ if z.columns.shape[0] < mm_nnz:
1780
+ z.columns = wp.empty(shape=(mm_nnz,), dtype=int, device=device)
1703
1781
 
1704
- from warp.context import runtime
1782
+ from warp.context import runtime
1705
1783
 
1706
- if device.is_cpu:
1707
- native_func = runtime.core.bsr_matrix_from_triplets_float_host
1708
- else:
1709
- native_func = runtime.core.bsr_matrix_from_triplets_float_device
1784
+ if device.is_cpu:
1785
+ native_func = runtime.core.bsr_matrix_from_triplets_float_host
1786
+ else:
1787
+ native_func = runtime.core.bsr_matrix_from_triplets_float_device
1710
1788
 
1711
- nnz_buf, nnz_event = z._nnz_transfer_buf_and_event()
1789
+ nnz_buf, nnz_event = z._nnz_transfer_buf_and_event()
1712
1790
 
1713
- with wp.ScopedDevice(z.device):
1714
- native_func(
1715
- z.block_shape[0],
1716
- z.block_shape[1],
1717
- z.nrow,
1718
- mm_nnz,
1719
- ctypes.cast(work_arrays._mm_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
1720
- ctypes.cast(work_arrays._mm_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
1721
- 0,
1722
- False,
1723
- ctypes.cast(z.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
1724
- ctypes.cast(z.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
1725
- 0,
1726
- ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
1727
- nnz_event,
1728
- )
1791
+ with wp.ScopedDevice(z.device):
1792
+ native_func(
1793
+ z.block_shape[0],
1794
+ z.block_shape[1],
1795
+ z.nrow,
1796
+ mm_nnz,
1797
+ ctypes.cast(work_arrays._mm_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
1798
+ ctypes.cast(work_arrays._mm_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
1799
+ 0,
1800
+ False,
1801
+ masked,
1802
+ ctypes.cast(z.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
1803
+ ctypes.cast(z.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
1804
+ 0,
1805
+ ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
1806
+ nnz_event,
1807
+ )
1729
1808
 
1730
- # Resize z to fit mm result if necessary
1731
- # If we are not reusing the product topology, this needs another synchronization
1732
- if not reuse_topology:
1733
- work_arrays.result_nnz = z.nnz_sync()
1734
- _bsr_ensure_fits(z, nnz=work_arrays.result_nnz)
1809
+ # Resize z to fit mm result if necessary
1810
+ # If we are not reusing the product topology, this needs another synchronization
1811
+ if not reuse_topology:
1812
+ work_arrays.result_nnz = z.nnz_sync()
1735
1813
 
1736
- z.values.zero_()
1814
+ _bsr_ensure_fits(z, nnz=work_arrays.result_nnz)
1815
+ z.values.zero_()
1737
1816
 
1738
- if copied_z_nnz > 0:
1739
- # Add back original z values
1740
- wp.launch(
1741
- kernel=_bsr_axpy_add_block,
1742
- device=device,
1743
- dim=copied_z_nnz,
1744
- inputs=[
1745
- 0,
1746
- beta,
1747
- work_arrays._mm_rows,
1748
- work_arrays._mm_cols,
1749
- z.offsets,
1750
- z.columns,
1751
- work_arrays._old_z_values,
1752
- z.values,
1753
- ],
1754
- )
1817
+ if copied_z_nnz > 0:
1818
+ # Add back original z values
1819
+ wp.launch(
1820
+ kernel=_bsr_axpy_add_block,
1821
+ device=device,
1822
+ dim=copied_z_nnz,
1823
+ inputs=[
1824
+ 0,
1825
+ beta,
1826
+ work_arrays._mm_rows,
1827
+ work_arrays._mm_cols,
1828
+ z.offsets,
1829
+ z.columns,
1830
+ work_arrays._old_z_values,
1831
+ z.values,
1832
+ ],
1833
+ )
1755
1834
 
1756
1835
  # Add mm blocks to z values
1757
- if (warp.types.type_is_matrix(x.values.dtype) or warp.types.type_is_matrix(y.values.dtype)) and not (
1758
- warp.types.type_is_matrix(z.values.dtype)
1759
- ):
1836
+ if (type_is_matrix(x.values.dtype) or type_is_matrix(y.values.dtype)) and not (type_is_matrix(z.values.dtype)):
1760
1837
  # Result block type is scalar, but operands are matrices
1761
1838
  # Cast result to (1x1) matrix to perform multiplication
1762
- mm_values = z.values.view(wp.types.matrix(shape=(1, 1), dtype=z.scalar_type))
1839
+ mm_values = z.values.view(wp.mat(shape=(1, 1), dtype=z.scalar_type))
1763
1840
  else:
1764
1841
  mm_values = z.values
1765
1842
 
@@ -1832,15 +1909,31 @@ def _bsr_mv_transpose_kernel(
1832
1909
  wp.atomic_add(y, A_columns[block], v)
1833
1910
 
1834
1911
 
1835
- def _bsr_mv_as_vec_array(array: wp.array) -> wp.array:
1836
- if array.ndim == 1:
1912
+ def _vec_array_view(array: wp.array, dtype: type, expected_scalar_count: int) -> wp.array:
1913
+ # cast a 1d or 2d array to a 1d array with the target dtype, adjusting shape as required
1914
+
1915
+ scalar_count = array.size * type_length(array.dtype)
1916
+ if scalar_count != expected_scalar_count:
1917
+ raise ValueError(f"Invalid array scalar size, expected {expected_scalar_count}, got {scalar_count}")
1918
+
1919
+ if array.ndim == 1 and types_equal(array.dtype, dtype):
1837
1920
  return array
1838
1921
 
1922
+ if type_scalar_type(array.dtype) != type_scalar_type(dtype):
1923
+ raise ValueError(f"Incompatible scalar types, {type_repr(array.dtype)} vs {type_repr(dtype)}")
1924
+
1839
1925
  if array.ndim > 2:
1840
1926
  raise ValueError(f"Incompatible array number of dimensions {array.ndim}")
1841
1927
 
1842
1928
  if not array.is_contiguous:
1843
- raise ValueError("2d array must be contiguous")
1929
+ raise ValueError("Array must be contiguous")
1930
+
1931
+ vec_length = type_length(dtype)
1932
+ vec_count = scalar_count // vec_length
1933
+ if vec_count * vec_length != scalar_count:
1934
+ raise ValueError(
1935
+ f"Array of shape {array.shape} and type {type_repr(array.dtype)} cannot be reshaped to an array of type {type_repr(dtype)}"
1936
+ )
1844
1937
 
1845
1938
  def vec_view(array):
1846
1939
  return wp.array(
@@ -1848,8 +1941,8 @@ def _bsr_mv_as_vec_array(array: wp.array) -> wp.array:
1848
1941
  ptr=array.ptr,
1849
1942
  capacity=array.capacity,
1850
1943
  device=array.device,
1851
- dtype=wp.vec(length=array.shape[1], dtype=array.dtype),
1852
- shape=array.shape[0],
1944
+ dtype=dtype,
1945
+ shape=vec_count,
1853
1946
  grad=None if array.grad is None else vec_view(array.grad),
1854
1947
  )
1855
1948
 
@@ -1867,20 +1960,20 @@ def bsr_mv(
1867
1960
  transpose: bool = False,
1868
1961
  work_buffer: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
1869
1962
  ) -> "Array[Vector[Rows, Scalar] | Scalar]":
1870
- """
1871
- Performs the sparse matrix-vector product ``y := alpha * A * x + beta * y`` and returns `y`.
1963
+ """Perform the sparse matrix-vector product ``y := alpha * A * x + beta * y`` and return ``y``.
1872
1964
 
1873
- The `x` and `y` vectors are allowed to alias.
1965
+ The ``x`` and ``y`` vectors are allowed to alias.
1874
1966
 
1875
1967
  Args:
1876
1968
  A: Read-only, left matrix factor of the matrix-vector product.
1877
1969
  x: Read-only, right vector factor of the matrix-vector product.
1878
- y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero.
1879
- alpha: Uniform scaling factor for `x`. If zero, `x` will not be read and may be left uninitialized.
1880
- beta: Uniform scaling factor for `y`. If zero, `y` will not be read and may be left uninitialized.
1881
- transpose: If ``True``, use the transpose of the matrix `A`. In this case the result is **non-deterministic**.
1882
- work_buffer: Temporary storage is required if and only if `x` and `y` are the same vector. If provided the `work_buffer` array
1883
- will be used for this purpose, otherwise a temporary allocation will be performed.
1970
+ y: Mutable left-hand-side. If ``y`` is not provided, it will be allocated and treated as zero.
1971
+ alpha: Uniform scaling factor for ``x``. If zero, ``x`` will not be read and may be left uninitialized.
1972
+ beta: Uniform scaling factor for ``y``. If zero, ``y`` will not be read and may be left uninitialized.
1973
+ transpose: If ``True``, use the transpose of the matrix ``A``. In this case the result is **non-deterministic**.
1974
+ work_buffer: Temporary storage is required if and only if ``x`` and ``y`` are the same vector.
1975
+ If provided, the ``work_buffer`` array will be used for this purpose,
1976
+ otherwise a temporary allocation will be performed.
1884
1977
  """
1885
1978
 
1886
1979
  A, A_scale = _extract_matrix_and_scale(A)
@@ -1900,22 +1993,11 @@ def bsr_mv(
1900
1993
  y = wp.empty(shape=(nrow,), device=A.values.device, dtype=y_dtype)
1901
1994
  beta = 0.0
1902
1995
 
1903
- if not isinstance(alpha, A.scalar_type):
1904
- alpha = A.scalar_type(alpha)
1905
- if not isinstance(beta, A.scalar_type):
1906
- beta = A.scalar_type(beta)
1996
+ alpha = A.scalar_type(alpha)
1997
+ beta = A.scalar_type(beta)
1907
1998
 
1908
1999
  if A.values.device != x.device or A.values.device != y.device:
1909
- raise ValueError("A, x and y must reside on the same device")
1910
-
1911
- if x.shape[0] != ncol:
1912
- raise ValueError("Number of columns of A must match number of rows of x")
1913
- if y.shape[0] != nrow:
1914
- raise ValueError("Number of rows of A must match number of rows of y")
1915
-
1916
- # View 2d arrays as arrays of vecs
1917
- x = _bsr_mv_as_vec_array(x)
1918
- y = _bsr_mv_as_vec_array(y)
2000
+ raise ValueError("A, x, and y must reside on the same device")
1919
2001
 
1920
2002
  if x.ptr == y.ptr:
1921
2003
  # Aliasing case, need temporary storage
@@ -1923,24 +2005,29 @@ def bsr_mv(
1923
2005
  work_buffer = wp.empty_like(y)
1924
2006
  elif work_buffer.size < y.size:
1925
2007
  raise ValueError(f"Work buffer size is insufficient, needs to be at least {y.size}")
1926
- elif not wp.types.types_equal(work_buffer.dtype, y.dtype):
1927
- raise ValueError(f"Work buffer must have same data type as y, {wp.types.type_repr(y.dtype)}")
2008
+ elif not types_equal(work_buffer.dtype, y.dtype):
2009
+ raise ValueError(f"Work buffer must have same data type as y, {type_repr(y.dtype)}")
1928
2010
 
1929
2011
  # Save old y values before overwriting vector
1930
2012
  wp.copy(dest=work_buffer, src=y, count=y.size)
1931
2013
  x = work_buffer
1932
2014
 
1933
2015
  # Promote scalar vectors to length-1 vecs and conversely
1934
- if warp.types.type_is_matrix(A.values.dtype):
1935
- if block_shape[0] == 1 and y.dtype == A.scalar_type:
1936
- y = y.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
1937
- if block_shape[1] == 1 and x.dtype == A.scalar_type:
1938
- x = x.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
2016
+ if type_is_matrix(A.values.dtype):
2017
+ x_dtype = wp.vec(length=block_shape[1], dtype=A.scalar_type)
2018
+ y_dtype = wp.vec(length=block_shape[0], dtype=A.scalar_type)
1939
2019
  else:
1940
- if block_shape[0] == 1 and y.dtype != A.scalar_type:
1941
- y = y.view(dtype=A.scalar_type)
1942
- if block_shape[1] == 1 and x.dtype != A.scalar_type:
1943
- x = x.view(dtype=A.scalar_type)
2020
+ x_dtype = A.scalar_type
2021
+ y_dtype = A.scalar_type
2022
+
2023
+ try:
2024
+ x_view = _vec_array_view(x, x_dtype, expected_scalar_count=ncol * block_shape[1])
2025
+ except ValueError as err:
2026
+ raise ValueError("Incompatible 'x' vector for bsr_mv") from err
2027
+ try:
2028
+ y_view = _vec_array_view(y, y_dtype, expected_scalar_count=nrow * block_shape[0])
2029
+ except ValueError as err:
2030
+ raise ValueError("Incompatible 'y' vector for bsr_mv") from err
1944
2031
 
1945
2032
  if transpose:
1946
2033
  if beta.value == 0.0:
@@ -1957,14 +2044,14 @@ def bsr_mv(
1957
2044
  kernel=_bsr_mv_transpose_kernel,
1958
2045
  device=A.values.device,
1959
2046
  dim=ncol,
1960
- inputs=[alpha, A.offsets, A.columns, A.values, x, y],
2047
+ inputs=[alpha, A.offsets, A.columns, A.values, x_view, y_view],
1961
2048
  )
1962
2049
  else:
1963
2050
  wp.launch(
1964
2051
  kernel=_bsr_mv_kernel,
1965
2052
  device=A.values.device,
1966
2053
  dim=nrow,
1967
- inputs=[alpha, A.offsets, A.columns, A.values, x, beta, y],
2054
+ inputs=[alpha, A.offsets, A.columns, A.values, x_view, beta, y_view],
1968
2055
  )
1969
2056
 
1970
2057
  return y