warp-lang 1.7.2__py3-none-macosx_10_13_universal2.whl → 1.8.0__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +3 -1
- warp/__init__.pyi +3489 -1
- warp/autograd.py +45 -122
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +241 -252
- warp/build_dll.py +125 -26
- warp/builtins.py +1907 -384
- warp/codegen.py +257 -101
- warp/config.py +12 -1
- warp/constants.py +1 -1
- warp/context.py +657 -223
- warp/dlpack.py +1 -1
- warp/examples/benchmarks/benchmark_cloth.py +2 -2
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/core/example_sample_mesh.py +1 -1
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/fem/example_adaptive_grid.py +5 -5
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +1 -1
- warp/examples/fem/example_convection_diffusion.py +9 -6
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion.py +2 -2
- warp/examples/fem/example_diffusion_3d.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +5 -3
- warp/examples/fem/example_mixed_elasticity.py +5 -3
- warp/examples/fem/example_navier_stokes.py +11 -9
- warp/examples/fem/example_nonconforming_contact.py +5 -3
- warp/examples/fem/example_streamlines.py +8 -3
- warp/examples/fem/utils.py +9 -8
- warp/examples/interop/example_jax_ffi_callback.py +2 -2
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/sim/example_cloth.py +1 -1
- warp/examples/sim/example_cloth_self_contact.py +48 -54
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +2 -1
- warp/examples/tile/example_tile_convolution.py +1 -1
- warp/examples/tile/example_tile_filtering.py +1 -1
- warp/examples/tile/example_tile_matmul.py +1 -1
- warp/examples/tile/example_tile_mlp.py +2 -0
- warp/fabric.py +7 -7
- warp/fem/__init__.py +5 -0
- warp/fem/adaptivity.py +1 -1
- warp/fem/cache.py +152 -63
- warp/fem/dirichlet.py +2 -2
- warp/fem/domain.py +136 -6
- warp/fem/field/field.py +141 -99
- warp/fem/field/nodal_field.py +85 -39
- warp/fem/field/virtual.py +97 -52
- warp/fem/geometry/adaptive_nanogrid.py +91 -86
- warp/fem/geometry/closest_point.py +13 -0
- warp/fem/geometry/deformed_geometry.py +102 -40
- warp/fem/geometry/element.py +56 -2
- warp/fem/geometry/geometry.py +323 -22
- warp/fem/geometry/grid_2d.py +157 -62
- warp/fem/geometry/grid_3d.py +116 -20
- warp/fem/geometry/hexmesh.py +86 -20
- warp/fem/geometry/nanogrid.py +166 -86
- warp/fem/geometry/partition.py +59 -25
- warp/fem/geometry/quadmesh.py +86 -135
- warp/fem/geometry/tetmesh.py +47 -119
- warp/fem/geometry/trimesh.py +77 -270
- warp/fem/integrate.py +107 -52
- warp/fem/linalg.py +25 -58
- warp/fem/operator.py +124 -27
- warp/fem/quadrature/pic_quadrature.py +36 -14
- warp/fem/quadrature/quadrature.py +40 -16
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +66 -46
- warp/fem/space/basis_space.py +17 -4
- warp/fem/space/dof_mapper.py +1 -1
- warp/fem/space/function_space.py +2 -2
- warp/fem/space/grid_2d_function_space.py +4 -1
- warp/fem/space/hexmesh_function_space.py +4 -2
- warp/fem/space/nanogrid_function_space.py +3 -1
- warp/fem/space/partition.py +11 -2
- warp/fem/space/quadmesh_function_space.py +4 -1
- warp/fem/space/restriction.py +5 -2
- warp/fem/space/shape/__init__.py +10 -8
- warp/fem/space/tetmesh_function_space.py +4 -1
- warp/fem/space/topology.py +52 -21
- warp/fem/space/trimesh_function_space.py +4 -1
- warp/fem/utils.py +53 -8
- warp/jax.py +1 -2
- warp/jax_experimental/ffi.py +12 -17
- warp/jax_experimental/xla_ffi.py +37 -24
- warp/math.py +171 -1
- warp/native/array.h +99 -0
- warp/native/builtin.h +174 -31
- warp/native/coloring.cpp +1 -1
- warp/native/exports.h +118 -63
- warp/native/intersect.h +3 -3
- warp/native/mat.h +5 -10
- warp/native/mathdx.cpp +11 -5
- warp/native/matnn.h +1 -123
- warp/native/quat.h +28 -4
- warp/native/sparse.cpp +121 -258
- warp/native/sparse.cu +181 -274
- warp/native/spatial.h +305 -17
- warp/native/tile.h +583 -72
- warp/native/tile_radix_sort.h +1108 -0
- warp/native/tile_reduce.h +237 -2
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +6 -16
- warp/native/warp.cpp +36 -4
- warp/native/warp.cu +574 -51
- warp/native/warp.h +47 -74
- warp/optim/linear.py +5 -1
- warp/paddle.py +7 -8
- warp/py.typed +0 -0
- warp/render/render_opengl.py +58 -29
- warp/render/render_usd.py +124 -61
- warp/sim/__init__.py +9 -0
- warp/sim/collide.py +252 -78
- warp/sim/graph_coloring.py +8 -1
- warp/sim/import_mjcf.py +4 -3
- warp/sim/import_usd.py +11 -7
- warp/sim/integrator.py +5 -2
- warp/sim/integrator_euler.py +1 -1
- warp/sim/integrator_featherstone.py +1 -1
- warp/sim/integrator_vbd.py +751 -320
- warp/sim/integrator_xpbd.py +1 -1
- warp/sim/model.py +265 -260
- warp/sim/utils.py +10 -7
- warp/sparse.py +303 -166
- warp/tape.py +52 -51
- warp/tests/cuda/test_conditional_captures.py +1046 -0
- warp/tests/cuda/test_streams.py +1 -1
- warp/tests/geometry/test_volume.py +2 -2
- warp/tests/interop/test_dlpack.py +9 -9
- warp/tests/interop/test_jax.py +0 -1
- warp/tests/run_coverage_serial.py +1 -1
- warp/tests/sim/disabled_kinematics.py +2 -2
- warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
- warp/tests/sim/test_collision.py +159 -51
- warp/tests/sim/test_coloring.py +15 -1
- warp/tests/test_array.py +254 -2
- warp/tests/test_array_reduce.py +2 -2
- warp/tests/test_atomic_cas.py +299 -0
- warp/tests/test_codegen.py +142 -19
- warp/tests/test_conditional.py +47 -1
- warp/tests/test_ctypes.py +0 -20
- warp/tests/test_devices.py +8 -0
- warp/tests/test_fabricarray.py +4 -2
- warp/tests/test_fem.py +58 -25
- warp/tests/test_func.py +42 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_lerp.py +1 -3
- warp/tests/test_map.py +481 -0
- warp/tests/test_mat.py +1 -24
- warp/tests/test_quat.py +6 -15
- warp/tests/test_rounding.py +10 -38
- warp/tests/test_runlength_encode.py +7 -7
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +51 -2
- warp/tests/test_spatial.py +507 -1
- warp/tests/test_struct.py +2 -2
- warp/tests/test_tuple.py +265 -0
- warp/tests/test_types.py +2 -2
- warp/tests/test_utils.py +24 -18
- warp/tests/tile/test_tile.py +420 -1
- warp/tests/tile/test_tile_mathdx.py +518 -14
- warp/tests/tile/test_tile_reduce.py +213 -0
- warp/tests/tile/test_tile_shared_memory.py +130 -1
- warp/tests/tile/test_tile_sort.py +117 -0
- warp/tests/unittest_suites.py +4 -6
- warp/types.py +462 -308
- warp/utils.py +647 -86
- {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
- {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/RECORD +178 -166
- warp/stubs.py +0 -3381
- warp/tests/sim/test_xpbd.py +0 -399
- warp/tests/test_mlp.py +0 -282
- {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
warp/sparse.py
CHANGED
|
@@ -28,9 +28,10 @@ from warp.types import (
|
|
|
28
28
|
is_array,
|
|
29
29
|
scalar_types,
|
|
30
30
|
type_is_matrix,
|
|
31
|
-
type_length,
|
|
32
31
|
type_repr,
|
|
33
32
|
type_scalar_type,
|
|
33
|
+
type_size,
|
|
34
|
+
type_size_in_bytes,
|
|
34
35
|
type_to_warp,
|
|
35
36
|
types_equal,
|
|
36
37
|
)
|
|
@@ -86,7 +87,7 @@ class BsrMatrix(Generic[_BlockType]):
|
|
|
86
87
|
@property
|
|
87
88
|
def block_size(self) -> int:
|
|
88
89
|
"""Size of the individual blocks, i.e. number of rows per block times number of columns per block."""
|
|
89
|
-
return
|
|
90
|
+
return type_size(self.values.dtype)
|
|
90
91
|
|
|
91
92
|
@property
|
|
92
93
|
def shape(self) -> Tuple[int, int]:
|
|
@@ -104,23 +105,15 @@ class BsrMatrix(Generic[_BlockType]):
|
|
|
104
105
|
"""Device on which ``offsets``, ``columns``, and ``values`` are allocated -- assumed to be the same for all three arrays."""
|
|
105
106
|
return self.values.device
|
|
106
107
|
|
|
108
|
+
@property
|
|
109
|
+
def requires_grad(self) -> bool:
|
|
110
|
+
"""Read-only property indicating whether the matrix participates in adjoint computations."""
|
|
111
|
+
return self.values.requires_grad
|
|
112
|
+
|
|
107
113
|
@property
|
|
108
114
|
def scalar_values(self) -> wp.array:
|
|
109
115
|
"""Accesses the ``values`` array as a 3d scalar array."""
|
|
110
|
-
|
|
111
|
-
return self.values.reshape((self.nnz, 1, 1))
|
|
112
|
-
|
|
113
|
-
def _as_3d_array(arr):
|
|
114
|
-
return wp.array(
|
|
115
|
-
ptr=arr.ptr,
|
|
116
|
-
capacity=arr.capacity,
|
|
117
|
-
device=arr.device,
|
|
118
|
-
dtype=self.scalar_type,
|
|
119
|
-
shape=(self.nnz, *self.block_shape),
|
|
120
|
-
grad=None if arr.grad is None else _as_3d_array(arr.grad),
|
|
121
|
-
)
|
|
122
|
-
|
|
123
|
-
values_view = _as_3d_array(self.values)
|
|
116
|
+
values_view = _as_3d_array(self.values, self.block_shape)
|
|
124
117
|
values_view._ref = self.values # keep ref in case we're garbage collected
|
|
125
118
|
return values_view
|
|
126
119
|
|
|
@@ -144,13 +137,14 @@ class BsrMatrix(Generic[_BlockType]):
|
|
|
144
137
|
See also :meth:`copy_nnz_async`.
|
|
145
138
|
"""
|
|
146
139
|
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
140
|
+
buf, event = self._nnz_transfer_if_any()
|
|
141
|
+
if buf is not None:
|
|
142
|
+
if event is not None:
|
|
143
|
+
wp.synchronize_event(event)
|
|
144
|
+
self.nnz = int(buf.numpy()[0])
|
|
151
145
|
return self.nnz
|
|
152
146
|
|
|
153
|
-
def copy_nnz_async(self
|
|
147
|
+
def copy_nnz_async(self) -> None:
|
|
154
148
|
"""
|
|
155
149
|
Start the asynchronous transfer of the exact nnz from the device offsets array to host and records an event for completion.
|
|
156
150
|
|
|
@@ -158,37 +152,25 @@ class BsrMatrix(Generic[_BlockType]):
|
|
|
158
152
|
|
|
159
153
|
See also :meth:`nnz_sync`.
|
|
160
154
|
"""
|
|
161
|
-
if known_nnz is not None:
|
|
162
|
-
self.nnz = int(known_nnz)
|
|
163
|
-
else:
|
|
164
|
-
self._setup_nnz_transfer()
|
|
165
155
|
|
|
166
|
-
|
|
167
|
-
if self.
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
stream.record_event(self._nnz_event)
|
|
156
|
+
buf, event = self._setup_nnz_transfer()
|
|
157
|
+
stream = wp.get_stream(self.device) if self.device.is_cuda else None
|
|
158
|
+
wp.copy(src=self.offsets, dest=buf, src_offset=self.nrow, count=1, stream=stream)
|
|
159
|
+
if event is not None:
|
|
160
|
+
stream.record_event(event)
|
|
172
161
|
|
|
173
162
|
def _setup_nnz_transfer(self):
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
BsrMatrix.__setattr__(
|
|
178
|
-
self, "_nnz_buf", wp.empty(dtype=int, shape=(1,), device="cpu", pinned=self.device.is_cuda)
|
|
179
|
-
)
|
|
180
|
-
if self.device.is_cuda:
|
|
181
|
-
BsrMatrix.__setattr__(self, "_nnz_event", wp.Event(self.device))
|
|
182
|
-
|
|
183
|
-
def _is_nnz_transfer_setup(self):
|
|
184
|
-
return hasattr(self, "_nnz_buf")
|
|
163
|
+
buf, event = self._nnz_transfer_if_any()
|
|
164
|
+
if buf is not None or self.device.is_capturing:
|
|
165
|
+
return buf, event
|
|
185
166
|
|
|
186
|
-
|
|
187
|
-
self.
|
|
167
|
+
buf = wp.empty(dtype=int, shape=(1,), device="cpu", pinned=self.device.is_cuda)
|
|
168
|
+
event = wp.Event(self.device) if self.device.is_cuda else None
|
|
169
|
+
BsrMatrix.__setattr__(self, "_nnz_transfer", (buf, event))
|
|
170
|
+
return buf, event
|
|
188
171
|
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
return self._nnz_buf, self._nnz_event.cuda_event
|
|
172
|
+
def _nnz_transfer_if_any(self):
|
|
173
|
+
return getattr(self, "_nnz_transfer", (None, None))
|
|
192
174
|
|
|
193
175
|
# Overloaded math operators
|
|
194
176
|
def __add__(self, y):
|
|
@@ -303,7 +285,7 @@ def bsr_zeros(
|
|
|
303
285
|
|
|
304
286
|
bsr.nrow = int(rows_of_blocks)
|
|
305
287
|
bsr.ncol = int(cols_of_blocks)
|
|
306
|
-
bsr.nnz =
|
|
288
|
+
bsr.nnz = 0
|
|
307
289
|
bsr.columns = wp.empty(shape=(0,), dtype=int, device=device)
|
|
308
290
|
bsr.values = wp.empty(shape=(0,), dtype=block_type, device=device)
|
|
309
291
|
bsr.offsets = wp.zeros(shape=(bsr.nrow + 1,), dtype=int, device=device)
|
|
@@ -311,7 +293,7 @@ def bsr_zeros(
|
|
|
311
293
|
return bsr
|
|
312
294
|
|
|
313
295
|
|
|
314
|
-
def _bsr_ensure_fits(bsr: BsrMatrix, nrow: int = None, nnz: int = None):
|
|
296
|
+
def _bsr_ensure_fits(bsr: BsrMatrix, nrow: Optional[int] = None, nnz: Optional[int] = None) -> None:
|
|
315
297
|
if nrow is None:
|
|
316
298
|
nrow = bsr.nrow
|
|
317
299
|
if nnz is None:
|
|
@@ -325,7 +307,9 @@ def _bsr_ensure_fits(bsr: BsrMatrix, nrow: int = None, nnz: int = None):
|
|
|
325
307
|
if bsr.columns.size < nnz:
|
|
326
308
|
bsr.columns = wp.empty(shape=(nnz,), dtype=int, device=bsr.columns.device)
|
|
327
309
|
if bsr.values.size < nnz:
|
|
328
|
-
bsr.values = wp.empty(
|
|
310
|
+
bsr.values = wp.empty(
|
|
311
|
+
shape=(nnz,), dtype=bsr.values.dtype, device=bsr.values.device, requires_grad=bsr.values.requires_grad
|
|
312
|
+
)
|
|
329
313
|
|
|
330
314
|
|
|
331
315
|
def bsr_set_zero(
|
|
@@ -348,7 +332,64 @@ def bsr_set_zero(
|
|
|
348
332
|
|
|
349
333
|
_bsr_ensure_fits(bsr, nnz=0)
|
|
350
334
|
bsr.offsets.zero_()
|
|
351
|
-
bsr.copy_nnz_async(
|
|
335
|
+
bsr.copy_nnz_async()
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
def _as_3d_array(arr, block_shape):
|
|
339
|
+
return wp.array(
|
|
340
|
+
ptr=arr.ptr,
|
|
341
|
+
capacity=arr.capacity,
|
|
342
|
+
device=arr.device,
|
|
343
|
+
dtype=type_scalar_type(arr.dtype),
|
|
344
|
+
shape=(arr.shape[0], *block_shape),
|
|
345
|
+
grad=None if arr.grad is None else _as_3d_array(arr.grad, block_shape),
|
|
346
|
+
)
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
def _optional_ctypes_pointer(array: Optional[wp.array], ctype):
|
|
350
|
+
return None if array is None else ctypes.cast(array.ptr, ctypes.POINTER(ctype))
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
def _optional_ctypes_event(event: Optional[wp.Event]):
|
|
354
|
+
return None if event is None else event.cuda_event
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
_zero_value_masks = {
|
|
358
|
+
wp.float16: 0x7FFF,
|
|
359
|
+
wp.float32: 0x7FFFFFFF,
|
|
360
|
+
wp.float64: 0x7FFFFFFFFFFFFFFF,
|
|
361
|
+
wp.int8: 0xFF,
|
|
362
|
+
wp.int16: 0xFFFF,
|
|
363
|
+
wp.int32: 0xFFFFFFFF,
|
|
364
|
+
wp.int64: 0xFFFFFFFFFFFFFFFF,
|
|
365
|
+
}
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
@wp.kernel
|
|
369
|
+
def _bsr_accumulate_triplet_values(
|
|
370
|
+
row_count: int,
|
|
371
|
+
tpl_summed_offsets: wp.array(dtype=int),
|
|
372
|
+
tpl_summed_indices: wp.array(dtype=int),
|
|
373
|
+
tpl_values: wp.array3d(dtype=Any),
|
|
374
|
+
bsr_offsets: wp.array(dtype=int),
|
|
375
|
+
bsr_values: wp.array3d(dtype=Any),
|
|
376
|
+
):
|
|
377
|
+
block, i, j = wp.tid()
|
|
378
|
+
|
|
379
|
+
if block >= bsr_offsets[row_count]:
|
|
380
|
+
return
|
|
381
|
+
|
|
382
|
+
if block == 0:
|
|
383
|
+
beg = 0
|
|
384
|
+
else:
|
|
385
|
+
beg = tpl_summed_offsets[block - 1]
|
|
386
|
+
end = tpl_summed_offsets[block]
|
|
387
|
+
|
|
388
|
+
val = tpl_values[tpl_summed_indices[beg], i, j]
|
|
389
|
+
for k in range(beg + 1, end):
|
|
390
|
+
val += tpl_values[tpl_summed_indices[k], i, j]
|
|
391
|
+
|
|
392
|
+
bsr_values[block, i, j] = val
|
|
352
393
|
|
|
353
394
|
|
|
354
395
|
def bsr_set_from_triplets(
|
|
@@ -356,6 +397,7 @@ def bsr_set_from_triplets(
|
|
|
356
397
|
rows: "Array[int]",
|
|
357
398
|
columns: "Array[int]",
|
|
358
399
|
values: Optional["Array[Union[Scalar, BlockType[Rows, Cols, Scalar]]]"] = None,
|
|
400
|
+
count: Optional["Array[int]"] = None,
|
|
359
401
|
prune_numerical_zeros: bool = True,
|
|
360
402
|
masked: bool = False,
|
|
361
403
|
):
|
|
@@ -370,27 +412,50 @@ def bsr_set_from_triplets(
|
|
|
370
412
|
values: Block values for each non-zero. Must be either a one-dimensional array with data type identical
|
|
371
413
|
to the ``dest`` matrix's block type, or a 3d array with data type equal to the ``dest`` matrix's scalar type.
|
|
372
414
|
If ``None``, the values array of the resulting matrix will be allocated but uninitialized.
|
|
415
|
+
count: Single-element array indicating the number of triplets. If ``None``, the number of triplets is determined from the shape of
|
|
416
|
+
``rows`` and ``columns`` arrays.
|
|
373
417
|
prune_numerical_zeros: If ``True``, will ignore the zero-valued blocks.
|
|
374
418
|
masked: If ``True``, ignore blocks that are not existing non-zeros of ``dest``.
|
|
375
419
|
"""
|
|
376
420
|
|
|
377
421
|
if rows.device != columns.device or rows.device != dest.device:
|
|
378
|
-
raise ValueError(
|
|
422
|
+
raise ValueError(
|
|
423
|
+
f"Rows and columns must reside on the destination matrix device, got {rows.device}, {columns.device} and {dest.device}"
|
|
424
|
+
)
|
|
379
425
|
|
|
380
426
|
if rows.shape[0] != columns.shape[0]:
|
|
381
|
-
raise ValueError(
|
|
427
|
+
raise ValueError(
|
|
428
|
+
f"Rows and columns arrays must have the same length, got {rows.shape[0]} and {columns.shape[0]}"
|
|
429
|
+
)
|
|
430
|
+
|
|
431
|
+
if rows.dtype != wp.int32 or columns.dtype != wp.int32:
|
|
432
|
+
raise TypeError("Rows and columns arrays must be of type int32")
|
|
433
|
+
|
|
434
|
+
if count is not None:
|
|
435
|
+
if count.device != rows.device:
|
|
436
|
+
raise ValueError(f"Count and rows must reside on the same device, got {count.device} and {rows.device}")
|
|
437
|
+
|
|
438
|
+
if count.shape != (1,):
|
|
439
|
+
raise ValueError(f"Count array must be a single-element array, got {count.shape}")
|
|
440
|
+
|
|
441
|
+
if count.dtype != wp.int32:
|
|
442
|
+
raise TypeError("Count array must be of type int32")
|
|
382
443
|
|
|
383
444
|
# Accept either array1d(dtype) or contiguous array3d(scalar_type) as values
|
|
384
445
|
if values is not None:
|
|
385
446
|
if values.device != rows.device:
|
|
386
|
-
raise ValueError("
|
|
447
|
+
raise ValueError(f"Values and rows must reside on the same device, got {values.device} and {rows.device}")
|
|
387
448
|
|
|
388
449
|
if values.shape[0] != rows.shape[0]:
|
|
389
|
-
raise ValueError(
|
|
450
|
+
raise ValueError(
|
|
451
|
+
f"Values and rows arrays must have the same length, got {values.shape[0]} and {rows.shape[0]}"
|
|
452
|
+
)
|
|
390
453
|
|
|
391
454
|
if values.ndim == 1:
|
|
392
|
-
if values.dtype
|
|
393
|
-
raise ValueError(
|
|
455
|
+
if not types_equal(values.dtype, dest.values.dtype):
|
|
456
|
+
raise ValueError(
|
|
457
|
+
f"Values array type must correspond to that of the dest matrix, got {type_repr(values.dtype)} and {type_repr(dest.values.dtype)}"
|
|
458
|
+
)
|
|
394
459
|
elif values.ndim == 3:
|
|
395
460
|
if values.shape[1:] != dest.block_shape:
|
|
396
461
|
raise ValueError(
|
|
@@ -398,12 +463,14 @@ def bsr_set_from_triplets(
|
|
|
398
463
|
)
|
|
399
464
|
|
|
400
465
|
if type_scalar_type(values.dtype) != dest.scalar_type:
|
|
401
|
-
raise ValueError(
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
raise ValueError("Multi-dimensional values array should be contiguous")
|
|
466
|
+
raise ValueError(
|
|
467
|
+
f"Scalar type of values array ({type_repr(values.dtype)}) should correspond to that of matrix ({type_repr(dest.scalar_type)})"
|
|
468
|
+
)
|
|
405
469
|
else:
|
|
406
|
-
raise ValueError("Number of dimension for values array should be 1 or 3")
|
|
470
|
+
raise ValueError(f"Number of dimension for values array should be 1 or 3, got {values.ndim}")
|
|
471
|
+
|
|
472
|
+
if prune_numerical_zeros and not values.is_contiguous:
|
|
473
|
+
raise ValueError("Values array should be contiguous for numerical zero pruning")
|
|
407
474
|
|
|
408
475
|
nnz = rows.shape[0]
|
|
409
476
|
if nnz == 0:
|
|
@@ -416,40 +483,54 @@ def bsr_set_from_triplets(
|
|
|
416
483
|
|
|
417
484
|
device = dest.values.device
|
|
418
485
|
scalar_type = dest.scalar_type
|
|
486
|
+
zero_value_mask = _zero_value_masks.get(scalar_type, 0)
|
|
487
|
+
|
|
488
|
+
# compute the BSR topology
|
|
489
|
+
|
|
419
490
|
from warp.context import runtime
|
|
420
491
|
|
|
421
492
|
if device.is_cpu:
|
|
422
|
-
|
|
423
|
-
native_func = runtime.core.bsr_matrix_from_triplets_float_host
|
|
424
|
-
elif scalar_type == wp.float64:
|
|
425
|
-
native_func = runtime.core.bsr_matrix_from_triplets_double_host
|
|
493
|
+
native_func = runtime.core.bsr_matrix_from_triplets_host
|
|
426
494
|
else:
|
|
427
|
-
|
|
428
|
-
native_func = runtime.core.bsr_matrix_from_triplets_float_device
|
|
429
|
-
elif scalar_type == wp.float64:
|
|
430
|
-
native_func = runtime.core.bsr_matrix_from_triplets_double_device
|
|
495
|
+
native_func = runtime.core.bsr_matrix_from_triplets_device
|
|
431
496
|
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
nnz_buf, nnz_event = dest._nnz_transfer_buf_and_event()
|
|
497
|
+
nnz_buf, nnz_event = dest._setup_nnz_transfer()
|
|
498
|
+
summed_triplet_offsets = wp.empty(shape=(nnz,), dtype=wp.int32, device=device)
|
|
499
|
+
summed_triplet_indices = wp.empty(shape=(nnz,), dtype=wp.int32, device=device)
|
|
436
500
|
|
|
437
501
|
with wp.ScopedDevice(device):
|
|
438
502
|
native_func(
|
|
439
|
-
dest.
|
|
440
|
-
|
|
503
|
+
dest.block_size,
|
|
504
|
+
type_size_in_bytes(scalar_type),
|
|
441
505
|
dest.nrow,
|
|
506
|
+
dest.ncol,
|
|
442
507
|
nnz,
|
|
508
|
+
_optional_ctypes_pointer(count, ctype=ctypes.c_int32),
|
|
443
509
|
ctypes.cast(rows.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
444
510
|
ctypes.cast(columns.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
445
|
-
|
|
446
|
-
|
|
511
|
+
_optional_ctypes_pointer(values, ctype=ctypes.c_int32),
|
|
512
|
+
zero_value_mask,
|
|
447
513
|
masked,
|
|
514
|
+
ctypes.cast(summed_triplet_offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
515
|
+
ctypes.cast(summed_triplet_indices.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
448
516
|
ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
449
517
|
ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
518
|
+
_optional_ctypes_pointer(nnz_buf, ctype=ctypes.c_int32),
|
|
519
|
+
_optional_ctypes_event(nnz_event),
|
|
520
|
+
)
|
|
521
|
+
|
|
522
|
+
# now accumulate repeated blocks
|
|
523
|
+
wp.launch(
|
|
524
|
+
_bsr_accumulate_triplet_values,
|
|
525
|
+
dim=(nnz, *dest.block_shape),
|
|
526
|
+
inputs=[
|
|
527
|
+
dest.nrow,
|
|
528
|
+
summed_triplet_offsets,
|
|
529
|
+
summed_triplet_indices,
|
|
530
|
+
_as_3d_array(values, dest.block_shape),
|
|
531
|
+
dest.offsets,
|
|
532
|
+
],
|
|
533
|
+
outputs=[dest.scalar_values],
|
|
453
534
|
)
|
|
454
535
|
|
|
455
536
|
|
|
@@ -483,6 +564,7 @@ def bsr_from_triplets(
|
|
|
483
564
|
A = bsr_zeros(
|
|
484
565
|
rows_of_blocks=rows_of_blocks, cols_of_blocks=cols_of_blocks, block_type=block_type, device=values.device
|
|
485
566
|
)
|
|
567
|
+
A.values.requires_grad = values.requires_grad
|
|
486
568
|
bsr_set_from_triplets(A, rows, columns, values, prune_numerical_zeros=prune_numerical_zeros)
|
|
487
569
|
return A
|
|
488
570
|
|
|
@@ -539,6 +621,10 @@ class _BsrScalingExpression(_BsrExpression):
|
|
|
539
621
|
def dtype(self) -> type:
|
|
540
622
|
return self.mat.dtype
|
|
541
623
|
|
|
624
|
+
@property
|
|
625
|
+
def requires_grad(self) -> bool:
|
|
626
|
+
return self.mat.requires_grad
|
|
627
|
+
|
|
542
628
|
@property
|
|
543
629
|
def device(self) -> wp.context.Device:
|
|
544
630
|
return self.mat.device
|
|
@@ -721,10 +807,10 @@ def bsr_assign(
|
|
|
721
807
|
src: Matrix to be copied.
|
|
722
808
|
dest: Destination matrix. May have a different block shape or scalar type
|
|
723
809
|
than ``src``, in which case the required casting will be performed.
|
|
724
|
-
structure_only: If ``True``, only the non-
|
|
810
|
+
structure_only: If ``True``, only the non-zero indices are copied, and uninitialized value storage is allocated
|
|
725
811
|
to accommodate at least ``src.nnz`` blocks. If ``structure_only`` is ``False``, values are also copied with implicit
|
|
726
812
|
casting if the two matrices use distinct scalar types.
|
|
727
|
-
masked: If ``True``, prevent the assignment operation from adding new non-
|
|
813
|
+
masked: If ``True``, prevent the assignment operation from adding new non-zero blocks to ``dest``.
|
|
728
814
|
"""
|
|
729
815
|
|
|
730
816
|
src, src_scale = _extract_matrix_and_scale(src)
|
|
@@ -741,7 +827,7 @@ def bsr_assign(
|
|
|
741
827
|
|
|
742
828
|
if src_subrows * dest.block_shape[0] != src.block_shape[0] * dest_subrows:
|
|
743
829
|
raise ValueError(
|
|
744
|
-
f"Incompatible dest and src block shapes; block rows must evenly divide one another (Got {
|
|
830
|
+
f"Incompatible dest and src block shapes; block rows must evenly divide one another (Got {dest.block_shape[0]}, {src.block_shape[0]})"
|
|
745
831
|
)
|
|
746
832
|
|
|
747
833
|
if src.block_shape[1] >= dest.block_shape[1]:
|
|
@@ -753,14 +839,16 @@ def bsr_assign(
|
|
|
753
839
|
|
|
754
840
|
if src_subcols * dest.block_shape[1] != src.block_shape[1] * dest_subcols:
|
|
755
841
|
raise ValueError(
|
|
756
|
-
f"Incompatible dest and src block shapes; block columns must evenly divide one another (Got {
|
|
842
|
+
f"Incompatible dest and src block shapes; block columns must evenly divide one another (Got {dest.block_shape[1]}, {src.block_shape[1]})"
|
|
757
843
|
)
|
|
758
844
|
|
|
759
845
|
dest_nrow = (src.nrow * src_subrows) // dest_subrows
|
|
760
846
|
dest_ncol = (src.ncol * src_subcols) // dest_subcols
|
|
761
847
|
|
|
762
848
|
if src.nrow * src_subrows != dest_nrow * dest_subrows or src.ncol * src_subcols != dest_ncol * dest_subcols:
|
|
763
|
-
raise ValueError(
|
|
849
|
+
raise ValueError(
|
|
850
|
+
f"The requested block shape {dest.block_shape} does not evenly divide the source matrix of total size {src.shape}"
|
|
851
|
+
)
|
|
764
852
|
|
|
765
853
|
nnz_alloc = src.nnz * src_subrows * src_subcols
|
|
766
854
|
if masked:
|
|
@@ -813,27 +901,30 @@ def bsr_assign(
|
|
|
813
901
|
from warp.context import runtime
|
|
814
902
|
|
|
815
903
|
if dest.device.is_cpu:
|
|
816
|
-
native_func = runtime.core.
|
|
904
|
+
native_func = runtime.core.bsr_matrix_from_triplets_host
|
|
817
905
|
else:
|
|
818
|
-
native_func = runtime.core.
|
|
906
|
+
native_func = runtime.core.bsr_matrix_from_triplets_device
|
|
819
907
|
|
|
820
|
-
nnz_buf, nnz_event = dest.
|
|
908
|
+
nnz_buf, nnz_event = dest._setup_nnz_transfer()
|
|
821
909
|
with wp.ScopedDevice(dest.device):
|
|
822
910
|
native_func(
|
|
823
|
-
dest.
|
|
824
|
-
|
|
911
|
+
dest.block_size,
|
|
912
|
+
0, # scalar_size_in_bytes
|
|
825
913
|
dest.nrow,
|
|
914
|
+
dest.ncol,
|
|
826
915
|
nnz_alloc,
|
|
916
|
+
None, # device nnz
|
|
827
917
|
ctypes.cast(dest_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
828
918
|
ctypes.cast(dest_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
829
|
-
|
|
830
|
-
|
|
919
|
+
None, # triplet values
|
|
920
|
+
0, # zero_value_mask
|
|
831
921
|
masked,
|
|
922
|
+
None, # summed block offsets
|
|
923
|
+
None, # summed block indices
|
|
832
924
|
ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
833
925
|
ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
nnz_event,
|
|
926
|
+
_optional_ctypes_pointer(nnz_buf, ctype=ctypes.c_int32),
|
|
927
|
+
_optional_ctypes_event(nnz_event),
|
|
837
928
|
)
|
|
838
929
|
|
|
839
930
|
# merge block values
|
|
@@ -893,10 +984,28 @@ def bsr_copy(
|
|
|
893
984
|
block_type=block_type,
|
|
894
985
|
device=A.device,
|
|
895
986
|
)
|
|
987
|
+
copy.values.requires_grad = A.requires_grad
|
|
896
988
|
bsr_assign(dest=copy, src=A, structure_only=structure_only)
|
|
897
989
|
return copy
|
|
898
990
|
|
|
899
991
|
|
|
992
|
+
@wp.kernel
|
|
993
|
+
def _bsr_transpose_values(
|
|
994
|
+
col_count: int,
|
|
995
|
+
scale: Any,
|
|
996
|
+
bsr_values: wp.array3d(dtype=Any),
|
|
997
|
+
block_index_map: wp.array(dtype=int),
|
|
998
|
+
transposed_bsr_offsets: wp.array(dtype=int),
|
|
999
|
+
transposed_bsr_values: wp.array3d(dtype=Any),
|
|
1000
|
+
):
|
|
1001
|
+
block, i, j = wp.tid()
|
|
1002
|
+
|
|
1003
|
+
if block >= transposed_bsr_offsets[col_count]:
|
|
1004
|
+
return
|
|
1005
|
+
|
|
1006
|
+
transposed_bsr_values[block, i, j] = bsr_values[block_index_map[block], j, i] * scale
|
|
1007
|
+
|
|
1008
|
+
|
|
900
1009
|
def bsr_set_transpose(
|
|
901
1010
|
dest: BsrMatrix[BlockType[Cols, Rows, Scalar]],
|
|
902
1011
|
src: BsrMatrixOrExpression[BlockType[Rows, Cols, Scalar]],
|
|
@@ -906,15 +1015,17 @@ def bsr_set_transpose(
|
|
|
906
1015
|
src, src_scale = _extract_matrix_and_scale(src)
|
|
907
1016
|
|
|
908
1017
|
if dest.values.device != src.values.device:
|
|
909
|
-
raise ValueError(
|
|
1018
|
+
raise ValueError(
|
|
1019
|
+
f"All arguments must reside on the same device, got {dest.values.device} and {src.values.device}"
|
|
1020
|
+
)
|
|
910
1021
|
|
|
911
1022
|
if dest.scalar_type != src.scalar_type:
|
|
912
|
-
raise ValueError("All arguments must have the same scalar type")
|
|
1023
|
+
raise ValueError(f"All arguments must have the same scalar type, got {dest.scalar_type} and {src.scalar_type}")
|
|
913
1024
|
|
|
914
1025
|
transpose_block_shape = src.block_shape[::-1]
|
|
915
1026
|
|
|
916
1027
|
if dest.block_shape != transpose_block_shape:
|
|
917
|
-
raise ValueError(f"Destination block shape must be {transpose_block_shape}")
|
|
1028
|
+
raise ValueError(f"Destination block shape must be {transpose_block_shape}, got {dest.block_shape}")
|
|
918
1029
|
|
|
919
1030
|
nnz = src.nnz
|
|
920
1031
|
dest.nrow = src.ncol
|
|
@@ -930,36 +1041,33 @@ def bsr_set_transpose(
|
|
|
930
1041
|
from warp.context import runtime
|
|
931
1042
|
|
|
932
1043
|
if dest.values.device.is_cpu:
|
|
933
|
-
|
|
934
|
-
native_func = runtime.core.bsr_transpose_float_host
|
|
935
|
-
elif dest.scalar_type == wp.float64:
|
|
936
|
-
native_func = runtime.core.bsr_transpose_double_host
|
|
1044
|
+
native_func = runtime.core.bsr_transpose_host
|
|
937
1045
|
else:
|
|
938
|
-
|
|
939
|
-
native_func = runtime.core.bsr_transpose_float_device
|
|
940
|
-
elif dest.scalar_type == wp.float64:
|
|
941
|
-
native_func = runtime.core.bsr_transpose_double_device
|
|
1046
|
+
native_func = runtime.core.bsr_transpose_device
|
|
942
1047
|
|
|
943
|
-
|
|
944
|
-
raise NotImplementedError(f"bsr_set_transpose not implemented for scalar type {dest.scalar_type}")
|
|
1048
|
+
block_index_map = wp.empty(shape=2 * nnz, dtype=int, device=src.device)
|
|
945
1049
|
|
|
946
1050
|
with wp.ScopedDevice(dest.device):
|
|
947
1051
|
native_func(
|
|
948
|
-
src.block_shape[0],
|
|
949
|
-
src.block_shape[1],
|
|
950
1052
|
src.nrow,
|
|
951
1053
|
src.ncol,
|
|
952
1054
|
nnz,
|
|
953
1055
|
ctypes.cast(src.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
954
1056
|
ctypes.cast(src.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
955
|
-
ctypes.cast(src.values.ptr, ctypes.c_void_p),
|
|
956
1057
|
ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
957
1058
|
ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
958
|
-
ctypes.cast(
|
|
1059
|
+
ctypes.cast(block_index_map.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
959
1060
|
)
|
|
960
1061
|
|
|
961
|
-
|
|
962
|
-
|
|
1062
|
+
dest.copy_nnz_async()
|
|
1063
|
+
|
|
1064
|
+
wp.launch(
|
|
1065
|
+
_bsr_transpose_values,
|
|
1066
|
+
dim=(nnz, *dest.block_shape),
|
|
1067
|
+
device=dest.device,
|
|
1068
|
+
inputs=[src.ncol, dest.scalar_type(src_scale), src.scalar_values, block_index_map, dest.offsets],
|
|
1069
|
+
outputs=[dest.scalar_values],
|
|
1070
|
+
)
|
|
963
1071
|
|
|
964
1072
|
|
|
965
1073
|
def bsr_transposed(A: BsrMatrixOrExpression) -> BsrMatrix:
|
|
@@ -976,6 +1084,7 @@ def bsr_transposed(A: BsrMatrixOrExpression) -> BsrMatrix:
|
|
|
976
1084
|
block_type=block_type,
|
|
977
1085
|
device=A.device,
|
|
978
1086
|
)
|
|
1087
|
+
transposed.values.requires_grad = A.requires_grad
|
|
979
1088
|
bsr_set_transpose(dest=transposed, src=A)
|
|
980
1089
|
return transposed
|
|
981
1090
|
|
|
@@ -1010,12 +1119,12 @@ def bsr_get_diag(A: BsrMatrixOrExpression[BlockType], out: "Optional[Array[Block
|
|
|
1010
1119
|
if out is None:
|
|
1011
1120
|
out = wp.zeros(shape=(dim,), dtype=A.values.dtype, device=A.values.device)
|
|
1012
1121
|
else:
|
|
1013
|
-
if out.dtype
|
|
1014
|
-
raise ValueError(f"Output array must have type {A.values.dtype}")
|
|
1122
|
+
if not types_equal(out.dtype, A.values.dtype):
|
|
1123
|
+
raise ValueError(f"Output array must have type {A.values.dtype}, got {out.dtype}")
|
|
1015
1124
|
if out.device != A.values.device:
|
|
1016
|
-
raise ValueError(f"Output array must reside on device {A.values.device}")
|
|
1125
|
+
raise ValueError(f"Output array must reside on device {A.values.device}, got {out.device}")
|
|
1017
1126
|
if out.shape[0] < dim:
|
|
1018
|
-
raise ValueError(f"Output array must be of length at least {dim}")
|
|
1127
|
+
raise ValueError(f"Output array must be of length at least {dim}, got {out.shape[0]}")
|
|
1019
1128
|
|
|
1020
1129
|
wp.launch(
|
|
1021
1130
|
kernel=_bsr_get_diag_kernel,
|
|
@@ -1095,7 +1204,7 @@ def bsr_set_diag(
|
|
|
1095
1204
|
elif diag is not None:
|
|
1096
1205
|
A.values.fill_(diag)
|
|
1097
1206
|
|
|
1098
|
-
A.copy_nnz_async(
|
|
1207
|
+
A.copy_nnz_async()
|
|
1099
1208
|
|
|
1100
1209
|
|
|
1101
1210
|
def bsr_diag(
|
|
@@ -1151,6 +1260,8 @@ def bsr_diag(
|
|
|
1151
1260
|
block_type = wp.mat(shape=diag.shape, dtype=diag.dtype)
|
|
1152
1261
|
|
|
1153
1262
|
A = bsr_zeros(rows_of_blocks, cols_of_blocks, block_type=block_type, device=device)
|
|
1263
|
+
if is_array(diag):
|
|
1264
|
+
A.values.requires_grad = diag.requires_grad
|
|
1154
1265
|
bsr_set_diag(A, diag)
|
|
1155
1266
|
return A
|
|
1156
1267
|
|
|
@@ -1292,8 +1403,8 @@ def bsr_axpy(
|
|
|
1292
1403
|
The ``x`` and ``y`` matrices are allowed to alias.
|
|
1293
1404
|
|
|
1294
1405
|
Args:
|
|
1295
|
-
x: Read-only
|
|
1296
|
-
y: Mutable
|
|
1406
|
+
x: Read-only first operand.
|
|
1407
|
+
y: Mutable second operand and output matrix. If ``y`` is not provided, it will be allocated and treated as zero.
|
|
1297
1408
|
alpha: Uniform scaling factor for ``x``.
|
|
1298
1409
|
beta: Uniform scaling factor for ``y``.
|
|
1299
1410
|
masked: If ``True``, discard all blocks from ``x`` which are not
|
|
@@ -1312,6 +1423,7 @@ def bsr_axpy(
|
|
|
1312
1423
|
|
|
1313
1424
|
# If not output matrix is provided, allocate it for convenience
|
|
1314
1425
|
y = bsr_zeros(x.nrow, x.ncol, block_type=x.values.dtype, device=x.values.device)
|
|
1426
|
+
y.values.requires_grad = x.requires_grad
|
|
1315
1427
|
beta = 0.0
|
|
1316
1428
|
|
|
1317
1429
|
x_nnz = x.nnz
|
|
@@ -1337,13 +1449,17 @@ def bsr_axpy(
|
|
|
1337
1449
|
# General case
|
|
1338
1450
|
|
|
1339
1451
|
if x.values.device != y.values.device:
|
|
1340
|
-
raise ValueError("All arguments must reside on the same device")
|
|
1452
|
+
raise ValueError(f"All arguments must reside on the same device, got {x.values.device} and {y.values.device}")
|
|
1341
1453
|
|
|
1342
1454
|
if x.scalar_type != y.scalar_type or x.block_shape != y.block_shape:
|
|
1343
|
-
raise ValueError(
|
|
1455
|
+
raise ValueError(
|
|
1456
|
+
f"Matrices must have the same block type, got ({x.block_shape}, {x.scalar_type}) and ({y.block_shape}, {y.scalar_type})"
|
|
1457
|
+
)
|
|
1344
1458
|
|
|
1345
1459
|
if x.nrow != y.nrow or x.ncol != y.ncol:
|
|
1346
|
-
raise ValueError(
|
|
1460
|
+
raise ValueError(
|
|
1461
|
+
f"Matrices must have the same number of rows and columns, got ({x.nrow}, {x.ncol}) and ({y.nrow}, {y.ncol})"
|
|
1462
|
+
)
|
|
1347
1463
|
|
|
1348
1464
|
if work_arrays is None:
|
|
1349
1465
|
work_arrays = bsr_axpy_work_arrays()
|
|
@@ -1368,29 +1484,32 @@ def bsr_axpy(
|
|
|
1368
1484
|
from warp.context import runtime
|
|
1369
1485
|
|
|
1370
1486
|
if device.is_cpu:
|
|
1371
|
-
native_func = runtime.core.
|
|
1487
|
+
native_func = runtime.core.bsr_matrix_from_triplets_host
|
|
1372
1488
|
else:
|
|
1373
|
-
native_func = runtime.core.
|
|
1489
|
+
native_func = runtime.core.bsr_matrix_from_triplets_device
|
|
1374
1490
|
|
|
1375
1491
|
old_y_nnz = y_nnz
|
|
1376
|
-
nnz_buf, nnz_event = y.
|
|
1492
|
+
nnz_buf, nnz_event = y._setup_nnz_transfer()
|
|
1377
1493
|
|
|
1378
1494
|
with wp.ScopedDevice(y.device):
|
|
1379
1495
|
native_func(
|
|
1380
|
-
y.
|
|
1381
|
-
|
|
1496
|
+
y.block_size,
|
|
1497
|
+
0, # scalar_size_in_bytes
|
|
1382
1498
|
y.nrow,
|
|
1499
|
+
y.ncol,
|
|
1383
1500
|
sum_nnz,
|
|
1501
|
+
None, # device nnz
|
|
1384
1502
|
ctypes.cast(work_arrays._sum_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1385
1503
|
ctypes.cast(work_arrays._sum_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1386
|
-
|
|
1387
|
-
|
|
1504
|
+
None, # triplet values
|
|
1505
|
+
0, # zero_value_mask
|
|
1388
1506
|
masked,
|
|
1507
|
+
None, # summed block offsets
|
|
1508
|
+
None, # summed block indices
|
|
1389
1509
|
ctypes.cast(y.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1390
1510
|
ctypes.cast(y.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1391
|
-
|
|
1392
|
-
|
|
1393
|
-
nnz_event,
|
|
1511
|
+
_optional_ctypes_pointer(nnz_buf, ctype=ctypes.c_int32),
|
|
1512
|
+
_optional_ctypes_event(nnz_event),
|
|
1394
1513
|
)
|
|
1395
1514
|
|
|
1396
1515
|
y.values.zero_()
|
|
@@ -1617,9 +1736,9 @@ def bsr_mm(
|
|
|
1617
1736
|
If the matrix ``z`` is not provided as input, it will be allocated and treated as zero.
|
|
1618
1737
|
|
|
1619
1738
|
Args:
|
|
1620
|
-
x: Read-only left
|
|
1621
|
-
y: Read-only right
|
|
1622
|
-
z: Mutable
|
|
1739
|
+
x: Read-only left operand of the matrix-matrix product.
|
|
1740
|
+
y: Read-only right operand of the matrix-matrix product.
|
|
1741
|
+
z: Mutable affine operand and result matrix. If ``z`` is not provided, it will be allocated and treated as zero.
|
|
1623
1742
|
alpha: Uniform scaling factor for the ``x @ y`` product
|
|
1624
1743
|
beta: Uniform scaling factor for ``z``
|
|
1625
1744
|
masked: If ``True``, ignore all blocks from ``x @ y`` which are not existing non-zeros of ``y``
|
|
@@ -1649,23 +1768,32 @@ def bsr_mm(
|
|
|
1649
1768
|
else:
|
|
1650
1769
|
z_block_type = wp.mat(shape=z_block_shape, dtype=x.scalar_type)
|
|
1651
1770
|
z = bsr_zeros(x.nrow, y.ncol, block_type=z_block_type, device=x.values.device)
|
|
1771
|
+
z.values.requires_grad = x.requires_grad or y.requires_grad
|
|
1652
1772
|
beta = 0.0
|
|
1653
1773
|
|
|
1654
1774
|
if x.values.device != y.values.device or x.values.device != z.values.device:
|
|
1655
|
-
raise ValueError(
|
|
1775
|
+
raise ValueError(
|
|
1776
|
+
f"All arguments must reside on the same device, got {x.values.device}, {y.values.device} and {z.values.device}"
|
|
1777
|
+
)
|
|
1656
1778
|
|
|
1657
1779
|
if x.scalar_type != y.scalar_type or x.scalar_type != z.scalar_type:
|
|
1658
|
-
raise ValueError(
|
|
1780
|
+
raise ValueError(
|
|
1781
|
+
f"Matrices must have the same scalar type, got {x.scalar_type}, {y.scalar_type} and {z.scalar_type}"
|
|
1782
|
+
)
|
|
1659
1783
|
|
|
1660
1784
|
if (
|
|
1661
1785
|
x.block_shape[0] != z.block_shape[0]
|
|
1662
1786
|
or y.block_shape[1] != z.block_shape[1]
|
|
1663
1787
|
or x.block_shape[1] != y.block_shape[0]
|
|
1664
1788
|
):
|
|
1665
|
-
raise ValueError(
|
|
1789
|
+
raise ValueError(
|
|
1790
|
+
f"Incompatible block sizes for matrix multiplication, got ({x.block_shape}, {y.block_shape}) and ({z.block_shape})"
|
|
1791
|
+
)
|
|
1666
1792
|
|
|
1667
1793
|
if x.nrow != z.nrow or z.ncol != y.ncol or x.ncol != y.nrow:
|
|
1668
|
-
raise ValueError(
|
|
1794
|
+
raise ValueError(
|
|
1795
|
+
f"Incompatible number of rows/columns for matrix multiplication, got ({x.nrow}, {x.ncol}) and ({y.nrow}, {y.ncol})"
|
|
1796
|
+
)
|
|
1669
1797
|
|
|
1670
1798
|
device = z.values.device
|
|
1671
1799
|
|
|
@@ -1696,7 +1824,9 @@ def bsr_mm(
|
|
|
1696
1824
|
mm_nnz = work_arrays._mm_nnz
|
|
1697
1825
|
else:
|
|
1698
1826
|
if device.is_capturing:
|
|
1699
|
-
raise RuntimeError(
|
|
1827
|
+
raise RuntimeError(
|
|
1828
|
+
"`bsr_mm` requires either `reuse_topology=True` or `masked=True` for use in graph capture"
|
|
1829
|
+
)
|
|
1700
1830
|
|
|
1701
1831
|
if work_arrays is None:
|
|
1702
1832
|
work_arrays = bsr_mm_work_arrays()
|
|
@@ -1725,7 +1855,7 @@ def bsr_mm(
|
|
|
1725
1855
|
|
|
1726
1856
|
# Get back total counts on host -- we need a synchronization here
|
|
1727
1857
|
# Use pinned buffer from z, we are going to need it later anyway
|
|
1728
|
-
nnz_buf, _ = z.
|
|
1858
|
+
nnz_buf, _ = z._setup_nnz_transfer()
|
|
1729
1859
|
stream = wp.get_stream(device) if device.is_cuda else None
|
|
1730
1860
|
wp.copy(dest=nnz_buf, src=work_arrays._mm_block_counts, src_offset=x.nnz, count=1, stream=stream)
|
|
1731
1861
|
if device.is_cuda:
|
|
@@ -1782,28 +1912,31 @@ def bsr_mm(
|
|
|
1782
1912
|
from warp.context import runtime
|
|
1783
1913
|
|
|
1784
1914
|
if device.is_cpu:
|
|
1785
|
-
native_func = runtime.core.
|
|
1915
|
+
native_func = runtime.core.bsr_matrix_from_triplets_host
|
|
1786
1916
|
else:
|
|
1787
|
-
native_func = runtime.core.
|
|
1917
|
+
native_func = runtime.core.bsr_matrix_from_triplets_device
|
|
1788
1918
|
|
|
1789
|
-
nnz_buf, nnz_event = z.
|
|
1919
|
+
nnz_buf, nnz_event = z._setup_nnz_transfer()
|
|
1790
1920
|
|
|
1791
1921
|
with wp.ScopedDevice(z.device):
|
|
1792
1922
|
native_func(
|
|
1793
|
-
z.
|
|
1794
|
-
|
|
1923
|
+
z.block_size,
|
|
1924
|
+
0, # scalar_size_in_bytes
|
|
1795
1925
|
z.nrow,
|
|
1926
|
+
z.ncol,
|
|
1796
1927
|
mm_nnz,
|
|
1928
|
+
None, # device nnz
|
|
1797
1929
|
ctypes.cast(work_arrays._mm_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1798
1930
|
ctypes.cast(work_arrays._mm_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1799
|
-
|
|
1800
|
-
|
|
1801
|
-
|
|
1931
|
+
None, # triplet values
|
|
1932
|
+
0, # zero_value_mask
|
|
1933
|
+
False, # masked_topology
|
|
1934
|
+
None, # summed block offsets
|
|
1935
|
+
None, # summed block indices
|
|
1802
1936
|
ctypes.cast(z.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1803
1937
|
ctypes.cast(z.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
|
|
1804
|
-
|
|
1805
|
-
|
|
1806
|
-
nnz_event,
|
|
1938
|
+
_optional_ctypes_pointer(nnz_buf, ctype=ctypes.c_int32),
|
|
1939
|
+
_optional_ctypes_event(nnz_event),
|
|
1807
1940
|
)
|
|
1808
1941
|
|
|
1809
1942
|
# Resize z to fit mm result if necessary
|
|
@@ -1912,7 +2045,7 @@ def _bsr_mv_transpose_kernel(
|
|
|
1912
2045
|
def _vec_array_view(array: wp.array, dtype: type, expected_scalar_count: int) -> wp.array:
|
|
1913
2046
|
# cast a 1d or 2d array to a 1d array with the target dtype, adjusting shape as required
|
|
1914
2047
|
|
|
1915
|
-
scalar_count = array.size *
|
|
2048
|
+
scalar_count = array.size * type_size(array.dtype)
|
|
1916
2049
|
if scalar_count != expected_scalar_count:
|
|
1917
2050
|
raise ValueError(f"Invalid array scalar size, expected {expected_scalar_count}, got {scalar_count}")
|
|
1918
2051
|
|
|
@@ -1920,15 +2053,15 @@ def _vec_array_view(array: wp.array, dtype: type, expected_scalar_count: int) ->
|
|
|
1920
2053
|
return array
|
|
1921
2054
|
|
|
1922
2055
|
if type_scalar_type(array.dtype) != type_scalar_type(dtype):
|
|
1923
|
-
raise ValueError(f"Incompatible scalar types, {type_repr(array.dtype)}
|
|
2056
|
+
raise ValueError(f"Incompatible scalar types, expected {type_repr(array.dtype)}, got {type_repr(dtype)}")
|
|
1924
2057
|
|
|
1925
2058
|
if array.ndim > 2:
|
|
1926
|
-
raise ValueError(f"Incompatible array number of dimensions {array.ndim}")
|
|
2059
|
+
raise ValueError(f"Incompatible array number of dimensions, expected 1 or 2, got {array.ndim}")
|
|
1927
2060
|
|
|
1928
2061
|
if not array.is_contiguous:
|
|
1929
2062
|
raise ValueError("Array must be contiguous")
|
|
1930
2063
|
|
|
1931
|
-
vec_length =
|
|
2064
|
+
vec_length = type_size(dtype)
|
|
1932
2065
|
vec_count = scalar_count // vec_length
|
|
1933
2066
|
if vec_count * vec_length != scalar_count:
|
|
1934
2067
|
raise ValueError(
|
|
@@ -1965,9 +2098,9 @@ def bsr_mv(
|
|
|
1965
2098
|
The ``x`` and ``y`` vectors are allowed to alias.
|
|
1966
2099
|
|
|
1967
2100
|
Args:
|
|
1968
|
-
A: Read-only, left matrix
|
|
1969
|
-
x: Read-only, right vector
|
|
1970
|
-
y: Mutable
|
|
2101
|
+
A: Read-only, left matrix operand of the matrix-vector product.
|
|
2102
|
+
x: Read-only, right vector operand of the matrix-vector product.
|
|
2103
|
+
y: Mutable affine operand and result vector. If ``y`` is not provided, it will be allocated and treated as zero.
|
|
1971
2104
|
alpha: Uniform scaling factor for ``x``. If zero, ``x`` will not be read and may be left uninitialized.
|
|
1972
2105
|
beta: Uniform scaling factor for ``y``. If zero, ``y`` will not be read and may be left uninitialized.
|
|
1973
2106
|
transpose: If ``True``, use the transpose of the matrix ``A``. In this case the result is **non-deterministic**.
|
|
@@ -1990,23 +2123,27 @@ def bsr_mv(
|
|
|
1990
2123
|
# If no output array is provided, allocate one for convenience
|
|
1991
2124
|
y_vec_len = block_shape[0]
|
|
1992
2125
|
y_dtype = A.scalar_type if y_vec_len == 1 else wp.vec(length=y_vec_len, dtype=A.scalar_type)
|
|
1993
|
-
y = wp.empty(shape=(nrow,), device=A.values.device, dtype=y_dtype)
|
|
2126
|
+
y = wp.empty(shape=(nrow,), device=A.values.device, dtype=y_dtype, requires_grad=x.requires_grad)
|
|
1994
2127
|
beta = 0.0
|
|
1995
2128
|
|
|
1996
2129
|
alpha = A.scalar_type(alpha)
|
|
1997
2130
|
beta = A.scalar_type(beta)
|
|
1998
2131
|
|
|
1999
2132
|
if A.values.device != x.device or A.values.device != y.device:
|
|
2000
|
-
raise ValueError(
|
|
2133
|
+
raise ValueError(
|
|
2134
|
+
f"A, x, and y must reside on the same device, got {A.values.device}, {x.device} and {y.device}"
|
|
2135
|
+
)
|
|
2001
2136
|
|
|
2002
2137
|
if x.ptr == y.ptr:
|
|
2003
2138
|
# Aliasing case, need temporary storage
|
|
2004
2139
|
if work_buffer is None:
|
|
2005
2140
|
work_buffer = wp.empty_like(y)
|
|
2006
2141
|
elif work_buffer.size < y.size:
|
|
2007
|
-
raise ValueError(f"Work buffer size is insufficient, needs to be at least {y.size}")
|
|
2142
|
+
raise ValueError(f"Work buffer size is insufficient, needs to be at least {y.size}, got {work_buffer.size}")
|
|
2008
2143
|
elif not types_equal(work_buffer.dtype, y.dtype):
|
|
2009
|
-
raise ValueError(
|
|
2144
|
+
raise ValueError(
|
|
2145
|
+
f"Work buffer must have same data type as y, {type_repr(y.dtype)} vs {type_repr(work_buffer.dtype)}"
|
|
2146
|
+
)
|
|
2010
2147
|
|
|
2011
2148
|
# Save old y values before overwriting vector
|
|
2012
2149
|
wp.copy(dest=work_buffer, src=y, count=y.size)
|