warp-lang 1.7.2-py3-none-manylinux_2_34_aarch64.whl → 1.8.0-py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.
Files changed (180)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +241 -252
  6. warp/build_dll.py +125 -26
  7. warp/builtins.py +1907 -384
  8. warp/codegen.py +257 -101
  9. warp/config.py +12 -1
  10. warp/constants.py +1 -1
  11. warp/context.py +657 -223
  12. warp/dlpack.py +1 -1
  13. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  14. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  15. warp/examples/core/example_sample_mesh.py +1 -1
  16. warp/examples/core/example_spin_lock.py +93 -0
  17. warp/examples/core/example_work_queue.py +118 -0
  18. warp/examples/fem/example_adaptive_grid.py +5 -5
  19. warp/examples/fem/example_apic_fluid.py +1 -1
  20. warp/examples/fem/example_burgers.py +1 -1
  21. warp/examples/fem/example_convection_diffusion.py +9 -6
  22. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  23. warp/examples/fem/example_deformed_geometry.py +1 -1
  24. warp/examples/fem/example_diffusion.py +2 -2
  25. warp/examples/fem/example_diffusion_3d.py +1 -1
  26. warp/examples/fem/example_distortion_energy.py +1 -1
  27. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  28. warp/examples/fem/example_magnetostatics.py +5 -3
  29. warp/examples/fem/example_mixed_elasticity.py +5 -3
  30. warp/examples/fem/example_navier_stokes.py +11 -9
  31. warp/examples/fem/example_nonconforming_contact.py +5 -3
  32. warp/examples/fem/example_streamlines.py +8 -3
  33. warp/examples/fem/utils.py +9 -8
  34. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  35. warp/examples/optim/example_drone.py +1 -1
  36. warp/examples/sim/example_cloth.py +1 -1
  37. warp/examples/sim/example_cloth_self_contact.py +48 -54
  38. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  39. warp/examples/tile/example_tile_cholesky.py +2 -1
  40. warp/examples/tile/example_tile_convolution.py +1 -1
  41. warp/examples/tile/example_tile_filtering.py +1 -1
  42. warp/examples/tile/example_tile_matmul.py +1 -1
  43. warp/examples/tile/example_tile_mlp.py +2 -0
  44. warp/fabric.py +7 -7
  45. warp/fem/__init__.py +5 -0
  46. warp/fem/adaptivity.py +1 -1
  47. warp/fem/cache.py +152 -63
  48. warp/fem/dirichlet.py +2 -2
  49. warp/fem/domain.py +136 -6
  50. warp/fem/field/field.py +141 -99
  51. warp/fem/field/nodal_field.py +85 -39
  52. warp/fem/field/virtual.py +97 -52
  53. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  54. warp/fem/geometry/closest_point.py +13 -0
  55. warp/fem/geometry/deformed_geometry.py +102 -40
  56. warp/fem/geometry/element.py +56 -2
  57. warp/fem/geometry/geometry.py +323 -22
  58. warp/fem/geometry/grid_2d.py +157 -62
  59. warp/fem/geometry/grid_3d.py +116 -20
  60. warp/fem/geometry/hexmesh.py +86 -20
  61. warp/fem/geometry/nanogrid.py +166 -86
  62. warp/fem/geometry/partition.py +59 -25
  63. warp/fem/geometry/quadmesh.py +86 -135
  64. warp/fem/geometry/tetmesh.py +47 -119
  65. warp/fem/geometry/trimesh.py +77 -270
  66. warp/fem/integrate.py +107 -52
  67. warp/fem/linalg.py +25 -58
  68. warp/fem/operator.py +124 -27
  69. warp/fem/quadrature/pic_quadrature.py +36 -14
  70. warp/fem/quadrature/quadrature.py +40 -16
  71. warp/fem/space/__init__.py +1 -1
  72. warp/fem/space/basis_function_space.py +66 -46
  73. warp/fem/space/basis_space.py +17 -4
  74. warp/fem/space/dof_mapper.py +1 -1
  75. warp/fem/space/function_space.py +2 -2
  76. warp/fem/space/grid_2d_function_space.py +4 -1
  77. warp/fem/space/hexmesh_function_space.py +4 -2
  78. warp/fem/space/nanogrid_function_space.py +3 -1
  79. warp/fem/space/partition.py +11 -2
  80. warp/fem/space/quadmesh_function_space.py +4 -1
  81. warp/fem/space/restriction.py +5 -2
  82. warp/fem/space/shape/__init__.py +10 -8
  83. warp/fem/space/tetmesh_function_space.py +4 -1
  84. warp/fem/space/topology.py +52 -21
  85. warp/fem/space/trimesh_function_space.py +4 -1
  86. warp/fem/utils.py +53 -8
  87. warp/jax.py +1 -2
  88. warp/jax_experimental/ffi.py +12 -17
  89. warp/jax_experimental/xla_ffi.py +37 -24
  90. warp/math.py +171 -1
  91. warp/native/array.h +99 -0
  92. warp/native/builtin.h +174 -31
  93. warp/native/coloring.cpp +1 -1
  94. warp/native/exports.h +118 -63
  95. warp/native/intersect.h +3 -3
  96. warp/native/mat.h +5 -10
  97. warp/native/mathdx.cpp +11 -5
  98. warp/native/matnn.h +1 -123
  99. warp/native/quat.h +28 -4
  100. warp/native/sparse.cpp +121 -258
  101. warp/native/sparse.cu +181 -274
  102. warp/native/spatial.h +305 -17
  103. warp/native/tile.h +583 -72
  104. warp/native/tile_radix_sort.h +1108 -0
  105. warp/native/tile_reduce.h +237 -2
  106. warp/native/tile_scan.h +240 -0
  107. warp/native/tuple.h +189 -0
  108. warp/native/vec.h +6 -16
  109. warp/native/warp.cpp +36 -4
  110. warp/native/warp.cu +574 -51
  111. warp/native/warp.h +47 -74
  112. warp/optim/linear.py +5 -1
  113. warp/paddle.py +7 -8
  114. warp/py.typed +0 -0
  115. warp/render/render_opengl.py +58 -29
  116. warp/render/render_usd.py +124 -61
  117. warp/sim/__init__.py +9 -0
  118. warp/sim/collide.py +252 -78
  119. warp/sim/graph_coloring.py +8 -1
  120. warp/sim/import_mjcf.py +4 -3
  121. warp/sim/import_usd.py +11 -7
  122. warp/sim/integrator.py +5 -2
  123. warp/sim/integrator_euler.py +1 -1
  124. warp/sim/integrator_featherstone.py +1 -1
  125. warp/sim/integrator_vbd.py +751 -320
  126. warp/sim/integrator_xpbd.py +1 -1
  127. warp/sim/model.py +265 -260
  128. warp/sim/utils.py +10 -7
  129. warp/sparse.py +303 -166
  130. warp/tape.py +52 -51
  131. warp/tests/cuda/test_conditional_captures.py +1046 -0
  132. warp/tests/cuda/test_streams.py +1 -1
  133. warp/tests/geometry/test_volume.py +2 -2
  134. warp/tests/interop/test_dlpack.py +9 -9
  135. warp/tests/interop/test_jax.py +0 -1
  136. warp/tests/run_coverage_serial.py +1 -1
  137. warp/tests/sim/disabled_kinematics.py +2 -2
  138. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  139. warp/tests/sim/test_collision.py +159 -51
  140. warp/tests/sim/test_coloring.py +15 -1
  141. warp/tests/test_array.py +254 -2
  142. warp/tests/test_array_reduce.py +2 -2
  143. warp/tests/test_atomic_cas.py +299 -0
  144. warp/tests/test_codegen.py +142 -19
  145. warp/tests/test_conditional.py +47 -1
  146. warp/tests/test_ctypes.py +0 -20
  147. warp/tests/test_devices.py +8 -0
  148. warp/tests/test_fabricarray.py +4 -2
  149. warp/tests/test_fem.py +58 -25
  150. warp/tests/test_func.py +42 -1
  151. warp/tests/test_grad.py +1 -1
  152. warp/tests/test_lerp.py +1 -3
  153. warp/tests/test_map.py +481 -0
  154. warp/tests/test_mat.py +1 -24
  155. warp/tests/test_quat.py +6 -15
  156. warp/tests/test_rounding.py +10 -38
  157. warp/tests/test_runlength_encode.py +7 -7
  158. warp/tests/test_smoothstep.py +1 -1
  159. warp/tests/test_sparse.py +51 -2
  160. warp/tests/test_spatial.py +507 -1
  161. warp/tests/test_struct.py +2 -2
  162. warp/tests/test_tuple.py +265 -0
  163. warp/tests/test_types.py +2 -2
  164. warp/tests/test_utils.py +24 -18
  165. warp/tests/tile/test_tile.py +420 -1
  166. warp/tests/tile/test_tile_mathdx.py +518 -14
  167. warp/tests/tile/test_tile_reduce.py +213 -0
  168. warp/tests/tile/test_tile_shared_memory.py +130 -1
  169. warp/tests/tile/test_tile_sort.py +117 -0
  170. warp/tests/unittest_suites.py +4 -6
  171. warp/types.py +462 -308
  172. warp/utils.py +647 -86
  173. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  174. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/RECORD +177 -165
  175. warp/stubs.py +0 -3381
  176. warp/tests/sim/test_xpbd.py +0 -399
  177. warp/tests/test_mlp.py +0 -282
  178. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  179. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  180. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
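The heaviest rewrite in this release is warp/sparse.py (+303 -166), shown next. Two user-visible changes stand out in that diff: BsrMatrix gains a read-only `requires_grad` property that mirrors its `values` array, and `bsr_set_from_triplets()` accepts an optional single-element `count` array holding the triplet count. A minimal sketch of the new surface, written against the 1.8.0 API as it appears below (the data is illustrative only):

    import warp as wp
    from warp.sparse import bsr_from_triplets, bsr_set_from_triplets, bsr_zeros

    wp.init()

    # three 2x2 blocks of a 2x2 block matrix; values participate in autodiff
    rows = wp.array([0, 1, 1], dtype=wp.int32)
    cols = wp.array([0, 0, 1], dtype=wp.int32)
    vals = wp.zeros(3, dtype=wp.mat22, requires_grad=True)

    A = bsr_from_triplets(2, 2, rows, cols, vals, prune_numerical_zeros=False)
    print(A.requires_grad)  # True in 1.8.0: propagated from the values array

    # new in 1.8.0: the triplet count may be passed as a single-element int32 array
    count = wp.array([3], dtype=wp.int32)
    B = bsr_zeros(2, 2, block_type=wp.mat22)
    bsr_set_from_triplets(B, rows, cols, vals, count=count)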
warp/sparse.py CHANGED
@@ -28,9 +28,10 @@ from warp.types import (
     is_array,
     scalar_types,
     type_is_matrix,
-    type_length,
     type_repr,
     type_scalar_type,
+    type_size,
+    type_size_in_bytes,
     type_to_warp,
     types_equal,
 )
@@ -86,7 +87,7 @@ class BsrMatrix(Generic[_BlockType]):
     @property
     def block_size(self) -> int:
         """Size of the individual blocks, i.e. number of rows per block times number of columns per block."""
-        return type_length(self.values.dtype)
+        return type_size(self.values.dtype)
 
     @property
     def shape(self) -> Tuple[int, int]:
@@ -104,23 +105,15 @@ class BsrMatrix(Generic[_BlockType]):
         """Device on which ``offsets``, ``columns``, and ``values`` are allocated -- assumed to be the same for all three arrays."""
         return self.values.device
 
+    @property
+    def requires_grad(self) -> bool:
+        """Read-only property indicating whether the matrix participates in adjoint computations."""
+        return self.values.requires_grad
+
     @property
     def scalar_values(self) -> wp.array:
         """Accesses the ``values`` array as a 3d scalar array."""
-        if self.block_shape == (1, 1):
-            return self.values.reshape((self.nnz, 1, 1))
-
-        def _as_3d_array(arr):
-            return wp.array(
-                ptr=arr.ptr,
-                capacity=arr.capacity,
-                device=arr.device,
-                dtype=self.scalar_type,
-                shape=(self.nnz, *self.block_shape),
-                grad=None if arr.grad is None else _as_3d_array(arr.grad),
-            )
-
-        values_view = _as_3d_array(self.values)
+        values_view = _as_3d_array(self.values, self.block_shape)
         values_view._ref = self.values  # keep ref in case we're garbage collected
         return values_view
 
@@ -144,13 +137,14 @@ class BsrMatrix(Generic[_BlockType]):
         See also :meth:`copy_nnz_async`.
         """
 
-        if self._is_nnz_transfer_setup():
-            if self.device.is_cuda:
-                wp.synchronize_event(self._nnz_event)
-            self.nnz = int(self._nnz_buf.numpy()[0])
+        buf, event = self._nnz_transfer_if_any()
+        if buf is not None:
+            if event is not None:
+                wp.synchronize_event(event)
+            self.nnz = int(buf.numpy()[0])
         return self.nnz
 
-    def copy_nnz_async(self, known_nnz: int = None):
+    def copy_nnz_async(self) -> None:
         """
         Start the asynchronous transfer of the exact nnz from the device offsets array to host and records an event for completion.
@@ -158,37 +152,25 @@
 
         See also :meth:`nnz_sync`.
         """
-        if known_nnz is not None:
-            self.nnz = int(known_nnz)
-        else:
-            self._setup_nnz_transfer()
 
-        # If a transfer is already ongoing, or if the actual nnz is unknown, schedule a new transfer
-        if self._is_nnz_transfer_setup():
-            stream = wp.get_stream(self.device) if self.device.is_cuda else None
-            wp.copy(src=self.offsets, dest=self._nnz_buf, src_offset=self.nrow, count=1, stream=stream)
-            if self.device.is_cuda:
-                stream.record_event(self._nnz_event)
+        buf, event = self._setup_nnz_transfer()
+        stream = wp.get_stream(self.device) if self.device.is_cuda else None
+        wp.copy(src=self.offsets, dest=buf, src_offset=self.nrow, count=1, stream=stream)
+        if event is not None:
+            stream.record_event(event)
 
     def _setup_nnz_transfer(self):
-        if self._is_nnz_transfer_setup():
-            return
-
-        BsrMatrix.__setattr__(
-            self, "_nnz_buf", wp.empty(dtype=int, shape=(1,), device="cpu", pinned=self.device.is_cuda)
-        )
-        if self.device.is_cuda:
-            BsrMatrix.__setattr__(self, "_nnz_event", wp.Event(self.device))
-
-    def _is_nnz_transfer_setup(self):
-        return hasattr(self, "_nnz_buf")
+        buf, event = self._nnz_transfer_if_any()
+        if buf is not None or self.device.is_capturing:
+            return buf, event
 
-    def _nnz_transfer_buf_and_event(self):
-        self._setup_nnz_transfer()
+        buf = wp.empty(dtype=int, shape=(1,), device="cpu", pinned=self.device.is_cuda)
+        event = wp.Event(self.device) if self.device.is_cuda else None
+        BsrMatrix.__setattr__(self, "_nnz_transfer", (buf, event))
+        return buf, event
 
-        if not self.device.is_cuda:
-            return self._nnz_buf, ctypes.c_void_p(None)
-        return self._nnz_buf, self._nnz_event.cuda_event
+    def _nnz_transfer_if_any(self):
+        return getattr(self, "_nnz_transfer", (None, None))
 
     # Overloaded math operators
     def __add__(self, y):
@@ -303,7 +285,7 @@ def bsr_zeros(
 
     bsr.nrow = int(rows_of_blocks)
    bsr.ncol = int(cols_of_blocks)
-    bsr.nnz = int(0)
+    bsr.nnz = 0
     bsr.columns = wp.empty(shape=(0,), dtype=int, device=device)
     bsr.values = wp.empty(shape=(0,), dtype=block_type, device=device)
     bsr.offsets = wp.zeros(shape=(bsr.nrow + 1,), dtype=int, device=device)
@@ -311,7 +293,7 @@
     return bsr
 
 
-def _bsr_ensure_fits(bsr: BsrMatrix, nrow: int = None, nnz: int = None):
+def _bsr_ensure_fits(bsr: BsrMatrix, nrow: Optional[int] = None, nnz: Optional[int] = None) -> None:
     if nrow is None:
         nrow = bsr.nrow
     if nnz is None:
@@ -325,7 +307,9 @@ def _bsr_ensure_fits(bsr: BsrMatrix, nrow: int = None, nnz: int = None):
     if bsr.columns.size < nnz:
         bsr.columns = wp.empty(shape=(nnz,), dtype=int, device=bsr.columns.device)
     if bsr.values.size < nnz:
-        bsr.values = wp.empty(shape=(nnz,), dtype=bsr.values.dtype, device=bsr.values.device)
+        bsr.values = wp.empty(
+            shape=(nnz,), dtype=bsr.values.dtype, device=bsr.values.device, requires_grad=bsr.values.requires_grad
+        )
 
 
 def bsr_set_zero(
@@ -348,7 +332,64 @@ def bsr_set_zero(
 
     _bsr_ensure_fits(bsr, nnz=0)
     bsr.offsets.zero_()
-    bsr.copy_nnz_async(known_nnz=0)
+    bsr.copy_nnz_async()
+
+
+def _as_3d_array(arr, block_shape):
+    return wp.array(
+        ptr=arr.ptr,
+        capacity=arr.capacity,
+        device=arr.device,
+        dtype=type_scalar_type(arr.dtype),
+        shape=(arr.shape[0], *block_shape),
+        grad=None if arr.grad is None else _as_3d_array(arr.grad, block_shape),
+    )
+
+
+def _optional_ctypes_pointer(array: Optional[wp.array], ctype):
+    return None if array is None else ctypes.cast(array.ptr, ctypes.POINTER(ctype))
+
+
+def _optional_ctypes_event(event: Optional[wp.Event]):
+    return None if event is None else event.cuda_event
+
+
+_zero_value_masks = {
+    wp.float16: 0x7FFF,
+    wp.float32: 0x7FFFFFFF,
+    wp.float64: 0x7FFFFFFFFFFFFFFF,
+    wp.int8: 0xFF,
+    wp.int16: 0xFFFF,
+    wp.int32: 0xFFFFFFFF,
+    wp.int64: 0xFFFFFFFFFFFFFFFF,
+}
+
+
+@wp.kernel
+def _bsr_accumulate_triplet_values(
+    row_count: int,
+    tpl_summed_offsets: wp.array(dtype=int),
+    tpl_summed_indices: wp.array(dtype=int),
+    tpl_values: wp.array3d(dtype=Any),
+    bsr_offsets: wp.array(dtype=int),
+    bsr_values: wp.array3d(dtype=Any),
+):
+    block, i, j = wp.tid()
+
+    if block >= bsr_offsets[row_count]:
+        return
+
+    if block == 0:
+        beg = 0
+    else:
+        beg = tpl_summed_offsets[block - 1]
+    end = tpl_summed_offsets[block]
+
+    val = tpl_values[tpl_summed_indices[beg], i, j]
+    for k in range(beg + 1, end):
+        val += tpl_values[tpl_summed_indices[k], i, j]
+
+    bsr_values[block, i, j] = val
 
 
 def bsr_set_from_triplets(
@@ -356,6 +397,7 @@ def bsr_set_from_triplets(
     rows: "Array[int]",
     columns: "Array[int]",
     values: Optional["Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]"] = None,
+    count: Optional["Array[int]"] = None,
     prune_numerical_zeros: bool = True,
     masked: bool = False,
 ):
@@ -370,27 +412,50 @@ def bsr_set_from_triplets(
         values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
             to the ``dest`` matrix's block type, or a 3d array with data type equal to the ``dest`` matrix's scalar type.
             If ``None``, the values array of the resulting matrix will be allocated but uninitialized.
+        count: Single-element array indicating the number of triplets. If ``None``, the number of triplets is determined from the shape of
+            ``rows`` and ``columns`` arrays.
         prune_numerical_zeros: If ``True``, will ignore the zero-valued blocks.
         masked: If ``True``, ignore blocks that are not existing non-zeros of ``dest``.
     """
 
     if rows.device != columns.device or rows.device != dest.device:
-        raise ValueError("All arguments must reside on the same device")
+        raise ValueError(
+            f"Rows and columns must reside on the destination matrix device, got {rows.device}, {columns.device} and {dest.device}"
+        )
 
     if rows.shape[0] != columns.shape[0]:
-        raise ValueError("All triplet arrays must have the same length")
+        raise ValueError(
+            f"Rows and columns arrays must have the same length, got {rows.shape[0]} and {columns.shape[0]}"
+        )
+
+    if rows.dtype != wp.int32 or columns.dtype != wp.int32:
+        raise TypeError("Rows and columns arrays must be of type int32")
+
+    if count is not None:
+        if count.device != rows.device:
+            raise ValueError(f"Count and rows must reside on the same device, got {count.device} and {rows.device}")
+
+        if count.shape != (1,):
+            raise ValueError(f"Count array must be a single-element array, got {count.shape}")
+
+        if count.dtype != wp.int32:
+            raise TypeError("Count array must be of type int32")
 
     # Accept either array1d(dtype) or contiguous array3d(scalar_type) as values
     if values is not None:
         if values.device != rows.device:
-            raise ValueError("All arguments must reside on the same device")
+            raise ValueError(f"Values and rows must reside on the same device, got {values.device} and {rows.device}")
 
         if values.shape[0] != rows.shape[0]:
-            raise ValueError("All triplet arrays must have the same length")
+            raise ValueError(
+                f"Values and rows arrays must have the same length, got {values.shape[0]} and {rows.shape[0]}"
+            )
 
         if values.ndim == 1:
-            if values.dtype != dest.values.dtype:
-                raise ValueError("Values array type must correspond to that of dest matrix")
+            if not types_equal(values.dtype, dest.values.dtype):
+                raise ValueError(
+                    f"Values array type must correspond to that of the dest matrix, got {type_repr(values.dtype)} and {type_repr(dest.values.dtype)}"
+                )
         elif values.ndim == 3:
             if values.shape[1:] != dest.block_shape:
                 raise ValueError(
@@ -398,12 +463,14 @@ def bsr_set_from_triplets(
                 )
 
             if type_scalar_type(values.dtype) != dest.scalar_type:
-                raise ValueError("Scalar type of values array should correspond to that of matrix")
-
-            if not values.is_contiguous:
-                raise ValueError("Multi-dimensional values array should be contiguous")
+                raise ValueError(
+                    f"Scalar type of values array ({type_repr(values.dtype)}) should correspond to that of matrix ({type_repr(dest.scalar_type)})"
+                )
        else:
-            raise ValueError("Number of dimension for values array should be 1 or 3")
+            raise ValueError(f"Number of dimension for values array should be 1 or 3, got {values.ndim}")
+
+        if prune_numerical_zeros and not values.is_contiguous:
+            raise ValueError("Values array should be contiguous for numerical zero pruning")
 
     nnz = rows.shape[0]
     if nnz == 0:
@@ -416,40 +483,54 @@ def bsr_set_from_triplets(
 
     device = dest.values.device
     scalar_type = dest.scalar_type
+    zero_value_mask = _zero_value_masks.get(scalar_type, 0)
+
+    # compute the BSR topology
+
    from warp.context import runtime
     if device.is_cpu:
-        if scalar_type == wp.float32:
-            native_func = runtime.core.bsr_matrix_from_triplets_float_host
-        elif scalar_type == wp.float64:
-            native_func = runtime.core.bsr_matrix_from_triplets_double_host
+        native_func = runtime.core.bsr_matrix_from_triplets_host
     else:
-        if scalar_type == wp.float32:
-            native_func = runtime.core.bsr_matrix_from_triplets_float_device
-        elif scalar_type == wp.float64:
-            native_func = runtime.core.bsr_matrix_from_triplets_double_device
+        native_func = runtime.core.bsr_matrix_from_triplets_device
 
-    if not native_func:
-        raise NotImplementedError(f"bsr_from_triplets not implemented for scalar type {scalar_type}")
-
-    nnz_buf, nnz_event = dest._nnz_transfer_buf_and_event()
+    nnz_buf, nnz_event = dest._setup_nnz_transfer()
+    summed_triplet_offsets = wp.empty(shape=(nnz,), dtype=wp.int32, device=device)
+    summed_triplet_indices = wp.empty(shape=(nnz,), dtype=wp.int32, device=device)
 
     with wp.ScopedDevice(device):
         native_func(
-            dest.block_shape[0],
-            dest.block_shape[1],
+            dest.block_size,
+            type_size_in_bytes(scalar_type),
             dest.nrow,
+            dest.ncol,
             nnz,
+            _optional_ctypes_pointer(count, ctype=ctypes.c_int32),
             ctypes.cast(rows.ptr, ctypes.POINTER(ctypes.c_int32)),
             ctypes.cast(columns.ptr, ctypes.POINTER(ctypes.c_int32)),
-            None if values is None else ctypes.cast(values.ptr, ctypes.c_void_p),
-            prune_numerical_zeros,
+            _optional_ctypes_pointer(values, ctype=ctypes.c_int32),
+            zero_value_mask,
             masked,
+            ctypes.cast(summed_triplet_offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
+            ctypes.cast(summed_triplet_indices.ptr, ctypes.POINTER(ctypes.c_int32)),
             ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
-            None if values is None else ctypes.cast(dest.values.ptr, ctypes.c_void_p),
-            ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
-            nnz_event,
+            _optional_ctypes_pointer(nnz_buf, ctype=ctypes.c_int32),
+            _optional_ctypes_event(nnz_event),
+        )
+
+        # now accumulate repeated blocks
+        wp.launch(
+            _bsr_accumulate_triplet_values,
+            dim=(nnz, *dest.block_shape),
+            inputs=[
+                dest.nrow,
+                summed_triplet_offsets,
+                summed_triplet_indices,
+                _as_3d_array(values, dest.block_shape),
+                dest.offsets,
+            ],
+            outputs=[dest.scalar_values],
        )
@@ -483,6 +564,7 @@ def bsr_from_triplets(
     A = bsr_zeros(
         rows_of_blocks=rows_of_blocks, cols_of_blocks=cols_of_blocks, block_type=block_type, device=values.device
     )
+    A.values.requires_grad = values.requires_grad
     bsr_set_from_triplets(A, rows, columns, values, prune_numerical_zeros=prune_numerical_zeros)
     return A
 
@@ -539,6 +621,10 @@ class _BsrScalingExpression(_BsrExpression):
     def dtype(self) -> type:
         return self.mat.dtype
 
+    @property
+    def requires_grad(self) -> bool:
+        return self.mat.requires_grad
+
     @property
     def device(self) -> wp.context.Device:
         return self.mat.device
@@ -721,10 +807,10 @@ def bsr_assign(
         src: Matrix to be copied.
         dest: Destination matrix. May have a different block shape or scalar type
             than ``src``, in which case the required casting will be performed.
-        structure_only: If ``True``, only the non-zeros indices are copied, and uninitialized value storage is allocated
+        structure_only: If ``True``, only the non-zero indices are copied, and uninitialized value storage is allocated
            to accommodate at least ``src.nnz`` blocks. If ``structure_only`` is ``False``, values are also copied with implicit
            casting if the two matrices use distinct scalar types.
-        masked: If ``True``, prevent the assignment operation from adding new non-zeros blocks to ``dest``.
+        masked: If ``True``, prevent the assignment operation from adding new non-zero blocks to ``dest``.
     """
 
     src, src_scale = _extract_matrix_and_scale(src)
@@ -741,7 +827,7 @@ def bsr_assign(
 
     if src_subrows * dest.block_shape[0] != src.block_shape[0] * dest_subrows:
         raise ValueError(
-            f"Incompatible dest and src block shapes; block rows must evenly divide one another (Got {src.block_shape[0]}, {dest.block_shape[0]})"
+            f"Incompatible dest and src block shapes; block rows must evenly divide one another (Got {dest.block_shape[0]}, {src.block_shape[0]})"
         )
 
     if src.block_shape[1] >= dest.block_shape[1]:
@@ -753,14 +839,16 @@ def bsr_assign(
 
     if src_subcols * dest.block_shape[1] != src.block_shape[1] * dest_subcols:
         raise ValueError(
-            f"Incompatible dest and src block shapes; block columns must evenly divide one another (Got {src.block_shape[1]}, {dest.block_shape[1]})"
+            f"Incompatible dest and src block shapes; block columns must evenly divide one another (Got {dest.block_shape[1]}, {src.block_shape[1]})"
        )
 
     dest_nrow = (src.nrow * src_subrows) // dest_subrows
     dest_ncol = (src.ncol * src_subcols) // dest_subcols
 
     if src.nrow * src_subrows != dest_nrow * dest_subrows or src.ncol * src_subcols != dest_ncol * dest_subcols:
-        raise ValueError("The requested block shape does not evenly divide the source matrix")
+        raise ValueError(
+            f"The requested block shape {dest.block_shape} does not evenly divide the source matrix of total size {src.shape}"
+        )
 
     nnz_alloc = src.nnz * src_subrows * src_subcols
     if masked:
@@ -813,27 +901,30 @@ def bsr_assign(
     from warp.context import runtime
 
     if dest.device.is_cpu:
-        native_func = runtime.core.bsr_matrix_from_triplets_float_host
+        native_func = runtime.core.bsr_matrix_from_triplets_host
     else:
-        native_func = runtime.core.bsr_matrix_from_triplets_float_device
+        native_func = runtime.core.bsr_matrix_from_triplets_device
 
-    nnz_buf, nnz_event = dest._nnz_transfer_buf_and_event()
+    nnz_buf, nnz_event = dest._setup_nnz_transfer()
     with wp.ScopedDevice(dest.device):
         native_func(
-            dest.block_shape[0],
-            dest.block_shape[1],
+            dest.block_size,
+            0,  # scalar_size_in_bytes
             dest.nrow,
+            dest.ncol,
             nnz_alloc,
+            None,  # device nnz
             ctypes.cast(dest_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
             ctypes.cast(dest_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
-            0,
-            False,
+            None,  # triplet values
+            0,  # zero_value_mask
            masked,
+            None,  # summed block offsets
+            None,  # summed block indices
            ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
-            0,
-            ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
-            nnz_event,
+            _optional_ctypes_pointer(nnz_buf, ctype=ctypes.c_int32),
+            _optional_ctypes_event(nnz_event),
        )
 
     # merge block values
@@ -893,10 +984,28 @@ def bsr_copy(
         block_type=block_type,
         device=A.device,
     )
+    copy.values.requires_grad = A.requires_grad
     bsr_assign(dest=copy, src=A, structure_only=structure_only)
     return copy
 
 
+@wp.kernel
+def _bsr_transpose_values(
+    col_count: int,
+    scale: Any,
+    bsr_values: wp.array3d(dtype=Any),
+    block_index_map: wp.array(dtype=int),
+    transposed_bsr_offsets: wp.array(dtype=int),
+    transposed_bsr_values: wp.array3d(dtype=Any),
+):
+    block, i, j = wp.tid()
+
+    if block >= transposed_bsr_offsets[col_count]:
+        return
+
+    transposed_bsr_values[block, i, j] = bsr_values[block_index_map[block], j, i] * scale
+
+
 def bsr_set_transpose(
     dest: BsrMatrix[BlockType[Cols, Rows, Scalar]],
     src: BsrMatrixOrExpression[BlockType[Rows, Cols, Scalar]],
@@ -906,15 +1015,17 @@ def bsr_set_transpose(
     src, src_scale = _extract_matrix_and_scale(src)
 
     if dest.values.device != src.values.device:
-        raise ValueError("All arguments must reside on the same device")
+        raise ValueError(
+            f"All arguments must reside on the same device, got {dest.values.device} and {src.values.device}"
+        )
 
     if dest.scalar_type != src.scalar_type:
-        raise ValueError("All arguments must have the same scalar type")
+        raise ValueError(f"All arguments must have the same scalar type, got {dest.scalar_type} and {src.scalar_type}")
 
     transpose_block_shape = src.block_shape[::-1]
 
     if dest.block_shape != transpose_block_shape:
-        raise ValueError(f"Destination block shape must be {transpose_block_shape}")
+        raise ValueError(f"Destination block shape must be {transpose_block_shape}, got {dest.block_shape}")
 
     nnz = src.nnz
     dest.nrow = src.ncol
@@ -930,36 +1041,33 @@ def bsr_set_transpose(
     from warp.context import runtime
 
     if dest.values.device.is_cpu:
-        if dest.scalar_type == wp.float32:
-            native_func = runtime.core.bsr_transpose_float_host
-        elif dest.scalar_type == wp.float64:
-            native_func = runtime.core.bsr_transpose_double_host
+        native_func = runtime.core.bsr_transpose_host
     else:
-        if dest.scalar_type == wp.float32:
-            native_func = runtime.core.bsr_transpose_float_device
-        elif dest.scalar_type == wp.float64:
-            native_func = runtime.core.bsr_transpose_double_device
+        native_func = runtime.core.bsr_transpose_device
 
-    if not native_func:
-        raise NotImplementedError(f"bsr_set_transpose not implemented for scalar type {dest.scalar_type}")
+    block_index_map = wp.empty(shape=2 * nnz, dtype=int, device=src.device)
 
     with wp.ScopedDevice(dest.device):
         native_func(
-            src.block_shape[0],
-            src.block_shape[1],
            src.nrow,
            src.ncol,
            nnz,
            ctypes.cast(src.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(src.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
-            ctypes.cast(src.values.ptr, ctypes.c_void_p),
            ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
-            ctypes.cast(dest.values.ptr, ctypes.c_void_p),
+            ctypes.cast(block_index_map.ptr, ctypes.POINTER(ctypes.c_int32)),
        )
 
-        dest.copy_nnz_async()
-        bsr_scale(dest, src_scale)
+    dest.copy_nnz_async()
+
+    wp.launch(
+        _bsr_transpose_values,
+        dim=(nnz, *dest.block_shape),
+        device=dest.device,
+        inputs=[src.ncol, dest.scalar_type(src_scale), src.scalar_values, block_index_map, dest.offsets],
+        outputs=[dest.scalar_values],
+    )
 
 
 def bsr_transposed(A: BsrMatrixOrExpression) -> BsrMatrix:
@@ -976,6 +1084,7 @@ def bsr_transposed(A: BsrMatrixOrExpression) -> BsrMatrix:
         block_type=block_type,
         device=A.device,
     )
+    transposed.values.requires_grad = A.requires_grad
     bsr_set_transpose(dest=transposed, src=A)
     return transposed
 
@@ -1010,12 +1119,12 @@ def bsr_get_diag(A: BsrMatrixOrExpression[BlockType], out: "Optional[Array[Block
     if out is None:
         out = wp.zeros(shape=(dim,), dtype=A.values.dtype, device=A.values.device)
     else:
-        if out.dtype != A.values.dtype:
-            raise ValueError(f"Output array must have type {A.values.dtype}")
+        if not types_equal(out.dtype, A.values.dtype):
+            raise ValueError(f"Output array must have type {A.values.dtype}, got {out.dtype}")
         if out.device != A.values.device:
-            raise ValueError(f"Output array must reside on device {A.values.device}")
+            raise ValueError(f"Output array must reside on device {A.values.device}, got {out.device}")
         if out.shape[0] < dim:
-            raise ValueError(f"Output array must be of length at least {dim}")
+            raise ValueError(f"Output array must be of length at least {dim}, got {out.shape[0]}")
 
     wp.launch(
         kernel=_bsr_get_diag_kernel,
@@ -1095,7 +1204,7 @@ def bsr_set_diag(
     elif diag is not None:
         A.values.fill_(diag)
 
-    A.copy_nnz_async(known_nnz=nnz)
+    A.copy_nnz_async()
 
 
 def bsr_diag(
@@ -1151,6 +1260,8 @@ def bsr_diag(
         block_type = wp.mat(shape=diag.shape, dtype=diag.dtype)
 
     A = bsr_zeros(rows_of_blocks, cols_of_blocks, block_type=block_type, device=device)
+    if is_array(diag):
+        A.values.requires_grad = diag.requires_grad
     bsr_set_diag(A, diag)
     return A
 
@@ -1292,8 +1403,8 @@ def bsr_axpy(
     The ``x`` and ``y`` matrices are allowed to alias.
 
     Args:
-        x: Read-only right-hand-side.
-        y: Mutable left-hand-side. If ``y`` is not provided, it will be allocated and treated as zero.
+        x: Read-only first operand.
+        y: Mutable second operand and output matrix. If ``y`` is not provided, it will be allocated and treated as zero.
         alpha: Uniform scaling factor for ``x``.
         beta: Uniform scaling factor for ``y``.
         masked: If ``True``, discard all blocks from ``x`` which are not
@@ -1312,6 +1423,7 @@ def bsr_axpy(
 
         # If not output matrix is provided, allocate it for convenience
         y = bsr_zeros(x.nrow, x.ncol, block_type=x.values.dtype, device=x.values.device)
+        y.values.requires_grad = x.requires_grad
         beta = 0.0
 
     x_nnz = x.nnz
@@ -1337,13 +1449,17 @@ def bsr_axpy(
     # General case
 
     if x.values.device != y.values.device:
-        raise ValueError("All arguments must reside on the same device")
+        raise ValueError(f"All arguments must reside on the same device, got {x.values.device} and {y.values.device}")
 
     if x.scalar_type != y.scalar_type or x.block_shape != y.block_shape:
-        raise ValueError("Matrices must have the same block type")
+        raise ValueError(
+            f"Matrices must have the same block type, got ({x.block_shape}, {x.scalar_type}) and ({y.block_shape}, {y.scalar_type})"
+        )
 
     if x.nrow != y.nrow or x.ncol != y.ncol:
-        raise ValueError("Matrices must have the same number of rows and columns")
+        raise ValueError(
+            f"Matrices must have the same number of rows and columns, got ({x.nrow}, {x.ncol}) and ({y.nrow}, {y.ncol})"
+        )
 
     if work_arrays is None:
         work_arrays = bsr_axpy_work_arrays()
@@ -1368,29 +1484,32 @@ def bsr_axpy(
     from warp.context import runtime
 
     if device.is_cpu:
-        native_func = runtime.core.bsr_matrix_from_triplets_float_host
+        native_func = runtime.core.bsr_matrix_from_triplets_host
     else:
-        native_func = runtime.core.bsr_matrix_from_triplets_float_device
+        native_func = runtime.core.bsr_matrix_from_triplets_device
 
     old_y_nnz = y_nnz
-    nnz_buf, nnz_event = y._nnz_transfer_buf_and_event()
+    nnz_buf, nnz_event = y._setup_nnz_transfer()
 
     with wp.ScopedDevice(y.device):
         native_func(
-            y.block_shape[0],
-            y.block_shape[1],
+            y.block_size,
+            0,  # scalar_size_in_bytes
            y.nrow,
+            y.ncol,
            sum_nnz,
+            None,  # device nnz
            ctypes.cast(work_arrays._sum_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(work_arrays._sum_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
-            0,
-            False,
+            None,  # triplet values
+            0,  # zero_value_mask
            masked,
+            None,  # summed block offsets
+            None,  # summed block indices
            ctypes.cast(y.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(y.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
-            0,
-            ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
-            nnz_event,
+            _optional_ctypes_pointer(nnz_buf, ctype=ctypes.c_int32),
+            _optional_ctypes_event(nnz_event),
        )
 
     y.values.zero_()
@@ -1617,9 +1736,9 @@ def bsr_mm(
     If the matrix ``z`` is not provided as input, it will be allocated and treated as zero.
 
     Args:
-        x: Read-only left factor of the matrix-matrix product.
-        y: Read-only right factor of the matrix-matrix product.
-        z: Mutable left-hand-side. If ``z`` is not provided, it will be allocated and treated as zero.
+        x: Read-only left operand of the matrix-matrix product.
+        y: Read-only right operand of the matrix-matrix product.
+        z: Mutable affine operand and result matrix. If ``z`` is not provided, it will be allocated and treated as zero.
        alpha: Uniform scaling factor for the ``x @ y`` product
        beta: Uniform scaling factor for ``z``
        masked: If ``True``, ignore all blocks from ``x @ y`` which are not existing non-zeros of ``y``
@@ -1649,23 +1768,32 @@ def bsr_mm(
     else:
         z_block_type = wp.mat(shape=z_block_shape, dtype=x.scalar_type)
         z = bsr_zeros(x.nrow, y.ncol, block_type=z_block_type, device=x.values.device)
+        z.values.requires_grad = x.requires_grad or y.requires_grad
        beta = 0.0
 
     if x.values.device != y.values.device or x.values.device != z.values.device:
-        raise ValueError("All arguments must reside on the same device")
+        raise ValueError(
+            f"All arguments must reside on the same device, got {x.values.device}, {y.values.device} and {z.values.device}"
+        )
 
     if x.scalar_type != y.scalar_type or x.scalar_type != z.scalar_type:
-        raise ValueError("Matrices must have the same scalar type")
+        raise ValueError(
+            f"Matrices must have the same scalar type, got {x.scalar_type}, {y.scalar_type} and {z.scalar_type}"
+        )
 
     if (
         x.block_shape[0] != z.block_shape[0]
         or y.block_shape[1] != z.block_shape[1]
        or x.block_shape[1] != y.block_shape[0]
     ):
-        raise ValueError("Incompatible block sizes for matrix multiplication")
+        raise ValueError(
+            f"Incompatible block sizes for matrix multiplication, got ({x.block_shape}, {y.block_shape}) and ({z.block_shape})"
+        )
 
     if x.nrow != z.nrow or z.ncol != y.ncol or x.ncol != y.nrow:
-        raise ValueError("Incompatible number of rows/columns for matrix multiplication")
+        raise ValueError(
+            f"Incompatible number of rows/columns for matrix multiplication, got ({x.nrow}, {x.ncol}) and ({y.nrow}, {y.ncol})"
+        )
 
     device = z.values.device
 
@@ -1696,7 +1824,9 @@ def bsr_mm(
         mm_nnz = work_arrays._mm_nnz
     else:
         if device.is_capturing:
-            raise RuntimeError("`bsr_mm` requires `reuse_topology=True` for use in graph capture")
+            raise RuntimeError(
+                "`bsr_mm` requires either `reuse_topology=True` or `masked=True` for use in graph capture"
+            )
 
         if work_arrays is None:
             work_arrays = bsr_mm_work_arrays()
@@ -1725,7 +1855,7 @@ def bsr_mm(
 
         # Get back total counts on host -- we need a synchronization here
         # Use pinned buffer from z, we are going to need it later anyway
-        nnz_buf, _ = z._nnz_transfer_buf_and_event()
+        nnz_buf, _ = z._setup_nnz_transfer()
        stream = wp.get_stream(device) if device.is_cuda else None
        wp.copy(dest=nnz_buf, src=work_arrays._mm_block_counts, src_offset=x.nnz, count=1, stream=stream)
        if device.is_cuda:
@@ -1782,28 +1912,31 @@ def bsr_mm(
     from warp.context import runtime
 
     if device.is_cpu:
-        native_func = runtime.core.bsr_matrix_from_triplets_float_host
+        native_func = runtime.core.bsr_matrix_from_triplets_host
     else:
-        native_func = runtime.core.bsr_matrix_from_triplets_float_device
+        native_func = runtime.core.bsr_matrix_from_triplets_device
 
-    nnz_buf, nnz_event = z._nnz_transfer_buf_and_event()
+    nnz_buf, nnz_event = z._setup_nnz_transfer()
 
     with wp.ScopedDevice(z.device):
         native_func(
-            z.block_shape[0],
-            z.block_shape[1],
+            z.block_size,
+            0,  # scalar_size_in_bytes
            z.nrow,
+            z.ncol,
            mm_nnz,
+            None,  # device nnz
            ctypes.cast(work_arrays._mm_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(work_arrays._mm_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
-            0,
-            False,
-            masked,
+            None,  # triplet values
+            0,  # zero_value_mask
+            False,  # masked_topology
+            None,  # summed block offsets
+            None,  # summed block indices
            ctypes.cast(z.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
            ctypes.cast(z.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
-            0,
-            ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
-            nnz_event,
+            _optional_ctypes_pointer(nnz_buf, ctype=ctypes.c_int32),
+            _optional_ctypes_event(nnz_event),
        )
 
     # Resize z to fit mm result if necessary
@@ -1912,7 +2045,7 @@ def _bsr_mv_transpose_kernel(
 def _vec_array_view(array: wp.array, dtype: type, expected_scalar_count: int) -> wp.array:
     # cast a 1d or 2d array to a 1d array with the target dtype, adjusting shape as required
 
-    scalar_count = array.size * type_length(array.dtype)
+    scalar_count = array.size * type_size(array.dtype)
     if scalar_count != expected_scalar_count:
         raise ValueError(f"Invalid array scalar size, expected {expected_scalar_count}, got {scalar_count}")
 
@@ -1920,15 +2053,15 @@ def _vec_array_view(array: wp.array, dtype: type, expected_scalar_count: int) ->
         return array
 
     if type_scalar_type(array.dtype) != type_scalar_type(dtype):
-        raise ValueError(f"Incompatible scalar types, {type_repr(array.dtype)} vs {type_repr(dtype)}")
+        raise ValueError(f"Incompatible scalar types, expected {type_repr(array.dtype)}, got {type_repr(dtype)}")
 
     if array.ndim > 2:
-        raise ValueError(f"Incompatible array number of dimensions {array.ndim}")
+        raise ValueError(f"Incompatible array number of dimensions, expected 1 or 2, got {array.ndim}")
 
     if not array.is_contiguous:
         raise ValueError("Array must be contiguous")
 
-    vec_length = type_length(dtype)
+    vec_length = type_size(dtype)
     vec_count = scalar_count // vec_length
     if vec_count * vec_length != scalar_count:
         raise ValueError(
@@ -1965,9 +2098,9 @@ def bsr_mv(
     The ``x`` and ``y`` vectors are allowed to alias.
 
     Args:
-        A: Read-only, left matrix factor of the matrix-vector product.
-        x: Read-only, right vector factor of the matrix-vector product.
-        y: Mutable left-hand-side. If ``y`` is not provided, it will be allocated and treated as zero.
+        A: Read-only, left matrix operand of the matrix-vector product.
+        x: Read-only, right vector operand of the matrix-vector product.
+        y: Mutable affine operand and result vector. If ``y`` is not provided, it will be allocated and treated as zero.
        alpha: Uniform scaling factor for ``x``. If zero, ``x`` will not be read and may be left uninitialized.
        beta: Uniform scaling factor for ``y``. If zero, ``y`` will not be read and may be left uninitialized.
        transpose: If ``True``, use the transpose of the matrix ``A``. In this case the result is **non-deterministic**.
@@ -1990,23 +2123,27 @@ def bsr_mv(
         # If no output array is provided, allocate one for convenience
         y_vec_len = block_shape[0]
         y_dtype = A.scalar_type if y_vec_len == 1 else wp.vec(length=y_vec_len, dtype=A.scalar_type)
-        y = wp.empty(shape=(nrow,), device=A.values.device, dtype=y_dtype)
+        y = wp.empty(shape=(nrow,), device=A.values.device, dtype=y_dtype, requires_grad=x.requires_grad)
        beta = 0.0
 
     alpha = A.scalar_type(alpha)
     beta = A.scalar_type(beta)
 
     if A.values.device != x.device or A.values.device != y.device:
-        raise ValueError("A, x, and y must reside on the same device")
+        raise ValueError(
+            f"A, x, and y must reside on the same device, got {A.values.device}, {x.device} and {y.device}"
+        )
 
     if x.ptr == y.ptr:
         # Aliasing case, need temporary storage
         if work_buffer is None:
             work_buffer = wp.empty_like(y)
         elif work_buffer.size < y.size:
-            raise ValueError(f"Work buffer size is insufficient, needs to be at least {y.size}")
+            raise ValueError(f"Work buffer size is insufficient, needs to be at least {y.size}, got {work_buffer.size}")
        elif not types_equal(work_buffer.dtype, y.dtype):
-            raise ValueError(f"Work buffer must have same data type as y, {type_repr(y.dtype)}")
+            raise ValueError(
+                f"Work buffer must have same data type as y, {type_repr(y.dtype)} vs {type_repr(work_buffer.dtype)}"
+            )
 
        # Save old y values before overwriting vector
        wp.copy(dest=work_buffer, src=y, count=y.size)
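
For completeness, a short sketch of the reworked nnz bookkeeping visible in the hunks above: `copy_nnz_async()` has lost its `known_nnz` argument, the internal `_nnz_buf`/`_nnz_event` attribute pair is folded into a single `_nnz_transfer` tuple, and `nnz_sync()` still waits on the recorded event before refreshing the host-side count. Assuming the default device and throwaway data:

    import warp as wp
    from warp.sparse import bsr_set_from_triplets, bsr_zeros

    wp.init()

    A = bsr_zeros(128, 128, block_type=wp.mat33)
    rows = wp.zeros(16, dtype=wp.int32)
    cols = wp.zeros(16, dtype=wp.int32)
    vals = wp.zeros(16, dtype=wp.mat33)

    # schedules the asynchronous device-to-host copy of the exact block count
    bsr_set_from_triplets(A, rows, cols, vals)

    # nnz_sync() synchronizes on that transfer before returning; zero-valued
    # blocks are pruned by default, so this example prints 0
    print(A.nnz_sync())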