warp-lang 1.6.2-py3-none-macosx_10_13_universal2.whl → 1.7.1-py3-none-macosx_10_13_universal2.whl

This diff shows the content of publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of warp-lang might be problematic.

Files changed (191)
  1. warp/__init__.py +7 -1
  2. warp/autograd.py +12 -2
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +410 -0
  6. warp/build_dll.py +6 -14
  7. warp/builtins.py +463 -372
  8. warp/codegen.py +196 -124
  9. warp/config.py +42 -6
  10. warp/context.py +496 -271
  11. warp/dlpack.py +8 -6
  12. warp/examples/assets/nonuniform.usd +0 -0
  13. warp/examples/assets/nvidia_logo.png +0 -0
  14. warp/examples/benchmarks/benchmark_cloth.py +1 -1
  15. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  16. warp/examples/core/example_sample_mesh.py +300 -0
  17. warp/examples/distributed/example_jacobi_mpi.py +507 -0
  18. warp/examples/fem/example_apic_fluid.py +1 -1
  19. warp/examples/fem/example_burgers.py +2 -2
  20. warp/examples/fem/example_deformed_geometry.py +1 -1
  21. warp/examples/fem/example_distortion_energy.py +1 -1
  22. warp/examples/fem/example_magnetostatics.py +6 -6
  23. warp/examples/fem/utils.py +9 -3
  24. warp/examples/interop/example_jax_callable.py +116 -0
  25. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  26. warp/examples/interop/example_jax_kernel.py +205 -0
  27. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  28. warp/examples/tile/example_tile_matmul.py +2 -4
  29. warp/fem/__init__.py +11 -1
  30. warp/fem/adaptivity.py +4 -4
  31. warp/fem/field/field.py +11 -1
  32. warp/fem/field/nodal_field.py +56 -88
  33. warp/fem/field/virtual.py +62 -23
  34. warp/fem/geometry/adaptive_nanogrid.py +16 -13
  35. warp/fem/geometry/closest_point.py +1 -1
  36. warp/fem/geometry/deformed_geometry.py +5 -2
  37. warp/fem/geometry/geometry.py +5 -0
  38. warp/fem/geometry/grid_2d.py +12 -12
  39. warp/fem/geometry/grid_3d.py +12 -15
  40. warp/fem/geometry/hexmesh.py +5 -7
  41. warp/fem/geometry/nanogrid.py +9 -11
  42. warp/fem/geometry/quadmesh.py +13 -13
  43. warp/fem/geometry/tetmesh.py +3 -4
  44. warp/fem/geometry/trimesh.py +7 -20
  45. warp/fem/integrate.py +262 -93
  46. warp/fem/linalg.py +5 -5
  47. warp/fem/quadrature/pic_quadrature.py +37 -22
  48. warp/fem/quadrature/quadrature.py +194 -25
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_function_space.py +4 -2
  51. warp/fem/space/basis_space.py +25 -18
  52. warp/fem/space/hexmesh_function_space.py +2 -2
  53. warp/fem/space/partition.py +6 -2
  54. warp/fem/space/quadmesh_function_space.py +8 -8
  55. warp/fem/space/shape/cube_shape_function.py +23 -23
  56. warp/fem/space/shape/square_shape_function.py +12 -12
  57. warp/fem/space/shape/triangle_shape_function.py +1 -1
  58. warp/fem/space/tetmesh_function_space.py +3 -3
  59. warp/fem/space/trimesh_function_space.py +2 -2
  60. warp/fem/utils.py +12 -6
  61. warp/jax.py +14 -1
  62. warp/jax_experimental/__init__.py +16 -0
  63. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -29
  64. warp/jax_experimental/ffi.py +702 -0
  65. warp/jax_experimental/xla_ffi.py +602 -0
  66. warp/math.py +89 -0
  67. warp/native/array.h +13 -0
  68. warp/native/builtin.h +29 -3
  69. warp/native/bvh.cpp +3 -1
  70. warp/native/bvh.cu +42 -14
  71. warp/native/bvh.h +2 -1
  72. warp/native/clang/clang.cpp +30 -3
  73. warp/native/cuda_util.cpp +14 -0
  74. warp/native/cuda_util.h +2 -0
  75. warp/native/exports.h +68 -63
  76. warp/native/intersect.h +26 -26
  77. warp/native/intersect_adj.h +33 -33
  78. warp/native/marching.cu +1 -1
  79. warp/native/mat.h +513 -9
  80. warp/native/mesh.h +10 -10
  81. warp/native/quat.h +99 -11
  82. warp/native/rand.h +6 -0
  83. warp/native/sort.cpp +122 -59
  84. warp/native/sort.cu +152 -15
  85. warp/native/sort.h +8 -1
  86. warp/native/sparse.cpp +43 -22
  87. warp/native/sparse.cu +52 -17
  88. warp/native/svd.h +116 -0
  89. warp/native/tile.h +312 -116
  90. warp/native/tile_reduce.h +46 -3
  91. warp/native/vec.h +68 -7
  92. warp/native/volume.cpp +85 -113
  93. warp/native/volume_builder.cu +25 -10
  94. warp/native/volume_builder.h +6 -0
  95. warp/native/warp.cpp +5 -6
  96. warp/native/warp.cu +100 -11
  97. warp/native/warp.h +19 -10
  98. warp/optim/linear.py +10 -10
  99. warp/render/render_opengl.py +19 -17
  100. warp/render/render_usd.py +93 -3
  101. warp/sim/articulation.py +4 -4
  102. warp/sim/collide.py +32 -19
  103. warp/sim/import_mjcf.py +449 -155
  104. warp/sim/import_urdf.py +32 -12
  105. warp/sim/inertia.py +189 -156
  106. warp/sim/integrator_euler.py +8 -5
  107. warp/sim/integrator_featherstone.py +3 -10
  108. warp/sim/integrator_vbd.py +207 -2
  109. warp/sim/integrator_xpbd.py +8 -5
  110. warp/sim/model.py +71 -25
  111. warp/sim/render.py +4 -0
  112. warp/sim/utils.py +2 -2
  113. warp/sparse.py +642 -555
  114. warp/stubs.py +217 -20
  115. warp/tests/__main__.py +0 -15
  116. warp/tests/assets/torus.usda +1 -1
  117. warp/tests/cuda/__init__.py +0 -0
  118. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  119. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  120. warp/tests/geometry/__init__.py +0 -0
  121. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  122. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  123. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  124. warp/tests/interop/__init__.py +0 -0
  125. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  126. warp/tests/sim/__init__.py +0 -0
  127. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  128. warp/tests/{test_collision.py → sim/test_collision.py} +236 -205
  129. warp/tests/sim/test_inertia.py +161 -0
  130. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  131. warp/tests/{flaky_test_sim_grad.py → sim/test_sim_grad.py} +4 -0
  132. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  133. warp/tests/sim/test_vbd.py +597 -0
  134. warp/tests/sim/test_xpbd.py +399 -0
  135. warp/tests/test_bool.py +1 -1
  136. warp/tests/test_codegen.py +24 -3
  137. warp/tests/test_examples.py +40 -38
  138. warp/tests/test_fem.py +98 -14
  139. warp/tests/test_linear_solvers.py +0 -11
  140. warp/tests/test_mat.py +577 -156
  141. warp/tests/test_mat_scalar_ops.py +4 -4
  142. warp/tests/test_overwrite.py +0 -60
  143. warp/tests/test_quat.py +356 -151
  144. warp/tests/test_rand.py +44 -37
  145. warp/tests/test_sparse.py +47 -6
  146. warp/tests/test_spatial.py +75 -0
  147. warp/tests/test_static.py +1 -1
  148. warp/tests/test_utils.py +84 -4
  149. warp/tests/test_vec.py +336 -178
  150. warp/tests/tile/__init__.py +0 -0
  151. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  152. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +98 -1
  153. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  154. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  155. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  156. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  157. warp/tests/unittest_serial.py +1 -0
  158. warp/tests/unittest_suites.py +45 -62
  159. warp/tests/unittest_utils.py +2 -1
  160. warp/thirdparty/unittest_parallel.py +3 -1
  161. warp/types.py +175 -666
  162. warp/utils.py +137 -72
  163. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/METADATA +46 -12
  164. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/RECORD +184 -171
  165. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/WHEEL +1 -1
  166. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info/licenses}/LICENSE.md +0 -26
  167. warp/examples/optim/example_walker.py +0 -317
  168. warp/native/cutlass_gemm.cpp +0 -43
  169. warp/native/cutlass_gemm.cu +0 -382
  170. warp/tests/test_matmul.py +0 -511
  171. warp/tests/test_matmul_lite.py +0 -411
  172. warp/tests/test_vbd.py +0 -386
  173. warp/tests/unused_test_misc.py +0 -77
  174. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  175. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  176. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  177. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  178. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  179. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  180. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  181. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  182. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  183. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  184. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  185. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  186. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  187. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  188. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  189. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  190. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  191. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/top_level.txt +0 -0
warp/types.py CHANGED
@@ -20,7 +20,21 @@ import ctypes
 import inspect
 import struct
 import zlib
-from typing import Any, Callable, Generic, List, Literal, NamedTuple, Optional, Sequence, Tuple, TypeVar, Union
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    List,
+    Literal,
+    NamedTuple,
+    Optional,
+    Sequence,
+    Tuple,
+    TypeVar,
+    Union,
+    get_args,
+    get_origin,
+)

 import numpy as np
 import numpy.typing as npt
@@ -56,7 +70,9 @@ class Transformation(Generic[Float]):


 class Array(Generic[DType]):
-    pass
+    device: Optional[warp.context.Device]
+    dtype: type
+    size: int


 int_tuple_type_hints = {
@@ -262,7 +278,12 @@ def vector(length, dtype):
         def __str__(self):
             return f"[{', '.join(map(str, self))}]"

+        def __repr__(self):
+            return f"{type_repr(self)}([{', '.join(map(repr, self))}])"
+
         def __eq__(self, other):
+            if self._length_ != len(other):
+                return False
             for i in range(self._length_):
                 if self[i] != other[i]:
                     return False
@@ -405,7 +426,11 @@ def matrix(shape, dtype):
             return "[" + ",\n ".join(row_str) + "]"

         def __eq__(self, other):
+            if self._shape_[0] != len(other):
+                return False
             for i in range(self._shape_[0]):
+                if self._shape_[1] != len(other[i]):
+                    return False
                 for j in range(self._shape_[1]):
                     if self[i][j] != other[i][j]:
                         return False
@@ -1139,7 +1164,7 @@ ARRAY_TYPE_FABRIC_INDEXED = 3
 class launch_bounds_t(ctypes.Structure):
     _fields_ = [("shape", ctypes.c_int32 * LAUNCH_MAX_DIMS), ("ndim", ctypes.c_int32), ("size", ctypes.c_size_t)]

-    def __init__(self, shape):
+    def __init__(self, shape: Union[int, Sequence[int]]):
         if isinstance(shape, int):
             # 1d launch
             self.ndim = 1
@@ -1260,7 +1285,7 @@ _type_size_cache = {
 }


-def type_size_in_bytes(dtype):
+def type_size_in_bytes(dtype: type) -> int:
     size = _type_size_cache.get(dtype)

     if size is None:
@@ -1279,7 +1304,7 @@ def type_size_in_bytes(dtype):
     return size


-def type_to_warp(dtype):
+def type_to_warp(dtype: type) -> type:
     if dtype == float:
         return float32
     elif dtype == int:
@@ -1290,7 +1315,7 @@ def type_to_warp(dtype):
     return dtype


-def type_typestr(dtype):
+def type_typestr(dtype: type) -> str:
     if dtype == bool:
         return "|b1"
     elif dtype == float16:
@@ -1323,22 +1348,67 @@ def type_typestr(dtype):
         raise Exception("Unknown ctype")


+def scalar_short_name(t):
+    if t == float32:
+        return "f"
+    elif t == float64:
+        return "d"
+    elif t == int8:
+        return "b"
+    elif t == int16:
+        return "s"
+    elif t == int32:
+        return "i"
+    elif t == int64:
+        return "l"
+    elif t == uint8:
+        return "ub"
+    elif t == uint16:
+        return "us"
+    elif t == uint32:
+        return "ui"
+    elif t == uint64:
+        return "ul"
+    return None
+
+
 # converts any known type to a human readable string, good for error messages, reporting etc
 def type_repr(t):
     if is_array(t):
-        return str(f"array(ndim={t.ndim}, dtype={t.dtype})")
+        if t.device is None:
+            # array is used as a type annotation - display ndim instead of shape
+            return f"array(ndim={t.ndim}, dtype={type_repr(t.dtype)})"
+        return f"array(shape={t.shape}, dtype={type_repr(t.dtype)})"
     if is_tile(t):
-        return str(f"tile(dtype={t.dtype}, shape={t.shape}")
-    if type_is_vector(t):
-        return str(f"vector(length={t._shape_[0]}, dtype={t._wp_scalar_type_})")
-    if type_is_matrix(t):
-        return str(f"matrix(shape=({t._shape_[0]}, {t._shape_[1]}), dtype={t._wp_scalar_type_})")
+        return f"tile(shape={t.shape}, dtype={type_repr(t.dtype)})"
     if isinstance(t, warp.codegen.Struct):
         return type_repr(t.cls)
+    sn = None
+    if hasattr(t, "_wp_scalar_type_"):
+        sn = scalar_short_name(t._wp_scalar_type_)
+    if type_is_transformation(t):
+        if sn is not None:
+            return f"transform{sn}"
+        return f"transform(dtype={type_repr(t._wp_scalar_type_)})"
+    if type_is_quaternion(t):
+        if sn is not None:
+            return f"quat{sn}"
+        return f"quat(dtype={type_repr(t._wp_scalar_type_)})"
+    if type_is_vector(t):
+        if sn is not None and t._shape_[0] <= 4:
+            return f"vec{t._shape_[0]}{sn}"
+        return f"vector(length={t._shape_[0]}, dtype={type_repr(t._wp_scalar_type_)})"
+    if type_is_matrix(t):
+        if sn is not None and t._shape_[0] <= 4 and t._shape_[1] <= 4:
+            return f"mat{t._shape_[0]}{t._shape_[1]}({sn})"
+        return f"matrix(shape=({t._shape_[0]}, {t._shape_[1]}), dtype={type_repr(t._wp_scalar_type_)})"
     if t in scalar_types:
         return t.__name__

-    name = getattr(t, "__qualname__", t.__name__)
+    name = getattr(t, "__name__", None)
+    if name is None:
+        return repr(t)
+    name = getattr(t, "__qualname__", name)
     return t.__module__ + "." + name
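
Note on the type_repr() change above: small fixed-size types whose scalar type has a short suffix now print compactly. A quick illustrative sketch; the expected strings in the comments are inferred from this hunk rather than taken from the package docs:

import warp as wp
from warp.types import type_repr

# float32 has short name "f"; lengths <= 4 use the compact form
print(type_repr(wp.vec3f))   # expected: "vec3f"
print(type_repr(wp.quatd))   # expected: "quatd"
# longer vectors fall back to the descriptive form
print(type_repr(wp.vector(length=5, dtype=wp.float32)))  # expected: "vector(length=5, dtype=float32)"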
@@ -1376,33 +1446,33 @@ def type_is_transformation(t):
     return getattr(t, "_wp_generic_type_hint_", None) is Transformation


-value_types = (int, float, builtins.bool) + scalar_types
+value_types = (int, float, builtins.bool) + scalar_and_bool_types


 # returns true for all value types (int, float, bool, scalars, vectors, matrices)
-def type_is_value(x):
+def type_is_value(x: Any) -> builtins.bool:
     return x in value_types or hasattr(x, "_wp_scalar_type_")


 # equivalent of the above but for values
-def is_int(x):
+def is_int(x: Any) -> builtins.bool:
     return type_is_int(type(x))


-def is_float(x):
+def is_float(x: Any) -> builtins.bool:
     return type_is_float(type(x))


-def is_value(x):
+def is_value(x: Any) -> builtins.bool:
     return type_is_value(type(x))


-# returns true if the passed *instance* is one of the array types
-def is_array(a):
+def is_array(a) -> builtins.bool:
+    """Return true if the passed *instance* is one of the array types."""
     return isinstance(a, array_types)


-def scalars_equal(a, b, match_generic):
+def scalars_equal(a, b, match_generic=False):
     # convert to canonical types
     if a == float:
         a = float32
@@ -1465,21 +1535,21 @@ def types_equal(a, b, match_generic=False):
         if a_length is None or b_length is None or a_length == b_length:
             return True

-    a_origin = warp.codegen.get_type_origin(a)
-    b_origin = warp.codegen.get_type_origin(b)
+    a_origin = get_origin(a)
+    b_origin = get_origin(b)
     if a_origin is tuple and b_origin is tuple:
-        a_args = warp.codegen.get_type_args(a)
-        b_args = warp.codegen.get_type_args(b)
+        a_args = get_args(a)
+        b_args = get_args(b)
         if len(a_args) == len(b_args) and all(
             scalars_equal(x, y, match_generic=match_generic) for x, y in zip(a_args, b_args)
         ):
             return True
     elif a_origin is tuple and isinstance(b, Sequence):
-        a_args = warp.codegen.get_type_args(a)
+        a_args = get_args(a)
         if len(a_args) == len(b) and all(scalars_equal(x, y, match_generic=match_generic) for x, y in zip(a_args, b)):
             return True
     elif b_origin is tuple and isinstance(a, Sequence):
-        b_args = warp.codegen.get_type_args(b)
+        b_args = get_args(b)
         if len(b_args) == len(a) and all(scalars_equal(x, y, match_generic=match_generic) for x, y in zip(b_args, a)):
             return True

@@ -1600,7 +1670,7 @@ def array_ctype_from_interface(interface: dict, dtype=None, owner=None):
     return array_ctype


-class array(Array):
+class array(Array[DType]):
     """A fixed-size multi-dimensional array containing values of the same type.

     Attributes:
@@ -1629,21 +1699,21 @@ class array(Array):

     def __init__(
         self,
-        data: Optional[Union[List, Tuple, npt.NDArray]] = None,
-        dtype: Union[DType, Any] = Any,
-        shape: Optional[Tuple[int, ...]] = None,
+        data: Union[List, Tuple, npt.NDArray, None] = None,
+        dtype: Any = Any,
+        shape: Union[int, Tuple[int, ...], List[int], None] = None,
         strides: Optional[Tuple[int, ...]] = None,
         length: Optional[int] = None,
         ptr: Optional[int] = None,
         capacity: Optional[int] = None,
         device=None,
-        pinned: bool = False,
-        copy: bool = True,
-        owner: bool = False,  # deprecated - pass deleter instead
+        pinned: builtins.bool = False,
+        copy: builtins.bool = True,
+        owner: builtins.bool = False,  # deprecated - pass deleter instead
         deleter: Optional[Callable[[int, int], None]] = None,
         ndim: Optional[int] = None,
         grad: Optional[array] = None,
-        requires_grad: bool = False,
+        requires_grad: builtins.bool = False,
     ):
         """Constructs a new Warp array object

@@ -2219,6 +2289,9 @@ class array(Array):
         else:
             return str(self.numpy())

+    def __repr__(self):
+        return type_repr(self)
+
     def __getitem__(self, key):
         if isinstance(key, int):
             if self.ndim == 1:
@@ -2939,7 +3012,7 @@ def from_ipc_handle(

 # A base class for non-contiguous arrays, providing the implementation of common methods like
 # contiguous(), to(), numpy(), list(), assign(), zero_(), and fill_().
-class noncontiguous_array_base(Generic[T]):
+class noncontiguous_array_base(Array[T]):
     def __init__(self, array_type_id):
         self.type_id = array_type_id
         self.is_contiguous = False
@@ -3036,12 +3109,18 @@ def check_index_array(indices, expected_device):
         raise ValueError(f"Index array device ({indices.device} does not match data array device ({expected_device}))")


-class indexedarray(noncontiguous_array_base[T]):
+class indexedarray(noncontiguous_array_base):
     # member attributes available during code-gen (e.g.: d = arr.shape[0])
     # (initialized when needed)
     _vars = None

-    def __init__(self, data: array = None, indices: Union[array, List[array]] = None, dtype=None, ndim=None):
+    def __init__(
+        self,
+        data: Optional[array] = None,
+        indices: Union[array, List[array], None] = None,
+        dtype=None,
+        ndim: Optional[int] = None,
+    ):
         super().__init__(ARRAY_TYPE_INDEXED)

         # canonicalize types
@@ -3232,7 +3311,7 @@ class Tile:
             return f"wp::tile_alloc_empty<{Var.type_to_ctype(self.dtype)},wp::tile_shape_t<{','.join(map(str, self.shape))}>,{'true' if requires_grad else 'false'}>()"
         else:
             # tile will be initialized by another call, e.g.: tile_transpose()
-            return "NULL"
+            return "nullptr"

     # return total tile size in bytes
     def size_in_bytes(self):
@@ -3634,7 +3713,7 @@ class Volume:
        instance.id = None
        return instance

-    def __init__(self, data: array, copy: bool = True):
+    def __init__(self, data: array, copy: builtins.bool = True):
        """Class representing a sparse grid.

        Args:
@@ -4361,6 +4440,15 @@ class Volume:
        translation_buf = (ctypes.c_float * 3)(translation[0], translation[1], translation[2])
        return transform_buf, translation_buf

+    # nanovdb types for which we instantiate the grid builder
+    # Should be in sync with WP_VOLUME_BUILDER_INSTANTIATE_TYPES in volume_builder.h
+    _supported_allocation_types = [
+        "int32",
+        "float",
+        "Vec3f",
+        "Vec4f",
+    ]
+
    @classmethod
    def allocate_by_tiles(
        cls,
@@ -4388,7 +4476,8 @@ class Volume:
                or a floating point scalar type (2D N-by-3 array of :class:`warp.float32` or 1D array of `warp.vec3f` values), indicating world space positions.
                Repeated points per tile are allowed and will be efficiently deduplicated.
            voxel_size (float or array-like): Voxel size(s) of the new volume. Ignored if `transform` is given.
-            bg_value (array-like, float, int or None): Value of unallocated voxels of the volume, also defines the volume's type. A :class:`warp.vec3` volume is created if this is `array-like`, an index volume will be created if `bg_value` is ``None``.
+            bg_value (array-like, scalar or None): Value of unallocated voxels of the volume, also defines the volume's type. An index volume will be created if `bg_value` is ``None``.
+                Other supported grid types are `int`, `float`, `vec3f`, and `vec4f`.
            translation (array-like): Translation between the index and world spaces.
            transform (array-like): Linear transform between the index and world spaces. If ``None``, deduced from `voxel_size`.
            device (Devicelike): The CUDA device to create the volume on, e.g.: "cuda" or "cuda:0".
@@ -4420,35 +4509,47 @@ class Volume:
                translation_buf,
                in_world_space,
            )
-        elif hasattr(bg_value, "__len__"):
-            volume.id = volume.runtime.core.volume_v_from_tiles_device(
-                volume.device.context,
-                ctypes.c_void_p(tile_points.ptr),
-                tile_points.shape[0],
-                transform_buf,
-                translation_buf,
-                in_world_space,
-                (ctypes.c_float * 3)(bg_value[0], bg_value[1], bg_value[2]),
-            )
-        elif isinstance(bg_value, int):
-            volume.id = volume.runtime.core.volume_i_from_tiles_device(
-                volume.device.context,
-                ctypes.c_void_p(tile_points.ptr),
-                tile_points.shape[0],
-                transform_buf,
-                translation_buf,
-                in_world_space,
-                bg_value,
-            )
        else:
-            volume.id = volume.runtime.core.volume_f_from_tiles_device(
+            # normalize background value type
+            grid_type = type_to_warp(type(bg_value))
+            if not (is_value(bg_value) or type_is_vector(grid_type)) and (
+                hasattr(bg_value, "__len__") and is_value(bg_value[0])
+            ):
+                # non-warp vectors are considered float, for backward compatibility
+                grid_type = vector(len(bg_value), dtype=float)
+
+            # look for corresponding nvdb type
+            try:
+                nvdb_type = next(
+                    typ
+                    for typ in Volume._supported_allocation_types
+                    if types_equal(grid_type, Volume._nvdb_type_to_dtype[typ])
+                )
+            except StopIteration as err:
+                raise TypeError(
+                    f"Unsupported bg_value type for volume allocation {type_repr(grid_type)}. Supported volume types are {', '.join(Volume._supported_allocation_types)}."
+                ) from err
+
+            # cast to ctype
+            # wrap scalar values in length-1 vectors to handle specific ctype conversion
+            if not type_is_vector(grid_type):
+                grid_type = vector(length=1, dtype=grid_type)
+
+            cvalue = grid_type(bg_value)
+            cvalue_ptr = ctypes.pointer(cvalue)
+            cvalue_size = ctypes.sizeof(cvalue)
+            cvalue_type = nvdb_type.encode("ascii")
+
+            volume.id = volume.runtime.core.volume_from_tiles_device(
                volume.device.context,
                ctypes.c_void_p(tile_points.ptr),
                tile_points.shape[0],
                transform_buf,
                translation_buf,
                in_world_space,
-                float(bg_value),
+                cvalue_ptr,
+                cvalue_size,
+                cvalue_type,
            )

        if volume.id == 0:
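
The hunk above folds the three per-type native entry points (volume_f_from_tiles_device, volume_i_from_tiles_device, volume_v_from_tiles_device) into a single volume_from_tiles_device call selected by the background value's type. A minimal usage sketch under that change; the tile coordinates and sizes below are made up for illustration:

import warp as wp

wp.init()

# any (N, 3) int32 array of ijk tile coordinates on a CUDA device works here
tile_points = wp.array([[0, 0, 0], [1, 0, 0]], dtype=wp.int32, device="cuda:0")

# bg_value is now matched against the instantiated NanoVDB builder types
# (int32, float, Vec3f, Vec4f), so a vec4 background grid is accepted as well
vol = wp.Volume.allocate_by_tiles(
    tile_points,
    voxel_size=0.1,
    bg_value=wp.vec4f(0.0, 0.0, 0.0, 1.0),
    device="cuda:0",
)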
@@ -4606,6 +4707,8 @@ def matmul(
 ):
     """Computes a generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.

+    .. versionremoved:: 1.7
+
     .. deprecated:: 1.6
         Use :doc:`tile primitives </modules/tiles>` instead.

@@ -4619,80 +4722,8 @@ def matmul(
         allow_tf32x3_arith (bool): whether to use CUTLASS's 3xTF32 GEMMs, which enable accuracy similar to FP32
             while using Tensor Cores
     """
-    from warp.context import runtime

-    warp.utils.warn(
-        "wp.matmul() is deprecated and will be removed in a\nfuture version. Use tile primitives instead.",
-        category=DeprecationWarning,
-        stacklevel=2,
-    )
-
-    device = a.device
-
-    if b.device != device or c.device != device or d.device != device:
-        raise RuntimeError("Matrices A, B, C, and D must all be on the same device as the runtime device.")
-
-    if a.dtype != b.dtype or a.dtype != c.dtype or a.dtype != d.dtype:
-        raise RuntimeError(
-            "wp.matmul currently only supports operation between {A, B, C, D} matrices of the same type."
-        )
-
-    if (
-        (not a.is_contiguous and not a.is_transposed)
-        or (not b.is_contiguous and not b.is_transposed)
-        or (not c.is_contiguous)
-        or (not d.is_contiguous)
-    ):
-        raise RuntimeError(
-            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B may be transposed."
-        )
-
-    m = a.shape[0]
-    n = b.shape[1]
-    k = a.shape[1]
-    if b.shape != (k, n) or c.shape != (m, n) or d.shape != (m, n):
-        raise RuntimeError(
-            "Invalid shapes for matrices: A = {} B = {} C = {} D = {}".format(a.shape, b.shape, c.shape, d.shape)
-        )
-
-    if runtime.tape:
-        runtime.tape.record_func(
-            backward=lambda: adj_matmul(a, b, c, a.grad, b.grad, c.grad, d.grad, alpha, beta, allow_tf32x3_arith),
-            arrays=[a, b, c, d],
-        )
-        if warp.config.verify_autograd_array_access:
-            d.mark_write()
-            a.mark_read()
-            b.mark_read()
-            c.mark_read()
-
-    # cpu fallback if no cuda devices found
-    if device == "cpu":
-        np_dtype = warp_type_to_np_dtype[a.dtype]
-        d.assign(alpha * np.matmul(a.numpy(), b.numpy(), dtype=np_dtype) + beta * c.numpy())
-        return
-
-    cc = device.arch
-    ret = runtime.core.cutlass_gemm(
-        device.context,
-        cc,
-        m,
-        n,
-        k,
-        type_typestr(a.dtype).encode(),
-        ctypes.c_void_p(a.ptr),
-        ctypes.c_void_p(b.ptr),
-        ctypes.c_void_p(c.ptr),
-        ctypes.c_void_p(d.ptr),
-        alpha,
-        beta,
-        not a.is_transposed,
-        not b.is_transposed,
-        allow_tf32x3_arith,
-        1,
-    )
-    if not ret:
-        raise RuntimeError("matmul failed.")
+    raise RuntimeError("This function has been removed. Use tile primitives instead.")


 def adj_matmul(
@@ -4724,171 +4755,8 @@ def adj_matmul(
         allow_tf32x3_arith (bool): whether to use CUTLASS's 3xTF32 GEMMs, which enable accuracy similar to FP32
             while using Tensor Cores
     """
-    from warp.context import runtime
-
-    device = a.device
-
-    if (
-        b.device != device
-        or c.device != device
-        or adj_a.device != device
-        or adj_b.device != device
-        or adj_c.device != device
-        or adj_d.device != device
-    ):
-        raise RuntimeError(
-            "Matrices A, B, C, D, and their adjoints must all be on the same device as the runtime device."
-        )
-
-    if (
-        a.dtype != b.dtype
-        or a.dtype != c.dtype
-        or a.dtype != adj_a.dtype
-        or a.dtype != adj_b.dtype
-        or a.dtype != adj_c.dtype
-        or a.dtype != adj_d.dtype
-    ):
-        raise RuntimeError(
-            "wp.adj_matmul currently only supports operation between {A, B, C, adj_D, adj_A, adj_B, adj_C} matrices of the same type."
-        )

-    if (
-        (not a.is_contiguous and not a.is_transposed)
-        or (not b.is_contiguous and not b.is_transposed)
-        or (not c.is_contiguous)
-        or (not adj_a.is_contiguous and not adj_a.is_transposed)
-        or (not adj_b.is_contiguous and not adj_b.is_transposed)
-        or (not adj_c.is_contiguous)
-        or (not adj_d.is_contiguous)
-    ):
-        raise RuntimeError(
-            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B and their associated adjoints may be transposed."
-        )
-
-    m = a.shape[0]
-    n = b.shape[1]
-    k = a.shape[1]
-    if (
-        a.shape != (m, k)
-        or b.shape != (k, n)
-        or c.shape != (m, n)
-        or adj_d.shape != (m, n)
-        or adj_a.shape != (m, k)
-        or adj_b.shape != (k, n)
-        or adj_c.shape != (m, n)
-    ):
-        raise RuntimeError(
-            "Invalid shapes for matrices: A = {} B = {} C = {} adj_D = {} adj_A = {} adj_B = {} adj_C = {}".format(
-                a.shape, b.shape, c.shape, adj_d.shape, adj_a.shape, adj_b.shape, adj_c.shape
-            )
-        )
-
-    # cpu fallback if no cuda devices found
-    if device == "cpu":
-        np_dtype = warp_type_to_np_dtype[a.dtype]
-        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose(), dtype=np_dtype) + adj_a.numpy())
-        adj_b.assign(alpha * np.matmul(a.numpy().transpose(), adj_d.numpy(), dtype=np_dtype) + adj_b.numpy())
-        adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
-        return
-
-    cc = device.arch
-
-    # adj_a
-    if not a.is_transposed:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            m,
-            k,
-            n,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(adj_d.ptr),
-            ctypes.c_void_p(b.ptr),
-            ctypes.c_void_p(adj_a.ptr),
-            ctypes.c_void_p(adj_a.ptr),
-            alpha,
-            1.0,
-            True,
-            b.is_transposed,
-            allow_tf32x3_arith,
-            1,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-    else:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            k,
-            m,
-            n,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(b.ptr),
-            ctypes.c_void_p(adj_d.ptr),
-            ctypes.c_void_p(adj_a.ptr),
-            ctypes.c_void_p(adj_a.ptr),
-            alpha,
-            1.0,
-            not b.is_transposed,
-            False,
-            allow_tf32x3_arith,
-            1,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-
-    # adj_b
-    if not b.is_transposed:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            k,
-            n,
-            m,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(a.ptr),
-            ctypes.c_void_p(adj_d.ptr),
-            ctypes.c_void_p(adj_b.ptr),
-            ctypes.c_void_p(adj_b.ptr),
-            alpha,
-            1.0,
-            a.is_transposed,
-            True,
-            allow_tf32x3_arith,
-            1,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-    else:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            n,
-            k,
-            m,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(adj_d.ptr),
-            ctypes.c_void_p(a.ptr),
-            ctypes.c_void_p(adj_b.ptr),
-            ctypes.c_void_p(adj_b.ptr),
-            alpha,
-            1.0,
-            False,
-            not a.is_transposed,
-            allow_tf32x3_arith,
-            1,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-
-    # adj_c
-    warp.launch(
-        kernel=warp.utils.add_kernel_2d,
-        dim=adj_c.shape,
-        inputs=[adj_c, adj_d, adj_d.dtype(beta)],
-        device=device,
-        record_tape=False,
-    )
+    raise RuntimeError("This function has been removed. Use tile primitives instead.")


 def batched_matmul(
@@ -4902,6 +4770,8 @@ def batched_matmul(
 ):
     """Computes a batched generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.

+    .. versionremoved:: 1.7
+
     .. deprecated:: 1.6
         Use :doc:`tile primitives </modules/tiles>` instead.

@@ -4915,107 +4785,8 @@ def batched_matmul(
         allow_tf32x3_arith (bool): whether to use CUTLASS's 3xTF32 GEMMs, which enable accuracy similar to FP32
             while using Tensor Cores
     """
-    from warp.context import runtime
-
-    device = a.device

-    if b.device != device or c.device != device or d.device != device:
-        raise RuntimeError("Matrices A, B, C, and D must all be on the same device as the runtime device.")
-
-    if a.dtype != b.dtype or a.dtype != c.dtype or a.dtype != d.dtype:
-        raise RuntimeError(
-            "wp.batched_matmul currently only supports operation between {A, B, C, D} matrices of the same type."
-        )
-
-    if (
-        (not a.is_contiguous and not a.is_transposed)
-        or (not b.is_contiguous and not b.is_transposed)
-        or (not c.is_contiguous)
-        or (not d.is_contiguous)
-    ):
-        raise RuntimeError(
-            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B may be transposed."
-        )
-
-    m = a.shape[1]
-    n = b.shape[2]
-    k = a.shape[2]
-    batch_count = a.shape[0]
-    if b.shape != (batch_count, k, n) or c.shape != (batch_count, m, n) or d.shape != (batch_count, m, n):
-        raise RuntimeError(
-            "Invalid shapes for matrices: A = {} B = {} C = {} D = {}".format(a.shape, b.shape, c.shape, d.shape)
-        )
-
-    if runtime.tape:
-        runtime.tape.record_func(
-            backward=lambda: adj_batched_matmul(
-                a, b, c, a.grad, b.grad, c.grad, d.grad, alpha, beta, allow_tf32x3_arith
-            ),
-            arrays=[a, b, c, d],
-        )
-        if warp.config.verify_autograd_array_access:
-            d.mark_write()
-            a.mark_read()
-            b.mark_read()
-            c.mark_read()
-
-    # cpu fallback if no cuda devices found
-    if device == "cpu":
-        np_dtype = warp_type_to_np_dtype[a.dtype]
-        d.assign(alpha * np.matmul(a.numpy(), b.numpy(), dtype=np_dtype) + beta * c.numpy())
-        return
-
-    # handle case in which batch_count exceeds max_batch_count, which is a CUDA array size maximum
-    max_batch_count = 65535
-    iters = int(batch_count / max_batch_count)
-    remainder = batch_count % max_batch_count
-
-    cc = device.arch
-    for i in range(iters):
-        idx_start = i * max_batch_count
-        idx_end = (i + 1) * max_batch_count if i < iters - 1 else batch_count
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            m,
-            n,
-            k,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(a[idx_start:idx_end, :, :].ptr),
-            ctypes.c_void_p(b[idx_start:idx_end, :, :].ptr),
-            ctypes.c_void_p(c[idx_start:idx_end, :, :].ptr),
-            ctypes.c_void_p(d[idx_start:idx_end, :, :].ptr),
-            alpha,
-            beta,
-            not a.is_transposed,
-            not b.is_transposed,
-            allow_tf32x3_arith,
-            max_batch_count,
-        )
-        if not ret:
-            raise RuntimeError("Batched matmul failed.")
-
-    idx_start = iters * max_batch_count
-    ret = runtime.core.cutlass_gemm(
-        device.context,
-        cc,
-        m,
-        n,
-        k,
-        type_typestr(a.dtype).encode(),
-        ctypes.c_void_p(a[idx_start:, :, :].ptr),
-        ctypes.c_void_p(b[idx_start:, :, :].ptr),
-        ctypes.c_void_p(c[idx_start:, :, :].ptr),
-        ctypes.c_void_p(d[idx_start:, :, :].ptr),
-        alpha,
-        beta,
-        not a.is_transposed,
-        not b.is_transposed,
-        allow_tf32x3_arith,
-        remainder,
-    )
-    if not ret:
-        raise RuntimeError("Batched matmul failed.")
+    raise RuntimeError("This function has been removed. Use tile primitives instead.")


 def adj_batched_matmul(
@@ -5045,270 +4816,8 @@ def adj_batched_matmul(
         allow_tf32x3_arith (bool): whether to use CUTLASS's 3xTF32 GEMMs, which enable accuracy similar to FP32
             while using Tensor Cores
     """
-    from warp.context import runtime
-
-    device = a.device
-
-    if (
-        b.device != device
-        or c.device != device
-        or adj_a.device != device
-        or adj_b.device != device
-        or adj_c.device != device
-        or adj_d.device != device
-    ):
-        raise RuntimeError(
-            "Matrices A, B, C, D, and their adjoints must all be on the same device as the runtime device."
-        )
-
-    if (
-        a.dtype != b.dtype
-        or a.dtype != c.dtype
-        or a.dtype != adj_a.dtype
-        or a.dtype != adj_b.dtype
-        or a.dtype != adj_c.dtype
-        or a.dtype != adj_d.dtype
-    ):
-        raise RuntimeError(
-            "wp.adj_batched_matmul currently only supports operation between {A, B, C, adj_D, adj_A, adj_B, adj_C} matrices of the same type."
-        )
-
-    m = a.shape[1]
-    n = b.shape[2]
-    k = a.shape[2]
-    batch_count = a.shape[0]
-    if (
-        b.shape != (batch_count, k, n)
-        or c.shape != (batch_count, m, n)
-        or adj_d.shape != (batch_count, m, n)
-        or adj_a.shape != (batch_count, m, k)
-        or adj_b.shape != (batch_count, k, n)
-        or adj_c.shape != (batch_count, m, n)
-    ):
-        raise RuntimeError(
-            "Invalid shapes for matrices: A = {} B = {} C = {} adj_D = {} adj_A = {} adj_B = {} adj_C = {}".format(
-                a.shape, b.shape, c.shape, adj_d.shape, adj_a.shape, adj_b.shape, adj_c.shape
-            )
-        )

-    if (
-        (not a.is_contiguous and not a.is_transposed)
-        or (not b.is_contiguous and not b.is_transposed)
-        or (not c.is_contiguous)
-        or (not adj_a.is_contiguous and not adj_a.is_transposed)
-        or (not adj_b.is_contiguous and not adj_b.is_transposed)
-        or (not adj_c.is_contiguous)
-        or (not adj_d.is_contiguous)
-    ):
-        raise RuntimeError(
-            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B and their associated adjoints may be transposed."
-        )
-
-    # cpu fallback if no cuda devices found
-    if device == "cpu":
-        np_dtype = warp_type_to_np_dtype[a.dtype]
-        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose((0, 2, 1)), dtype=np_dtype) + adj_a.numpy())
-        adj_b.assign(alpha * np.matmul(a.numpy().transpose((0, 2, 1)), adj_d.numpy(), dtype=np_dtype) + adj_b.numpy())
-        adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
-        return
-
-    # handle case in which batch_count exceeds max_batch_count, which is a CUDA array size maximum
-    max_batch_count = 65535
-    iters = int(batch_count / max_batch_count)
-    remainder = batch_count % max_batch_count
-
-    cc = device.arch
-
-    for i in range(iters):
-        idx_start = i * max_batch_count
-        idx_end = (i + 1) * max_batch_count if i < iters - 1 else batch_count
-
-        # adj_a
-        if not a.is_transposed:
-            ret = runtime.core.cutlass_gemm(
-                device.context,
-                cc,
-                m,
-                k,
-                n,
-                type_typestr(a.dtype).encode(),
-                ctypes.c_void_p(adj_d[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(b[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_a[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_a[idx_start:idx_end, :, :].ptr),
-                alpha,
-                1.0,
-                True,
-                b.is_transposed,
-                allow_tf32x3_arith,
-                max_batch_count,
-            )
-            if not ret:
-                raise RuntimeError("adj_matmul failed.")
-        else:
-            ret = runtime.core.cutlass_gemm(
-                device.context,
-                cc,
-                k,
-                m,
-                n,
-                type_typestr(a.dtype).encode(),
-                ctypes.c_void_p(b[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_d[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_a[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_a[idx_start:idx_end, :, :].ptr),
-                alpha,
-                1.0,
-                not b.is_transposed,
-                False,
-                allow_tf32x3_arith,
-                max_batch_count,
-            )
-            if not ret:
-                raise RuntimeError("adj_matmul failed.")
-
-        # adj_b
-        if not b.is_transposed:
-            ret = runtime.core.cutlass_gemm(
-                device.context,
-                cc,
-                k,
-                n,
-                m,
-                type_typestr(a.dtype).encode(),
-                ctypes.c_void_p(a[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_d[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_b[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_b[idx_start:idx_end, :, :].ptr),
-                alpha,
-                1.0,
-                a.is_transposed,
-                True,
-                allow_tf32x3_arith,
-                max_batch_count,
-            )
-            if not ret:
-                raise RuntimeError("adj_matmul failed.")
-        else:
-            ret = runtime.core.cutlass_gemm(
-                device.context,
-                cc,
-                n,
-                k,
-                m,
-                type_typestr(a.dtype).encode(),
-                ctypes.c_void_p(adj_d[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(a[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_b[idx_start:idx_end, :, :].ptr),
-                ctypes.c_void_p(adj_b[idx_start:idx_end, :, :].ptr),
-                alpha,
-                1.0,
-                False,
-                not a.is_transposed,
-                allow_tf32x3_arith,
-                max_batch_count,
-            )
-            if not ret:
-                raise RuntimeError("adj_matmul failed.")
-
-    idx_start = iters * max_batch_count
-
-    # adj_a
-    if not a.is_transposed:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            m,
-            k,
-            n,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(adj_d[idx_start:, :, :].ptr),
-            ctypes.c_void_p(b[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_a[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_a[idx_start:, :, :].ptr),
-            alpha,
-            1.0,
-            True,
-            b.is_transposed,
-            allow_tf32x3_arith,
-            remainder,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-    else:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            k,
-            m,
-            n,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(b[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_d[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_a[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_a[idx_start:, :, :].ptr),
-            alpha,
-            1.0,
-            not b.is_transposed,
-            False,
-            allow_tf32x3_arith,
-            remainder,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-
-    # adj_b
-    if not b.is_transposed:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            k,
-            n,
-            m,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(a[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_d[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_b[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_b[idx_start:, :, :].ptr),
-            alpha,
-            1.0,
-            a.is_transposed,
-            True,
-            allow_tf32x3_arith,
-            remainder,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-    else:
-        ret = runtime.core.cutlass_gemm(
-            device.context,
-            cc,
-            n,
-            k,
-            m,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(adj_d[idx_start:, :, :].ptr),
-            ctypes.c_void_p(a[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_b[idx_start:, :, :].ptr),
-            ctypes.c_void_p(adj_b[idx_start:, :, :].ptr),
-            alpha,
-            1.0,
-            False,
-            not a.is_transposed,
-            allow_tf32x3_arith,
-            remainder,
-        )
-        if not ret:
-            raise RuntimeError("adj_matmul failed.")
-
-    # adj_c
-    warp.launch(
-        kernel=warp.utils.add_kernel_3d,
-        dim=adj_c.shape,
-        inputs=[adj_c, adj_d, adj_d.dtype(beta)],
-        device=device,
-        record_tape=False,
-    )
+    raise RuntimeError("This function has been removed. Use tile primitives instead.")


 class HashGrid:
@@ -5691,7 +5200,7 @@ simple_type_codes = {
 }


-def get_type_code(arg_type):
+def get_type_code(arg_type: type) -> str:
     if arg_type == Any:
         # special case for generics
         # note: since Python 3.11 Any is a type, so we check for it first
@@ -5755,8 +5264,8 @@ def get_type_code(arg_type):
     raise TypeError(f"Unrecognized type '{arg_type}'")


-def get_signature(arg_types, func_name=None, arg_names=None):
-    type_codes = []
+def get_signature(arg_types: List[type], func_name: Optional[str] = None, arg_names: Optional[List[str]] = None) -> str:
+    type_codes: List[str] = []
     for i, arg_type in enumerate(arg_types):
         try:
             type_codes.append(get_type_code(arg_type))
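
With wp.matmul(), wp.adj_matmul(), and the batched variants now raising unconditionally, the supported GEMM path in 1.7 is the tile API (see warp/examples/tile/example_tile_matmul.py in the file list above). The sketch below is modeled on that example and assumes the shape/offset keyword form of the tile builtins; the tile sizes and matrix shapes are illustrative only:

import numpy as np
import warp as wp

TILE_M = wp.constant(16)
TILE_N = wp.constant(16)
TILE_K = wp.constant(16)
TILE_THREADS = 64


@wp.kernel
def tile_gemm(a: wp.array2d(dtype=wp.float32), b: wp.array2d(dtype=wp.float32), c: wp.array2d(dtype=wp.float32)):
    # one block of TILE_THREADS threads cooperates on one output tile
    i, j = wp.tid()
    acc = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=wp.float32)
    count = int(a.shape[1] / TILE_K)
    for k in range(count):
        a_tile = wp.tile_load(a, shape=(TILE_M, TILE_K), offset=(i * TILE_M, k * TILE_K))
        b_tile = wp.tile_load(b, shape=(TILE_K, TILE_N), offset=(k * TILE_K, j * TILE_N))
        # acc += a_tile @ b_tile
        wp.tile_matmul(a_tile, b_tile, acc)
    wp.tile_store(c, acc, offset=(i * TILE_M, j * TILE_N))


wp.init()
m, n, k = 64, 64, 64
a = wp.array(np.random.rand(m, k), dtype=wp.float32, device="cuda:0")
b = wp.array(np.random.rand(k, n), dtype=wp.float32, device="cuda:0")
c = wp.zeros((m, n), dtype=wp.float32, device="cuda:0")
wp.launch_tiled(tile_gemm, dim=(int(m / TILE_M), int(n / TILE_N)), inputs=[a, b, c], block_dim=TILE_THREADS)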