warp-lang 1.2.2-py3-none-manylinux2014_x86_64.whl → 1.3.1-py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of warp-lang has been flagged as potentially problematic.

Files changed (193)
  1. warp/__init__.py +8 -6
  2. warp/autograd.py +823 -0
  3. warp/bin/warp.so +0 -0
  4. warp/build.py +6 -2
  5. warp/builtins.py +1412 -888
  6. warp/codegen.py +503 -166
  7. warp/config.py +48 -18
  8. warp/context.py +400 -198
  9. warp/dlpack.py +8 -0
  10. warp/examples/assets/bunny.usd +0 -0
  11. warp/examples/benchmarks/benchmark_cloth_warp.py +1 -1
  12. warp/examples/benchmarks/benchmark_interop_torch.py +158 -0
  13. warp/examples/benchmarks/benchmark_launches.py +1 -1
  14. warp/examples/core/example_cupy.py +78 -0
  15. warp/examples/fem/example_apic_fluid.py +17 -36
  16. warp/examples/fem/example_burgers.py +9 -18
  17. warp/examples/fem/example_convection_diffusion.py +7 -17
  18. warp/examples/fem/example_convection_diffusion_dg.py +27 -47
  19. warp/examples/fem/example_deformed_geometry.py +11 -22
  20. warp/examples/fem/example_diffusion.py +7 -18
  21. warp/examples/fem/example_diffusion_3d.py +24 -28
  22. warp/examples/fem/example_diffusion_mgpu.py +7 -14
  23. warp/examples/fem/example_magnetostatics.py +190 -0
  24. warp/examples/fem/example_mixed_elasticity.py +111 -80
  25. warp/examples/fem/example_navier_stokes.py +30 -34
  26. warp/examples/fem/example_nonconforming_contact.py +290 -0
  27. warp/examples/fem/example_stokes.py +17 -32
  28. warp/examples/fem/example_stokes_transfer.py +12 -21
  29. warp/examples/fem/example_streamlines.py +350 -0
  30. warp/examples/fem/utils.py +936 -0
  31. warp/fabric.py +5 -2
  32. warp/fem/__init__.py +13 -3
  33. warp/fem/cache.py +161 -11
  34. warp/fem/dirichlet.py +37 -28
  35. warp/fem/domain.py +105 -14
  36. warp/fem/field/__init__.py +14 -3
  37. warp/fem/field/field.py +454 -11
  38. warp/fem/field/nodal_field.py +33 -18
  39. warp/fem/geometry/deformed_geometry.py +50 -15
  40. warp/fem/geometry/hexmesh.py +12 -24
  41. warp/fem/geometry/nanogrid.py +106 -31
  42. warp/fem/geometry/quadmesh_2d.py +6 -11
  43. warp/fem/geometry/tetmesh.py +103 -61
  44. warp/fem/geometry/trimesh_2d.py +98 -47
  45. warp/fem/integrate.py +231 -186
  46. warp/fem/operator.py +14 -9
  47. warp/fem/quadrature/pic_quadrature.py +35 -9
  48. warp/fem/quadrature/quadrature.py +119 -32
  49. warp/fem/space/basis_space.py +98 -22
  50. warp/fem/space/collocated_function_space.py +3 -1
  51. warp/fem/space/function_space.py +7 -2
  52. warp/fem/space/grid_2d_function_space.py +3 -3
  53. warp/fem/space/grid_3d_function_space.py +4 -4
  54. warp/fem/space/hexmesh_function_space.py +3 -2
  55. warp/fem/space/nanogrid_function_space.py +12 -14
  56. warp/fem/space/partition.py +45 -47
  57. warp/fem/space/restriction.py +19 -16
  58. warp/fem/space/shape/cube_shape_function.py +91 -3
  59. warp/fem/space/shape/shape_function.py +7 -0
  60. warp/fem/space/shape/square_shape_function.py +32 -0
  61. warp/fem/space/shape/tet_shape_function.py +11 -7
  62. warp/fem/space/shape/triangle_shape_function.py +10 -1
  63. warp/fem/space/topology.py +116 -42
  64. warp/fem/types.py +8 -1
  65. warp/fem/utils.py +301 -83
  66. warp/native/array.h +16 -0
  67. warp/native/builtin.h +0 -15
  68. warp/native/cuda_util.cpp +14 -6
  69. warp/native/exports.h +1348 -1308
  70. warp/native/quat.h +79 -0
  71. warp/native/rand.h +27 -4
  72. warp/native/sparse.cpp +83 -81
  73. warp/native/sparse.cu +381 -453
  74. warp/native/vec.h +64 -0
  75. warp/native/volume.cpp +40 -49
  76. warp/native/volume_builder.cu +2 -3
  77. warp/native/volume_builder.h +12 -17
  78. warp/native/warp.cu +3 -3
  79. warp/native/warp.h +69 -59
  80. warp/render/render_opengl.py +17 -9
  81. warp/sim/articulation.py +117 -17
  82. warp/sim/collide.py +35 -29
  83. warp/sim/model.py +123 -18
  84. warp/sim/render.py +3 -1
  85. warp/sparse.py +867 -203
  86. warp/stubs.py +312 -541
  87. warp/tape.py +29 -1
  88. warp/tests/disabled_kinematics.py +1 -1
  89. warp/tests/test_adam.py +1 -1
  90. warp/tests/test_arithmetic.py +1 -1
  91. warp/tests/test_array.py +58 -1
  92. warp/tests/test_array_reduce.py +1 -1
  93. warp/tests/test_async.py +1 -1
  94. warp/tests/test_atomic.py +1 -1
  95. warp/tests/test_bool.py +1 -1
  96. warp/tests/test_builtins_resolution.py +1 -1
  97. warp/tests/test_bvh.py +6 -1
  98. warp/tests/test_closest_point_edge_edge.py +1 -1
  99. warp/tests/test_codegen.py +91 -1
  100. warp/tests/test_compile_consts.py +1 -1
  101. warp/tests/test_conditional.py +1 -1
  102. warp/tests/test_copy.py +1 -1
  103. warp/tests/test_ctypes.py +1 -1
  104. warp/tests/test_dense.py +1 -1
  105. warp/tests/test_devices.py +1 -1
  106. warp/tests/test_dlpack.py +1 -1
  107. warp/tests/test_examples.py +33 -4
  108. warp/tests/test_fabricarray.py +5 -2
  109. warp/tests/test_fast_math.py +1 -1
  110. warp/tests/test_fem.py +213 -6
  111. warp/tests/test_fp16.py +1 -1
  112. warp/tests/test_func.py +1 -1
  113. warp/tests/test_future_annotations.py +90 -0
  114. warp/tests/test_generics.py +1 -1
  115. warp/tests/test_grad.py +1 -1
  116. warp/tests/test_grad_customs.py +1 -1
  117. warp/tests/test_grad_debug.py +247 -0
  118. warp/tests/test_hash_grid.py +6 -1
  119. warp/tests/test_implicit_init.py +354 -0
  120. warp/tests/test_import.py +1 -1
  121. warp/tests/test_indexedarray.py +1 -1
  122. warp/tests/test_intersect.py +1 -1
  123. warp/tests/test_jax.py +1 -1
  124. warp/tests/test_large.py +1 -1
  125. warp/tests/test_launch.py +1 -1
  126. warp/tests/test_lerp.py +1 -1
  127. warp/tests/test_linear_solvers.py +1 -1
  128. warp/tests/test_lvalue.py +1 -1
  129. warp/tests/test_marching_cubes.py +5 -2
  130. warp/tests/test_mat.py +34 -35
  131. warp/tests/test_mat_lite.py +2 -1
  132. warp/tests/test_mat_scalar_ops.py +1 -1
  133. warp/tests/test_math.py +1 -1
  134. warp/tests/test_matmul.py +20 -16
  135. warp/tests/test_matmul_lite.py +1 -1
  136. warp/tests/test_mempool.py +1 -1
  137. warp/tests/test_mesh.py +5 -2
  138. warp/tests/test_mesh_query_aabb.py +1 -1
  139. warp/tests/test_mesh_query_point.py +1 -1
  140. warp/tests/test_mesh_query_ray.py +1 -1
  141. warp/tests/test_mlp.py +1 -1
  142. warp/tests/test_model.py +1 -1
  143. warp/tests/test_module_hashing.py +77 -1
  144. warp/tests/test_modules_lite.py +1 -1
  145. warp/tests/test_multigpu.py +1 -1
  146. warp/tests/test_noise.py +1 -1
  147. warp/tests/test_operators.py +1 -1
  148. warp/tests/test_options.py +1 -1
  149. warp/tests/test_overwrite.py +542 -0
  150. warp/tests/test_peer.py +1 -1
  151. warp/tests/test_pinned.py +1 -1
  152. warp/tests/test_print.py +1 -1
  153. warp/tests/test_quat.py +15 -1
  154. warp/tests/test_rand.py +1 -1
  155. warp/tests/test_reload.py +1 -1
  156. warp/tests/test_rounding.py +1 -1
  157. warp/tests/test_runlength_encode.py +1 -1
  158. warp/tests/test_scalar_ops.py +95 -0
  159. warp/tests/test_sim_grad.py +1 -1
  160. warp/tests/test_sim_kinematics.py +1 -1
  161. warp/tests/test_smoothstep.py +1 -1
  162. warp/tests/test_sparse.py +82 -15
  163. warp/tests/test_spatial.py +1 -1
  164. warp/tests/test_special_values.py +2 -11
  165. warp/tests/test_streams.py +11 -1
  166. warp/tests/test_struct.py +1 -1
  167. warp/tests/test_tape.py +1 -1
  168. warp/tests/test_torch.py +194 -1
  169. warp/tests/test_transient_module.py +1 -1
  170. warp/tests/test_types.py +1 -1
  171. warp/tests/test_utils.py +1 -1
  172. warp/tests/test_vec.py +15 -63
  173. warp/tests/test_vec_lite.py +2 -1
  174. warp/tests/test_vec_scalar_ops.py +65 -1
  175. warp/tests/test_verify_fp.py +1 -1
  176. warp/tests/test_volume.py +28 -2
  177. warp/tests/test_volume_write.py +1 -1
  178. warp/tests/unittest_serial.py +1 -1
  179. warp/tests/unittest_suites.py +9 -1
  180. warp/tests/walkthrough_debug.py +1 -1
  181. warp/thirdparty/unittest_parallel.py +2 -5
  182. warp/torch.py +103 -41
  183. warp/types.py +341 -224
  184. warp/utils.py +11 -2
  185. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/METADATA +99 -46
  186. warp_lang-1.3.1.dist-info/RECORD +368 -0
  187. warp/examples/fem/bsr_utils.py +0 -378
  188. warp/examples/fem/mesh_utils.py +0 -133
  189. warp/examples/fem/plot_utils.py +0 -292
  190. warp_lang-1.2.2.dist-info/RECORD +0 -359
  191. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/LICENSE.md +0 -0
  192. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/WHEEL +0 -0
  193. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/top_level.txt +0 -0
warp/sparse.py CHANGED
@@ -1,3 +1,4 @@
+ import ctypes
  from typing import Any, Generic, Optional, Tuple, TypeVar, Union

  import warp as wp
@@ -31,7 +32,7 @@ class BsrMatrix(Generic[_BlockType]):
  Attributes:
  nrow (int): Number of rows of blocks
  ncol (int): Number of columns of blocks
- nnz (int): Number of non-zero blocks: must be equal to ``offsets[nrow-1]``, cached on host for convenience
+ nnz (int): Upper bound for the number of non-zero blocks, used for dimensioning launches; the exact number is at ``offsets[nrow-1]``. See also :meth:`nnz_sync`.
  offsets (Array[int]): Array of size at least ``1 + nrows`` such that the start and end indices of the blocks of row ``r`` are ``offsets[r]`` and ``offsets[r+1]``, respectively.
  columns (Array[int]): Array of size at least equal to ``nnz`` containing block column indices
  values (Array[BlockType]): Array of size at least equal to ``nnz`` containing block values
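Because `nnz` is now only an upper bound, callers that need the exact block count have to synchronize explicitly. A minimal sketch of the intended usage (illustrative values, default device assumed):

    import warp as wp
    import warp.sparse as sparse

    wp.init()

    rows = wp.array([0, 0, 1], dtype=int)
    cols = wp.array([0, 1, 1], dtype=int)
    vals = wp.array([1.0, 2.0, 3.0], dtype=wp.float32)

    A = sparse.bsr_zeros(2, 2, block_type=wp.float32)
    sparse.bsr_set_from_triplets(A, rows, cols, vals)

    print(A.nnz)         # upper bound, readable without synchronizing
    print(A.nnz_sync())  # exact count; waits for the async device-to-host copy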
@@ -68,6 +69,111 @@ class BsrMatrix(Generic[_BlockType]):
  """Device on which offsets, columns and values are allocated -- assumed to be the same for all three arrays"""
  return self.values.device

+ def nnz_sync(self):
+ """Ensures that any ongoing transfer of the exact nnz number from the device offsets array to the host has completed,
+ and updates the nnz upper bound.
+
+ See also :meth:`copy_nnz_async`
+ """
+
+ if self._is_nnz_transfer_setup():
+ if self.device.is_cuda:
+ wp.synchronize_event(self._nnz_event)
+ self.nnz = int(self._nnz_buf.numpy()[0])
+ return self.nnz
+
+ def copy_nnz_async(self, known_nnz: int = None):
+ """
+ Starts the asynchronous transfer of the exact nnz from the device offsets array to host, and records an event for completion.
+ Needs to be called whenever the offsets array has been modified from outside ``warp.sparse``.
+
+ See also :meth:`nnz_sync`
+ """
+ if known_nnz is not None:
+ self.nnz = int(known_nnz)
+ else:
+ self._setup_nnz_transfer()
+
+ # If a transfer is already ongoing, or if the actual nnz is unknown, schedule a new transfer
+ if self._is_nnz_transfer_setup():
+ stream = wp.get_stream(self.device) if self.device.is_cuda else None
+ wp.copy(src=self.offsets, dest=self._nnz_buf, src_offset=self.nrow, count=1, stream=stream)
+ if self.device.is_cuda:
+ stream.record_event(self._nnz_event)
+
+ def _setup_nnz_transfer(self):
+ if self._is_nnz_transfer_setup():
+ return
+
+ BsrMatrix.__setattr__(
+ self, "_nnz_buf", wp.zeros(dtype=int, shape=(1,), device="cpu", pinned=self.device.is_cuda)
+ )
+ if self.device.is_cuda:
+ BsrMatrix.__setattr__(self, "_nnz_event", wp.Event(self.device))
+
+ def _is_nnz_transfer_setup(self):
+ return hasattr(self, "_nnz_buf")
+
+ def _nnz_transfer_buf_and_event(self):
+ self._setup_nnz_transfer()
+
+ if not self.device.is_cuda:
+ return self._nnz_buf, ctypes.c_void_p(None)
+ return self._nnz_buf, self._nnz_event.cuda_event
+
+ # Overloaded math operators
+ def __add__(self, y):
+ return bsr_axpy(y, bsr_copy(self))
+
+ def __iadd__(self, y):
+ return bsr_axpy(y, self)
+
+ def __radd__(self, x):
+ return bsr_axpy(x, bsr_copy(self))
+
+ def __sub__(self, y):
+ return bsr_axpy(y, bsr_copy(self), alpha=-1.0)
+
+ def __rsub__(self, x):
+ return bsr_axpy(x, bsr_copy(self), beta=-1.0)
+
+ def __isub__(self, y):
+ return bsr_axpy(y, self, alpha=-1.0)
+
+ def __mul__(self, y):
+ return _BsrScalingExpression(self, y)
+
+ def __rmul__(self, x):
+ return _BsrScalingExpression(self, x)
+
+ def __imul__(self, y):
+ return bsr_scale(self, y)
+
+ def __matmul__(self, y):
+ if isinstance(y, wp.array):
+ return bsr_mv(self, y)
+
+ return bsr_mm(self, y)
+
+ def __rmatmul__(self, x):
+ if isinstance(x, wp.array):
+ return bsr_mv(self, x, transpose=True)
+
+ return bsr_mm(x, self)
+
+ def __imatmul__(self, y):
+ return bsr_mm(self, y, self)
+
+ def __truediv__(self, y):
+ return _BsrScalingExpression(self, 1.0 / y)
+
+ def __neg__(self):
+ return _BsrScalingExpression(self, -1.0)
+
+ def transpose(self):
+ """Returns a transposed copy of this matrix"""
+ return bsr_transposed(self)
+

  def bsr_matrix_t(dtype: BlockType):
  dtype = wp.types.type_to_warp(dtype)
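The new operator overloads let BSR matrices be combined with the usual Python operators. A minimal sketch (illustrative values; `sparse` refers to `warp.sparse`):

    import warp as wp
    import warp.sparse as sparse

    wp.init()

    A = sparse.bsr_diag(wp.array([1.0, 2.0], dtype=wp.float32))
    B = sparse.bsr_diag(wp.array([3.0, 4.0], dtype=wp.float32))
    x = wp.array([1.0, 1.0], dtype=wp.float32)

    C = A + B     # sum, evaluated via bsr_axpy on a copy of A
    A += B        # in-place sum
    y = A @ x     # sparse matrix-vector product (bsr_mv)
    D = A @ B     # sparse matrix-matrix product (bsr_mm)
    A *= 0.5      # in-place scaling (bsr_scale)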
@@ -83,7 +189,7 @@ def bsr_matrix_t(dtype: BlockType):
  ncol: int
  """Number of columns of blocks"""
  nnz: int
- """Number of non-zero blocks: equal to offsets[-1], cached on host for convenience"""
+ """Upper bound for the number of non-zeros"""
  offsets: wp.array(dtype=int)
  """Array of size at least 1 + nrows"""
  columns: wp.array(dtype=int)
@@ -130,7 +236,7 @@ def bsr_zeros(

  bsr.nrow = int(rows_of_blocks)
  bsr.ncol = int(cols_of_blocks)
- bsr.nnz = 0
+ bsr.nnz = int(0)
  bsr.columns = wp.empty(shape=(0,), dtype=int, device=device)
  bsr.values = wp.empty(shape=(0,), dtype=block_type, device=device)
  bsr.offsets = wp.zeros(shape=(bsr.nrow + 1,), dtype=int, device=device)
@@ -143,6 +249,9 @@ def _bsr_ensure_fits(bsr: BsrMatrix, nrow: int = None, nnz: int = None):
  nrow = bsr.nrow
  if nnz is None:
  nnz = bsr.nnz
+ else:
+ # update nnz upper bound
+ bsr.nnz = int(nnz)

  if bsr.offsets.size < nrow + 1:
  bsr.offsets = wp.empty(shape=(nrow + 1,), dtype=int, device=bsr.offsets.device)
@@ -170,9 +279,10 @@ def bsr_set_zero(
  bsr.nrow = int(rows_of_blocks)
  if cols_of_blocks is not None:
  bsr.ncol = int(cols_of_blocks)
- bsr.nnz = 0
- _bsr_ensure_fits(bsr)
+
+ _bsr_ensure_fits(bsr, nnz=0)
  bsr.offsets.zero_()
+ bsr.copy_nnz_async(known_nnz=0)


  def bsr_set_from_triplets(
@@ -180,11 +290,12 @@ def bsr_set_from_triplets(
  rows: "Array[int]",
  columns: "Array[int]",
  values: "Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]",
+ prune_numerical_zeros: bool = True,
  ):
  """
  Fills a BSR matrix with values defined by coordinate-oriented (COO) triplets, discarding existing blocks.

- The first dimension of the three input arrays must match, and determines the number of non-zeros in the constructed matrix.
+ The first dimension of the three input arrays must match and indicates the number of COO triplets.

  Args:
  dest: Sparse matrix to populate
@@ -192,6 +303,7 @@ def bsr_set_from_triplets(
  columns: Columns index for each non-zero
  values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
  to the `dest` matrix's block type, or a 3d array with data type equal to the `dest` matrix's scalar type.
+ prune_numerical_zeros: If True, will ignore the zero-valued blocks
  """

  if values.device != columns.device or values.device != rows.device or values.device != dest.values.device:
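A hedged sketch of the new `prune_numerical_zeros` flag, reusing the triplet arrays from the earlier example; passing `False` keeps explicitly stored zero-valued blocks in the sparsity pattern:

    # Keep zero-valued blocks instead of pruning them (new in 1.3)
    sparse.bsr_set_from_triplets(A, rows, cols, vals, prune_numerical_zeros=False)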
@@ -244,62 +356,477 @@ def bsr_set_from_triplets(
  if not native_func:
  raise NotImplementedError(f"bsr_from_triplets not implemented for scalar type {scalar_type}")

- dest.nnz = native_func(
- dest.block_shape[0],
- dest.block_shape[1],
- dest.nrow,
- nnz,
- rows.ptr,
- columns.ptr,
- values.ptr,
- dest.offsets.ptr,
- dest.columns.ptr,
- dest.values.ptr,
+ nnz_buf, nnz_event = dest._nnz_transfer_buf_and_event()
+
+ with wp.ScopedDevice(device):
+ native_func(
+ dest.block_shape[0],
+ dest.block_shape[1],
+ dest.nrow,
+ nnz,
+ ctypes.cast(rows.ptr, ctypes.POINTER(ctypes.c_int32)),
+ ctypes.cast(columns.ptr, ctypes.POINTER(ctypes.c_int32)),
+ ctypes.cast(values.ptr, ctypes.c_void_p),
+ prune_numerical_zeros,
+ ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
+ ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
+ ctypes.cast(dest.values.ptr, ctypes.c_void_p),
+ ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
+ nnz_event,
+ )
+
+
+ class _BsrExpression(Generic[_BlockType]):
+ pass
+
+
+ class _BsrScalingExpression(_BsrExpression):
+ def __init__(self, mat, scale):
+ self.mat = mat
+ self.scale = scale
+
+ def eval(self):
+ return bsr_copy(self)
+
+ @property
+ def nrow(self) -> int:
+ return self.mat.nrow
+
+ @property
+ def ncol(self) -> int:
+ return self.mat.ncol
+
+ @property
+ def nnz(self) -> int:
+ return self.mat.nnz
+
+ @property
+ def offsets(self) -> wp.array:
+ return self.mat.offsets
+
+ @property
+ def columns(self) -> wp.array:
+ return self.mat.columns
+
+ @property
+ def scalar_type(self) -> Scalar:
+ return self.mat.scalar_type
+
+ @property
+ def block_shape(self) -> Tuple[int, int]:
+ return self.mat.block_shape
+
+ @property
+ def block_size(self) -> int:
+ return self.mat.block_size
+
+ @property
+ def shape(self) -> Tuple[int, int]:
+ return self.mat.shape
+
+ @property
+ def dtype(self) -> type:
+ return self.mat.dtype
+
+ @property
+ def device(self) -> wp.context.Device:
+ return self.mat.device
+
+ # Overloaded math operators
+ def __add__(self, y):
+ return bsr_axpy(y, bsr_copy(self.mat), alpha=self.scale)
+
+ def __radd__(self, x):
+ return bsr_axpy(x, bsr_copy(self.mat), beta=self.scale)
+
+ def __sub__(self, y):
+ return bsr_axpy(y, bsr_copy(self.mat), alpha=-self.scale)
+
+ def __rsub__(self, x):
+ return bsr_axpy(x, bsr_copy(self.mat), beta=-self.scale)
+
+ def __mul__(self, y):
+ return _BsrScalingExpression(self.mat, y * self.scale)
+
+ def __rmul__(self, x):
+ return _BsrScalingExpression(self.mat, x * self.scale)
+
+ def __matmul__(self, y):
+ if isinstance(y, wp.array):
+ return bsr_mv(self.mat, y, alpha=self.scale)
+
+ return bsr_mm(self.mat, y, alpha=self.scale)
+
+ def __rmatmul__(self, x):
+ if isinstance(x, wp.array):
+ return bsr_mv(self.mat, x, alpha=self.scale, transpose=True)
+
+ return bsr_mm(x, self.mat, alpha=self.scale)
+
+ def __truediv__(self, y):
+ return _BsrScalingExpression(self.mat, self.scale / y)
+
+ def __neg__(self):
+ return _BsrScalingExpression(self.mat, -self.scale)
+
+ def transpose(self):
+ """Returns a transposed copy of this matrix"""
+ return _BsrScalingExpression(self.mat.transpose(), self.scale)
+
+
+ BsrMatrixOrExpression = Union[BsrMatrix[_BlockType], _BsrExpression[_BlockType]]
+
+
+ def _extract_matrix_and_scale(bsr: BsrMatrixOrExpression):
+ if isinstance(bsr, BsrMatrix):
+ return bsr, 1.0
+ if isinstance(bsr, _BsrScalingExpression):
+ return bsr.mat, bsr.scale
+
+ raise ValueError("Argument cannot be interpreted as a BsrMatrix")
+
+
+ @wp.kernel
+ def _bsr_assign_split_offsets(
+ row_factor: int,
+ col_factor: int,
+ src_offsets: wp.array(dtype=int),
+ dest_offsets: wp.array(dtype=int),
+ ):
+ row = wp.tid()
+
+ base_offset = src_offsets[row] * row_factor * col_factor
+ row_count = src_offsets[1 + row] - src_offsets[row]
+
+ for k in range(row_factor):
+ dest_offsets[1 + k + row_factor * row] = base_offset + row_count * col_factor * (k + 1)
+
+ if row == 0:
+ dest_offsets[0] = 0
+
+
+ @wp.kernel
+ def _bsr_assign_split_blocks(
+ structure_only: wp.bool,
+ scale: Any,
+ row_factor: int,
+ col_factor: int,
+ dest_row_count: int,
+ src_offsets: wp.array(dtype=int),
+ src_columns: wp.array(dtype=int),
+ src_values: wp.array3d(dtype=Any),
+ dest_offsets: wp.array(dtype=int),
+ dest_columns: wp.array(dtype=int),
+ dest_values: wp.array3d(dtype=Any),
+ ):
+ dest_block = wp.tid()
+
+ if dest_block >= dest_offsets[dest_row_count]:
+ return
+
+ dest_row = wp.lower_bound(dest_offsets, dest_block + 1) - 1
+ src_row = dest_row // row_factor
+
+ dest_col_in_row = dest_block - dest_offsets[dest_row]
+ src_col_in_row = dest_col_in_row // col_factor
+
+ src_block = src_offsets[src_row] + src_col_in_row
+
+ dest_rows_per_block = dest_values.shape[1]
+ dest_cols_per_block = dest_values.shape[2]
+
+ split_row = dest_row - row_factor * src_row
+ split_col = dest_col_in_row - col_factor * src_col_in_row
+
+ dest_columns[dest_block] = src_columns[src_block] * col_factor + split_col
+
+ if not structure_only:
+ src_base_i = split_row * dest_rows_per_block
+ src_base_j = split_col * dest_cols_per_block
+ for i in range(dest_rows_per_block):
+ for j in range(dest_cols_per_block):
+ dest_values[dest_block, i, j] = dest_values.dtype(
+ scale * src_values[src_block, i + src_base_i, j + src_base_j]
+ )
+
+
+ @wp.kernel
+ def _bsr_assign_merge_row_col(
+ row_factor: int,
+ col_factor: int,
+ src_row_count: int,
+ src_offsets: wp.array(dtype=int),
+ src_columns: wp.array(dtype=int),
+ dest_rows: wp.array(dtype=int),
+ dest_cols: wp.array(dtype=int),
+ ):
+ block = wp.tid()
+
+ if block >= src_offsets[src_row_count]:
+ dest_rows[block] = -1 # invalid
+ dest_cols[block] = -1
+ else:
+ row = wp.lower_bound(src_offsets, block + 1) - 1
+ dest_rows[block] = row // row_factor
+ dest_cols[block] = src_columns[block] // col_factor
+
+
+ @wp.kernel
+ def _bsr_assign_merge_blocks(
+ scale: Any,
+ row_factor: int,
+ col_factor: int,
+ src_row_count: int,
+ src_offsets: wp.array(dtype=int),
+ src_columns: wp.array(dtype=int),
+ src_values: wp.array3d(dtype=Any),
+ dest_offsets: wp.array(dtype=int),
+ dest_columns: wp.array(dtype=int),
+ dest_values: wp.array3d(dtype=Any),
+ ):
+ src_block = wp.tid()
+
+ if src_block >= src_offsets[src_row_count]:
+ return
+
+ src_row = wp.lower_bound(src_offsets, src_block + 1) - 1
+ src_col = src_columns[src_block]
+
+ dest_row = src_row // row_factor
+ dest_col = src_col // col_factor
+
+ dest_block = wp.lower_bound(dest_columns, dest_offsets[dest_row], dest_offsets[dest_row + 1], dest_col)
+
+ src_rows_per_block = src_values.shape[1]
+ src_cols_per_block = src_values.shape[2]
+
+ split_row = src_row - row_factor * dest_row
+ split_col = src_col - col_factor * dest_col
+
+ dest_base_i = split_row * src_rows_per_block
+ dest_base_j = split_col * src_cols_per_block
+
+ for i in range(src_rows_per_block):
+ for j in range(src_cols_per_block):
+ dest_values[dest_block, i + dest_base_i, j + dest_base_j] = dest_values.dtype(
+ scale * src_values[src_block, i, j]
+ )
+
+
+ def _bsr_values_as_3d_array(A: BsrMatrix) -> wp.array:
+ if A.block_shape == (1, 1):
+ return A.values.reshape((A.values.shape[0], 1, 1))
+
+ return wp.array(
+ data=None,
+ ptr=A.values.ptr,
+ capacity=A.values.capacity,
+ device=A.device,
+ dtype=A.scalar_type,
+ shape=(A.values.shape[0], A.block_shape[0], A.block_shape[1]),
  )


  def bsr_assign(
  dest: BsrMatrix[BlockType[Rows, Cols, Scalar]],
- src: BsrMatrix[BlockType[Rows, Cols, Any]],
+ src: BsrMatrixOrExpression[BlockType[Any, Any, Any]],
+ structure_only: bool = False,
  ):
- """Copies the content of the `src` matrix to `dest`, casting the block values if the two matrices use distinct scalar types."""
+ """Copies the content of the `src` BSR matrix to `dest`.
+
+ Args:
+ src: Matrix to be copied
+ dest: Destination matrix. May have a different block shape or scalar type than `src`, in which case the required casting will be performed.
+ structure_only: If ``True``, only the non-zero indices are copied, and uninitialized value storage is allocated
+ to accommodate at least `src.nnz` blocks. If `structure_only` is ``False``, values are also copied with implicit
+ casting if the two matrices use distinct scalar types.
+ """
+
+ src, src_scale = _extract_matrix_and_scale(src)

  if dest.values.device != src.values.device:
  raise ValueError("Source and destination matrices must reside on the same device")

- if dest.block_shape != src.block_shape:
- raise ValueError("Source and destination matrices must have the same block shape")
+ if dest.block_shape == src.block_shape:
+ dest.nrow = src.nrow
+ dest.ncol = src.ncol
+
+ nnz_alloc = src.nnz
+ _bsr_ensure_fits(dest, nnz=nnz_alloc)
+
+ wp.copy(dest=dest.offsets, src=src.offsets, count=src.nrow + 1)
+ dest.copy_nnz_async()
+
+ if nnz_alloc > 0:
+ wp.copy(dest=dest.columns, src=src.columns, count=nnz_alloc)

- dest.nrow = src.nrow
- dest.ncol = src.ncol
- dest.nnz = src.nnz
+ if not structure_only:
+ warp.utils.array_cast(out_array=dest.values, in_array=src.values, count=nnz_alloc)
+ bsr_scale(dest, src_scale)

- _bsr_ensure_fits(dest)
+ elif src.block_shape[0] >= dest.block_shape[0] and src.block_shape[1] >= dest.block_shape[1]:
+ # Split blocks
+
+ row_factor = src.block_shape[0] // dest.block_shape[0]
+ col_factor = src.block_shape[1] // dest.block_shape[1]
+
+ if (
+ row_factor * dest.block_shape[0] != src.block_shape[0]
+ or col_factor * dest.block_shape[1] != src.block_shape[1]
+ ):
+ raise ValueError(
+ f"Dest block shape {dest.block_shape} is not an exact divider of src block shape {src.block_shape}"
+ )

- wp.copy(dest=dest.offsets, src=src.offsets, count=src.nrow + 1)
- if src.nnz > 0:
- wp.copy(dest=dest.columns, src=src.columns, count=src.nnz)
- warp.utils.array_cast(out_array=dest.values, in_array=src.values, count=src.nnz)
+ dest.nrow = src.nrow * row_factor
+ dest.ncol = src.ncol * col_factor

+ nnz_alloc = src.nnz * row_factor * col_factor
+ _bsr_ensure_fits(dest, nnz=nnz_alloc)

- def bsr_copy(A: BsrMatrix, scalar_type: Optional[Scalar] = None):
+ wp.launch(
+ _bsr_assign_split_offsets,
+ dim=src.nrow,
+ device=dest.device,
+ inputs=[row_factor, col_factor, src.offsets, dest.offsets],
+ )
+ wp.launch(
+ _bsr_assign_split_blocks,
+ dim=dest.nnz,
+ device=dest.device,
+ inputs=[
+ wp.bool(structure_only),
+ src.scalar_type(src_scale),
+ row_factor,
+ col_factor,
+ dest.nrow,
+ src.offsets,
+ src.columns,
+ _bsr_values_as_3d_array(src),
+ dest.offsets,
+ dest.columns,
+ _bsr_values_as_3d_array(dest),
+ ],
+ )
+
+ elif src.block_shape[0] <= dest.block_shape[0] and src.block_shape[1] <= dest.block_shape[1]:
+ # Merge blocks
+
+ row_factor = dest.block_shape[0] // src.block_shape[0]
+ col_factor = dest.block_shape[1] // src.block_shape[1]
+
+ if (
+ row_factor * src.block_shape[0] != dest.block_shape[0]
+ or col_factor * src.block_shape[1] != dest.block_shape[1]
+ ):
+ raise ValueError(
+ f"Dest block shape {dest.block_shape} is not an exact multiple of src block shape {src.block_shape}"
+ )
+
+ if src.nrow % row_factor != 0 or src.ncol % col_factor != 0:
+ raise ValueError(
+ "The total rows and columns of the src matrix cannot be evenly divided using the requested block shape"
+ )
+
+ dest.nrow = src.nrow // row_factor
+ dest.ncol = src.ncol // col_factor
+
+ nnz_alloc = src.nnz # Conservative, in case all nnz in src belong to distinct merged blocks
+ _bsr_ensure_fits(dest, nnz=nnz_alloc)
+
+ # Compute destination rows and columns
+ dest_rows = wp.empty_like(src.columns)
+ dest_cols = wp.empty_like(src.columns)
+ wp.launch(
+ _bsr_assign_merge_row_col,
+ dim=src.nnz,
+ device=dest.device,
+ inputs=[row_factor, col_factor, src.nrow, src.offsets, src.columns, dest_rows, dest_cols],
+ )
+
+ # Compute destination offsets from triplets
+ from warp.context import runtime
+
+ if dest.device.is_cpu:
+ native_func = runtime.core.bsr_matrix_from_triplets_float_host
+ else:
+ native_func = runtime.core.bsr_matrix_from_triplets_float_device
+
+ nnz_buf, nnz_event = dest._nnz_transfer_buf_and_event()
+ with wp.ScopedDevice(dest.device):
+ native_func(
+ dest.block_shape[0],
+ dest.block_shape[1],
+ dest.nrow,
+ dest.nnz,
+ ctypes.cast(dest_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
+ ctypes.cast(dest_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
+ 0,
+ False,
+ ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
+ ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
+ 0,
+ ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
+ nnz_event,
+ )
+
+ # merge block values
+ if not structure_only:
+ dest.values.zero_()
+ wp.launch(
+ _bsr_assign_merge_blocks,
+ dim=src.nnz,
+ device=dest.device,
+ inputs=[
+ src.scalar_type(src_scale),
+ row_factor,
+ col_factor,
+ src.nrow,
+ src.offsets,
+ src.columns,
+ _bsr_values_as_3d_array(src),
+ dest.offsets,
+ dest.columns,
+ _bsr_values_as_3d_array(dest),
+ ],
+ )
+
+ else:
+ raise ValueError("Incompatible dest and src block shapes")
+
+
+ def bsr_copy(
+ A: BsrMatrixOrExpression,
+ scalar_type: Optional[Scalar] = None,
+ block_shape: Optional[Tuple[int, int]] = None,
+ structure_only: bool = False,
+ ):
  """Returns a copy of matrix ``A``, possibly changing its scalar type.

  Args:
+ A: Matrix to be copied
  scalar_type: If provided, the returned matrix will use this scalar type instead of the one from `A`.
+ block_shape: If provided, the returned matrix will use blocks of this shape instead of the one from `A`.
+ Both dimensions of `block_shape` must be either a multiple or an exact divider of the ones from `A`.
+ structure_only: If ``True``, only the non-zero indices are copied, and uninitialized value storage is allocated
+ to accommodate at least `src.nnz` blocks. If `structure_only` is ``False``, values are also copied with implicit
+ casting if the two matrices use distinct scalar types.
  """
  if scalar_type is None:
- block_type = A.values.dtype
- elif A.block_shape == (1, 1):
+ scalar_type = A.scalar_type
+ if block_shape is None:
+ block_shape = A.block_shape
+
+ if block_shape == (1, 1):
  block_type = scalar_type
  else:
- block_type = wp.types.matrix(shape=A.block_shape, dtype=scalar_type)
+ block_type = wp.types.matrix(shape=block_shape, dtype=scalar_type)

  copy = bsr_zeros(
  rows_of_blocks=A.nrow,
  cols_of_blocks=A.ncol,
  block_type=block_type,
- device=A.values.device,
+ device=A.device,
  )
  bsr_assign(dest=copy, src=A)
  return copy
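A short sketch of the extended `bsr_copy`, assuming `A` is an existing BsrMatrix with 2x2 blocks (names are illustrative):

    A64 = sparse.bsr_copy(A, scalar_type=wp.float64)       # change scalar type
    A11 = sparse.bsr_copy(A, block_shape=(1, 1))            # re-block into 1x1 blocks
    pattern = sparse.bsr_copy(A, structure_only=True)       # sparsity pattern only, values uninitialized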
@@ -307,10 +834,12 @@ def bsr_copy(A: BsrMatrix, scalar_type: Optional[Scalar] = None):

  def bsr_set_transpose(
  dest: BsrMatrix[BlockType[Cols, Rows, Scalar]],
- src: BsrMatrix[BlockType[Rows, Cols, Scalar]],
+ src: BsrMatrixOrExpression[BlockType[Rows, Cols, Scalar]],
  ):
  """Assigns the transposed matrix `src` to matrix `dest`"""

+ src, src_scale = _extract_matrix_and_scale(src)
+
  if dest.values.device != src.values.device:
  raise ValueError("All arguments must reside on the same device")

@@ -322,15 +851,16 @@ def bsr_set_transpose(
  if dest.block_shape != transpose_block_shape:
  raise ValueError(f"Destination block shape must be {transpose_block_shape}")

+ nnz = src.nnz
  dest.nrow = src.ncol
  dest.ncol = src.nrow
- dest.nnz = src.nnz

- if src.nnz == 0:
+ if nnz == 0:
+ bsr_set_zero(dest)
  return

  # Increase dest array sizes if needed
- _bsr_ensure_fits(dest)
+ _bsr_ensure_fits(dest, nnz=nnz)

  from warp.context import runtime

@@ -348,22 +878,26 @@ def bsr_set_transpose(
  if not native_func:
  raise NotImplementedError(f"bsr_set_transpose not implemented for scalar type {dest.scalar_type}")

- native_func(
- src.block_shape[0],
- src.block_shape[1],
- src.nrow,
- src.ncol,
- src.nnz,
- src.offsets.ptr,
- src.columns.ptr,
- src.values.ptr,
- dest.offsets.ptr,
- dest.columns.ptr,
- dest.values.ptr,
- )
+ with wp.ScopedDevice(dest.device):
+ native_func(
+ src.block_shape[0],
+ src.block_shape[1],
+ src.nrow,
+ src.ncol,
+ nnz,
+ ctypes.cast(src.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
+ ctypes.cast(src.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
+ ctypes.cast(src.values.ptr, ctypes.c_void_p),
+ ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
+ ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
+ ctypes.cast(dest.values.ptr, ctypes.c_void_p),
+ )

+ dest.copy_nnz_async()
+ bsr_scale(dest, src_scale)

- def bsr_transposed(A: BsrMatrix):
+
+ def bsr_transposed(A: BsrMatrixOrExpression):
  """Returns a copy of the transposed matrix `A`"""

  if A.block_shape == (1, 1):
@@ -375,7 +909,7 @@ def bsr_transposed(A: BsrMatrix):
  rows_of_blocks=A.ncol,
  cols_of_blocks=A.nrow,
  block_type=block_type,
- device=A.values.device,
+ device=A.device,
  )
  bsr_set_transpose(dest=transposed, src=A)
  return transposed
@@ -383,6 +917,7 @@ def bsr_transposed(A: BsrMatrix):

  @wp.kernel
  def _bsr_get_diag_kernel(
+ scale: Any,
  A_offsets: wp.array(dtype=int),
  A_columns: wp.array(dtype=int),
  A_values: wp.array(dtype=Any),
@@ -395,10 +930,10 @@ def _bsr_get_diag_kernel(
  diag = wp.lower_bound(A_columns, beg, end, row)
  if diag < end:
  if A_columns[diag] == row:
- out[row] = A_values[diag]
+ out[row] = scale * A_values[diag]


- def bsr_get_diag(A: BsrMatrix[_BlockType], out: "Optional[Array[BlockType]]" = None) -> "Array[BlockType]":
+ def bsr_get_diag(A: BsrMatrixOrExpression[BlockType], out: "Optional[Array[BlockType]]" = None) -> "Array[BlockType]":
  """Returns the array of blocks that constitute the diagonal of a sparse matrix.

  Args:
@@ -406,6 +941,8 @@ def bsr_get_diag(A: BsrMatrix[_BlockType], out: "Optional[Array[BlockType]]" = N
  out: if provided, the array into which to store the diagonal blocks
  """

+ A, scale = _extract_matrix_and_scale(A)
+
  dim = min(A.nrow, A.ncol)

  if out is None:
@@ -422,7 +959,7 @@ def bsr_get_diag(A: BsrMatrix[_BlockType], out: "Optional[Array[BlockType]]" = N
  kernel=_bsr_get_diag_kernel,
  dim=dim,
  device=A.values.device,
- inputs=[A.offsets, A.columns, A.values, out],
+ inputs=[A.scalar_type(scale), A.offsets, A.columns, A.values, out],
  )

  return out
@@ -495,13 +1032,13 @@ def bsr_set_diag(
  A.nrow = rows_of_blocks
  A.ncol = cols_of_blocks

- A.nnz = min(A.nrow, A.ncol)
- _bsr_ensure_fits(A)
+ nnz = min(A.nrow, A.ncol)
+ _bsr_ensure_fits(A, nnz=nnz)

  if warp.types.is_array(diag):
  wp.launch(
  kernel=_bsr_set_diag_kernel,
- dim=A.nnz,
+ dim=nnz,
  device=A.values.device,
  inputs=[diag, A.offsets, A.columns, A.values],
  )
@@ -511,11 +1048,13 @@ def bsr_set_diag(
  diag = A.values.dtype(diag)
  wp.launch(
  kernel=_bsr_set_diag_constant_kernel,
- dim=A.nnz,
+ dim=nnz,
  device=A.values.device,
  inputs=[diag, A.offsets, A.columns, A.values],
  )

+ A.copy_nnz_async(known_nnz=nnz)
+

  def bsr_diag(
  diag: "Union[BlockType, Array[BlockType]]",
@@ -619,11 +1158,14 @@ def _bsr_scale_kernel(
  values[wp.tid()] = alpha * values[wp.tid()]


- def bsr_scale(x: BsrMatrix, alpha: Scalar) -> BsrMatrix:
+ def bsr_scale(x: BsrMatrixOrExpression, alpha: Scalar) -> BsrMatrix:
  """
  Performs the operation ``x := alpha * x`` on BSR matrix `x` and returns `x`
  """

+ x, scale = _extract_matrix_and_scale(x)
+ alpha *= scale
+
  if alpha != 1.0 and x.nnz > 0:
  if alpha == 0.0:
  bsr_set_zero(x)
@@ -642,11 +1184,14 @@ def bsr_scale(x: BsrMatrix, alpha: Scalar) -> BsrMatrix:


  @wp.kernel
- def _bsr_get_block_row(dest_offset: int, bsr_offsets: wp.array(dtype=int), rows: wp.array(dtype=int)):
+ def _bsr_get_block_row(dest_offset: int, row_count: int, bsr_offsets: wp.array(dtype=int), rows: wp.array(dtype=int)):
  i = wp.tid()

- row = wp.lower_bound(bsr_offsets, i + 1) - 1
- rows[dest_offset + i] = row
+ if i >= bsr_offsets[row_count]:
+ rows[dest_offset + i] = -1 # invalid
+ else:
+ row = wp.lower_bound(bsr_offsets, i + 1) - 1
+ rows[dest_offset + i] = row


  @wp.kernel
@@ -662,6 +1207,10 @@ def _bsr_axpy_add_block(
  ):
  i = wp.tid()
  row = rows[i + src_offset]
+
+ if row < 0:
+ return
+
  col = cols[i + src_offset]
  beg = dst_offsets[row]
  end = dst_offsets[row + 1]
@@ -694,11 +1243,11 @@ class bsr_axpy_work_arrays:
  self._sum_cols = wp.empty(shape=(sum_nnz), dtype=int, device=self.device)

  if self._old_y_values is None or self._old_y_values.size < y.nnz:
- self._old_y_values = wp.empty(shape=(y.nnz), dtype=y.values.dtype, device=self.device)
+ self._old_y_values = wp.empty(shape=(y.nnz,), dtype=y.values.dtype, device=self.device)


  def bsr_axpy(
- x: BsrMatrix[BlockType[Rows, Cols, Scalar]],
+ x: BsrMatrixOrExpression,
  y: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
  alpha: Scalar = 1.0,
  beta: Scalar = 1.0,
@@ -717,17 +1266,23 @@ def bsr_axpy(
  work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_axpy_work_arrays` in `work_arrays`.
  """

+ x, x_scale = _extract_matrix_and_scale(x)
+ alpha *= x_scale
+
  if y is None:
  # If not output matrix is provided, allocate it for convenience
  y = bsr_zeros(x.nrow, x.ncol, block_type=x.values.dtype, device=x.values.device)
  beta = 0.0

+ x_nnz = x.nnz
+ y_nnz = y.nnz
+
  # Handle easy cases first
- if beta == 0.0 or y.nnz == 0:
+ if beta == 0.0 or y_nnz == 0:
  bsr_assign(src=x, dest=y)
  return bsr_scale(y, alpha=alpha)

- if alpha == 0.0 or x.nnz == 0:
+ if alpha == 0.0 or x_nnz == 0:
  return bsr_scale(y, alpha=beta)

  if not isinstance(alpha, y.scalar_type):
@@ -753,28 +1308,28 @@ def bsr_axpy(
  if work_arrays is None:
  work_arrays = bsr_axpy_work_arrays()

- sum_nnz = x.nnz + y.nnz
+ sum_nnz = x_nnz + y_nnz
  device = y.values.device
  work_arrays._allocate(device, y, sum_nnz)

- wp.copy(work_arrays._sum_cols, y.columns, 0, 0, y.nnz)
+ wp.copy(work_arrays._sum_cols, y.columns, 0, 0, y_nnz)
  wp.launch(
  kernel=_bsr_get_block_row,
  device=device,
- dim=y.nnz,
- inputs=[0, y.offsets, work_arrays._sum_rows],
+ dim=y_nnz,
+ inputs=[0, y.nrow, y.offsets, work_arrays._sum_rows],
  )

- wp.copy(work_arrays._sum_cols, x.columns, y.nnz, 0, x.nnz)
+ wp.copy(work_arrays._sum_cols, x.columns, y_nnz, 0, x_nnz)
  wp.launch(
  kernel=_bsr_get_block_row,
  device=device,
- dim=x.nnz,
- inputs=[y.nnz, x.offsets, work_arrays._sum_rows],
+ dim=x_nnz,
+ inputs=[y_nnz, x.nrow, x.offsets, work_arrays._sum_rows],
  )

  # Save old y values before overwriting matrix
- wp.copy(dest=work_arrays._old_y_values, src=y.values, count=y.nnz)
+ wp.copy(dest=work_arrays._old_y_values, src=y.values, count=y_nnz)

  # Increase dest array sizes if needed
  if y.columns.shape[0] < sum_nnz:
@@ -787,21 +1342,28 @@ def bsr_axpy(
  else:
  native_func = runtime.core.bsr_matrix_from_triplets_float_device

- old_y_nnz = y.nnz
- y.nnz = native_func(
- y.block_shape[0],
- y.block_shape[1],
- y.nrow,
- sum_nnz,
- work_arrays._sum_rows.ptr,
- work_arrays._sum_cols.ptr,
- 0,
- y.offsets.ptr,
- y.columns.ptr,
- 0,
- )
+ old_y_nnz = y_nnz
+ nnz_buf, nnz_event = y._nnz_transfer_buf_and_event()
+
+ with wp.ScopedDevice(y.device):
+ native_func(
+ y.block_shape[0],
+ y.block_shape[1],
+ y.nrow,
+ sum_nnz,
+ ctypes.cast(work_arrays._sum_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
+ ctypes.cast(work_arrays._sum_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
+ 0,
+ False,
+ ctypes.cast(y.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
+ ctypes.cast(y.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
+ 0,
+ ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
+ nnz_event,
+ )
+
+ _bsr_ensure_fits(y, nnz=sum_nnz)

- _bsr_ensure_fits(y)
  y.values.zero_()

  wp.launch(
@@ -823,7 +1385,7 @@ def bsr_axpy(
  wp.launch(
  kernel=_bsr_axpy_add_block,
  device=device,
- dim=x.nnz,
+ dim=x_nnz,
  inputs=[
  old_y_nnz,
  alpha,
@@ -918,8 +1480,9 @@ def _bsr_mm_compute_values(
  y_end = y_offsets[x_col + 1]

  y_block = wp.lower_bound(y_columns, y_beg, y_end, col)
- if y_block < y_end and y_columns[y_block] == col:
- mm_val += x_values[x_block] * y_values[y_block]
+ if y_block < y_end:
+ if y_columns[y_block] == col:
+ mm_val += x_values[x_block] * y_values[y_block]

  mm_values[mm_block] += alpha * mm_val

@@ -932,38 +1495,38 @@ class bsr_mm_work_arrays:

  def _reset(self, device):
  self.device = device
- self._pinned_count_buffer = None
  self._mm_row_counts = None
  self._mm_rows = None
  self._mm_cols = None
  self._old_z_values = None
  self._old_z_offsets = None
  self._old_z_columns = None
+ self._mm_nnz = 0

- def _allocate_stage_1(self, device, z: BsrMatrix, copied_z_nnz: int, z_aliasing: bool):
+ def _allocate_stage_1(self, device, z: BsrMatrix, beta: float, z_aliasing: bool):
  if self.device != device:
  self._reset(device)

  # Allocations that do not depend on any computation
- if self.device.is_cuda:
- if self._pinned_count_buffer is None:
- self._pinned_count_buffer = wp.empty(shape=(1,), dtype=int, pinned=True, device="cpu")
+ z_nnz = z.nnz_sync()
+ self._copied_z_nnz = z_nnz if beta != 0.0 or z_aliasing else 0

  if self._mm_row_counts is None or self._mm_row_counts.size < z.nrow + 1:
  self._mm_row_counts = wp.empty(shape=(z.nrow + 1,), dtype=int, device=self.device)

- if copied_z_nnz > 0:
- if self._old_z_values is None or self._old_z_values.size < copied_z_nnz:
- self._old_z_values = wp.empty(shape=(copied_z_nnz,), dtype=z.values.dtype, device=self.device)
+ if self._copied_z_nnz > 0:
+ if self._old_z_values is None or self._old_z_values.size < self._copied_z_nnz:
+ self._old_z_values = wp.empty(shape=(self._copied_z_nnz,), dtype=z.values.dtype, device=self.device)

  if z_aliasing:
- if self._old_z_columns is None or self._old_z_columns.size < z.nnz:
- self._old_z_columns = wp.empty(shape=(z.nnz,), dtype=z.columns.dtype, device=self.device)
+ if self._old_z_columns is None or self._old_z_columns.size < z_nnz:
+ self._old_z_columns = wp.empty(shape=(z_nnz,), dtype=z.columns.dtype, device=self.device)
  if self._old_z_offsets is None or self._old_z_offsets.size < z.nrow + 1:
  self._old_z_offsets = wp.empty(shape=(z.nrow + 1,), dtype=z.offsets.dtype, device=self.device)

  def _allocate_stage_2(self, mm_nnz: int):
  # Allocations that depend on unmerged nnz estimate
+ self._mm_nnz = mm_nnz
  if self._mm_rows is None or self._mm_rows.size < mm_nnz:
  self._mm_rows = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
  if self._mm_cols is None or self._mm_cols.size < mm_nnz:
@@ -971,12 +1534,13 @@ class bsr_mm_work_arrays:


  def bsr_mm(
- x: BsrMatrix[BlockType[Rows, Any, Scalar]],
- y: BsrMatrix[BlockType[Any, Cols, Scalar]],
+ x: BsrMatrixOrExpression[BlockType[Rows, Any, Scalar]],
+ y: BsrMatrixOrExpression[BlockType[Any, Cols, Scalar]],
  z: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None,
  alpha: Scalar = 1.0,
  beta: Scalar = 0.0,
  work_arrays: Optional[bsr_mm_work_arrays] = None,
+ reuse_topology: bool = False,
  ) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
  """
  Performs the sparse matrix-matrix multiplication ``z := alpha * x * y + beta * z`` on BSR matrices `x`, `y` and `z`, and returns `z`.
@@ -991,8 +1555,16 @@ def bsr_mm(
  alpha: Uniform scaling factor for the ``x * y`` product
  beta: Uniform scaling factor for `z`
  work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_mm_work_arrays` in `work_arrays`.
+ reuse_topology: If True, reuse the product topology information stored in `work_arrays` rather than recompute it from scratch.
+ The matrices x, y and z must be structurally similar to the previous call in which `work_arrays` were populated.
+ This is necessary for `bsr_mm` to be captured in a CUDA graph.
  """

+ x, x_scale = _extract_matrix_and_scale(x)
+ alpha *= x_scale
+ y, y_scale = _extract_matrix_and_scale(y)
+ alpha *= y_scale
+
  if z is None:
  # If not output matrix is provided, allocate it for convenience
  z_block_shape = (x.block_shape[0], y.block_shape[1])
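A minimal sketch of reusing the product topology so that `bsr_mm` can run inside a CUDA graph capture. It assumes a CUDA device and BSR matrices `x`, `y`, `z` whose sparsity patterns stay the same across calls; `sparse` refers to `warp.sparse`:

    work = sparse.bsr_mm_work_arrays()

    # First call computes the product topology and stores it in `work`
    sparse.bsr_mm(x, y, z, work_arrays=work)

    # Later, structurally identical calls can reuse it, which avoids the host
    # synchronization and makes the operation graph-capturable
    with wp.ScopedCapture() as capture:
        sparse.bsr_mm(x, y, z, work_arrays=work, reuse_topology=True)
    wp.capture_launch(capture.graph)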
@@ -1030,76 +1602,84 @@ def bsr_mm(
  if not isinstance(beta, z.scalar_type):
  beta = z.scalar_type(beta)

- if work_arrays is None:
- work_arrays = bsr_mm_work_arrays()
-
  z_aliasing = z == x or z == y
- copied_z_nnz = z.nnz if beta != 0.0 or z_aliasing else 0

- work_arrays._allocate_stage_1(device, z, copied_z_nnz, z_aliasing)
+ if reuse_topology:
+ if work_arrays is None:
+ raise ValueError("`work_arrays` must not be ``None`` in order to reuse matrix-matrix product topology")

- # Prefix sum of number of (unmerged) mm blocks per row
- wp.launch(
- kernel=_bsr_mm_count_coeffs,
- device=device,
- dim=z.nrow,
- inputs=[
- copied_z_nnz,
- x.offsets,
- x.columns,
- y.offsets,
- work_arrays._mm_row_counts,
- ],
- )
- warp.utils.array_scan(work_arrays._mm_row_counts, work_arrays._mm_row_counts)
-
- # Get back total counts on host
- if device.is_cuda:
- wp.copy(
- dest=work_arrays._pinned_count_buffer,
- src=work_arrays._mm_row_counts,
- src_offset=z.nrow,
- count=1,
- )
- wp.synchronize_stream(wp.get_stream(device))
- mm_nnz = int(work_arrays._pinned_count_buffer.numpy()[0])
+ copied_z_nnz = work_arrays._copied_z_nnz
+ mm_nnz = work_arrays._mm_nnz
  else:
- mm_nnz = int(work_arrays._mm_row_counts.numpy()[z.nrow])
+ if device.is_capturing:
+ raise RuntimeError("`bsr_mm` requires `reuse_topology=True` for use in graph capture")

- work_arrays._allocate_stage_2(mm_nnz)
+ if work_arrays is None:
+ work_arrays = bsr_mm_work_arrays()

- # If z has a non-zero scale, save current data before overwriting it
- if copied_z_nnz > 0:
- # Copy z row and column indices
- wp.copy(dest=work_arrays._mm_cols, src=z.columns, count=copied_z_nnz)
+ work_arrays._allocate_stage_1(device, z, beta, z_aliasing)
+ copied_z_nnz = work_arrays._copied_z_nnz
+
+ # Prefix sum of number of (unmerged) mm blocks per row
  wp.launch(
- kernel=_bsr_get_block_row,
+ kernel=_bsr_mm_count_coeffs,
  device=device,
- dim=copied_z_nnz,
- inputs=[0, z.offsets, work_arrays._mm_rows],
+ dim=z.nrow,
+ inputs=[
+ copied_z_nnz,
+ x.offsets,
+ x.columns,
+ y.offsets,
+ work_arrays._mm_row_counts,
+ ],
+ )
+ warp.utils.array_scan(work_arrays._mm_row_counts, work_arrays._mm_row_counts)
+
+ # Get back total counts on host -- we need a synchronization here
+ # Use pinned buffer from z, we are going to need it later anyway
+ nnz_buf, _ = z._nnz_transfer_buf_and_event()
+ stream = wp.get_stream(device) if device.is_cuda else None
+ wp.copy(dest=nnz_buf, src=work_arrays._mm_row_counts, src_offset=z.nrow, count=1, stream=stream)
+ if device.is_cuda:
+ wp.synchronize_stream(stream)
+ mm_nnz = int(nnz_buf.numpy()[0])
+
+ work_arrays._allocate_stage_2(mm_nnz)
+
+ # If z has a non-zero scale, save current data before overwriting it
+ if copied_z_nnz > 0:
+ # Copy z row and column indices
+ wp.copy(dest=work_arrays._mm_cols, src=z.columns, count=copied_z_nnz)
+ wp.launch(
+ kernel=_bsr_get_block_row,
+ device=device,
+ dim=copied_z_nnz,
+ inputs=[0, z.nrow, z.offsets, work_arrays._mm_rows],
+ )
+ if z_aliasing:
+ # If z is aliasing with x or y, need to save topology as well
+ wp.copy(src=z.columns, dest=work_arrays._old_z_columns, count=copied_z_nnz)
+ wp.copy(src=z.offsets, dest=work_arrays._old_z_offsets, count=z.nrow + 1)
+
+ # Fill unmerged mm blocks rows and columns
+ wp.launch(
+ kernel=_bsr_mm_list_coeffs,
+ device=device,
+ dim=z.nrow,
+ inputs=[
+ x.offsets,
+ x.columns,
+ y.offsets,
+ y.columns,
+ work_arrays._mm_row_counts,
+ work_arrays._mm_rows,
+ work_arrays._mm_cols,
+ ],
  )
+
+ if copied_z_nnz > 0:
  # Save current z values in temporary buffer
  wp.copy(src=z.values, dest=work_arrays._old_z_values, count=copied_z_nnz)
- if z_aliasing:
- # If z is aliasing with x or y, need to save topology as well
- wp.copy(src=z.columns, dest=work_arrays._old_z_columns, count=copied_z_nnz)
- wp.copy(src=z.offsets, dest=work_arrays._old_z_offsets, count=z.nrow + 1)
-
- # Fill unmerged mm blocks rows and columns
- wp.launch(
- kernel=_bsr_mm_list_coeffs,
- device=device,
- dim=z.nrow,
- inputs=[
- x.offsets,
- x.columns,
- y.offsets,
- y.columns,
- work_arrays._mm_row_counts,
- work_arrays._mm_rows,
- work_arrays._mm_cols,
- ],
- )

  # Increase dest array size if needed
  if z.columns.shape[0] < mm_nnz:
@@ -1112,20 +1692,31 @@ def bsr_mm(
  else:
  native_func = runtime.core.bsr_matrix_from_triplets_float_device

- z.nnz = native_func(
- z.block_shape[0],
- z.block_shape[1],
- z.nrow,
- mm_nnz,
- work_arrays._mm_rows.ptr,
- work_arrays._mm_cols.ptr,
- 0,
- z.offsets.ptr,
- z.columns.ptr,
- 0,
- )
+ nnz_buf, nnz_event = z._nnz_transfer_buf_and_event()
+
+ with wp.ScopedDevice(z.device):
+ native_func(
+ z.block_shape[0],
+ z.block_shape[1],
+ z.nrow,
+ mm_nnz,
+ ctypes.cast(work_arrays._mm_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
+ ctypes.cast(work_arrays._mm_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
+ 0,
+ False,
+ ctypes.cast(z.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
+ ctypes.cast(z.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
+ 0,
+ ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
+ nnz_event,
+ )
+
+ # Resize z to fit mm result if necessary
+ # If we are not reusing the product topology, this needs another synchronization
+ if not reuse_topology:
+ work_arrays.result_nnz = z.nnz_sync()
+ _bsr_ensure_fits(z, nnz=work_arrays.result_nnz)

- _bsr_ensure_fits(z)
  z.values.zero_()

  if copied_z_nnz > 0:
@@ -1206,12 +1797,57 @@ def _bsr_mv_kernel(
  y[row] = v


+ @wp.kernel
+ def _bsr_mv_transpose_kernel(
+ alpha: Any,
+ A_offsets: wp.array(dtype=int),
+ A_columns: wp.array(dtype=int),
+ A_values: wp.array(dtype=Any),
+ x: wp.array(dtype=Any),
+ y: wp.array(dtype=Any),
+ ):
+ row = wp.tid()
+ beg = A_offsets[row]
+ end = A_offsets[row + 1]
+ xr = alpha * x[row]
+ for block in range(beg, end):
+ v = wp.transpose(A_values[block]) * xr
+ wp.atomic_add(y, A_columns[block], v)
+
+
+ def _bsr_mv_as_vec_array(array: wp.array) -> wp.array:
+ if array.ndim == 1:
+ return array
+
+ if array.ndim > 2:
+ raise ValueError(f"Incompatible array number of dimensions {array.ndim}")
+
+ if not array.is_contiguous:
+ raise ValueError("2d array must be contiguous")
+
+ def vec_view(array):
+ return wp.array(
+ data=None,
+ ptr=array.ptr,
+ capacity=array.capacity,
+ device=array.device,
+ dtype=wp.vec(length=array.shape[1], dtype=array.dtype),
+ shape=array.shape[0],
+ grad=None if array.grad is None else vec_view(array.grad),
+ )
+
+ view = vec_view(array)
+ view._ref = array
+ return view
+
+
  def bsr_mv(
- A: BsrMatrix[BlockType[Rows, Cols, Scalar]],
+ A: BsrMatrixOrExpression[BlockType[Rows, Cols, Scalar]],
  x: "Array[Vector[Cols, Scalar] | Scalar]",
  y: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
  alpha: Scalar = 1.0,
  beta: Scalar = 0.0,
+ transpose: bool = False,
  work_buffer: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
  ) -> "Array[Vector[Rows, Scalar] | Scalar]":
  """
@@ -1225,16 +1861,26 @@ def bsr_mv(
  y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero.
  alpha: Uniform scaling factor for `x`. If zero, `x` will not be read and may be left uninitialized.
  beta: Uniform scaling factor for `y`. If zero, `y` will not be read and may be left uninitialized.
+ transpose: If ``True``, use the transpose of the matrix `A`. In this case the result is **non-deterministic**.
  work_buffer: Temporary storage is required if and only if `x` and `y` are the same vector. If provided the `work_buffer` array
  will be used for this purpose, otherwise a temporary allocation will be performed.
  """

+ A, A_scale = _extract_matrix_and_scale(A)
+ alpha *= A_scale
+
+ if transpose:
+ block_shape = A.block_shape[1], A.block_shape[0]
+ nrow, ncol = A.ncol, A.nrow
+ else:
+ block_shape = A.block_shape
+ nrow, ncol = A.nrow, A.ncol
+
  if y is None:
  # If no output array is provided, allocate one for convenience
- y_vec_len = A.block_shape[0]
+ y_vec_len = block_shape[0]
  y_dtype = A.scalar_type if y_vec_len == 1 else wp.vec(length=y_vec_len, dtype=A.scalar_type)
- y = wp.empty(shape=(A.nrow,), device=A.values.device, dtype=y_dtype)
- y.zero_()
+ y = wp.empty(shape=(nrow,), device=A.values.device, dtype=y_dtype)
  beta = 0.0

  if not isinstance(alpha, A.scalar_type):
@@ -1245,12 +1891,16 @@ def bsr_mv(
  if A.values.device != x.device or A.values.device != y.device:
  raise ValueError("A, x and y must reside on the same device")

- if x.shape[0] != A.ncol:
+ if x.shape[0] != ncol:
  raise ValueError("Number of columns of A must match number of rows of x")
- if y.shape[0] != A.nrow:
+ if y.shape[0] != nrow:
  raise ValueError("Number of rows of A must match number of rows of y")

- if x == y:
+ # View 2d arrays as arrays of vecs
+ x = _bsr_mv_as_vec_array(x)
+ y = _bsr_mv_as_vec_array(y)
+
+ if x.ptr == y.ptr:
  # Aliasing case, need temporary storage
  if work_buffer is None:
  work_buffer = wp.empty_like(y)
@@ -1265,25 +1915,39 @@ def bsr_mv(

  # Promote scalar vectors to length-1 vecs and conversely
  if warp.types.type_is_matrix(A.values.dtype):
- if A.block_shape[0] == 1:
- if y.dtype == A.scalar_type:
- y = y.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
- if A.block_shape[1] == 1:
- if x.dtype == A.scalar_type:
- x = x.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
+ if block_shape[0] == 1 and y.dtype == A.scalar_type:
+ y = y.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
+ if block_shape[1] == 1 and x.dtype == A.scalar_type:
+ x = x.view(dtype=wp.vec(length=1, dtype=A.scalar_type))
  else:
- if A.block_shape[0] == 1:
- if y.dtype != A.scalar_type:
- y = y.view(dtype=A.scalar_type)
- if A.block_shape[1] == 1:
- if x.dtype != A.scalar_type:
- x = x.view(dtype=A.scalar_type)
-
- wp.launch(
- kernel=_bsr_mv_kernel,
- device=A.values.device,
- dim=A.nrow,
- inputs=[alpha, A.offsets, A.columns, A.values, x, beta, y],
- )
+ if block_shape[0] == 1 and y.dtype != A.scalar_type:
+ y = y.view(dtype=A.scalar_type)
+ if block_shape[1] == 1 and x.dtype != A.scalar_type:
+ x = x.view(dtype=A.scalar_type)
+
+ if transpose:
+ if beta.value == 0.0:
+ y.zero_()
+ elif beta.value != 1.0:
+ wp.launch(
+ kernel=_bsr_scale_kernel,
+ device=y.device,
+ dim=y.shape[0],
+ inputs=[beta, y],
+ )
+ if alpha.value != 0.0:
+ wp.launch(
+ kernel=_bsr_mv_transpose_kernel,
+ device=A.values.device,
+ dim=ncol,
+ inputs=[alpha, A.offsets, A.columns, A.values, x, y],
+ )
+ else:
+ wp.launch(
+ kernel=_bsr_mv_kernel,
+ device=A.values.device,
+ dim=nrow,
+ inputs=[alpha, A.offsets, A.columns, A.values, x, beta, y],
+ )

  return y
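To close, a brief sketch of the new `transpose` option (the result order is non-deterministic, as noted in the docstring); `A` and `x` are assumed to be a compatible BSR matrix and dense vector, with `sparse` referring to `warp.sparse`:

    # y := alpha * A^T x + beta * y, without materializing the transposed matrix
    y = sparse.bsr_mv(A, x, alpha=1.0, beta=0.0, transpose=True)

    # Operator form: right-multiplying by a dense vector applies the transpose
    y2 = x @ A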