warp-lang 1.8.1__py3-none-macosx_10_13_universal2.whl → 1.9.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (134) hide show
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +482 -110
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +47 -67
  7. warp/builtins.py +955 -137
  8. warp/codegen.py +312 -206
  9. warp/config.py +1 -1
  10. warp/context.py +1249 -784
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/fabric.py +1 -1
  18. warp/fem/cache.py +27 -19
  19. warp/fem/domain.py +2 -2
  20. warp/fem/field/nodal_field.py +2 -2
  21. warp/fem/field/virtual.py +264 -166
  22. warp/fem/geometry/geometry.py +5 -5
  23. warp/fem/integrate.py +129 -51
  24. warp/fem/space/restriction.py +4 -0
  25. warp/fem/space/shape/tet_shape_function.py +3 -10
  26. warp/jax_experimental/custom_call.py +1 -1
  27. warp/jax_experimental/ffi.py +2 -1
  28. warp/marching_cubes.py +708 -0
  29. warp/native/array.h +99 -4
  30. warp/native/builtin.h +82 -5
  31. warp/native/bvh.cpp +64 -28
  32. warp/native/bvh.cu +58 -58
  33. warp/native/bvh.h +2 -2
  34. warp/native/clang/clang.cpp +7 -7
  35. warp/native/coloring.cpp +8 -2
  36. warp/native/crt.cpp +2 -2
  37. warp/native/crt.h +3 -5
  38. warp/native/cuda_util.cpp +41 -10
  39. warp/native/cuda_util.h +10 -4
  40. warp/native/exports.h +1842 -1908
  41. warp/native/fabric.h +2 -1
  42. warp/native/hashgrid.cpp +37 -37
  43. warp/native/hashgrid.cu +2 -2
  44. warp/native/initializer_array.h +1 -1
  45. warp/native/intersect.h +2 -2
  46. warp/native/mat.h +1910 -116
  47. warp/native/mathdx.cpp +43 -43
  48. warp/native/mesh.cpp +24 -24
  49. warp/native/mesh.cu +26 -26
  50. warp/native/mesh.h +4 -2
  51. warp/native/nanovdb/GridHandle.h +179 -12
  52. warp/native/nanovdb/HostBuffer.h +8 -7
  53. warp/native/nanovdb/NanoVDB.h +517 -895
  54. warp/native/nanovdb/NodeManager.h +323 -0
  55. warp/native/nanovdb/PNanoVDB.h +2 -2
  56. warp/native/quat.h +331 -14
  57. warp/native/range.h +7 -1
  58. warp/native/reduce.cpp +10 -10
  59. warp/native/reduce.cu +13 -14
  60. warp/native/runlength_encode.cpp +2 -2
  61. warp/native/runlength_encode.cu +5 -5
  62. warp/native/scan.cpp +3 -3
  63. warp/native/scan.cu +4 -4
  64. warp/native/sort.cpp +10 -10
  65. warp/native/sort.cu +22 -22
  66. warp/native/sparse.cpp +8 -8
  67. warp/native/sparse.cu +13 -13
  68. warp/native/spatial.h +366 -17
  69. warp/native/temp_buffer.h +2 -2
  70. warp/native/tile.h +283 -69
  71. warp/native/vec.h +381 -14
  72. warp/native/volume.cpp +54 -54
  73. warp/native/volume.cu +1 -1
  74. warp/native/volume.h +2 -1
  75. warp/native/volume_builder.cu +30 -37
  76. warp/native/warp.cpp +150 -149
  77. warp/native/warp.cu +323 -192
  78. warp/native/warp.h +227 -226
  79. warp/optim/linear.py +736 -271
  80. warp/render/imgui_manager.py +289 -0
  81. warp/render/render_opengl.py +85 -6
  82. warp/sim/graph_coloring.py +2 -2
  83. warp/sparse.py +558 -175
  84. warp/tests/aux_test_module_aot.py +7 -0
  85. warp/tests/cuda/test_async.py +3 -3
  86. warp/tests/cuda/test_conditional_captures.py +101 -0
  87. warp/tests/geometry/test_marching_cubes.py +233 -12
  88. warp/tests/sim/test_coloring.py +6 -6
  89. warp/tests/test_array.py +56 -5
  90. warp/tests/test_codegen.py +3 -2
  91. warp/tests/test_context.py +8 -15
  92. warp/tests/test_enum.py +136 -0
  93. warp/tests/test_examples.py +2 -2
  94. warp/tests/test_fem.py +45 -2
  95. warp/tests/test_fixedarray.py +229 -0
  96. warp/tests/test_func.py +18 -15
  97. warp/tests/test_future_annotations.py +7 -5
  98. warp/tests/test_linear_solvers.py +30 -0
  99. warp/tests/test_map.py +1 -1
  100. warp/tests/test_mat.py +1518 -378
  101. warp/tests/test_mat_assign_copy.py +178 -0
  102. warp/tests/test_mat_constructors.py +574 -0
  103. warp/tests/test_module_aot.py +287 -0
  104. warp/tests/test_print.py +69 -0
  105. warp/tests/test_quat.py +140 -34
  106. warp/tests/test_quat_assign_copy.py +145 -0
  107. warp/tests/test_reload.py +2 -1
  108. warp/tests/test_sparse.py +71 -0
  109. warp/tests/test_spatial.py +140 -34
  110. warp/tests/test_spatial_assign_copy.py +160 -0
  111. warp/tests/test_struct.py +43 -3
  112. warp/tests/test_types.py +0 -20
  113. warp/tests/test_vec.py +179 -34
  114. warp/tests/test_vec_assign_copy.py +143 -0
  115. warp/tests/tile/test_tile.py +184 -18
  116. warp/tests/tile/test_tile_cholesky.py +605 -0
  117. warp/tests/tile/test_tile_load.py +169 -0
  118. warp/tests/tile/test_tile_mathdx.py +2 -558
  119. warp/tests/tile/test_tile_matmul.py +1 -1
  120. warp/tests/tile/test_tile_mlp.py +1 -1
  121. warp/tests/tile/test_tile_shared_memory.py +5 -5
  122. warp/tests/unittest_suites.py +6 -0
  123. warp/tests/walkthrough_debug.py +1 -1
  124. warp/thirdparty/unittest_parallel.py +108 -9
  125. warp/types.py +554 -264
  126. warp/utils.py +68 -86
  127. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
  128. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/RECORD +131 -121
  129. warp/native/marching.cpp +0 -19
  130. warp/native/marching.cu +0 -514
  131. warp/native/marching.h +0 -19
  132. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
  133. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
  134. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
warp/sparse.py CHANGED
@@ -36,6 +36,30 @@ from warp.types import (
36
36
  types_equal,
37
37
  )
38
38
 
39
+ __all__ = [
40
+ "BsrMatrix",
41
+ "bsr_assign",
42
+ "bsr_axpy",
43
+ "bsr_copy",
44
+ "bsr_diag",
45
+ "bsr_from_triplets",
46
+ "bsr_get_diag",
47
+ "bsr_identity",
48
+ "bsr_matrix_t",
49
+ "bsr_mm",
50
+ "bsr_mm_work_arrays",
51
+ "bsr_mv",
52
+ "bsr_scale",
53
+ "bsr_set_diag",
54
+ "bsr_set_from_triplets",
55
+ "bsr_set_identity",
56
+ "bsr_set_transpose",
57
+ "bsr_set_zero",
58
+ "bsr_transposed",
59
+ "bsr_zeros",
60
+ ]
61
+
62
+
39
63
  # typing hints
40
64
 
41
65
  _BlockType = TypeVar("BlockType") # noqa: PLC0132
