warp-lang 1.6.2__py3-none-macosx_10_13_universal2.whl → 1.7.1__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warp/__init__.py +7 -1
- warp/autograd.py +12 -2
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +410 -0
- warp/build_dll.py +6 -14
- warp/builtins.py +463 -372
- warp/codegen.py +196 -124
- warp/config.py +42 -6
- warp/context.py +496 -271
- warp/dlpack.py +8 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_cloth.py +1 -1
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/distributed/example_jacobi_mpi.py +507 -0
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +2 -2
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_magnetostatics.py +6 -6
- warp/examples/fem/utils.py +9 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/tile/example_tile_matmul.py +2 -4
- warp/fem/__init__.py +11 -1
- warp/fem/adaptivity.py +4 -4
- warp/fem/field/field.py +11 -1
- warp/fem/field/nodal_field.py +56 -88
- warp/fem/field/virtual.py +62 -23
- warp/fem/geometry/adaptive_nanogrid.py +16 -13
- warp/fem/geometry/closest_point.py +1 -1
- warp/fem/geometry/deformed_geometry.py +5 -2
- warp/fem/geometry/geometry.py +5 -0
- warp/fem/geometry/grid_2d.py +12 -12
- warp/fem/geometry/grid_3d.py +12 -15
- warp/fem/geometry/hexmesh.py +5 -7
- warp/fem/geometry/nanogrid.py +9 -11
- warp/fem/geometry/quadmesh.py +13 -13
- warp/fem/geometry/tetmesh.py +3 -4
- warp/fem/geometry/trimesh.py +7 -20
- warp/fem/integrate.py +262 -93
- warp/fem/linalg.py +5 -5
- warp/fem/quadrature/pic_quadrature.py +37 -22
- warp/fem/quadrature/quadrature.py +194 -25
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +4 -2
- warp/fem/space/basis_space.py +25 -18
- warp/fem/space/hexmesh_function_space.py +2 -2
- warp/fem/space/partition.py +6 -2
- warp/fem/space/quadmesh_function_space.py +8 -8
- warp/fem/space/shape/cube_shape_function.py +23 -23
- warp/fem/space/shape/square_shape_function.py +12 -12
- warp/fem/space/shape/triangle_shape_function.py +1 -1
- warp/fem/space/tetmesh_function_space.py +3 -3
- warp/fem/space/trimesh_function_space.py +2 -2
- warp/fem/utils.py +12 -6
- warp/jax.py +14 -1
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -29
- warp/jax_experimental/ffi.py +702 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +89 -0
- warp/native/array.h +13 -0
- warp/native/builtin.h +29 -3
- warp/native/bvh.cpp +3 -1
- warp/native/bvh.cu +42 -14
- warp/native/bvh.h +2 -1
- warp/native/clang/clang.cpp +30 -3
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -0
- warp/native/exports.h +68 -63
- warp/native/intersect.h +26 -26
- warp/native/intersect_adj.h +33 -33
- warp/native/marching.cu +1 -1
- warp/native/mat.h +513 -9
- warp/native/mesh.h +10 -10
- warp/native/quat.h +99 -11
- warp/native/rand.h +6 -0
- warp/native/sort.cpp +122 -59
- warp/native/sort.cu +152 -15
- warp/native/sort.h +8 -1
- warp/native/sparse.cpp +43 -22
- warp/native/sparse.cu +52 -17
- warp/native/svd.h +116 -0
- warp/native/tile.h +312 -116
- warp/native/tile_reduce.h +46 -3
- warp/native/vec.h +68 -7
- warp/native/volume.cpp +85 -113
- warp/native/volume_builder.cu +25 -10
- warp/native/volume_builder.h +6 -0
- warp/native/warp.cpp +5 -6
- warp/native/warp.cu +100 -11
- warp/native/warp.h +19 -10
- warp/optim/linear.py +10 -10
- warp/render/render_opengl.py +19 -17
- warp/render/render_usd.py +93 -3
- warp/sim/articulation.py +4 -4
- warp/sim/collide.py +32 -19
- warp/sim/import_mjcf.py +449 -155
- warp/sim/import_urdf.py +32 -12
- warp/sim/inertia.py +189 -156
- warp/sim/integrator_euler.py +8 -5
- warp/sim/integrator_featherstone.py +3 -10
- warp/sim/integrator_vbd.py +207 -2
- warp/sim/integrator_xpbd.py +8 -5
- warp/sim/model.py +71 -25
- warp/sim/render.py +4 -0
- warp/sim/utils.py +2 -2
- warp/sparse.py +642 -555
- warp/stubs.py +217 -20
- warp/tests/__main__.py +0 -15
- warp/tests/assets/torus.usda +1 -1
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
- warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
- warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
- warp/tests/{test_collision.py → sim/test_collision.py} +236 -205
- warp/tests/sim/test_inertia.py +161 -0
- warp/tests/{test_model.py → sim/test_model.py} +40 -0
- warp/tests/{flaky_test_sim_grad.py → sim/test_sim_grad.py} +4 -0
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/sim/test_xpbd.py +399 -0
- warp/tests/test_bool.py +1 -1
- warp/tests/test_codegen.py +24 -3
- warp/tests/test_examples.py +40 -38
- warp/tests/test_fem.py +98 -14
- warp/tests/test_linear_solvers.py +0 -11
- warp/tests/test_mat.py +577 -156
- warp/tests/test_mat_scalar_ops.py +4 -4
- warp/tests/test_overwrite.py +0 -60
- warp/tests/test_quat.py +356 -151
- warp/tests/test_rand.py +44 -37
- warp/tests/test_sparse.py +47 -6
- warp/tests/test_spatial.py +75 -0
- warp/tests/test_static.py +1 -1
- warp/tests/test_utils.py +84 -4
- warp/tests/test_vec.py +336 -178
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +98 -1
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
- warp/tests/unittest_serial.py +1 -0
- warp/tests/unittest_suites.py +45 -62
- warp/tests/unittest_utils.py +2 -1
- warp/thirdparty/unittest_parallel.py +3 -1
- warp/types.py +175 -666
- warp/utils.py +137 -72
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/METADATA +46 -12
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/RECORD +184 -171
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/WHEEL +1 -1
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info/licenses}/LICENSE.md +0 -26
- warp/examples/optim/example_walker.py +0 -317
- warp/native/cutlass_gemm.cpp +0 -43
- warp/native/cutlass_gemm.cu +0 -382
- warp/tests/test_matmul.py +0 -511
- warp/tests/test_matmul_lite.py +0 -411
- warp/tests/test_vbd.py +0 -386
- warp/tests/unused_test_misc.py +0 -77
- /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
- /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
- /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
- /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
- /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
- /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
- /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
- /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
- /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
- /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
- /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
- /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
- /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
- /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
- /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
- /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
- /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/top_level.txt +0 -0
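The listing above also shows warp/jax_experimental.py being split into a package: the existing custom-call path moves to custom_call.py, and new ffi.py/xla_ffi.py modules plus interop examples (example_jax_kernel.py, example_jax_callable.py, example_jax_ffi_callback.py) are added. A hedged sketch of what the new layout suggests for imports; the entry-point names below are assumptions based on the added file names, not verified signatures:

    # pre-1.7 style entry point; the rename above suggests it now lives in
    # warp/jax_experimental/custom_call.py but remains importable from the package:
    from warp.jax_experimental import jax_kernel

    # assumed 1.7 addition: an XLA-FFI-based path in the new submodule
    from warp.jax_experimental import ffi  # e.g. ffi.jax_kernel(...); see the added interop examples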
warp/types.py
CHANGED
@@ -20,7 +20,21 @@ import ctypes
 import inspect
 import struct
 import zlib
-from typing import
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    List,
+    Literal,
+    NamedTuple,
+    Optional,
+    Sequence,
+    Tuple,
+    TypeVar,
+    Union,
+    get_args,
+    get_origin,
+)
 
 import numpy as np
 import numpy.typing as npt
@@ -56,7 +70,9 @@ class Transformation(Generic[Float]):
 
 
 class Array(Generic[DType]):
-
+    device: Optional[warp.context.Device]
+    dtype: type
+    size: int
 
 
 int_tuple_type_hints = {
@@ -262,7 +278,12 @@ def vector(length, dtype):
         def __str__(self):
            return f"[{', '.join(map(str, self))}]"
 
+        def __repr__(self):
+            return f"{type_repr(self)}([{', '.join(map(repr, self))}])"
+
         def __eq__(self, other):
+            if self._length_ != len(other):
+                return False
             for i in range(self._length_):
                 if self[i] != other[i]:
                     return False
@@ -405,7 +426,11 @@ def matrix(shape, dtype):
             return "[" + ",\n ".join(row_str) + "]"
 
         def __eq__(self, other):
+            if self._shape_[0] != len(other):
+                return False
             for i in range(self._shape_[0]):
+                if self._shape_[1] != len(other[i]):
+                    return False
                 for j in range(self._shape_[1]):
                     if self[i][j] != other[i][j]:
                         return False
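The two __eq__ hunks above add an explicit length/shape check before the element-wise comparison. A small illustration of the presumable effect (behaviour inferred from the added branches, not output captured from the package):

    import warp as wp

    v = wp.vec3f(1.0, 2.0, 3.0)
    print(v == (1.0, 2.0, 3.0))  # True: same length and equal components
    print(v == (1.0, 2.0))       # now False; previously the element loop could index past the end of the shorter operand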
@@ -1139,7 +1164,7 @@ ARRAY_TYPE_FABRIC_INDEXED = 3
 class launch_bounds_t(ctypes.Structure):
     _fields_ = [("shape", ctypes.c_int32 * LAUNCH_MAX_DIMS), ("ndim", ctypes.c_int32), ("size", ctypes.c_size_t)]
 
-    def __init__(self, shape):
+    def __init__(self, shape: Union[int, Sequence[int]]):
         if isinstance(shape, int):
             # 1d launch
             self.ndim = 1
@@ -1260,7 +1285,7 @@ _type_size_cache = {
 }
 
 
-def type_size_in_bytes(dtype):
+def type_size_in_bytes(dtype: type) -> int:
     size = _type_size_cache.get(dtype)
 
     if size is None:
@@ -1279,7 +1304,7 @@ def type_size_in_bytes(dtype):
     return size
 
 
-def type_to_warp(dtype):
+def type_to_warp(dtype: type) -> type:
     if dtype == float:
         return float32
     elif dtype == int:
@@ -1290,7 +1315,7 @@ def type_to_warp(dtype):
         return dtype
 
 
-def type_typestr(dtype):
+def type_typestr(dtype: type) -> str:
     if dtype == bool:
         return "|b1"
     elif dtype == float16:
@@ -1323,22 +1348,67 @@ def type_typestr(dtype):
         raise Exception("Unknown ctype")
 
 
+def scalar_short_name(t):
+    if t == float32:
+        return "f"
+    elif t == float64:
+        return "d"
+    elif t == int8:
+        return "b"
+    elif t == int16:
+        return "s"
+    elif t == int32:
+        return "i"
+    elif t == int64:
+        return "l"
+    elif t == uint8:
+        return "ub"
+    elif t == uint16:
+        return "us"
+    elif t == uint32:
+        return "ui"
+    elif t == uint64:
+        return "ul"
+    return None
+
+
 # converts any known type to a human readable string, good for error messages, reporting etc
 def type_repr(t):
     if is_array(t):
-
+        if t.device is None:
+            # array is used as a type annotation - display ndim instead of shape
+            return f"array(ndim={t.ndim}, dtype={type_repr(t.dtype)})"
+        return f"array(shape={t.shape}, dtype={type_repr(t.dtype)})"
     if is_tile(t):
-        return
-    if type_is_vector(t):
-        return str(f"vector(length={t._shape_[0]}, dtype={t._wp_scalar_type_})")
-    if type_is_matrix(t):
-        return str(f"matrix(shape=({t._shape_[0]}, {t._shape_[1]}), dtype={t._wp_scalar_type_})")
+        return f"tile(shape={t.shape}, dtype={type_repr(t.dtype)})"
     if isinstance(t, warp.codegen.Struct):
         return type_repr(t.cls)
+    sn = None
+    if hasattr(t, "_wp_scalar_type_"):
+        sn = scalar_short_name(t._wp_scalar_type_)
+    if type_is_transformation(t):
+        if sn is not None:
+            return f"transform{sn}"
+        return f"transform(dtype={type_repr(t._wp_scalar_type_)})"
+    if type_is_quaternion(t):
+        if sn is not None:
+            return f"quat{sn}"
+        return f"quat(dtype={type_repr(t._wp_scalar_type_)})"
+    if type_is_vector(t):
+        if sn is not None and t._shape_[0] <= 4:
+            return f"vec{t._shape_[0]}{sn}"
+        return f"vector(length={t._shape_[0]}, dtype={type_repr(t._wp_scalar_type_)})"
+    if type_is_matrix(t):
+        if sn is not None and t._shape_[0] <= 4 and t._shape_[1] <= 4:
+            return f"mat{t._shape_[0]}{t._shape_[1]}({sn})"
+        return f"matrix(shape=({t._shape_[0]}, {t._shape_[1]}), dtype={type_repr(t._wp_scalar_type_)})"
     if t in scalar_types:
         return t.__name__
 
-    name = getattr(t, "
+    name = getattr(t, "__name__", None)
+    if name is None:
+        return repr(t)
+    name = getattr(t, "__qualname__", name)
     return t.__module__ + "." + name
 
 
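The type_repr hunk above switches small vectors, matrices, quaternions, and transforms to compact names built from scalar_short_name. A quick illustration of the expected output (inferred from the added branches; worth verifying against 1.7.1):

    import warp as wp
    from warp.types import type_repr

    print(type_repr(wp.vec3f))    # "vec3f"
    print(type_repr(wp.quatf))    # "quatf"
    print(type_repr(wp.mat33f))   # "mat33(f)"
    # longer or generic types keep the verbose form:
    print(type_repr(wp.types.vector(length=5, dtype=wp.float32)))  # "vector(length=5, dtype=float32)"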
@@ -1376,33 +1446,33 @@ def type_is_transformation(t):
     return getattr(t, "_wp_generic_type_hint_", None) is Transformation
 
 
-value_types = (int, float, builtins.bool) +
+value_types = (int, float, builtins.bool) + scalar_and_bool_types
 
 
 # returns true for all value types (int, float, bool, scalars, vectors, matrices)
-def type_is_value(x):
+def type_is_value(x: Any) -> builtins.bool:
     return x in value_types or hasattr(x, "_wp_scalar_type_")
 
 
 # equivalent of the above but for values
-def is_int(x):
+def is_int(x: Any) -> builtins.bool:
     return type_is_int(type(x))
 
 
-def is_float(x):
+def is_float(x: Any) -> builtins.bool:
     return type_is_float(type(x))
 
 
-def is_value(x):
+def is_value(x: Any) -> builtins.bool:
     return type_is_value(type(x))
 
 
-
-
+def is_array(a) -> builtins.bool:
+    """Return true if the passed *instance* is one of the array types."""
     return isinstance(a, array_types)
 
 
-def scalars_equal(a, b, match_generic):
+def scalars_equal(a, b, match_generic=False):
     # convert to canonical types
     if a == float:
         a = float32
@@ -1465,21 +1535,21 @@ def types_equal(a, b, match_generic=False):
         if a_length is None or b_length is None or a_length == b_length:
             return True
 
-    a_origin =
-    b_origin =
+    a_origin = get_origin(a)
+    b_origin = get_origin(b)
     if a_origin is tuple and b_origin is tuple:
-        a_args =
-        b_args =
+        a_args = get_args(a)
+        b_args = get_args(b)
         if len(a_args) == len(b_args) and all(
             scalars_equal(x, y, match_generic=match_generic) for x, y in zip(a_args, b_args)
         ):
             return True
     elif a_origin is tuple and isinstance(b, Sequence):
-        a_args =
+        a_args = get_args(a)
         if len(a_args) == len(b) and all(scalars_equal(x, y, match_generic=match_generic) for x, y in zip(a_args, b)):
             return True
     elif b_origin is tuple and isinstance(a, Sequence):
-        b_args =
+        b_args = get_args(b)
         if len(b_args) == len(a) and all(scalars_equal(x, y, match_generic=match_generic) for x, y in zip(b_args, a)):
             return True
 
@@ -1600,7 +1670,7 @@ def array_ctype_from_interface(interface: dict, dtype=None, owner=None):
     return array_ctype
 
 
-class array(Array):
+class array(Array[DType]):
     """A fixed-size multi-dimensional array containing values of the same type.
 
     Attributes:
@@ -1629,21 +1699,21 @@ class array(Array):
 
     def __init__(
         self,
-        data:
-        dtype:
-        shape:
+        data: Union[List, Tuple, npt.NDArray, None] = None,
+        dtype: Any = Any,
+        shape: Union[int, Tuple[int, ...], List[int], None] = None,
         strides: Optional[Tuple[int, ...]] = None,
         length: Optional[int] = None,
         ptr: Optional[int] = None,
         capacity: Optional[int] = None,
         device=None,
-        pinned: bool = False,
-        copy: bool = True,
-        owner: bool = False,  # deprecated - pass deleter instead
+        pinned: builtins.bool = False,
+        copy: builtins.bool = True,
+        owner: builtins.bool = False,  # deprecated - pass deleter instead
         deleter: Optional[Callable[[int, int], None]] = None,
         ndim: Optional[int] = None,
         grad: Optional[array] = None,
-        requires_grad: bool = False,
+        requires_grad: builtins.bool = False,
     ):
         """Constructs a new Warp array object
 
@@ -2219,6 +2289,9 @@ class array(Array):
         else:
             return str(self.numpy())
 
+    def __repr__(self):
+        return type_repr(self)
+
     def __getitem__(self, key):
         if isinstance(key, int):
             if self.ndim == 1:
@@ -2939,7 +3012,7 @@ def from_ipc_handle(
 
 # A base class for non-contiguous arrays, providing the implementation of common methods like
 # contiguous(), to(), numpy(), list(), assign(), zero_(), and fill_().
-class noncontiguous_array_base(
+class noncontiguous_array_base(Array[T]):
     def __init__(self, array_type_id):
         self.type_id = array_type_id
         self.is_contiguous = False
@@ -3036,12 +3109,18 @@ def check_index_array(indices, expected_device):
         raise ValueError(f"Index array device ({indices.device} does not match data array device ({expected_device}))")
 
 
-class indexedarray(noncontiguous_array_base
+class indexedarray(noncontiguous_array_base):
     # member attributes available during code-gen (e.g.: d = arr.shape[0])
     # (initialized when needed)
     _vars = None
 
-    def __init__(
+    def __init__(
+        self,
+        data: Optional[array] = None,
+        indices: Union[array, List[array], None] = None,
+        dtype=None,
+        ndim: Optional[int] = None,
+    ):
         super().__init__(ARRAY_TYPE_INDEXED)
 
         # canonicalize types
@@ -3232,7 +3311,7 @@ class Tile:
            return f"wp::tile_alloc_empty<{Var.type_to_ctype(self.dtype)},wp::tile_shape_t<{','.join(map(str, self.shape))}>,{'true' if requires_grad else 'false'}>()"
         else:
             # tile will be initialized by another call, e.g.: tile_transpose()
-            return "
+            return "nullptr"
 
     # return total tile size in bytes
     def size_in_bytes(self):
@@ -3634,7 +3713,7 @@ class Volume:
         instance.id = None
         return instance
 
-    def __init__(self, data: array, copy: bool = True):
+    def __init__(self, data: array, copy: builtins.bool = True):
         """Class representing a sparse grid.
 
         Args:
@@ -4361,6 +4440,15 @@ class Volume:
         translation_buf = (ctypes.c_float * 3)(translation[0], translation[1], translation[2])
         return transform_buf, translation_buf
 
+    # nanovdb types for which we instantiate the grid builder
+    # Should be in sync with WP_VOLUME_BUILDER_INSTANTIATE_TYPES in volume_builder.h
+    _supported_allocation_types = [
+        "int32",
+        "float",
+        "Vec3f",
+        "Vec4f",
+    ]
+
     @classmethod
     def allocate_by_tiles(
         cls,
@@ -4388,7 +4476,8 @@ class Volume:
                 or a floating point scalar type (2D N-by-3 array of :class:`warp.float32` or 1D array of `warp.vec3f` values), indicating world space positions.
                 Repeated points per tile are allowed and will be efficiently deduplicated.
             voxel_size (float or array-like): Voxel size(s) of the new volume. Ignored if `transform` is given.
-            bg_value (array-like,
+            bg_value (array-like, scalar or None): Value of unallocated voxels of the volume, also defines the volume's type. An index volume will be created if `bg_value` is ``None``.
+                Other supported grid types are `int`, `float`, `vec3f`, and `vec4f`.
             translation (array-like): Translation between the index and world spaces.
             transform (array-like): Linear transform between the index and world spaces. If ``None``, deduced from `voxel_size`.
             device (Devicelike): The CUDA device to create the volume on, e.g.: "cuda" or "cuda:0".
@@ -4420,35 +4509,47 @@ class Volume:
                 translation_buf,
                 in_world_space,
             )
-        elif hasattr(bg_value, "__len__"):
-            volume.id = volume.runtime.core.volume_v_from_tiles_device(
-                volume.device.context,
-                ctypes.c_void_p(tile_points.ptr),
-                tile_points.shape[0],
-                transform_buf,
-                translation_buf,
-                in_world_space,
-                (ctypes.c_float * 3)(bg_value[0], bg_value[1], bg_value[2]),
-            )
-        elif isinstance(bg_value, int):
-            volume.id = volume.runtime.core.volume_i_from_tiles_device(
-                volume.device.context,
-                ctypes.c_void_p(tile_points.ptr),
-                tile_points.shape[0],
-                transform_buf,
-                translation_buf,
-                in_world_space,
-                bg_value,
-            )
         else:
-
+            # normalize background value type
+            grid_type = type_to_warp(type(bg_value))
+            if not (is_value(bg_value) or type_is_vector(grid_type)) and (
+                hasattr(bg_value, "__len__") and is_value(bg_value[0])
+            ):
+                # non-warp vectors are considered float, for backward compatibility
+                grid_type = vector(len(bg_value), dtype=float)
+
+            # look for corresponding nvdb type
+            try:
+                nvdb_type = next(
+                    typ
+                    for typ in Volume._supported_allocation_types
+                    if types_equal(grid_type, Volume._nvdb_type_to_dtype[typ])
+                )
+            except StopIteration as err:
+                raise TypeError(
+                    f"Unsupported bg_value type for volume allocation {type_repr(grid_type)}. Supported volume types are {', '.join(Volume._supported_allocation_types)}."
+                ) from err
+
+            # cast to ctype
+            # wrap scalar values in length-1 vectors to handle specific ctype conversion
+            if not type_is_vector(grid_type):
+                grid_type = vector(length=1, dtype=grid_type)
+
+            cvalue = grid_type(bg_value)
+            cvalue_ptr = ctypes.pointer(cvalue)
+            cvalue_size = ctypes.sizeof(cvalue)
+            cvalue_type = nvdb_type.encode("ascii")
+
+            volume.id = volume.runtime.core.volume_from_tiles_device(
                 volume.device.context,
                 ctypes.c_void_p(tile_points.ptr),
                 tile_points.shape[0],
                 transform_buf,
                 translation_buf,
                 in_world_space,
-
+                cvalue_ptr,
+                cvalue_size,
+                cvalue_type,
             )
 
         if volume.id == 0:
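The allocate_by_tiles hunk above folds the old per-type volume_v_from_tiles_device / volume_i_from_tiles_device entry points into a single volume_from_tiles_device call and validates bg_value against _supported_allocation_types. A minimal usage sketch, assuming the parameter names from the docstring above and a CUDA device named "cuda:0":

    import numpy as np
    import warp as wp

    wp.init()

    # integer ijk coordinates of the tiles to allocate (N-by-3 int32, per the docstring above)
    tile_points = wp.array(np.array([[0, 0, 0], [8, 0, 0]], dtype=np.int32), dtype=wp.int32, device="cuda:0")

    # bg_value picks the grid type: int32, float, vec3f, or vec4f are instantiated
    vol_f = wp.Volume.allocate_by_tiles(tile_points, voxel_size=0.1, bg_value=0.0, device="cuda:0")
    vol_v3 = wp.Volume.allocate_by_tiles(tile_points, voxel_size=0.1, bg_value=wp.vec3f(0.0), device="cuda:0")
    # bg_value=None would allocate an index grid instead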
@@ -4606,6 +4707,8 @@ def matmul(
 ):
     """Computes a generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.
 
+    .. versionremoved:: 1.7
+
     .. deprecated:: 1.6
         Use :doc:`tile primitives </modules/tiles>` instead.
 
@@ -4619,80 +4722,8 @@ def matmul(
         allow_tf32x3_arith (bool): whether to use CUTLASS's 3xTF32 GEMMs, which enable accuracy similar to FP32
             while using Tensor Cores
     """
-    from warp.context import runtime
 
-
-        "wp.matmul() is deprecated and will be removed in a\nfuture version. Use tile primitives instead.",
-        category=DeprecationWarning,
-        stacklevel=2,
-    )
-
-    device = a.device
-
-    if b.device != device or c.device != device or d.device != device:
-        raise RuntimeError("Matrices A, B, C, and D must all be on the same device as the runtime device.")
-
-    if a.dtype != b.dtype or a.dtype != c.dtype or a.dtype != d.dtype:
-        raise RuntimeError(
-            "wp.matmul currently only supports operation between {A, B, C, D} matrices of the same type."
-        )
-
-    if (
-        (not a.is_contiguous and not a.is_transposed)
-        or (not b.is_contiguous and not b.is_transposed)
-        or (not c.is_contiguous)
-        or (not d.is_contiguous)
-    ):
-        raise RuntimeError(
-            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B may be transposed."
-        )
-
-    m = a.shape[0]
-    n = b.shape[1]
-    k = a.shape[1]
-    if b.shape != (k, n) or c.shape != (m, n) or d.shape != (m, n):
-        raise RuntimeError(
-            "Invalid shapes for matrices: A = {} B = {} C = {} D = {}".format(a.shape, b.shape, c.shape, d.shape)
-        )
-
-    if runtime.tape:
-        runtime.tape.record_func(
-            backward=lambda: adj_matmul(a, b, c, a.grad, b.grad, c.grad, d.grad, alpha, beta, allow_tf32x3_arith),
-            arrays=[a, b, c, d],
-        )
-        if warp.config.verify_autograd_array_access:
-            d.mark_write()
-            a.mark_read()
-            b.mark_read()
-            c.mark_read()
-
-    # cpu fallback if no cuda devices found
-    if device == "cpu":
-        np_dtype = warp_type_to_np_dtype[a.dtype]
-        d.assign(alpha * np.matmul(a.numpy(), b.numpy(), dtype=np_dtype) + beta * c.numpy())
-        return
-
-    cc = device.arch
-    ret = runtime.core.cutlass_gemm(
-        device.context,
-        cc,
-        m,
-        n,
-        k,
-        type_typestr(a.dtype).encode(),
-        ctypes.c_void_p(a.ptr),
-        ctypes.c_void_p(b.ptr),
-        ctypes.c_void_p(c.ptr),
-        ctypes.c_void_p(d.ptr),
-        alpha,
-        beta,
-        not a.is_transposed,
-        not b.is_transposed,
-        allow_tf32x3_arith,
-        1,
-    )
-    if not ret:
-        raise RuntimeError("matmul failed.")
+    raise RuntimeError("This function has been removed. Use tile primitives instead.")
 
 
 def adj_matmul(
@@ -4724,171 +4755,8 @@ def adj_matmul(
         allow_tf32x3_arith (bool): whether to use CUTLASS's 3xTF32 GEMMs, which enable accuracy similar to FP32
             while using Tensor Cores
     """
-    from warp.context import runtime
-
-    device = a.device
-
-    if (
-        b.device != device
-        or c.device != device
-        or adj_a.device != device
-        or adj_b.device != device
-        or adj_c.device != device
-        or adj_d.device != device
-    ):
-        raise RuntimeError(
-            "Matrices A, B, C, D, and their adjoints must all be on the same device as the runtime device."
-        )
-
-    if (
-        a.dtype != b.dtype
-        or a.dtype != c.dtype
-        or a.dtype != adj_a.dtype
-        or a.dtype != adj_b.dtype
-        or a.dtype != adj_c.dtype
-        or a.dtype != adj_d.dtype
-    ):
-        raise RuntimeError(
-            "wp.adj_matmul currently only supports operation between {A, B, C, adj_D, adj_A, adj_B, adj_C} matrices of the same type."
-        )
 
-
-        (not a.is_contiguous and not a.is_transposed)
-        or (not b.is_contiguous and not b.is_transposed)
-        or (not c.is_contiguous)
-        or (not adj_a.is_contiguous and not adj_a.is_transposed)
-        or (not adj_b.is_contiguous and not adj_b.is_transposed)
-        or (not adj_c.is_contiguous)
-        or (not adj_d.is_contiguous)
-    ):
-        raise RuntimeError(
-            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B and their associated adjoints may be transposed."
-        )
-
-    m = a.shape[0]
-    n = b.shape[1]
-    k = a.shape[1]
-    if (
-        a.shape != (m, k)
-        or b.shape != (k, n)
-        or c.shape != (m, n)
-        or adj_d.shape != (m, n)
-        or adj_a.shape != (m, k)
-        or adj_b.shape != (k, n)
-        or adj_c.shape != (m, n)
-    ):
-        raise RuntimeError(
-            "Invalid shapes for matrices: A = {} B = {} C = {} adj_D = {} adj_A = {} adj_B = {} adj_C = {}".format(
-                a.shape, b.shape, c.shape, adj_d.shape, adj_a.shape, adj_b.shape, adj_c.shape
-            )
-        )
-
-    # cpu fallback if no cuda devices found
-    if device == "cpu":
-        np_dtype = warp_type_to_np_dtype[a.dtype]
-        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose(), dtype=np_dtype) + adj_a.numpy())
-        adj_b.assign(alpha * np.matmul(a.numpy().transpose(), adj_d.numpy(), dtype=np_dtype) + adj_b.numpy())
-        adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
-        return
-
-    cc = device.arch
-
-    # adj_a
-    if not a.is_transposed:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            m,
-            k,
-            n,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(adj_d.ptr),
-            ctypes.c_void_p(b.ptr),
-            ctypes.c_void_p(adj_a.ptr),
-            ctypes.c_void_p(adj_a.ptr),
-            alpha,
-            1.0,
-            True,
-            b.is_transposed,
-            allow_tf32x3_arith,
-            1,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-    else:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            k,
-            m,
-            n,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(b.ptr),
-            ctypes.c_void_p(adj_d.ptr),
-            ctypes.c_void_p(adj_a.ptr),
-            ctypes.c_void_p(adj_a.ptr),
-            alpha,
-            1.0,
-            not b.is_transposed,
-            False,
-            allow_tf32x3_arith,
-            1,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-
-    # adj_b
-    if not b.is_transposed:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            k,
-            n,
-            m,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(a.ptr),
-            ctypes.c_void_p(adj_d.ptr),
-            ctypes.c_void_p(adj_b.ptr),
-            ctypes.c_void_p(adj_b.ptr),
-            alpha,
-            1.0,
-            a.is_transposed,
-            True,
-            allow_tf32x3_arith,
-            1,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-    else:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            n,
-            k,
-            m,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(adj_d.ptr),
-            ctypes.c_void_p(a.ptr),
-            ctypes.c_void_p(adj_b.ptr),
-            ctypes.c_void_p(adj_b.ptr),
-            alpha,
-            1.0,
-            False,
-            not a.is_transposed,
-            allow_tf32x3_arith,
-            1,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-
-    # adj_c
-    warp.launch(
-        kernel=warp.utils.add_kernel_2d,
-        dim=adj_c.shape,
-        inputs=[adj_c, adj_d, adj_d.dtype(beta)],
-        device=device,
-        record_tape=False,
-    )
+    raise RuntimeError("This function has been removed. Use tile primitives instead.")
 
 
 def batched_matmul(
@@ -4902,6 +4770,8 @@ def batched_matmul(
 ):
     """Computes a batched generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.
 
+    .. versionremoved:: 1.7
+
     .. deprecated:: 1.6
         Use :doc:`tile primitives </modules/tiles>` instead.
 
@@ -4915,107 +4785,8 @@ def batched_matmul(
         allow_tf32x3_arith (bool): whether to use CUTLASS's 3xTF32 GEMMs, which enable accuracy similar to FP32
             while using Tensor Cores
     """
-    from warp.context import runtime
-
-    device = a.device
 
-
-        raise RuntimeError("Matrices A, B, C, and D must all be on the same device as the runtime device.")
-
-    if a.dtype != b.dtype or a.dtype != c.dtype or a.dtype != d.dtype:
-        raise RuntimeError(
-            "wp.batched_matmul currently only supports operation between {A, B, C, D} matrices of the same type."
-        )
-
-    if (
-        (not a.is_contiguous and not a.is_transposed)
-        or (not b.is_contiguous and not b.is_transposed)
-        or (not c.is_contiguous)
-        or (not d.is_contiguous)
-    ):
-        raise RuntimeError(
-            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B may be transposed."
-        )
-
-    m = a.shape[1]
-    n = b.shape[2]
-    k = a.shape[2]
-    batch_count = a.shape[0]
-    if b.shape != (batch_count, k, n) or c.shape != (batch_count, m, n) or d.shape != (batch_count, m, n):
-        raise RuntimeError(
-            "Invalid shapes for matrices: A = {} B = {} C = {} D = {}".format(a.shape, b.shape, c.shape, d.shape)
-        )
-
-    if runtime.tape:
-        runtime.tape.record_func(
-            backward=lambda: adj_batched_matmul(
-                a, b, c, a.grad, b.grad, c.grad, d.grad, alpha, beta, allow_tf32x3_arith
-            ),
-            arrays=[a, b, c, d],
-        )
-        if warp.config.verify_autograd_array_access:
-            d.mark_write()
-            a.mark_read()
-            b.mark_read()
-            c.mark_read()
-
-    # cpu fallback if no cuda devices found
-    if device == "cpu":
-        np_dtype = warp_type_to_np_dtype[a.dtype]
-        d.assign(alpha * np.matmul(a.numpy(), b.numpy(), dtype=np_dtype) + beta * c.numpy())
-        return
-
-    # handle case in which batch_count exceeds max_batch_count, which is a CUDA array size maximum
-    max_batch_count = 65535
-    iters = int(batch_count / max_batch_count)
-    remainder = batch_count % max_batch_count
-
-    cc = device.arch
-    for i in range(iters):
-        idx_start = i * max_batch_count
-        idx_end = (i + 1) * max_batch_count if i < iters - 1 else batch_count
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            m,
-            n,
-            k,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(a[idx_start:idx_end, :, :].ptr),
-            ctypes.c_void_p(b[idx_start:idx_end, :, :].ptr),
-            ctypes.c_void_p(c[idx_start:idx_end, :, :].ptr),
-            ctypes.c_void_p(d[idx_start:idx_end, :, :].ptr),
-            alpha,
-            beta,
-            not a.is_transposed,
-            not b.is_transposed,
-            allow_tf32x3_arith,
-            max_batch_count,
-        )
-        if not ret:
-            raise RuntimeError("Batched matmul failed.")
-
-    idx_start = iters * max_batch_count
-    ret = runtime.core.cutlass_gemm(
-        device.context,
-        cc,
-        m,
-        n,
-        k,
-        type_typestr(a.dtype).encode(),
-        ctypes.c_void_p(a[idx_start:, :, :].ptr),
-        ctypes.c_void_p(b[idx_start:, :, :].ptr),
-        ctypes.c_void_p(c[idx_start:, :, :].ptr),
-        ctypes.c_void_p(d[idx_start:, :, :].ptr),
-        alpha,
-        beta,
-        not a.is_transposed,
-        not b.is_transposed,
-        allow_tf32x3_arith,
-        remainder,
-    )
-    if not ret:
-        raise RuntimeError("Batched matmul failed.")
+    raise RuntimeError("This function has been removed. Use tile primitives instead.")
 
 
 def adj_batched_matmul(
@@ -5045,270 +4816,8 @@ def adj_batched_matmul(
         allow_tf32x3_arith (bool): whether to use CUTLASS's 3xTF32 GEMMs, which enable accuracy similar to FP32
             while using Tensor Cores
     """
-    from warp.context import runtime
-
-    device = a.device
-
-    if (
-        b.device != device
-        or c.device != device
-        or adj_a.device != device
-        or adj_b.device != device
-        or adj_c.device != device
-        or adj_d.device != device
-    ):
-        raise RuntimeError(
-            "Matrices A, B, C, D, and their adjoints must all be on the same device as the runtime device."
-        )
-
-    if (
-        a.dtype != b.dtype
-        or a.dtype != c.dtype
-        or a.dtype != adj_a.dtype
-        or a.dtype != adj_b.dtype
-        or a.dtype != adj_c.dtype
-        or a.dtype != adj_d.dtype
-    ):
-        raise RuntimeError(
-            "wp.adj_batched_matmul currently only supports operation between {A, B, C, adj_D, adj_A, adj_B, adj_C} matrices of the same type."
-        )
-
-    m = a.shape[1]
-    n = b.shape[2]
-    k = a.shape[2]
-    batch_count = a.shape[0]
-    if (
-        b.shape != (batch_count, k, n)
-        or c.shape != (batch_count, m, n)
-        or adj_d.shape != (batch_count, m, n)
-        or adj_a.shape != (batch_count, m, k)
-        or adj_b.shape != (batch_count, k, n)
-        or adj_c.shape != (batch_count, m, n)
-    ):
-        raise RuntimeError(
-            "Invalid shapes for matrices: A = {} B = {} C = {} adj_D = {} adj_A = {} adj_B = {} adj_C = {}".format(
-                a.shape, b.shape, c.shape, adj_d.shape, adj_a.shape, adj_b.shape, adj_c.shape
-            )
-        )
 
-
-        (not a.is_contiguous and not a.is_transposed)
-        or (not b.is_contiguous and not b.is_transposed)
-        or (not c.is_contiguous)
-        or (not adj_a.is_contiguous and not adj_a.is_transposed)
-        or (not adj_b.is_contiguous and not adj_b.is_transposed)
-        or (not adj_c.is_contiguous)
-        or (not adj_d.is_contiguous)
-    ):
-        raise RuntimeError(
-            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B and their associated adjoints may be transposed."
-        )
-
-    # cpu fallback if no cuda devices found
-    if device == "cpu":
-        np_dtype = warp_type_to_np_dtype[a.dtype]
-        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose((0, 2, 1)), dtype=np_dtype) + adj_a.numpy())
-        adj_b.assign(alpha * np.matmul(a.numpy().transpose((0, 2, 1)), adj_d.numpy(), dtype=np_dtype) + adj_b.numpy())
-        adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
-        return
-
-    # handle case in which batch_count exceeds max_batch_count, which is a CUDA array size maximum
-    max_batch_count = 65535
-    iters = int(batch_count / max_batch_count)
-    remainder = batch_count % max_batch_count
-
-    cc = device.arch
-
-    for i in range(iters):
-        idx_start = i * max_batch_count
-        idx_end = (i + 1) * max_batch_count if i < iters - 1 else batch_count
-
-        # adj_a
-        if not a.is_transposed:
-            ret = runtime.core.cutlass_gemm(
-                device.context,
-                cc,
-                m,
-                k,
-                n,
-                type_typestr(a.dtype).encode(),
-                ctypes.c_void_p(adj_d[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(b[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_a[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_a[idx_start:idx_end, :, :].ptr),
-                alpha,
-                1.0,
-                True,
-                b.is_transposed,
-                allow_tf32x3_arith,
-                max_batch_count,
-            )
-            if not ret:
-                raise RuntimeError("adj_matmul failed.")
-        else:
-            ret = runtime.core.cutlass_gemm(
-                device.context,
-                cc,
-                k,
-                m,
-                n,
-                type_typestr(a.dtype).encode(),
-                ctypes.c_void_p(b[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_d[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_a[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_a[idx_start:idx_end, :, :].ptr),
-                alpha,
-                1.0,
-                not b.is_transposed,
-                False,
-                allow_tf32x3_arith,
-                max_batch_count,
-            )
-            if not ret:
-                raise RuntimeError("adj_matmul failed.")
-
-        # adj_b
-        if not b.is_transposed:
-            ret = runtime.core.cutlass_gemm(
-                device.context,
-                cc,
-                k,
-                n,
-                m,
-                type_typestr(a.dtype).encode(),
-                ctypes.c_void_p(a[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_d[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_b[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_b[idx_start:idx_end, :, :].ptr),
-                alpha,
-                1.0,
-                a.is_transposed,
-                True,
-                allow_tf32x3_arith,
-                max_batch_count,
-            )
-            if not ret:
-                raise RuntimeError("adj_matmul failed.")
-        else:
-            ret = runtime.core.cutlass_gemm(
-                device.context,
-                cc,
-                n,
-                k,
-                m,
-                type_typestr(a.dtype).encode(),
-                ctypes.c_void_p(adj_d[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(a[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_b[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_b[idx_start:idx_end, :, :].ptr),
-                alpha,
-                1.0,
-                False,
-                not a.is_transposed,
-                allow_tf32x3_arith,
-                max_batch_count,
-            )
-            if not ret:
-                raise RuntimeError("adj_matmul failed.")
-
-    idx_start = iters * max_batch_count
-
-    # adj_a
-    if not a.is_transposed:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            m,
-            k,
-            n,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(adj_d[idx_start:, :, :].ptr),
-            ctypes.c_void_p(b[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_a[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_a[idx_start:, :, :].ptr),
-            alpha,
-            1.0,
-            True,
-            b.is_transposed,
-            allow_tf32x3_arith,
-            remainder,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-    else:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            k,
-            m,
-            n,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(b[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_d[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_a[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_a[idx_start:, :, :].ptr),
-            alpha,
-            1.0,
-            not b.is_transposed,
-            False,
-            allow_tf32x3_arith,
-            remainder,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-
-    # adj_b
-    if not b.is_transposed:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            k,
-            n,
-            m,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(a[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_d[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_b[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_b[idx_start:, :, :].ptr),
-            alpha,
-            1.0,
-            a.is_transposed,
-            True,
-            allow_tf32x3_arith,
-            remainder,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-    else:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            n,
-            k,
-            m,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(adj_d[idx_start:, :, :].ptr),
-            ctypes.c_void_p(a[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_b[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_b[idx_start:, :, :].ptr),
-            alpha,
-            1.0,
-            False,
-            not a.is_transposed,
-            allow_tf32x3_arith,
-            remainder,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-
-    # adj_c
-    warp.launch(
-        kernel=warp.utils.add_kernel_3d,
-        dim=adj_c.shape,
-        inputs=[adj_c, adj_d, adj_d.dtype(beta)],
-        device=device,
-        record_tape=False,
-    )
+    raise RuntimeError("This function has been removed. Use tile primitives instead.")
 
 
 class HashGrid:
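With matmul, adj_matmul, batched_matmul, and adj_batched_matmul reduced to a RuntimeError above, the replacement path is the tile API that the 1.6 deprecation note already pointed to. A minimal sketch of a tiled GEMM kernel; the tile_load/tile_store shape and offset keywords and the launch_tiled call follow the 1.7 tile examples shipped in this wheel (for example warp/examples/tile/example_tile_matmul.py), but the exact signatures should be checked against the documentation:

    import warp as wp

    TILE_M, TILE_N, TILE_K = 32, 32, 8
    TILE_THREADS = 64

    @wp.kernel
    def tile_gemm(a: wp.array2d(dtype=wp.float32),
                  b: wp.array2d(dtype=wp.float32),
                  c: wp.array2d(dtype=wp.float32)):
        # one output tile per block
        i, j = wp.tid()
        acc = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=wp.float32)
        count = int(a.shape[1] / TILE_K)
        for k in range(count):
            a_tile = wp.tile_load(a, shape=(TILE_M, TILE_K), offset=(i * TILE_M, k * TILE_K))
            b_tile = wp.tile_load(b, shape=(TILE_K, TILE_N), offset=(k * TILE_K, j * TILE_N))
            wp.tile_matmul(a_tile, b_tile, acc)  # accumulates a_tile @ b_tile into acc
        wp.tile_store(c, acc, offset=(i * TILE_M, j * TILE_N))

    # launched over the output tile grid, e.g.:
    # wp.launch_tiled(tile_gemm, dim=(M // TILE_M, N // TILE_N), inputs=[a, b, c], block_dim=TILE_THREADS, device="cuda:0")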
@@ -5691,7 +5200,7 @@ simple_type_codes = {
 }
 
 
-def get_type_code(arg_type):
+def get_type_code(arg_type: type) -> str:
     if arg_type == Any:
         # special case for generics
         # note: since Python 3.11 Any is a type, so we check for it first
@@ -5755,8 +5264,8 @@ def get_type_code(arg_type):
     raise TypeError(f"Unrecognized type '{arg_type}'")
 
 
-def get_signature(arg_types, func_name=None, arg_names=None):
-    type_codes = []
+def get_signature(arg_types: List[type], func_name: Optional[str] = None, arg_names: Optional[List[str]] = None) -> str:
+    type_codes: List[str] = []
     for i, arg_type in enumerate(arg_types):
         try:
             type_codes.append(get_type_code(arg_type))