warp-lang 1.6.2__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- warp/__init__.py +7 -1
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +410 -0
- warp/build_dll.py +6 -14
- warp/builtins.py +452 -362
- warp/codegen.py +179 -119
- warp/config.py +42 -6
- warp/context.py +490 -271
- warp/dlpack.py +8 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +2 -2
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_magnetostatics.py +6 -6
- warp/examples/fem/utils.py +9 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/tile/example_tile_matmul.py +2 -4
- warp/fem/__init__.py +11 -1
- warp/fem/adaptivity.py +4 -4
- warp/fem/field/nodal_field.py +22 -68
- warp/fem/field/virtual.py +62 -23
- warp/fem/geometry/adaptive_nanogrid.py +9 -10
- warp/fem/geometry/closest_point.py +1 -1
- warp/fem/geometry/deformed_geometry.py +5 -2
- warp/fem/geometry/geometry.py +5 -0
- warp/fem/geometry/grid_2d.py +12 -12
- warp/fem/geometry/grid_3d.py +12 -15
- warp/fem/geometry/hexmesh.py +5 -7
- warp/fem/geometry/nanogrid.py +9 -11
- warp/fem/geometry/quadmesh.py +13 -13
- warp/fem/geometry/tetmesh.py +3 -4
- warp/fem/geometry/trimesh.py +3 -8
- warp/fem/integrate.py +262 -93
- warp/fem/linalg.py +5 -5
- warp/fem/quadrature/pic_quadrature.py +37 -22
- warp/fem/quadrature/quadrature.py +194 -25
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +4 -2
- warp/fem/space/basis_space.py +25 -18
- warp/fem/space/hexmesh_function_space.py +2 -2
- warp/fem/space/partition.py +6 -2
- warp/fem/space/quadmesh_function_space.py +8 -8
- warp/fem/space/shape/cube_shape_function.py +23 -23
- warp/fem/space/shape/square_shape_function.py +12 -12
- warp/fem/space/shape/triangle_shape_function.py +1 -1
- warp/fem/space/tetmesh_function_space.py +3 -3
- warp/fem/space/trimesh_function_space.py +2 -2
- warp/fem/utils.py +12 -6
- warp/jax.py +14 -1
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +14 -27
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +89 -0
- warp/native/array.h +13 -0
- warp/native/builtin.h +29 -3
- warp/native/bvh.cpp +3 -1
- warp/native/bvh.cu +42 -14
- warp/native/bvh.h +2 -1
- warp/native/clang/clang.cpp +30 -3
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -0
- warp/native/exports.h +68 -63
- warp/native/intersect.h +26 -26
- warp/native/intersect_adj.h +33 -33
- warp/native/marching.cu +1 -1
- warp/native/mat.h +513 -9
- warp/native/mesh.h +10 -10
- warp/native/quat.h +99 -11
- warp/native/rand.h +6 -0
- warp/native/sort.cpp +122 -59
- warp/native/sort.cu +152 -15
- warp/native/sort.h +8 -1
- warp/native/sparse.cpp +43 -22
- warp/native/sparse.cu +52 -17
- warp/native/svd.h +116 -0
- warp/native/tile.h +301 -105
- warp/native/tile_reduce.h +46 -3
- warp/native/vec.h +68 -7
- warp/native/volume.cpp +85 -113
- warp/native/volume_builder.cu +25 -10
- warp/native/volume_builder.h +6 -0
- warp/native/warp.cpp +5 -6
- warp/native/warp.cu +99 -10
- warp/native/warp.h +19 -10
- warp/optim/linear.py +10 -10
- warp/sim/articulation.py +4 -4
- warp/sim/collide.py +21 -10
- warp/sim/import_mjcf.py +449 -155
- warp/sim/import_urdf.py +32 -12
- warp/sim/integrator_euler.py +5 -5
- warp/sim/integrator_featherstone.py +3 -10
- warp/sim/integrator_vbd.py +207 -2
- warp/sim/integrator_xpbd.py +5 -5
- warp/sim/model.py +42 -13
- warp/sim/utils.py +2 -2
- warp/sparse.py +642 -555
- warp/stubs.py +216 -19
- warp/tests/__main__.py +0 -15
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
- warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
- warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
- warp/tests/{test_collision.py → sim/test_collision.py} +2 -2
- warp/tests/{test_model.py → sim/test_model.py} +40 -0
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_bool.py +1 -1
- warp/tests/test_examples.py +28 -36
- warp/tests/test_fem.py +23 -4
- warp/tests/test_linear_solvers.py +0 -11
- warp/tests/test_mat.py +233 -79
- warp/tests/test_mat_scalar_ops.py +4 -4
- warp/tests/test_overwrite.py +0 -60
- warp/tests/test_quat.py +67 -46
- warp/tests/test_rand.py +44 -37
- warp/tests/test_sparse.py +47 -6
- warp/tests/test_spatial.py +75 -0
- warp/tests/test_static.py +1 -1
- warp/tests/test_utils.py +84 -4
- warp/tests/test_vec.py +46 -34
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +1 -1
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
- warp/tests/unittest_serial.py +1 -0
- warp/tests/unittest_suites.py +45 -59
- warp/tests/unittest_utils.py +2 -1
- warp/thirdparty/unittest_parallel.py +3 -1
- warp/types.py +110 -658
- warp/utils.py +137 -72
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/METADATA +29 -7
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/RECORD +172 -162
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
- warp/examples/optim/example_walker.py +0 -317
- warp/native/cutlass_gemm.cpp +0 -43
- warp/native/cutlass_gemm.cu +0 -382
- warp/tests/test_matmul.py +0 -511
- warp/tests/test_matmul_lite.py +0 -411
- warp/tests/test_vbd.py +0 -386
- warp/tests/unused_test_misc.py +0 -77
- /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
- /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
- /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
- /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
- /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
- /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
- /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
- /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
- /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
- /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
- /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
- /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
- /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
- /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
- /warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +0 -0
- /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
- /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
- /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info/licenses}/LICENSE.md +0 -0
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/types.py
CHANGED
@@ -20,7 +20,21 @@ import ctypes
 import inspect
 import struct
 import zlib
-from typing import
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    List,
+    Literal,
+    NamedTuple,
+    Optional,
+    Sequence,
+    Tuple,
+    TypeVar,
+    Union,
+    get_args,
+    get_origin,
+)

 import numpy as np
 import numpy.typing as npt
@@ -56,7 +70,9 @@ class Transformation(Generic[Float]):


 class Array(Generic[DType]):
-
+    device: Optional[warp.context.Device]
+    dtype: type
+    size: int


 int_tuple_type_hints = {
@@ -1139,7 +1155,7 @@ ARRAY_TYPE_FABRIC_INDEXED = 3
 class launch_bounds_t(ctypes.Structure):
     _fields_ = [("shape", ctypes.c_int32 * LAUNCH_MAX_DIMS), ("ndim", ctypes.c_int32), ("size", ctypes.c_size_t)]

-    def __init__(self, shape):
+    def __init__(self, shape: Union[int, Sequence[int]]):
         if isinstance(shape, int):
             # 1d launch
             self.ndim = 1
@@ -1260,7 +1276,7 @@ _type_size_cache = {
 }


-def type_size_in_bytes(dtype):
+def type_size_in_bytes(dtype: type) -> int:
     size = _type_size_cache.get(dtype)

     if size is None:
@@ -1279,7 +1295,7 @@ def type_size_in_bytes(dtype):
     return size


-def type_to_warp(dtype):
+def type_to_warp(dtype: type) -> type:
     if dtype == float:
         return float32
     elif dtype == int:
@@ -1290,7 +1306,7 @@ def type_to_warp(dtype):
         return dtype


-def type_typestr(dtype):
+def type_typestr(dtype: type) -> str:
     if dtype == bool:
         return "|b1"
     elif dtype == float16:
@@ -1376,29 +1392,29 @@ def type_is_transformation(t):
     return getattr(t, "_wp_generic_type_hint_", None) is Transformation


-value_types = (int, float, builtins.bool) +
+value_types = (int, float, builtins.bool) + scalar_and_bool_types


 # returns true for all value types (int, float, bool, scalars, vectors, matrices)
-def type_is_value(x):
+def type_is_value(x: Any) -> builtins.bool:
     return x in value_types or hasattr(x, "_wp_scalar_type_")


 # equivalent of the above but for values
-def is_int(x):
+def is_int(x: Any) -> builtins.bool:
     return type_is_int(type(x))


-def is_float(x):
+def is_float(x: Any) -> builtins.bool:
     return type_is_float(type(x))


-def is_value(x):
+def is_value(x: Any) -> builtins.bool:
     return type_is_value(type(x))


-
-
+def is_array(a) -> builtins.bool:
+    """Return true if the passed *instance* is one of the array types."""
     return isinstance(a, array_types)


@@ -1465,21 +1481,21 @@ def types_equal(a, b, match_generic=False):
     if a_length is None or b_length is None or a_length == b_length:
         return True

-    a_origin =
-    b_origin =
+    a_origin = get_origin(a)
+    b_origin = get_origin(b)
     if a_origin is tuple and b_origin is tuple:
-        a_args =
-        b_args =
+        a_args = get_args(a)
+        b_args = get_args(b)
         if len(a_args) == len(b_args) and all(
             scalars_equal(x, y, match_generic=match_generic) for x, y in zip(a_args, b_args)
         ):
             return True
     elif a_origin is tuple and isinstance(b, Sequence):
-        a_args =
+        a_args = get_args(a)
         if len(a_args) == len(b) and all(scalars_equal(x, y, match_generic=match_generic) for x, y in zip(a_args, b)):
             return True
     elif b_origin is tuple and isinstance(a, Sequence):
-        b_args =
+        b_args = get_args(b)
         if len(b_args) == len(a) and all(scalars_equal(x, y, match_generic=match_generic) for x, y in zip(b_args, a)):
             return True

@@ -1600,7 +1616,7 @@ def array_ctype_from_interface(interface: dict, dtype=None, owner=None):
     return array_ctype


-class array(Array):
+class array(Array[DType]):
     """A fixed-size multi-dimensional array containing values of the same type.

     Attributes:
@@ -1629,21 +1645,21 @@ class array(Array):

     def __init__(
         self,
-        data:
-        dtype:
-        shape:
+        data: Union[List, Tuple, npt.NDArray, None] = None,
+        dtype: Any = Any,
+        shape: Union[int, Tuple[int, ...], List[int], None] = None,
         strides: Optional[Tuple[int, ...]] = None,
         length: Optional[int] = None,
         ptr: Optional[int] = None,
         capacity: Optional[int] = None,
         device=None,
-        pinned: bool = False,
-        copy: bool = True,
-        owner: bool = False,  # deprecated - pass deleter instead
+        pinned: builtins.bool = False,
+        copy: builtins.bool = True,
+        owner: builtins.bool = False,  # deprecated - pass deleter instead
         deleter: Optional[Callable[[int, int], None]] = None,
         ndim: Optional[int] = None,
         grad: Optional[array] = None,
-        requires_grad: bool = False,
+        requires_grad: builtins.bool = False,
     ):
         """Constructs a new Warp array object

@@ -2939,7 +2955,7 @@ def from_ipc_handle(

 # A base class for non-contiguous arrays, providing the implementation of common methods like
 # contiguous(), to(), numpy(), list(), assign(), zero_(), and fill_().
-class noncontiguous_array_base(
+class noncontiguous_array_base(Array[T]):
     def __init__(self, array_type_id):
         self.type_id = array_type_id
         self.is_contiguous = False
@@ -3036,12 +3052,18 @@ def check_index_array(indices, expected_device):
     raise ValueError(f"Index array device ({indices.device} does not match data array device ({expected_device}))")


-class indexedarray(noncontiguous_array_base
+class indexedarray(noncontiguous_array_base):
     # member attributes available during code-gen (e.g.: d = arr.shape[0])
     # (initialized when needed)
     _vars = None

-    def __init__(
+    def __init__(
+        self,
+        data: Optional[array] = None,
+        indices: Union[array, List[array], None] = None,
+        dtype=None,
+        ndim: Optional[int] = None,
+    ):
         super().__init__(ARRAY_TYPE_INDEXED)

         # canonicalize types
@@ -3232,7 +3254,7 @@ class Tile:
             return f"wp::tile_alloc_empty<{Var.type_to_ctype(self.dtype)},wp::tile_shape_t<{','.join(map(str, self.shape))}>,{'true' if requires_grad else 'false'}>()"
         else:
             # tile will be initialized by another call, e.g.: tile_transpose()
-            return "
+            return "nullptr"

     # return total tile size in bytes
     def size_in_bytes(self):
@@ -3634,7 +3656,7 @@ class Volume:
         instance.id = None
         return instance

-    def __init__(self, data: array, copy: bool = True):
+    def __init__(self, data: array, copy: builtins.bool = True):
         """Class representing a sparse grid.

         Args:
@@ -4361,6 +4383,15 @@ class Volume:
         translation_buf = (ctypes.c_float * 3)(translation[0], translation[1], translation[2])
         return transform_buf, translation_buf

+    # nanovdb types for which we instantiate the grid builder
+    # Should be in sync with WP_VOLUME_BUILDER_INSTANTIATE_TYPES in volume_builder.h
+    _supported_allocation_types = [
+        "int32",
+        "float",
+        "Vec3f",
+        "Vec4f",
+    ]
+
     @classmethod
     def allocate_by_tiles(
         cls,
@@ -4388,7 +4419,8 @@ class Volume:
                 or a floating point scalar type (2D N-by-3 array of :class:`warp.float32` or 1D array of `warp.vec3f` values), indicating world space positions.
                 Repeated points per tile are allowed and will be efficiently deduplicated.
             voxel_size (float or array-like): Voxel size(s) of the new volume. Ignored if `transform` is given.
-            bg_value (array-like,
+            bg_value (array-like, scalar or None): Value of unallocated voxels of the volume, also defines the volume's type. An index volume will be created if `bg_value` is ``None``.
+                Other supported grid types are `int`, `float`, `vec3f`, and `vec4f`.
             translation (array-like): Translation between the index and world spaces.
             transform (array-like): Linear transform between the index and world spaces. If ``None``, deduced from `voxel_size`.
             device (Devicelike): The CUDA device to create the volume on, e.g.: "cuda" or "cuda:0".
@@ -4420,35 +4452,47 @@ class Volume:
                 translation_buf,
                 in_world_space,
             )
-        elif hasattr(bg_value, "__len__"):
-            volume.id = volume.runtime.core.volume_v_from_tiles_device(
-                volume.device.context,
-                ctypes.c_void_p(tile_points.ptr),
-                tile_points.shape[0],
-                transform_buf,
-                translation_buf,
-                in_world_space,
-                (ctypes.c_float * 3)(bg_value[0], bg_value[1], bg_value[2]),
-            )
-        elif isinstance(bg_value, int):
-            volume.id = volume.runtime.core.volume_i_from_tiles_device(
-                volume.device.context,
-                ctypes.c_void_p(tile_points.ptr),
-                tile_points.shape[0],
-                transform_buf,
-                translation_buf,
-                in_world_space,
-                bg_value,
-            )
         else:
-
+            # normalize background value type
+            grid_type = type_to_warp(type(bg_value))
+            if not (is_value(bg_value) or type_is_vector(grid_type)) and (
+                hasattr(bg_value, "__len__") and is_value(bg_value[0])
+            ):
+                # non-warp vectors are considered float, for backward compatibility
+                grid_type = vector(len(bg_value), dtype=float)
+
+            # look for corresponding nvdb type
+            try:
+                nvdb_type = next(
+                    typ
+                    for typ in Volume._supported_allocation_types
+                    if types_equal(grid_type, Volume._nvdb_type_to_dtype[typ])
+                )
+            except StopIteration as err:
+                raise TypeError(
+                    f"Unsupported bg_value type for volume allocation {type_repr(grid_type)}. Supported volume types are {', '.join(Volume._supported_allocation_types)}."
+                ) from err
+
+            # cast to ctype
+            # wrap scalar values in length-1 vectors to handle specific ctype conversion
+            if not type_is_vector(grid_type):
+                grid_type = vector(length=1, dtype=grid_type)
+
+            cvalue = grid_type(bg_value)
+            cvalue_ptr = ctypes.pointer(cvalue)
+            cvalue_size = ctypes.sizeof(cvalue)
+            cvalue_type = nvdb_type.encode("ascii")
+
+            volume.id = volume.runtime.core.volume_from_tiles_device(
                 volume.device.context,
                 ctypes.c_void_p(tile_points.ptr),
                 tile_points.shape[0],
                 transform_buf,
                 translation_buf,
                 in_world_space,
-
+                cvalue_ptr,
+                cvalue_size,
+                cvalue_type,
             )

         if volume.id == 0:
@@ -4606,6 +4650,8 @@ def matmul(
 ):
     """Computes a generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.

+    .. versionremoved:: 1.7
+
     .. deprecated:: 1.6
         Use :doc:`tile primitives </modules/tiles>` instead.

@@ -4619,80 +4665,8 @@ def matmul(
         allow_tf32x3_arith (bool): whether to use CUTLASS's 3xTF32 GEMMs, which enable accuracy similar to FP32
             while using Tensor Cores
     """
-    from warp.context import runtime
-
-    warp.utils.warn(
-        "wp.matmul() is deprecated and will be removed in a\nfuture version. Use tile primitives instead.",
-        category=DeprecationWarning,
-        stacklevel=2,
-    )
-
-    device = a.device
-
-    if b.device != device or c.device != device or d.device != device:
-        raise RuntimeError("Matrices A, B, C, and D must all be on the same device as the runtime device.")
-
-    if a.dtype != b.dtype or a.dtype != c.dtype or a.dtype != d.dtype:
-        raise RuntimeError(
-            "wp.matmul currently only supports operation between {A, B, C, D} matrices of the same type."
-        )
-
-    if (
-        (not a.is_contiguous and not a.is_transposed)
-        or (not b.is_contiguous and not b.is_transposed)
-        or (not c.is_contiguous)
-        or (not d.is_contiguous)
-    ):
-        raise RuntimeError(
-            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B may be transposed."
-        )

-
-    n = b.shape[1]
-    k = a.shape[1]
-    if b.shape != (k, n) or c.shape != (m, n) or d.shape != (m, n):
-        raise RuntimeError(
-            "Invalid shapes for matrices: A = {} B = {} C = {} D = {}".format(a.shape, b.shape, c.shape, d.shape)
-        )
-
-    if runtime.tape:
-        runtime.tape.record_func(
-            backward=lambda: adj_matmul(a, b, c, a.grad, b.grad, c.grad, d.grad, alpha, beta, allow_tf32x3_arith),
-            arrays=[a, b, c, d],
-        )
-    if warp.config.verify_autograd_array_access:
-        d.mark_write()
-        a.mark_read()
-        b.mark_read()
-        c.mark_read()
-
-    # cpu fallback if no cuda devices found
-    if device == "cpu":
-        np_dtype = warp_type_to_np_dtype[a.dtype]
-        d.assign(alpha * np.matmul(a.numpy(), b.numpy(), dtype=np_dtype) + beta * c.numpy())
-        return
-
-    cc = device.arch
-    ret = runtime.core.cutlass_gemm(
-        device.context,
-        cc,
-        m,
-        n,
-        k,
-        type_typestr(a.dtype).encode(),
-        ctypes.c_void_p(a.ptr),
-        ctypes.c_void_p(b.ptr),
-        ctypes.c_void_p(c.ptr),
-        ctypes.c_void_p(d.ptr),
-        alpha,
-        beta,
-        not a.is_transposed,
-        not b.is_transposed,
-        allow_tf32x3_arith,
-        1,
-    )
-    if not ret:
-        raise RuntimeError("matmul failed.")
+    raise RuntimeError("This function has been removed. Use tile primitives instead.")


 def adj_matmul(
@@ -4724,171 +4698,8 @@ def adj_matmul(
         allow_tf32x3_arith (bool): whether to use CUTLASS's 3xTF32 GEMMs, which enable accuracy similar to FP32
             while using Tensor Cores
     """
-    from warp.context import runtime
-
-    device = a.device
-
-    if (
-        b.device != device
-        or c.device != device
-        or adj_a.device != device
-        or adj_b.device != device
-        or adj_c.device != device
-        or adj_d.device != device
-    ):
-        raise RuntimeError(
-            "Matrices A, B, C, D, and their adjoints must all be on the same device as the runtime device."
-        )
-
-    if (
-        a.dtype != b.dtype
-        or a.dtype != c.dtype
-        or a.dtype != adj_a.dtype
-        or a.dtype != adj_b.dtype
-        or a.dtype != adj_c.dtype
-        or a.dtype != adj_d.dtype
-    ):
-        raise RuntimeError(
-            "wp.adj_matmul currently only supports operation between {A, B, C, adj_D, adj_A, adj_B, adj_C} matrices of the same type."
-        )
-
-    if (
-        (not a.is_contiguous and not a.is_transposed)
-        or (not b.is_contiguous and not b.is_transposed)
-        or (not c.is_contiguous)
-        or (not adj_a.is_contiguous and not adj_a.is_transposed)
-        or (not adj_b.is_contiguous and not adj_b.is_transposed)
-        or (not adj_c.is_contiguous)
-        or (not adj_d.is_contiguous)
-    ):
-        raise RuntimeError(
-            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B and their associated adjoints may be transposed."
-        )

-
-    n = b.shape[1]
-    k = a.shape[1]
-    if (
-        a.shape != (m, k)
-        or b.shape != (k, n)
-        or c.shape != (m, n)
-        or adj_d.shape != (m, n)
-        or adj_a.shape != (m, k)
-        or adj_b.shape != (k, n)
-        or adj_c.shape != (m, n)
-    ):
-        raise RuntimeError(
-            "Invalid shapes for matrices: A = {} B = {} C = {} adj_D = {} adj_A = {} adj_B = {} adj_C = {}".format(
-                a.shape, b.shape, c.shape, adj_d.shape, adj_a.shape, adj_b.shape, adj_c.shape
-            )
-        )
-
-    # cpu fallback if no cuda devices found
-    if device == "cpu":
-        np_dtype = warp_type_to_np_dtype[a.dtype]
-        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose(), dtype=np_dtype) + adj_a.numpy())
-        adj_b.assign(alpha * np.matmul(a.numpy().transpose(), adj_d.numpy(), dtype=np_dtype) + adj_b.numpy())
-        adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
-        return
-
-    cc = device.arch
-
-    # adj_a
-    if not a.is_transposed:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            m,
-            k,
-            n,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(adj_d.ptr),
-            ctypes.c_void_p(b.ptr),
-            ctypes.c_void_p(adj_a.ptr),
-            ctypes.c_void_p(adj_a.ptr),
-            alpha,
-            1.0,
-            True,
-            b.is_transposed,
-            allow_tf32x3_arith,
-            1,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-    else:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            k,
-            m,
-            n,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(b.ptr),
-            ctypes.c_void_p(adj_d.ptr),
-            ctypes.c_void_p(adj_a.ptr),
-            ctypes.c_void_p(adj_a.ptr),
-            alpha,
-            1.0,
-            not b.is_transposed,
-            False,
-            allow_tf32x3_arith,
-            1,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-
-    # adj_b
-    if not b.is_transposed:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            k,
-            n,
-            m,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(a.ptr),
-            ctypes.c_void_p(adj_d.ptr),
-            ctypes.c_void_p(adj_b.ptr),
-            ctypes.c_void_p(adj_b.ptr),
-            alpha,
-            1.0,
-            a.is_transposed,
-            True,
-            allow_tf32x3_arith,
-            1,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-    else:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            n,
-            k,
-            m,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(adj_d.ptr),
-            ctypes.c_void_p(a.ptr),
-            ctypes.c_void_p(adj_b.ptr),
-            ctypes.c_void_p(adj_b.ptr),
-            alpha,
-            1.0,
-            False,
-            not a.is_transposed,
-            allow_tf32x3_arith,
-            1,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-
-    # adj_c
-    warp.launch(
-        kernel=warp.utils.add_kernel_2d,
-        dim=adj_c.shape,
-        inputs=[adj_c, adj_d, adj_d.dtype(beta)],
-        device=device,
-        record_tape=False,
-    )
+    raise RuntimeError("This function has been removed. Use tile primitives instead.")


 def batched_matmul(
@@ -4902,6 +4713,8 @@ def batched_matmul(
 ):
     """Computes a batched generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.

+    .. versionremoved:: 1.7
+
     .. deprecated:: 1.6
         Use :doc:`tile primitives </modules/tiles>` instead.

@@ -4915,107 +4728,8 @@ def batched_matmul(
         allow_tf32x3_arith (bool): whether to use CUTLASS's 3xTF32 GEMMs, which enable accuracy similar to FP32
             while using Tensor Cores
     """
-    from warp.context import runtime
-
-    device = a.device
-
-    if b.device != device or c.device != device or d.device != device:
-        raise RuntimeError("Matrices A, B, C, and D must all be on the same device as the runtime device.")
-
-    if a.dtype != b.dtype or a.dtype != c.dtype or a.dtype != d.dtype:
-        raise RuntimeError(
-            "wp.batched_matmul currently only supports operation between {A, B, C, D} matrices of the same type."
-        )
-
-    if (
-        (not a.is_contiguous and not a.is_transposed)
-        or (not b.is_contiguous and not b.is_transposed)
-        or (not c.is_contiguous)
-        or (not d.is_contiguous)
-    ):
-        raise RuntimeError(
-            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B may be transposed."
-        )
-
-    m = a.shape[1]
-    n = b.shape[2]
-    k = a.shape[2]
-    batch_count = a.shape[0]
-    if b.shape != (batch_count, k, n) or c.shape != (batch_count, m, n) or d.shape != (batch_count, m, n):
-        raise RuntimeError(
-            "Invalid shapes for matrices: A = {} B = {} C = {} D = {}".format(a.shape, b.shape, c.shape, d.shape)
-        )

-
-    runtime.tape.record_func(
-        backward=lambda: adj_batched_matmul(
-            a, b, c, a.grad, b.grad, c.grad, d.grad, alpha, beta, allow_tf32x3_arith
-        ),
-        arrays=[a, b, c, d],
-    )
-    if warp.config.verify_autograd_array_access:
-        d.mark_write()
-        a.mark_read()
-        b.mark_read()
-        c.mark_read()
-
-    # cpu fallback if no cuda devices found
-    if device == "cpu":
-        np_dtype = warp_type_to_np_dtype[a.dtype]
-        d.assign(alpha * np.matmul(a.numpy(), b.numpy(), dtype=np_dtype) + beta * c.numpy())
-        return
-
-    # handle case in which batch_count exceeds max_batch_count, which is a CUDA array size maximum
-    max_batch_count = 65535
-    iters = int(batch_count / max_batch_count)
-    remainder = batch_count % max_batch_count
-
-    cc = device.arch
-    for i in range(iters):
-        idx_start = i * max_batch_count
-        idx_end = (i + 1) * max_batch_count if i < iters - 1 else batch_count
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            m,
-            n,
-            k,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(a[idx_start:idx_end, :, :].ptr),
-            ctypes.c_void_p(b[idx_start:idx_end, :, :].ptr),
-            ctypes.c_void_p(c[idx_start:idx_end, :, :].ptr),
-            ctypes.c_void_p(d[idx_start:idx_end, :, :].ptr),
-            alpha,
-            beta,
-            not a.is_transposed,
-            not b.is_transposed,
-            allow_tf32x3_arith,
-            max_batch_count,
-        )
-        if not ret:
-            raise RuntimeError("Batched matmul failed.")
-
-    idx_start = iters * max_batch_count
-    ret = runtime.core.cutlass_gemm(
-        device.context,
-        cc,
-        m,
-        n,
-        k,
-        type_typestr(a.dtype).encode(),
-        ctypes.c_void_p(a[idx_start:, :, :].ptr),
-        ctypes.c_void_p(b[idx_start:, :, :].ptr),
-        ctypes.c_void_p(c[idx_start:, :, :].ptr),
-        ctypes.c_void_p(d[idx_start:, :, :].ptr),
-        alpha,
-        beta,
-        not a.is_transposed,
-        not b.is_transposed,
-        allow_tf32x3_arith,
-        remainder,
-    )
-    if not ret:
-        raise RuntimeError("Batched matmul failed.")
+    raise RuntimeError("This function has been removed. Use tile primitives instead.")


 def adj_batched_matmul(
@@ -5045,270 +4759,8 @@ def adj_batched_matmul(
         allow_tf32x3_arith (bool): whether to use CUTLASS's 3xTF32 GEMMs, which enable accuracy similar to FP32
             while using Tensor Cores
     """
-    from warp.context import runtime

-
-
-    if (
-        b.device != device
-        or c.device != device
-        or adj_a.device != device
-        or adj_b.device != device
-        or adj_c.device != device
-        or adj_d.device != device
-    ):
-        raise RuntimeError(
-            "Matrices A, B, C, D, and their adjoints must all be on the same device as the runtime device."
-        )
-
-    if (
-        a.dtype != b.dtype
-        or a.dtype != c.dtype
-        or a.dtype != adj_a.dtype
-        or a.dtype != adj_b.dtype
-        or a.dtype != adj_c.dtype
-        or a.dtype != adj_d.dtype
-    ):
-        raise RuntimeError(
-            "wp.adj_batched_matmul currently only supports operation between {A, B, C, adj_D, adj_A, adj_B, adj_C} matrices of the same type."
-        )
-
-    m = a.shape[1]
-    n = b.shape[2]
-    k = a.shape[2]
-    batch_count = a.shape[0]
-    if (
-        b.shape != (batch_count, k, n)
-        or c.shape != (batch_count, m, n)
-        or adj_d.shape != (batch_count, m, n)
-        or adj_a.shape != (batch_count, m, k)
-        or adj_b.shape != (batch_count, k, n)
-        or adj_c.shape != (batch_count, m, n)
-    ):
-        raise RuntimeError(
-            "Invalid shapes for matrices: A = {} B = {} C = {} adj_D = {} adj_A = {} adj_B = {} adj_C = {}".format(
-                a.shape, b.shape, c.shape, adj_d.shape, adj_a.shape, adj_b.shape, adj_c.shape
-            )
-        )
-
-    if (
-        (not a.is_contiguous and not a.is_transposed)
-        or (not b.is_contiguous and not b.is_transposed)
-        or (not c.is_contiguous)
-        or (not adj_a.is_contiguous and not adj_a.is_transposed)
-        or (not adj_b.is_contiguous and not adj_b.is_transposed)
-        or (not adj_c.is_contiguous)
-        or (not adj_d.is_contiguous)
-    ):
-        raise RuntimeError(
-            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B and their associated adjoints may be transposed."
-        )
-
-    # cpu fallback if no cuda devices found
-    if device == "cpu":
-        np_dtype = warp_type_to_np_dtype[a.dtype]
-        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose((0, 2, 1)), dtype=np_dtype) + adj_a.numpy())
-        adj_b.assign(alpha * np.matmul(a.numpy().transpose((0, 2, 1)), adj_d.numpy(), dtype=np_dtype) + adj_b.numpy())
-        adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
-        return
-
-    # handle case in which batch_count exceeds max_batch_count, which is a CUDA array size maximum
-    max_batch_count = 65535
-    iters = int(batch_count / max_batch_count)
-    remainder = batch_count % max_batch_count
-
-    cc = device.arch
-
-    for i in range(iters):
-        idx_start = i * max_batch_count
-        idx_end = (i + 1) * max_batch_count if i < iters - 1 else batch_count
-
-        # adj_a
-        if not a.is_transposed:
-            ret = runtime.core.cutlass_gemm(
-                device.context,
-                cc,
-                m,
-                k,
-                n,
-                type_typestr(a.dtype).encode(),
-                ctypes.c_void_p(adj_d[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(b[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_a[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_a[idx_start:idx_end, :, :].ptr),
-                alpha,
-                1.0,
-                True,
-                b.is_transposed,
-                allow_tf32x3_arith,
-                max_batch_count,
-            )
-            if not ret:
-                raise RuntimeError("adj_matmul failed.")
-        else:
-            ret = runtime.core.cutlass_gemm(
-                device.context,
-                cc,
-                k,
-                m,
-                n,
-                type_typestr(a.dtype).encode(),
-                ctypes.c_void_p(b[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_d[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_a[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_a[idx_start:idx_end, :, :].ptr),
-                alpha,
-                1.0,
-                not b.is_transposed,
-                False,
-                allow_tf32x3_arith,
-                max_batch_count,
-            )
-            if not ret:
-                raise RuntimeError("adj_matmul failed.")
-
-        # adj_b
-        if not b.is_transposed:
-            ret = runtime.core.cutlass_gemm(
-                device.context,
-                cc,
-                k,
-                n,
-                m,
-                type_typestr(a.dtype).encode(),
-                ctypes.c_void_p(a[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_d[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_b[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_b[idx_start:idx_end, :, :].ptr),
-                alpha,
-                1.0,
-                a.is_transposed,
-                True,
-                allow_tf32x3_arith,
-                max_batch_count,
-            )
-            if not ret:
-                raise RuntimeError("adj_matmul failed.")
-        else:
-            ret = runtime.core.cutlass_gemm(
-                device.context,
-                cc,
-                n,
-                k,
-                m,
-                type_typestr(a.dtype).encode(),
-                ctypes.c_void_p(adj_d[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(a[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_b[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_b[idx_start:idx_end, :, :].ptr),
-                alpha,
-                1.0,
-                False,
-                not a.is_transposed,
-                allow_tf32x3_arith,
-                max_batch_count,
-            )
-            if not ret:
-                raise RuntimeError("adj_matmul failed.")
-
-    idx_start = iters * max_batch_count
-
-    # adj_a
-    if not a.is_transposed:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            m,
-            k,
-            n,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(adj_d[idx_start:, :, :].ptr),
-            ctypes.c_void_p(b[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_a[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_a[idx_start:, :, :].ptr),
-            alpha,
-            1.0,
-            True,
-            b.is_transposed,
-            allow_tf32x3_arith,
-            remainder,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-    else:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            k,
-            m,
-            n,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(b[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_d[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_a[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_a[idx_start:, :, :].ptr),
-            alpha,
-            1.0,
-            not b.is_transposed,
-            False,
-            allow_tf32x3_arith,
-            remainder,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-
-    # adj_b
-    if not b.is_transposed:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            k,
-            n,
-            m,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(a[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_d[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_b[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_b[idx_start:, :, :].ptr),
-            alpha,
-            1.0,
-            a.is_transposed,
-            True,
-            allow_tf32x3_arith,
-            remainder,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-    else:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            n,
-            k,
-            m,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(adj_d[idx_start:, :, :].ptr),
-            ctypes.c_void_p(a[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_b[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_b[idx_start:, :, :].ptr),
-            alpha,
-            1.0,
-            False,
-            not a.is_transposed,
-            allow_tf32x3_arith,
-            remainder,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-
-    # adj_c
-    warp.launch(
-        kernel=warp.utils.add_kernel_3d,
-        dim=adj_c.shape,
-        inputs=[adj_c, adj_d, adj_d.dtype(beta)],
-        device=device,
-        record_tape=False,
-    )
+    raise RuntimeError("This function has been removed. Use tile primitives instead.")


 class HashGrid:
@@ -5691,7 +5143,7 @@ simple_type_codes = {
 }


-def get_type_code(arg_type):
+def get_type_code(arg_type: type) -> str:
     if arg_type == Any:
         # special case for generics
         # note: since Python 3.11 Any is a type, so we check for it first
@@ -5755,8 +5207,8 @@ def get_type_code(arg_type):
     raise TypeError(f"Unrecognized type '{arg_type}'")


-def get_signature(arg_types, func_name=None, arg_names=None):
-    type_codes = []
+def get_signature(arg_types: List[type], func_name: Optional[str] = None, arg_names: Optional[List[str]] = None) -> str:
+    type_codes: List[str] = []
     for i, arg_type in enumerate(arg_types):
         try:
             type_codes.append(get_type_code(arg_type))
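
The wp.matmul(), adj_matmul(), batched_matmul(), and adj_batched_matmul() bodies above are reduced in 1.7.0 to a RuntimeError pointing at tile primitives. The sketch below is a migration reference only, not code taken from the package: it assumes the shape/offset keyword form of wp.tile_load()/wp.tile_store() and matrix dimensions that are multiples of the tile sizes; see warp/examples/tile/example_tile_matmul.py in this release for the canonical version and exact signatures.

import warp as wp

# tile sizes are compile-time constants
TILE_M = wp.constant(16)
TILE_N = wp.constant(16)
TILE_K = wp.constant(8)


@wp.kernel
def tile_gemm(A: wp.array2d(dtype=wp.float32), B: wp.array2d(dtype=wp.float32), C: wp.array2d(dtype=wp.float32)):
    # with wp.launch_tiled(), one block of threads cooperates on each (i, j) output tile
    i, j = wp.tid()
    acc = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=wp.float32)
    count = int(A.shape[1] / TILE_K)
    for k in range(count):
        a = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(i * TILE_M, k * TILE_K))
        b = wp.tile_load(B, shape=(TILE_K, TILE_N), offset=(k * TILE_K, j * TILE_N))
        wp.tile_matmul(a, b, acc)  # acc += a @ b
    wp.tile_store(C, acc, offset=(i * TILE_M, j * TILE_N))


# launch sketch: one block per output tile (M and N assumed divisible by the tile sizes)
# wp.launch_tiled(tile_gemm, dim=[M // TILE_M, N // TILE_N], inputs=[a_arr, b_arr, c_arr], block_dim=64, device="cuda:0")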
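The Volume.allocate_by_tiles() hunk above also collapses the separate volume_f/_v/_i_from_tiles_device entry points into a single volume_from_tiles_device call, with bg_value now selecting the grid type (int, float, vec3f, vec4f, or None for an index grid). The following usage sketch is illustrative: it assumes the keyword names shown in the docstring above and integer ijk tile coordinates, and is not copied from the package.

import numpy as np
import warp as wp

wp.init()
device = "cuda:0"

# tile origins in index space: an N-by-3 int32 array, one row per tile
tile_points = wp.array(np.array([[0, 0, 0], [8, 0, 0]], dtype=np.int32), dtype=wp.int32, device=device)

# bg_value defines both the background value and the grid type (int, float, vec3f, vec4f)
vol_f = wp.Volume.allocate_by_tiles(tile_points, voxel_size=0.1, bg_value=0.0, device=device)
vol_v = wp.Volume.allocate_by_tiles(tile_points, voxel_size=0.1, bg_value=wp.vec3f(0.0, 0.0, 0.0), device=device)

# bg_value=None creates an index volume with no stored background value
vol_idx = wp.Volume.allocate_by_tiles(tile_points, voxel_size=0.1, bg_value=None, device=device)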