@@ -52,6 +76,7 @@ class _ScalarBlockType(Generic[Scalar]):
52
76
  BlockType = Union[_MatrixBlockType[Rows, Cols, Scalar], _ScalarBlockType[Scalar]]
53
77
 
54
78
  _struct_cache = {}
79
+ _transfer_buffer_cache = {}
55
80
 
56
81
 
57
82
  class BsrMatrix(Generic[_BlockType]):
@@ -131,17 +156,21 @@ class BsrMatrix(Generic[_BlockType]):
131
156
  return out
132
157
 
133
158
  def nnz_sync(self):
134
- """Ensure that any ongoing transfer of the exact nnz number from the device offsets array to the host has completed
135
- and update the nnz upper bound.
159
+ """Ensures that any ongoing transfer of the exact nnz number from the device offsets array to the host has completed,
160
+ or, if none has been scheduled yet, starts a new transfer and waits for it to complete.
161
+ Then updates the nnz upper bound.
136
162
 
137
163
  See also :meth:`copy_nnz_async`.
138
164
  """
139
165
 
140
166
  buf, event = self._nnz_transfer_if_any()
141
- if buf is not None:
142
- if event is not None:
143
- wp.synchronize_event(event)
144
- self.nnz = int(buf.numpy()[0])
167
+ if buf is None:
168
+ self.copy_nnz_async()
169
+ buf, event = self._nnz_transfer_if_any()
170
+
171
+ if event is not None:
172
+ wp.synchronize_event(event)
173
+ self.nnz = int(buf.numpy()[0])
145
174
  return self.nnz
146
175
 
147
176
  def copy_nnz_async(self) -> None:
@@ -161,17 +190,23 @@ class BsrMatrix(Generic[_BlockType]):
161
190
 
162
191
  def _setup_nnz_transfer(self):
163
192
  buf, event = self._nnz_transfer_if_any()
164
- if buf is not None or self.device.is_capturing:
193
+ if buf is not None:
165
194
  return buf, event
166
195
 
167
- buf = wp.empty(dtype=int, shape=(1,), device="cpu", pinned=self.device.is_cuda)
168
- event = wp.Event(self.device) if self.device.is_cuda else None
169
- BsrMatrix.__setattr__(self, "_nnz_transfer", (buf, event))
196
+ buf, event = _allocate_transfer_buf(self.device)
197
+ if buf is not None:
198
+ BsrMatrix.__setattr__(self, "_nnz_transfer", (buf, event))
199
+
170
200
  return buf, event
171
201
 
172
202
  def _nnz_transfer_if_any(self):
173
203
  return getattr(self, "_nnz_transfer", (None, None))
174
204
 
205
+ def __del__(self):
206
+ buf, event = self._nnz_transfer_if_any()
207
+ if buf is not None:
208
+ _redeem_transfer_buf(self.device, buf, event)
209
+
175
210
  # Overloaded math operators
176
211
  def __add__(self, y):
177
212
  return bsr_axpy(y, bsr_copy(self))
@@ -226,6 +261,31 @@ class BsrMatrix(Generic[_BlockType]):
226
261
  return bsr_transposed(self)
227
262
 
228
263
 
264
+ def _allocate_transfer_buf(device):
265
+ if device.ordinal in _transfer_buffer_cache:
266
+ all_, pool = _transfer_buffer_cache[device.ordinal]
267
+ else:
268
+ all_ = []
269
+ pool = []
270
+ _transfer_buffer_cache[device.ordinal] = (all_, pool)
271
+
272
+ if pool:
273
+ return pool.pop()
274
+
275
+ if device.is_capturing:
276
+ return None, None
277
+
278
+ buf = wp.empty(dtype=int, shape=(1,), device="cpu", pinned=device.is_cuda)
279
+ event = wp.Event(device) if device.is_cuda else None
280
+ all_.append((buf, event)) # keep a reference to the buffer and event, prevent garbage collection before redeem
281
+ return buf, event
282
+
283
+
284
+ def _redeem_transfer_buf(device, buf, event):
285
+ all_, pool = _transfer_buffer_cache[device.ordinal]
286
+ pool.append((buf, event))
287
+
288
+
229
289
  def bsr_matrix_t(dtype: BlockType):
230
290
  dtype = type_to_warp(dtype)
231
291
 
@@ -490,9 +550,9 @@ def bsr_set_from_triplets(
490
550
  from warp.context import runtime
491
551
 
492
552
  if device.is_cpu:
493
- native_func = runtime.core.bsr_matrix_from_triplets_host
553
+ native_func = runtime.core.wp_bsr_matrix_from_triplets_host
494
554
  else:
495
- native_func = runtime.core.bsr_matrix_from_triplets_device
555
+ native_func = runtime.core.wp_bsr_matrix_from_triplets_device
496
556
 
497
557
  nnz_buf, nnz_event = dest._setup_nnz_transfer()
498
558
  summed_triplet_offsets = wp.empty(shape=(nnz,), dtype=wp.int32, device=device)
@@ -901,9 +961,9 @@ def bsr_assign(
901
961
  from warp.context import runtime
902
962
 
903
963
  if dest.device.is_cpu:
904
- native_func = runtime.core.bsr_matrix_from_triplets_host
964
+ native_func = runtime.core.wp_bsr_matrix_from_triplets_host
905
965
  else:
906
- native_func = runtime.core.bsr_matrix_from_triplets_device
966
+ native_func = runtime.core.wp_bsr_matrix_from_triplets_device
907
967
 
908
968
  nnz_buf, nnz_event = dest._setup_nnz_transfer()
909
969
  with wp.ScopedDevice(dest.device):
@@ -1041,9 +1101,9 @@ def bsr_set_transpose(
1041
1101
  from warp.context import runtime
1042
1102
 
1043
1103
  if dest.values.device.is_cpu:
1044
- native_func = runtime.core.bsr_transpose_host
1104
+ native_func = runtime.core.wp_bsr_transpose_host
1045
1105
  else:
1046
- native_func = runtime.core.bsr_transpose_device
1106
+ native_func = runtime.core.wp_bsr_transpose_device
1047
1107
 
1048
1108
  block_index_map = wp.empty(shape=2 * nnz, dtype=int, device=src.device)
1049
1109
 
@@ -1094,14 +1154,14 @@ def _bsr_get_diag_kernel(
1094
1154
  scale: Any,
1095
1155
  A_offsets: wp.array(dtype=int),
1096
1156
  A_columns: wp.array(dtype=int),
1097
- A_values: wp.array(dtype=Any),
1098
- out: wp.array(dtype=Any),
1157
+ A_values: wp.array3d(dtype=Any),
1158
+ out: wp.array3d(dtype=Any),
1099
1159
  ):
1100
- row = wp.tid()
1160
+ row, br, bc = wp.tid()
1101
1161
 
1102
1162
  diag = _bsr_block_index(row, row, A_offsets, A_columns)
1103
1163
  if diag != -1:
1104
- out[row] = scale * A_values[diag]
1164
+ out[row, br, bc] = scale * A_values[diag, br, bc]
1105
1165
 
1106
1166
 
1107
1167
  def bsr_get_diag(A: BsrMatrixOrExpression[BlockType], out: "Optional[Array[BlockType]]" = None) -> "Array[BlockType]":
@@ -1128,9 +1188,9 @@ def bsr_get_diag(A: BsrMatrixOrExpression[BlockType], out: "Optional[Array[Block
1128
1188
 
1129
1189
  wp.launch(
1130
1190
  kernel=_bsr_get_diag_kernel,
1131
- dim=dim,
1191
+ dim=(dim, *A.block_shape),
1132
1192
  device=A.values.device,
1133
- inputs=[A.scalar_type(scale), A.offsets, A.columns, A.values, out],
1193
+ inputs=[A.scalar_type(scale), A.offsets, A.columns, A.scalar_values, _as_3d_array(out, A.block_shape)],
1134
1194
  )
1135
1195
 
1136
1196
  return out
@@ -1312,7 +1372,17 @@ def _bsr_scale_kernel(
1312
1372
  alpha: Any,
1313
1373
  values: wp.array(dtype=Any),
1314
1374
  ):
1315
- values[wp.tid()] = alpha * values[wp.tid()]
1375
+ row = wp.tid()
1376
+ values[row] = alpha * values[row]
1377
+
1378
+
1379
+ @wp.kernel
1380
+ def _bsr_scale_kernel(
1381
+ alpha: Any,
1382
+ values: wp.array3d(dtype=Any),
1383
+ ):
1384
+ row, br, bc = wp.tid()
1385
+ values[row, br, bc] = alpha * values[row, br, bc]
1316
1386
 
1317
1387
 
1318
1388
  def bsr_scale(x: BsrMatrixOrExpression, alpha: Scalar) -> BsrMatrix:
@@ -1329,9 +1399,9 @@ def bsr_scale(x: BsrMatrixOrExpression, alpha: Scalar) -> BsrMatrix:
1329
1399
 
1330
1400
  wp.launch(
1331
1401
  kernel=_bsr_scale_kernel,
1332
- dim=x.nnz,
1402
+ dim=(x.nnz, *x.block_shape),
1333
1403
  device=x.values.device,
1334
- inputs=[alpha, x.values],
1404
+ inputs=[alpha, x.scalar_values],
1335
1405
  )
1336
1406
 
1337
1407
  return x
@@ -1351,16 +1421,16 @@ def _bsr_axpy_add_block(
1351
1421
  cols: wp.array(dtype=int),
1352
1422
  dst_offsets: wp.array(dtype=int),
1353
1423
  dst_columns: wp.array(dtype=int),
1354
- src_values: wp.array(dtype=Any),
1355
- dst_values: wp.array(dtype=Any),
1424
+ src_values: wp.array3d(dtype=Any),
1425
+ dst_values: wp.array3d(dtype=Any),
1356
1426
  ):
1357
- i = wp.tid()
1427
+ i, br, bc = wp.tid()
1358
1428
  row = rows[i + src_offset]
1359
1429
  col = cols[i + src_offset]
1360
1430
 
1361
1431
  block = _bsr_block_index(row, col, dst_offsets, dst_columns)
1362
1432
  if block != -1:
1363
- dst_values[block] += scale * src_values[i]
1433
+ dst_values[block, br, bc] += scale * src_values[i, br, bc]
1364
1434
 
1365
1435
 
1366
1436
  class bsr_axpy_work_arrays:
@@ -1386,7 +1456,7 @@ class bsr_axpy_work_arrays:
1386
1456
  self._sum_cols = wp.empty(shape=(sum_nnz), dtype=int, device=self.device)
1387
1457
 
1388
1458
  if self._old_y_values is None or self._old_y_values.size < y.nnz:
1389
- self._old_y_values = wp.empty(shape=(y.nnz,), dtype=y.values.dtype, device=self.device)
1459
+ self._old_y_values = wp.empty_like(y.values[: y.nnz])
1390
1460
 
1391
1461
 
1392
1462
  def bsr_axpy(
@@ -1475,7 +1545,7 @@ def bsr_axpy(
1475
1545
  x.uncompress_rows(out=work_arrays._sum_rows[y_nnz:])
1476
1546
 
1477
1547
  # Save old y values before overwriting matrix
1478
- wp.copy(dest=work_arrays._old_y_values, src=y.values, count=y_nnz)
1548
+ wp.copy(dest=work_arrays._old_y_values, src=y.values, count=y.nnz)
1479
1549
 
1480
1550
  # Increase dest array sizes if needed
1481
1551
  if not masked:
@@ -1484,9 +1554,9 @@ def bsr_axpy(
1484
1554
  from warp.context import runtime
1485
1555
 
1486
1556
  if device.is_cpu:
1487
- native_func = runtime.core.bsr_matrix_from_triplets_host
1557
+ native_func = runtime.core.wp_bsr_matrix_from_triplets_host
1488
1558
  else:
1489
- native_func = runtime.core.bsr_matrix_from_triplets_device
1559
+ native_func = runtime.core.wp_bsr_matrix_from_triplets_device
1490
1560
 
1491
1561
  old_y_nnz = y_nnz
1492
1562
  nnz_buf, nnz_event = y._setup_nnz_transfer()
@@ -1517,7 +1587,7 @@ def bsr_axpy(
1517
1587
  wp.launch(
1518
1588
  kernel=_bsr_axpy_add_block,
1519
1589
  device=device,
1520
- dim=old_y_nnz,
1590
+ dim=(old_y_nnz, y.block_shape[0], y.block_shape[1]),
1521
1591
  inputs=[
1522
1592
  0,
1523
1593
  beta,
@@ -1525,15 +1595,15 @@ def bsr_axpy(
1525
1595
  work_arrays._sum_cols,
1526
1596
  y.offsets,
1527
1597
  y.columns,
1528
- work_arrays._old_y_values,
1529
- y.values,
1598
+ _as_3d_array(work_arrays._old_y_values, y.block_shape),
1599
+ y.scalar_values,
1530
1600
  ],
1531
1601
  )
1532
1602
 
1533
1603
  wp.launch(
1534
1604
  kernel=_bsr_axpy_add_block,
1535
1605
  device=device,
1536
- dim=x_nnz,
1606
+ dim=(x_nnz, y.block_shape[0], y.block_shape[1]),
1537
1607
  inputs=[
1538
1608
  old_y_nnz,
1539
1609
  alpha,
@@ -1541,60 +1611,78 @@ def bsr_axpy(
1541
1611
  work_arrays._sum_cols,
1542
1612
  y.offsets,
1543
1613
  y.columns,
1544
- x.values,
1545
- y.values,
1614
+ x.scalar_values,
1615
+ y.scalar_values,
1546
1616
  ],
1547
1617
  )
1548
1618
 
1549
1619
  return y
1550
1620
 
1551
1621
 
1552
- @wp.kernel(enable_backward=False)
1553
- def _bsr_mm_count_coeffs(
1554
- y_ncol: int,
1555
- z_nnz: int,
1556
- x_offsets: wp.array(dtype=int),
1557
- x_columns: wp.array(dtype=int),
1558
- y_offsets: wp.array(dtype=int),
1559
- y_columns: wp.array(dtype=int),
1560
- row_min: wp.array(dtype=int),
1561
- block_counts: wp.array(dtype=int),
1562
- ):
1563
- row = wp.tid()
1564
- row_count = int(0)
1622
+ def make_bsr_mm_count_coeffs(tile_size):
1623
+ from warp.fem.cache import dynamic_kernel
1565
1624
 
1566
- x_beg = x_offsets[row]
1567
- x_end = x_offsets[row + 1]
1625
+ @dynamic_kernel(suffix=tile_size)
1626
+ def bsr_mm_count_coeffs(
1627
+ y_ncol: int,
1628
+ z_nnz: int,
1629
+ x_offsets: wp.array(dtype=int),
1630
+ x_columns: wp.array(dtype=int),
1631
+ y_offsets: wp.array(dtype=int),
1632
+ y_columns: wp.array(dtype=int),
1633
+ row_min: wp.array(dtype=int),
1634
+ block_counts: wp.array(dtype=int),
1635
+ ):
1636
+ row, lane = wp.tid()
1637
+ row_count = int(0)
1568
1638
 
1569
- min_col = y_ncol
1570
- max_col = int(0)
1639
+ x_beg = x_offsets[row]
1640
+ x_end = x_offsets[row + 1]
1571
1641
 
1572
- for x_block in range(x_beg, x_end):
1573
- x_col = x_columns[x_block]
1574
- y_row_end = y_offsets[x_col + 1]
1575
- y_row_beg = y_offsets[x_col]
1576
- block_count = y_row_end - y_row_beg
1577
- if block_count != 0:
1578
- min_col = wp.min(y_columns[y_row_beg], min_col)
1579
- max_col = wp.max(y_columns[y_row_end - 1], max_col)
1580
-
1581
- block_counts[x_block + 1] = block_count
1582
- row_count += block_count
1583
-
1584
- if row_count > wp.max(0, max_col - min_col):
1585
- row_min[row] = min_col
1586
- block_counts[x_end] = max_col + 1 - min_col
1587
- for x_block in range(x_beg, x_end - 1):
1588
- block_counts[x_block + 1] = 0
1589
- else:
1590
- row_min[row] = -1
1642
+ min_col = y_ncol
1643
+ max_col = int(0)
1591
1644
 
1592
- if row == 0:
1593
- block_counts[0] = z_nnz
1645
+ for x_block in range(x_beg + lane, x_end, tile_size):
1646
+ x_col = x_columns[x_block]
1647
+ y_row_end = y_offsets[x_col + 1]
1648
+ y_row_beg = y_offsets[x_col]
1649
+ block_count = y_row_end - y_row_beg
1650
+ if block_count != 0:
1651
+ min_col = wp.min(y_columns[y_row_beg], min_col)
1652
+ max_col = wp.max(y_columns[y_row_end - 1], max_col)
1653
+
1654
+ block_counts[x_block + 1] = block_count
1655
+ row_count += block_count
1656
+
1657
+ if wp.static(tile_size) > 1:
1658
+ row_count = wp.tile_sum(wp.tile(row_count))[0]
1659
+ min_col = wp.tile_min(wp.tile(min_col))[0]
1660
+ max_col = wp.tile_max(wp.tile(max_col))[0]
1661
+ col_range_size = wp.max(0, max_col - min_col + 1)
1662
+
1663
+ if row_count > col_range_size:
1664
+ # Optimization for deep products.
1665
+ # Do not store the whole list of src product terms, they would be highly redundant
1666
+ # Instead just mark a range in the output matrix
1667
+
1668
+ if lane == 0:
1669
+ row_min[row] = min_col
1670
+ block_counts[x_end] = col_range_size
1671
+
1672
+ for x_block in range(x_beg + lane, x_end - 1, tile_size):
1673
+ block_counts[x_block + 1] = 0
1674
+ elif lane == 0:
1675
+ row_min[row] = -1
1676
+
1677
+ if lane == 0 and row == 0:
1678
+ block_counts[0] = z_nnz
1679
+
1680
+ return bsr_mm_count_coeffs
1594
1681
 
1595
1682
 
1596
1683
  @wp.kernel(enable_backward=False)
1597
1684
  def _bsr_mm_list_coeffs(
1685
+ copied_z_nnz: int,
1598
1686
  x_nrow: int,
1599
1687
  x_offsets: wp.array(dtype=int),
1600
1688
  x_columns: wp.array(dtype=int),
@@ -1604,38 +1692,58 @@ def _bsr_mm_list_coeffs(
1604
1692
  mm_offsets: wp.array(dtype=int),
1605
1693
  mm_rows: wp.array(dtype=int),
1606
1694
  mm_cols: wp.array(dtype=int),
1695
+ mm_src_blocks: wp.array(dtype=int),
1607
1696
  ):
1608
- x_block = wp.tid()
1609
- mm_block = mm_offsets[x_block]
1697
+ mm_block = wp.tid() + copied_z_nnz
1698
+
1699
+ x_nnz = x_offsets[x_nrow]
1700
+ x_block = wp.lower_bound(mm_offsets, 0, x_nnz + 1, mm_block + 1) - 1
1701
+ pos = mm_block - mm_offsets[x_block]
1610
1702
 
1611
1703
  row = _bsr_row_index(x_offsets, x_nrow, x_block)
1612
- if row == -1:
1613
- return
1614
1704
 
1615
1705
  row_min_col = mm_row_min[row]
1616
- if row_min_col != -1:
1706
+ if row_min_col == -1:
1617
1707
  x_col = x_columns[x_block]
1618
-
1619
1708
  y_beg = y_offsets[x_col]
1620
- y_end = y_offsets[x_col + 1]
1709
+ y_block = y_beg + pos
1710
+ col = y_columns[y_block]
1711
+ src_block = x_block
1712
+ else:
1713
+ col = row_min_col + pos
1714
+ src_block = -1
1621
1715
 
1622
- for y_block in range(y_beg, y_end):
1623
- col = y_columns[y_block]
1624
- mm_rows[mm_block + col - row_min_col] = row
1625
- mm_cols[mm_block + col - row_min_col] = col
1716
+ mm_cols[mm_block] = col
1717
+ mm_rows[mm_block] = row
1718
+ mm_src_blocks[mm_block] = src_block
1626
1719
 
1627
- return
1628
1720
 
1629
- x_col = x_columns[x_block]
1630
- y_beg = y_offsets[x_col]
1631
- y_end = y_offsets[x_col + 1]
1632
- for y_block in range(y_beg, y_end):
1633
- mm_cols[mm_block] = y_columns[y_block]
1634
- mm_rows[mm_block] = row
1635
- mm_block += 1
1721
+ @wp.func
1722
+ def _bsr_mm_use_triplets(
1723
+ row: int,
1724
+ mm_block: int,
1725
+ mm_row_min: wp.array(dtype=int),
1726
+ row_offsets: wp.array(dtype=int),
1727
+ summed_triplet_offsets: wp.array(dtype=int),
1728
+ ):
1729
+ x_beg = row_offsets[row]
1730
+ x_end = row_offsets[row + 1]
1636
1731
 
1732
+ if mm_row_min:
1733
+ if mm_row_min[row] == -1:
1734
+ if mm_block == 0:
1735
+ block_beg = 0
1736
+ else:
1737
+ block_beg = summed_triplet_offsets[mm_block - 1]
1738
+ block_end = summed_triplet_offsets[mm_block]
1637
1739
 
1638
- @wp.kernel
1740
+ if x_end - x_beg > 3 * (block_end - block_beg):
1741
+ return True, block_beg, block_end
1742
+
1743
+ return False, x_beg, x_end
1744
+
1745
+
1746
+ @wp.kernel(enable_backward=False)
1639
1747
  def _bsr_mm_compute_values(
1640
1748
  alpha: Any,
1641
1749
  x_offsets: wp.array(dtype=int),
@@ -1644,6 +1752,9 @@ def _bsr_mm_compute_values(
1644
1752
  y_offsets: wp.array(dtype=int),
1645
1753
  y_columns: wp.array(dtype=int),
1646
1754
  y_values: wp.array(dtype=Any),
1755
+ mm_row_min: wp.array(dtype=int),
1756
+ summed_triplet_offsets: wp.array(dtype=int),
1757
+ summed_triplet_src_blocks: wp.indexedarray(dtype=int),
1647
1758
  mm_row_count: int,
1648
1759
  mm_offsets: wp.array(dtype=int),
1649
1760
  mm_cols: wp.array(dtype=int),
@@ -1655,21 +1766,135 @@ def _bsr_mm_compute_values(
1655
1766
  if row == -1:
1656
1767
  return
1657
1768
 
1658
- col = mm_cols[mm_block]
1769
+ use_triplets, block_beg, block_end = _bsr_mm_use_triplets(
1770
+ row, mm_block, mm_row_min, x_offsets, summed_triplet_offsets
1771
+ )
1659
1772
 
1660
1773
  mm_val = mm_values.dtype(type(alpha)(0.0))
1661
-
1662
- x_beg = x_offsets[row]
1663
- x_end = x_offsets[row + 1]
1664
- for x_block in range(x_beg, x_end):
1665
- x_col = x_columns[x_block]
1666
- y_block = _bsr_block_index(x_col, col, y_offsets, y_columns)
1667
- if y_block != -1:
1668
- mm_val += x_values[x_block] * y_values[y_block]
1774
+ col = mm_cols[mm_block]
1775
+ if use_triplets:
1776
+ for tpl_idx in range(block_beg, block_end):
1777
+ x_block = summed_triplet_src_blocks[tpl_idx]
1778
+ x_col = x_columns[x_block]
1779
+ if x_block != -1:
1780
+ y_block = _bsr_block_index(x_col, col, y_offsets, y_columns)
1781
+ mm_val += x_values[x_block] * y_values[y_block]
1782
+ else:
1783
+ for x_block in range(block_beg, block_end):
1784
+ x_col = x_columns[x_block]
1785
+ y_block = _bsr_block_index(x_col, col, y_offsets, y_columns)
1786
+ if y_block != -1:
1787
+ mm_val += x_values[x_block] * y_values[y_block]
1669
1788
 
1670
1789
  mm_values[mm_block] += alpha * mm_val
1671
1790
 
1672
1791
 
1792
+ def make_bsr_mm_compute_values_tiled_outer(subblock_rows, subblock_cols, block_depth, scalar_type, tile_size):
1793
+ from warp.fem.cache import dynamic_func, dynamic_kernel
1794
+
1795
+ mm_type = wp.mat(dtype=scalar_type, shape=(subblock_rows, subblock_cols))
1796
+
1797
+ x_col_vec_t = wp.vec(dtype=scalar_type, length=subblock_rows)
1798
+ y_row_vec_t = wp.vec(dtype=scalar_type, length=subblock_cols)
1799
+
1800
+ suffix = f"{subblock_rows}{subblock_cols}{block_depth}{tile_size}{scalar_type.__name__}"
1801
+
1802
+ @dynamic_func(suffix=suffix)
1803
+ def _outer_product(
1804
+ x_values: wp.array2d(dtype=scalar_type),
1805
+ y_values: wp.array2d(dtype=scalar_type),
1806
+ brow_off: int,
1807
+ bcol_off: int,
1808
+ block_col: int,
1809
+ brow_count: int,
1810
+ bcol_count: int,
1811
+ ):
1812
+ x_col_vec = x_col_vec_t()
1813
+ y_row_vec = y_row_vec_t()
1814
+
1815
+ for k in range(brow_count):
1816
+ x_col_vec[k] = x_values[brow_off + k, block_col]
1817
+ for k in range(bcol_count):
1818
+ y_row_vec[k] = y_values[block_col, bcol_off + k]
1819
+
1820
+ return wp.outer(x_col_vec, y_row_vec)
1821
+
1822
+ @dynamic_kernel(suffix=suffix, kernel_options={"enable_backward": False})
1823
+ def bsr_mm_compute_values(
1824
+ alpha: scalar_type,
1825
+ x_offsets: wp.array(dtype=int),
1826
+ x_columns: wp.array(dtype=int),
1827
+ x_values: wp.array3d(dtype=scalar_type),
1828
+ y_offsets: wp.array(dtype=int),
1829
+ y_columns: wp.array(dtype=int),
1830
+ y_values: wp.array3d(dtype=scalar_type),
1831
+ mm_row_min: wp.array(dtype=int),
1832
+ summed_triplet_offsets: wp.array(dtype=int),
1833
+ summed_triplet_src_blocks: wp.indexedarray(dtype=int),
1834
+ mm_row_count: int,
1835
+ mm_offsets: wp.array(dtype=int),
1836
+ mm_cols: wp.array(dtype=int),
1837
+ mm_values: wp.array3d(dtype=scalar_type),
1838
+ ):
1839
+ mm_block, subrow, subcol, lane = wp.tid()
1840
+
1841
+ brow_off = subrow * wp.static(subblock_rows)
1842
+ bcol_off = subcol * wp.static(subblock_cols)
1843
+
1844
+ brow_count = wp.min(mm_values.shape[1] - brow_off, subblock_rows)
1845
+ bcol_count = wp.min(mm_values.shape[2] - bcol_off, subblock_cols)
1846
+
1847
+ mm_row = _bsr_row_index(mm_offsets, mm_row_count, mm_block)
1848
+ if mm_row == -1:
1849
+ return
1850
+
1851
+ lane_val = mm_type()
1852
+
1853
+ use_triplets, block_beg, block_end = _bsr_mm_use_triplets(
1854
+ mm_row, mm_block, mm_row_min, x_offsets, summed_triplet_offsets
1855
+ )
1856
+
1857
+ col_count = (block_end - block_beg) * block_depth
1858
+
1859
+ mm_col = mm_cols[mm_block]
1860
+ if use_triplets:
1861
+ for col in range(lane, col_count, tile_size):
1862
+ tpl_block = col // wp.static(block_depth)
1863
+ block_col = col - tpl_block * wp.static(block_depth)
1864
+ tpl_block += block_beg
1865
+
1866
+ x_block = summed_triplet_src_blocks[tpl_block]
1867
+ if x_block != -1:
1868
+ x_col = x_columns[x_block]
1869
+ y_block = _bsr_block_index(x_col, mm_col, y_offsets, y_columns)
1870
+ lane_val += _outer_product(
1871
+ x_values[x_block], y_values[y_block], brow_off, bcol_off, block_col, brow_count, bcol_count
1872
+ )
1873
+ else:
1874
+ for col in range(lane, col_count, tile_size):
1875
+ x_block = col // wp.static(block_depth)
1876
+ block_col = col - x_block * wp.static(block_depth)
1877
+ x_block += block_beg
1878
+
1879
+ x_col = x_columns[x_block]
1880
+ y_block = _bsr_block_index(x_col, mm_col, y_offsets, y_columns)
1881
+
1882
+ if y_block != -1:
1883
+ lane_val += _outer_product(
1884
+ x_values[x_block], y_values[y_block], brow_off, bcol_off, block_col, brow_count, bcol_count
1885
+ )
1886
+
1887
+ mm_val = wp.tile_sum(wp.tile(lane_val, preserve_type=True))[0]
1888
+
1889
+ for coef in range(lane, wp.static(subblock_cols * subblock_rows), tile_size):
1890
+ br = coef // subblock_cols
1891
+ bc = coef - br * subblock_cols
1892
+ if br < brow_count and bc < bcol_count:
1893
+ mm_values[mm_block, br + brow_off, bc + bcol_off] += mm_val[br, bc] * alpha
1894
+
1895
+ return bsr_mm_compute_values
1896
+
1897
+
1673
1898
  class bsr_mm_work_arrays:
1674
1899
  """Opaque structure for persisting :func:`bsr_mm` temporary work buffers across calls."""
1675
1900
 
@@ -1682,6 +1907,7 @@ class bsr_mm_work_arrays:
1682
1907
  self._mm_block_counts = None
1683
1908
  self._mm_rows = None
1684
1909
  self._mm_cols = None
1910
+ self._mm_src_blocks = None
1685
1911
  self._old_z_values = None
1686
1912
  self._old_z_offsets = None
1687
1913
  self._old_z_columns = None
@@ -1717,6 +1943,8 @@ class bsr_mm_work_arrays:
1717
1943
  self._mm_rows = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
1718
1944
  if self._mm_cols is None or self._mm_cols.size < mm_nnz:
1719
1945
  self._mm_cols = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
1946
+ if self._mm_src_blocks is None or self._mm_src_blocks.size < mm_nnz:
1947
+ self._mm_src_blocks = wp.empty(shape=(mm_nnz,), dtype=int, device=self.device)
1720
1948
 
1721
1949
 
1722
1950
  def bsr_mm(
@@ -1728,6 +1956,7 @@ def bsr_mm(
1728
1956
  masked: bool = False,
1729
1957
  work_arrays: Optional[bsr_mm_work_arrays] = None,
1730
1958
  reuse_topology: bool = False,
1959
+ tile_size: int = 0,
1731
1960
  ) -> BsrMatrix[BlockType[Rows, Cols, Scalar]]:
1732
1961
  """
1733
1962
  Perform the sparse matrix-matrix multiplication ``z := alpha * x @ y + beta * z`` on BSR matrices ``x``, ``y`` and ``z``, and return ``z``.
@@ -1750,6 +1979,9 @@ def bsr_mm(
1750
1979
  The matrices ``x``, ``y`` and ``z`` must be structurally similar to
1751
1980
  the previous call in which ``work_arrays`` were populated.
1752
1981
  This is necessary for ``bsr_mm`` to be captured in a CUDA graph.
1982
+ tile_size: If a positive integer, use tiles of this size to compute the matrix-matrix product.
1983
+ If negative, disable tile-based computation. Defaults to ``0``, which determines whether to
1984
+ use tiles using a heuristic based on the matrix shape and number of non-zeros.
1753
1985
  """
1754
1986
 
1755
1987
  x, x_scale = _extract_matrix_and_scale(x)
@@ -1835,11 +2067,17 @@ def bsr_mm(
1835
2067
  copied_z_nnz = work_arrays._copied_z_nnz
1836
2068
 
1837
2069
  # Prefix sum of number of (unmerged) mm blocks per row
2070
+ # Use either a thread or a block per row depending on avg nnz/row
1838
2071
  work_arrays._mm_block_counts.zero_()
2072
+ count_tile_size = 32
2073
+ if not device.is_cuda or x.nnz < 3 * count_tile_size * x.nrow:
2074
+ count_tile_size = 1
2075
+
1839
2076
  wp.launch(
1840
- kernel=_bsr_mm_count_coeffs,
2077
+ kernel=make_bsr_mm_count_coeffs(count_tile_size),
1841
2078
  device=device,
1842
- dim=z.nrow,
2079
+ dim=(z.nrow, count_tile_size),
2080
+ block_dim=count_tile_size if count_tile_size > 1 else 256,
1843
2081
  inputs=[
1844
2082
  y.ncol,
1845
2083
  copied_z_nnz,
@@ -1851,7 +2089,7 @@ def bsr_mm(
1851
2089
  work_arrays._mm_block_counts,
1852
2090
  ],
1853
2091
  )
1854
- warp.utils.array_scan(work_arrays._mm_block_counts, work_arrays._mm_block_counts)
2092
+ warp.utils.array_scan(work_arrays._mm_block_counts[: x.nnz + 1], work_arrays._mm_block_counts[: x.nnz + 1])
1855
2093
 
1856
2094
  # Get back total counts on host -- we need a synchronization here
1857
2095
  # Use pinned buffer from z, we are going to need it later anyway
@@ -1873,18 +2111,19 @@ def bsr_mm(
1873
2111
  # Copy z row and column indices
1874
2112
  wp.copy(dest=work_arrays._mm_cols, src=z.columns, count=copied_z_nnz)
1875
2113
  z.uncompress_rows(out=work_arrays._mm_rows)
2114
+ work_arrays._mm_src_blocks[:copied_z_nnz].fill_(-1)
1876
2115
  if z_aliasing:
1877
2116
  # If z is aliasing with x or y, need to save topology as well
1878
2117
  wp.copy(src=z.columns, dest=work_arrays._old_z_columns, count=copied_z_nnz)
1879
2118
  wp.copy(src=z.offsets, dest=work_arrays._old_z_offsets, count=z.nrow + 1)
1880
2119
 
1881
2120
  # Fill unmerged mm blocks rows and columns
1882
- work_arrays._mm_rows[copied_z_nnz:].fill_(-1)
1883
2121
  wp.launch(
1884
2122
  kernel=_bsr_mm_list_coeffs,
1885
2123
  device=device,
1886
- dim=x.nnz,
2124
+ dim=mm_nnz - copied_z_nnz,
1887
2125
  inputs=[
2126
+ copied_z_nnz,
1888
2127
  x.nrow,
1889
2128
  x.offsets,
1890
2129
  x.columns,
@@ -1894,6 +2133,7 @@ def bsr_mm(
1894
2133
  work_arrays._mm_block_counts,
1895
2134
  work_arrays._mm_rows,
1896
2135
  work_arrays._mm_cols,
2136
+ work_arrays._mm_src_blocks,
1897
2137
  ],
1898
2138
  )
1899
2139
 
@@ -1912,11 +2152,13 @@ def bsr_mm(
1912
2152
  from warp.context import runtime
1913
2153
 
1914
2154
  if device.is_cpu:
1915
- native_func = runtime.core.bsr_matrix_from_triplets_host
2155
+ native_func = runtime.core.wp_bsr_matrix_from_triplets_host
1916
2156
  else:
1917
- native_func = runtime.core.bsr_matrix_from_triplets_device
2157
+ native_func = runtime.core.wp_bsr_matrix_from_triplets_device
1918
2158
 
1919
2159
  nnz_buf, nnz_event = z._setup_nnz_transfer()
2160
+ summed_triplet_offsets = wp.empty(shape=(mm_nnz,), dtype=wp.int32, device=device)
2161
+ summed_triplet_indices = wp.empty(shape=(mm_nnz,), dtype=wp.int32, device=device)
1920
2162
 
1921
2163
  with wp.ScopedDevice(z.device):
1922
2164
  native_func(
@@ -1931,8 +2173,8 @@ def bsr_mm(
1931
2173
  None, # triplet values
1932
2174
  0, # zero_value_mask
1933
2175
  False, # masked_topology
1934
- None, # summed block offsets
1935
- None, # summed block indices
2176
+ ctypes.cast(summed_triplet_offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
2177
+ ctypes.cast(summed_triplet_indices.ptr, ctypes.POINTER(ctypes.c_int32)),
1936
2178
  ctypes.cast(z.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
1937
2179
  ctypes.cast(z.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
1938
2180
  _optional_ctypes_pointer(nnz_buf, ctype=ctypes.c_int32),
@@ -1952,7 +2194,7 @@ def bsr_mm(
1952
2194
  wp.launch(
1953
2195
  kernel=_bsr_axpy_add_block,
1954
2196
  device=device,
1955
- dim=copied_z_nnz,
2197
+ dim=(copied_z_nnz, z.block_shape[0], z.block_shape[1]),
1956
2198
  inputs=[
1957
2199
  0,
1958
2200
  beta,
@@ -1960,11 +2202,61 @@ def bsr_mm(
1960
2202
  work_arrays._mm_cols,
1961
2203
  z.offsets,
1962
2204
  z.columns,
1963
- work_arrays._old_z_values,
1964
- z.values,
2205
+ _as_3d_array(work_arrays._old_z_values, z.block_shape),
2206
+ z.scalar_values,
1965
2207
  ],
1966
2208
  )
1967
2209
 
2210
+ max_subblock_dim = 12
2211
+ if tile_size > 0:
2212
+ use_tiles = True
2213
+ elif tile_size < 0:
2214
+ use_tiles = False
2215
+ else:
2216
+ # Heuristic for using tiled variant: few or very large blocks
2217
+ tile_size = 64
2218
+ max_tiles_per_sm = 2048 // tile_size # assume 64 resident warps per SM
2219
+ use_tiles = device.is_cuda and (
2220
+ max(x.block_size, y.block_size, z.block_size) > max_subblock_dim**2
2221
+ or mm_nnz < max_tiles_per_sm * device.sm_count
2222
+ )
2223
+
2224
+ if use_tiles:
2225
+ subblock_rows = min(max_subblock_dim, z.block_shape[0])
2226
+ subblock_cols = min(max_subblock_dim, z.block_shape[1])
2227
+
2228
+ wp.launch(
2229
+ kernel=make_bsr_mm_compute_values_tiled_outer(
2230
+ subblock_rows, subblock_cols, x.block_shape[1], z.scalar_type, tile_size
2231
+ ),
2232
+ device=device,
2233
+ dim=(
2234
+ z.nnz,
2235
+ (z.block_shape[0] + subblock_rows - 1) // subblock_rows,
2236
+ (z.block_shape[1] + subblock_cols - 1) // subblock_cols,
2237
+ tile_size,
2238
+ ),
2239
+ block_dim=tile_size,
2240
+ inputs=[
2241
+ alpha,
2242
+ work_arrays._old_z_offsets if x == z else x.offsets,
2243
+ work_arrays._old_z_columns if x == z else x.columns,
2244
+ _as_3d_array(work_arrays._old_z_values, z.block_shape) if x == z else x.scalar_values,
2245
+ work_arrays._old_z_offsets if y == z else y.offsets,
2246
+ work_arrays._old_z_columns if y == z else y.columns,
2247
+ _as_3d_array(work_arrays._old_z_values, z.block_shape) if y == z else y.scalar_values,
2248
+ None if masked else work_arrays._mm_row_min,
2249
+ None if masked else summed_triplet_offsets,
2250
+ None if masked else work_arrays._mm_src_blocks[summed_triplet_indices],
2251
+ z.nrow,
2252
+ z.offsets,
2253
+ z.columns,
2254
+ z.scalar_values,
2255
+ ],
2256
+ )
2257
+
2258
+ return z
2259
+
1968
2260
  # Add mm blocks to z values
1969
2261
  if (type_is_matrix(x.values.dtype) or type_is_matrix(y.values.dtype)) and not (type_is_matrix(z.values.dtype)):
1970
2262
  # Result block type is scalar, but operands are matrices
@@ -1985,6 +2277,9 @@ def bsr_mm(
1985
2277
  work_arrays._old_z_offsets if y == z else y.offsets,
1986
2278
  work_arrays._old_z_columns if y == z else y.columns,
1987
2279
  work_arrays._old_z_values if y == z else y.values,
2280
+ None if masked else work_arrays._mm_row_min,
2281
+ None if masked else summed_triplet_offsets,
2282
+ None if masked else work_arrays._mm_src_blocks[summed_triplet_indices],
1988
2283
  z.nrow,
1989
2284
  z.offsets,
1990
2285
  z.columns,
@@ -1995,51 +2290,125 @@ def bsr_mm(
1995
2290
  return z
1996
2291
 
1997
2292
 
1998
- @wp.kernel
1999
- def _bsr_mv_kernel(
2000
- alpha: Any,
2001
- A_offsets: wp.array(dtype=int),
2002
- A_columns: wp.array(dtype=int),
2003
- A_values: wp.array(dtype=Any),
2004
- x: wp.array(dtype=Any),
2005
- beta: Any,
2006
- y: wp.array(dtype=Any),
2007
- ):
2008
- row = wp.tid()
2293
+ def make_bsr_mv_kernel(block_cols: int):
2294
+ from warp.fem.cache import dynamic_kernel
2009
2295
 
2010
- # zero-initialize with type of y elements
2011
- scalar_zero = type(alpha)(0)
2012
- v = y.dtype(scalar_zero)
2296
+ @dynamic_kernel(suffix=f"{block_cols}", kernel_options={"enable_backward": False})
2297
+ def bsr_mv_kernel(
2298
+ alpha: Any,
2299
+ A_offsets: wp.array(dtype=int),
2300
+ A_columns: wp.array(dtype=int),
2301
+ A_values: wp.array3d(dtype=Any),
2302
+ x: wp.array(dtype=Any),
2303
+ beta: Any,
2304
+ y: wp.array(dtype=Any),
2305
+ ):
2306
+ row, subrow = wp.tid()
2013
2307
 
2014
- if alpha != scalar_zero:
2015
- beg = A_offsets[row]
2016
- end = A_offsets[row + 1]
2017
- for block in range(beg, end):
2018
- v += A_values[block] * x[A_columns[block]]
2019
- v *= alpha
2308
+ block_rows = A_values.shape[1]
2020
2309
 
2021
- if beta != scalar_zero:
2022
- v += beta * y[row]
2310
+ yi = row * block_rows + subrow
2023
2311
 
2024
- y[row] = v
2312
+ # zero-initialize with type of y elements
2313
+ scalar_zero = type(alpha)(0)
2314
+ v = scalar_zero
2025
2315
 
2316
+ if alpha != scalar_zero:
2317
+ beg = A_offsets[row]
2318
+ end = A_offsets[row + 1]
2319
+ for block in range(beg, end):
2320
+ xs = A_columns[block] * block_cols
2321
+ for col in range(wp.static(block_cols)):
2322
+ v += A_values[block, subrow, col] * x[xs + col]
2323
+ v *= alpha
2026
2324
 
2027
- @wp.kernel
2028
- def _bsr_mv_transpose_kernel(
2029
- alpha: Any,
2030
- A_offsets: wp.array(dtype=int),
2031
- A_columns: wp.array(dtype=int),
2032
- A_values: wp.array(dtype=Any),
2033
- x: wp.array(dtype=Any),
2034
- y: wp.array(dtype=Any),
2035
- ):
2036
- row = wp.tid()
2037
- beg = A_offsets[row]
2038
- end = A_offsets[row + 1]
2039
- xr = alpha * x[row]
2040
- for block in range(beg, end):
2041
- v = wp.transpose(A_values[block]) * xr
2042
- wp.atomic_add(y, A_columns[block], v)
2325
+ if beta != scalar_zero:
2326
+ v += beta * y[yi]
2327
+
2328
+ y[yi] = v
2329
+
2330
+ return bsr_mv_kernel
2331
+
2332
+
2333
+ def make_bsr_mv_tiled_kernel(tile_size: int):
2334
+ from warp.fem.cache import dynamic_kernel
2335
+
2336
+ @dynamic_kernel(suffix=f"{tile_size}", kernel_options={"enable_backward": False})
2337
+ def bsr_mv_tiled_kernel(
2338
+ alpha: Any,
2339
+ A_offsets: wp.array(dtype=int),
2340
+ A_columns: wp.array(dtype=int),
2341
+ A_values: wp.array3d(dtype=Any),
2342
+ x: wp.array(dtype=Any),
2343
+ beta: Any,
2344
+ y: wp.array(dtype=Any),
2345
+ ):
2346
+ row, subrow, lane = wp.tid()
2347
+
2348
+ scalar_zero = type(alpha)(0)
2349
+ block_rows = A_values.shape[1]
2350
+ block_cols = A_values.shape[2]
2351
+
2352
+ yi = row * block_rows + subrow
2353
+
2354
+ if beta == scalar_zero:
2355
+ subrow_sum = wp.tile_zeros(shape=(1,), dtype=y.dtype)
2356
+ else:
2357
+ subrow_sum = beta * wp.tile_load(y, 1, yi)
2358
+
2359
+ if alpha != scalar_zero:
2360
+ block_beg = A_offsets[row]
2361
+ col_count = (A_offsets[row + 1] - block_beg) * block_cols
2362
+
2363
+ col = lane
2364
+ lane_sum = y.dtype(0)
2365
+
2366
+ for col in range(lane, col_count, tile_size):
2367
+ block = col // block_cols
2368
+ block_col = col - block * block_cols
2369
+ block += block_beg
2370
+
2371
+ xi = x[A_columns[block] * block_cols + block_col]
2372
+ lane_sum += A_values[block, subrow, block_col] * xi
2373
+
2374
+ lane_sum *= alpha
2375
+ subrow_sum += wp.tile_sum(wp.tile(lane_sum))
2376
+
2377
+ wp.tile_store(y, subrow_sum, yi)
2378
+
2379
+ return bsr_mv_tiled_kernel
2380
+
2381
+
2382
+ def make_bsr_mv_transpose_kernel(block_rows: int):
2383
+ from warp.fem.cache import dynamic_kernel
2384
+
2385
+ @dynamic_kernel(suffix=f"{block_rows}", kernel_options={"enable_backward": False})
2386
+ def bsr_mv_transpose_kernel(
2387
+ alpha: Any,
2388
+ A_row_count: int,
2389
+ A_offsets: wp.array(dtype=int),
2390
+ A_columns: wp.array(dtype=int),
2391
+ A_values: wp.array3d(dtype=Any),
2392
+ x: wp.array(dtype=Any),
2393
+ y: wp.array(dtype=Any),
2394
+ ):
2395
+ block, subcol = wp.tid()
2396
+
2397
+ row = _bsr_row_index(A_offsets, A_row_count, block)
2398
+ if row == -1:
2399
+ return
2400
+
2401
+ block_cols = A_values.shape[2]
2402
+
2403
+ A_block = A_values[block]
2404
+
2405
+ col_sum = type(alpha)(0)
2406
+ for subrow in range(wp.static(block_rows)):
2407
+ col_sum += A_block[subrow, subcol] * x[row * block_rows + subrow]
2408
+
2409
+ wp.atomic_add(y, A_columns[block] * block_cols + subcol, alpha * col_sum)
2410
+
2411
+ return bsr_mv_transpose_kernel
2043
2412
 
2044
2413
 
2045
2414
  def _vec_array_view(array: wp.array, dtype: type, expected_scalar_count: int) -> wp.array:
@@ -2092,6 +2461,7 @@ def bsr_mv(
2092
2461
  beta: Scalar = 0.0,
2093
2462
  transpose: bool = False,
2094
2463
  work_buffer: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None,
2464
+ tile_size: int = 0,
2095
2465
  ) -> "Array[Vector[Rows, Scalar] | Scalar]":
2096
2466
  """Perform the sparse matrix-vector product ``y := alpha * A * x + beta * y`` and return ``y``.
2097
2467
 
@@ -2107,6 +2477,9 @@ def bsr_mv(
2107
2477
  work_buffer: Temporary storage is required if and only if ``x`` and ``y`` are the same vector.
2108
2478
  If provided, the ``work_buffer`` array will be used for this purpose,
2109
2479
  otherwise a temporary allocation will be performed.
2480
+ tile_size: If a positive integer, use tiles of this size to compute the matrix-vector product.
2481
+ If negative, disable tile-based computation. Defaults to ``0``, which determines whether to
2482
+ use tiles using a heuristic based on the matrix shape and number of non-zeros.
2110
2483
  """
2111
2484
 
2112
2485
  A, A_scale = _extract_matrix_and_scale(A)
@@ -2129,6 +2502,7 @@ def bsr_mv(
2129
2502
  alpha = A.scalar_type(alpha)
2130
2503
  beta = A.scalar_type(beta)
2131
2504
 
2505
+ device = A.values.device
2132
2506
  if A.values.device != x.device or A.values.device != y.device:
2133
2507
  raise ValueError(
2134
2508
  f"A, x, and y must reside on the same device, got {A.values.device}, {x.device} and {y.device}"
@@ -2149,23 +2523,24 @@ def bsr_mv(
2149
2523
  wp.copy(dest=work_buffer, src=y, count=y.size)
2150
2524
  x = work_buffer
2151
2525
 
2152
- # Promote scalar vectors to length-1 vecs and conversely
2153
- if type_is_matrix(A.values.dtype):
2154
- x_dtype = wp.vec(length=block_shape[1], dtype=A.scalar_type)
2155
- y_dtype = wp.vec(length=block_shape[0], dtype=A.scalar_type)
2156
- else:
2157
- x_dtype = A.scalar_type
2158
- y_dtype = A.scalar_type
2159
-
2160
2526
  try:
2161
- x_view = _vec_array_view(x, x_dtype, expected_scalar_count=ncol * block_shape[1])
2527
+ x_view = _vec_array_view(x, A.scalar_type, expected_scalar_count=ncol * block_shape[1])
2162
2528
  except ValueError as err:
2163
2529
  raise ValueError("Incompatible 'x' vector for bsr_mv") from err
2164
2530
  try:
2165
- y_view = _vec_array_view(y, y_dtype, expected_scalar_count=nrow * block_shape[0])
2531
+ y_view = _vec_array_view(y, A.scalar_type, expected_scalar_count=nrow * block_shape[0])
2166
2532
  except ValueError as err:
2167
2533
  raise ValueError("Incompatible 'y' vector for bsr_mv") from err
2168
2534
 
2535
+ # heuristic to use tiled version for long rows
2536
+ if tile_size > 0:
2537
+ use_tiles = True
2538
+ elif tile_size < 0:
2539
+ use_tiles = False
2540
+ else:
2541
+ tile_size = 64
2542
+ use_tiles = device.is_cuda and A.nnz * A.block_size > 2 * tile_size * A.shape[0]
2543
+
2169
2544
  if transpose:
2170
2545
  if beta.value == 0.0:
2171
2546
  y.zero_()
@@ -2173,22 +2548,30 @@ def bsr_mv(
2173
2548
  wp.launch(
2174
2549
  kernel=_bsr_scale_kernel,
2175
2550
  device=y.device,
2176
- dim=y.shape[0],
2177
- inputs=[beta, y],
2551
+ dim=y_view.shape[0],
2552
+ inputs=[beta, y_view],
2178
2553
  )
2179
2554
  if alpha.value != 0.0:
2180
2555
  wp.launch(
2181
- kernel=_bsr_mv_transpose_kernel,
2556
+ kernel=make_bsr_mv_transpose_kernel(block_rows=block_shape[1]),
2182
2557
  device=A.values.device,
2183
- dim=ncol,
2184
- inputs=[alpha, A.offsets, A.columns, A.values, x_view, y_view],
2558
+ dim=(A.nnz, block_shape[0]),
2559
+ inputs=[alpha, A.nrow, A.offsets, A.columns, A.scalar_values, x_view, y_view],
2185
2560
  )
2561
+ elif use_tiles:
2562
+ wp.launch(
2563
+ kernel=make_bsr_mv_tiled_kernel(tile_size),
2564
+ device=A.values.device,
2565
+ dim=(nrow, block_shape[0], tile_size),
2566
+ block_dim=tile_size,
2567
+ inputs=[alpha, A.offsets, A.columns, A.scalar_values, x_view, beta, y_view],
2568
+ )
2186
2569
  else:
2187
2570
  wp.launch(
2188
- kernel=_bsr_mv_kernel,
2571
+ kernel=make_bsr_mv_kernel(block_cols=block_shape[1]),
2189
2572
  device=A.values.device,
2190
- dim=nrow,
2191
- inputs=[alpha, A.offsets, A.columns, A.values, x_view, beta, y_view],
2573
+ dim=(nrow, block_shape[0]),
2574
+ inputs=[alpha, A.offsets, A.columns, A.scalar_values, x_view, beta, y_view],
2192
2575
  )
2193
2576
 
2194
2577
  return y