warp-lang: warp_lang-1.6.2-py3-none-macosx_10_13_universal2.whl → warp_lang-1.7.0-py3-none-macosx_10_13_universal2.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release.
Files changed (179)
  1. warp/__init__.py +7 -1
  2. warp/bin/libwarp-clang.dylib +0 -0
  3. warp/bin/libwarp.dylib +0 -0
  4. warp/build.py +410 -0
  5. warp/build_dll.py +6 -14
  6. warp/builtins.py +452 -362
  7. warp/codegen.py +179 -119
  8. warp/config.py +42 -6
  9. warp/context.py +490 -271
  10. warp/dlpack.py +8 -6
  11. warp/examples/assets/nonuniform.usd +0 -0
  12. warp/examples/assets/nvidia_logo.png +0 -0
  13. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  14. warp/examples/core/example_sample_mesh.py +300 -0
  15. warp/examples/fem/example_apic_fluid.py +1 -1
  16. warp/examples/fem/example_burgers.py +2 -2
  17. warp/examples/fem/example_deformed_geometry.py +1 -1
  18. warp/examples/fem/example_distortion_energy.py +1 -1
  19. warp/examples/fem/example_magnetostatics.py +6 -6
  20. warp/examples/fem/utils.py +9 -3
  21. warp/examples/interop/example_jax_callable.py +116 -0
  22. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  23. warp/examples/interop/example_jax_kernel.py +205 -0
  24. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  25. warp/examples/tile/example_tile_matmul.py +2 -4
  26. warp/fem/__init__.py +11 -1
  27. warp/fem/adaptivity.py +4 -4
  28. warp/fem/field/nodal_field.py +22 -68
  29. warp/fem/field/virtual.py +62 -23
  30. warp/fem/geometry/adaptive_nanogrid.py +9 -10
  31. warp/fem/geometry/closest_point.py +1 -1
  32. warp/fem/geometry/deformed_geometry.py +5 -2
  33. warp/fem/geometry/geometry.py +5 -0
  34. warp/fem/geometry/grid_2d.py +12 -12
  35. warp/fem/geometry/grid_3d.py +12 -15
  36. warp/fem/geometry/hexmesh.py +5 -7
  37. warp/fem/geometry/nanogrid.py +9 -11
  38. warp/fem/geometry/quadmesh.py +13 -13
  39. warp/fem/geometry/tetmesh.py +3 -4
  40. warp/fem/geometry/trimesh.py +3 -8
  41. warp/fem/integrate.py +262 -93
  42. warp/fem/linalg.py +5 -5
  43. warp/fem/quadrature/pic_quadrature.py +37 -22
  44. warp/fem/quadrature/quadrature.py +194 -25
  45. warp/fem/space/__init__.py +1 -1
  46. warp/fem/space/basis_function_space.py +4 -2
  47. warp/fem/space/basis_space.py +25 -18
  48. warp/fem/space/hexmesh_function_space.py +2 -2
  49. warp/fem/space/partition.py +6 -2
  50. warp/fem/space/quadmesh_function_space.py +8 -8
  51. warp/fem/space/shape/cube_shape_function.py +23 -23
  52. warp/fem/space/shape/square_shape_function.py +12 -12
  53. warp/fem/space/shape/triangle_shape_function.py +1 -1
  54. warp/fem/space/tetmesh_function_space.py +3 -3
  55. warp/fem/space/trimesh_function_space.py +2 -2
  56. warp/fem/utils.py +12 -6
  57. warp/jax.py +14 -1
  58. warp/jax_experimental/__init__.py +16 -0
  59. warp/{jax_experimental.py → jax_experimental/custom_call.py} +14 -27
  60. warp/jax_experimental/ffi.py +698 -0
  61. warp/jax_experimental/xla_ffi.py +602 -0
  62. warp/math.py +89 -0
  63. warp/native/array.h +13 -0
  64. warp/native/builtin.h +29 -3
  65. warp/native/bvh.cpp +3 -1
  66. warp/native/bvh.cu +42 -14
  67. warp/native/bvh.h +2 -1
  68. warp/native/clang/clang.cpp +30 -3
  69. warp/native/cuda_util.cpp +14 -0
  70. warp/native/cuda_util.h +2 -0
  71. warp/native/exports.h +68 -63
  72. warp/native/intersect.h +26 -26
  73. warp/native/intersect_adj.h +33 -33
  74. warp/native/marching.cu +1 -1
  75. warp/native/mat.h +513 -9
  76. warp/native/mesh.h +10 -10
  77. warp/native/quat.h +99 -11
  78. warp/native/rand.h +6 -0
  79. warp/native/sort.cpp +122 -59
  80. warp/native/sort.cu +152 -15
  81. warp/native/sort.h +8 -1
  82. warp/native/sparse.cpp +43 -22
  83. warp/native/sparse.cu +52 -17
  84. warp/native/svd.h +116 -0
  85. warp/native/tile.h +301 -105
  86. warp/native/tile_reduce.h +46 -3
  87. warp/native/vec.h +68 -7
  88. warp/native/volume.cpp +85 -113
  89. warp/native/volume_builder.cu +25 -10
  90. warp/native/volume_builder.h +6 -0
  91. warp/native/warp.cpp +5 -6
  92. warp/native/warp.cu +99 -10
  93. warp/native/warp.h +19 -10
  94. warp/optim/linear.py +10 -10
  95. warp/sim/articulation.py +4 -4
  96. warp/sim/collide.py +21 -10
  97. warp/sim/import_mjcf.py +449 -155
  98. warp/sim/import_urdf.py +32 -12
  99. warp/sim/integrator_euler.py +5 -5
  100. warp/sim/integrator_featherstone.py +3 -10
  101. warp/sim/integrator_vbd.py +207 -2
  102. warp/sim/integrator_xpbd.py +5 -5
  103. warp/sim/model.py +42 -13
  104. warp/sim/utils.py +2 -2
  105. warp/sparse.py +642 -555
  106. warp/stubs.py +216 -19
  107. warp/tests/__main__.py +0 -15
  108. warp/tests/cuda/__init__.py +0 -0
  109. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  110. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  111. warp/tests/geometry/__init__.py +0 -0
  112. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  113. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  114. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  115. warp/tests/interop/__init__.py +0 -0
  116. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  117. warp/tests/sim/__init__.py +0 -0
  118. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  119. warp/tests/{test_collision.py → sim/test_collision.py} +2 -2
  120. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  121. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  122. warp/tests/sim/test_vbd.py +597 -0
  123. warp/tests/test_bool.py +1 -1
  124. warp/tests/test_examples.py +28 -36
  125. warp/tests/test_fem.py +23 -4
  126. warp/tests/test_linear_solvers.py +0 -11
  127. warp/tests/test_mat.py +233 -79
  128. warp/tests/test_mat_scalar_ops.py +4 -4
  129. warp/tests/test_overwrite.py +0 -60
  130. warp/tests/test_quat.py +67 -46
  131. warp/tests/test_rand.py +44 -37
  132. warp/tests/test_sparse.py +47 -6
  133. warp/tests/test_spatial.py +75 -0
  134. warp/tests/test_static.py +1 -1
  135. warp/tests/test_utils.py +84 -4
  136. warp/tests/test_vec.py +46 -34
  137. warp/tests/tile/__init__.py +0 -0
  138. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  139. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +1 -1
  140. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  141. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  142. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  143. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  144. warp/tests/unittest_serial.py +1 -0
  145. warp/tests/unittest_suites.py +45 -59
  146. warp/tests/unittest_utils.py +2 -1
  147. warp/thirdparty/unittest_parallel.py +3 -1
  148. warp/types.py +110 -658
  149. warp/utils.py +137 -72
  150. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/METADATA +29 -7
  151. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/RECORD +172 -162
  152. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  153. warp/examples/optim/example_walker.py +0 -317
  154. warp/native/cutlass_gemm.cpp +0 -43
  155. warp/native/cutlass_gemm.cu +0 -382
  156. warp/tests/test_matmul.py +0 -511
  157. warp/tests/test_matmul_lite.py +0 -411
  158. warp/tests/test_vbd.py +0 -386
  159. warp/tests/unused_test_misc.py +0 -77
  160. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  161. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  162. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  163. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  164. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  165. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  166. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  167. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  168. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  169. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  170. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  171. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  172. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  173. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  174. /warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +0 -0
  175. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  176. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  177. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  178. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info/licenses}/LICENSE.md +0 -0
  179. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/types.py CHANGED
@@ -20,7 +20,21 @@ import ctypes
  import inspect
  import struct
  import zlib
- from typing import Any, Callable, Generic, List, Literal, NamedTuple, Optional, Sequence, Tuple, TypeVar, Union
+ from typing import (
+ Any,
+ Callable,
+ Generic,
+ List,
+ Literal,
+ NamedTuple,
+ Optional,
+ Sequence,
+ Tuple,
+ TypeVar,
+ Union,
+ get_args,
+ get_origin,
+ )

  import numpy as np
  import numpy.typing as npt
@@ -56,7 +70,9 @@ class Transformation(Generic[Float]):


  class Array(Generic[DType]):
- pass
+ device: Optional[warp.context.Device]
+ dtype: type
+ size: int


  int_tuple_type_hints = {
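
Note (illustration, not part of the diff): with Array now declared as a Generic[DType] carrying device, dtype, and size attributes, Python-side helpers can be annotated against it so static type checkers see those attributes. A minimal sketch; the describe() helper is hypothetical:

    import warp as wp
    from warp.types import Array

    def describe(a: Array[wp.float32]) -> str:
        # device, dtype and size are now declared on the generic base class
        return f"{a.size} values of {a.dtype.__name__} on {a.device}"

    arr = wp.zeros(8, dtype=wp.float32, device="cpu")
    print(describe(arr))
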
@@ -1139,7 +1155,7 @@ ARRAY_TYPE_FABRIC_INDEXED = 3
  class launch_bounds_t(ctypes.Structure):
  _fields_ = [("shape", ctypes.c_int32 * LAUNCH_MAX_DIMS), ("ndim", ctypes.c_int32), ("size", ctypes.c_size_t)]

- def __init__(self, shape):
+ def __init__(self, shape: Union[int, Sequence[int]]):
  if isinstance(shape, int):
  # 1d launch
  self.ndim = 1
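
Note (illustration, not part of the diff): launch_bounds_t accepting either an int or a sequence mirrors how wp.launch() takes a scalar or multi-dimensional dim, for example:

    import warp as wp

    @wp.kernel
    def fill_1d(out: wp.array(dtype=wp.float32)):
        i = wp.tid()
        out[i] = float(i)

    @wp.kernel
    def fill_2d(out: wp.array2d(dtype=wp.float32)):
        i, j = wp.tid()
        out[i, j] = float(i + j)

    a = wp.zeros(64, dtype=wp.float32, device="cpu")
    wp.launch(fill_1d, dim=64, inputs=[a], device="cpu")      # 1d launch (int shape)

    b = wp.zeros((8, 8), dtype=wp.float32, device="cpu")
    wp.launch(fill_2d, dim=(8, 8), inputs=[b], device="cpu")  # 2d launch (sequence shape)
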
@@ -1260,7 +1276,7 @@ _type_size_cache = {
  }


- def type_size_in_bytes(dtype):
+ def type_size_in_bytes(dtype: type) -> int:
  size = _type_size_cache.get(dtype)

  if size is None:
@@ -1279,7 +1295,7 @@ def type_size_in_bytes(dtype):
  return size


- def type_to_warp(dtype):
+ def type_to_warp(dtype: type) -> type:
  if dtype == float:
  return float32
  elif dtype == int:
@@ -1290,7 +1306,7 @@ def type_to_warp(dtype):
  return dtype


- def type_typestr(dtype):
+ def type_typestr(dtype: type) -> str:
  if dtype == bool:
  return "|b1"
  elif dtype == float16:
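
Note (illustration, not part of the diff): the newly annotated type helpers behave as before, e.g.:

    import warp as wp
    from warp.types import type_size_in_bytes, type_to_warp, type_typestr

    print(type_to_warp(float))           # warp float32 type
    print(type_to_warp(int))             # warp int32 type
    print(type_typestr(bool))            # '|b1'
    print(type_size_in_bytes(wp.vec3f))  # 12 (3 x 4-byte floats)
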
@@ -1376,29 +1392,29 @@ def type_is_transformation(t):
  return getattr(t, "_wp_generic_type_hint_", None) is Transformation


- value_types = (int, float, builtins.bool) + scalar_types
+ value_types = (int, float, builtins.bool) + scalar_and_bool_types


  # returns true for all value types (int, float, bool, scalars, vectors, matrices)
- def type_is_value(x):
+ def type_is_value(x: Any) -> builtins.bool:
  return x in value_types or hasattr(x, "_wp_scalar_type_")


  # equivalent of the above but for values
- def is_int(x):
+ def is_int(x: Any) -> builtins.bool:
  return type_is_int(type(x))


- def is_float(x):
+ def is_float(x: Any) -> builtins.bool:
  return type_is_float(type(x))


- def is_value(x):
+ def is_value(x: Any) -> builtins.bool:
  return type_is_value(type(x))


- # returns true if the passed *instance* is one of the array types
- def is_array(a):
+ def is_array(a) -> builtins.bool:
+ """Return true if the passed *instance* is one of the array types."""
  return isinstance(a, array_types)

@@ -1465,21 +1481,21 @@ def types_equal(a, b, match_generic=False):
  if a_length is None or b_length is None or a_length == b_length:
  return True

- a_origin = warp.codegen.get_type_origin(a)
- b_origin = warp.codegen.get_type_origin(b)
+ a_origin = get_origin(a)
+ b_origin = get_origin(b)
  if a_origin is tuple and b_origin is tuple:
- a_args = warp.codegen.get_type_args(a)
- b_args = warp.codegen.get_type_args(b)
+ a_args = get_args(a)
+ b_args = get_args(b)
  if len(a_args) == len(b_args) and all(
  scalars_equal(x, y, match_generic=match_generic) for x, y in zip(a_args, b_args)
  ):
  return True
  elif a_origin is tuple and isinstance(b, Sequence):
- a_args = warp.codegen.get_type_args(a)
+ a_args = get_args(a)
  if len(a_args) == len(b) and all(scalars_equal(x, y, match_generic=match_generic) for x, y in zip(a_args, b)):
  return True
  elif b_origin is tuple and isinstance(a, Sequence):
- b_args = warp.codegen.get_type_args(b)
+ b_args = get_args(b)
  if len(b_args) == len(a) and all(scalars_equal(x, y, match_generic=match_generic) for x, y in zip(b_args, a)):
  return True

@@ -1600,7 +1616,7 @@ def array_ctype_from_interface(interface: dict, dtype=None, owner=None):
  return array_ctype


- class array(Array):
+ class array(Array[DType]):
  """A fixed-size multi-dimensional array containing values of the same type.

  Attributes:
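
Note (context, not part of the diff): the types_equal() hunk above replaces the warp.codegen.get_type_origin()/get_type_args() wrappers with the standard typing helpers; for reference:

    from typing import Tuple, get_args, get_origin

    hint = Tuple[int, float]
    print(get_origin(hint))  # <class 'tuple'>
    print(get_args(hint))    # (<class 'int'>, <class 'float'>)

    # plain (non-generic) types have no origin or args
    print(get_origin(int))   # None
    print(get_args(int))     # ()
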
@@ -1629,21 +1645,21 @@ class array(Array):

  def __init__(
  self,
- data: Optional[Union[List, Tuple, npt.NDArray]] = None,
- dtype: Union[DType, Any] = Any,
- shape: Optional[Tuple[int, ...]] = None,
+ data: Union[List, Tuple, npt.NDArray, None] = None,
+ dtype: Any = Any,
+ shape: Union[int, Tuple[int, ...], List[int], None] = None,
  strides: Optional[Tuple[int, ...]] = None,
  length: Optional[int] = None,
  ptr: Optional[int] = None,
  capacity: Optional[int] = None,
  device=None,
- pinned: bool = False,
- copy: bool = True,
- owner: bool = False, # deprecated - pass deleter instead
+ pinned: builtins.bool = False,
+ copy: builtins.bool = True,
+ owner: builtins.bool = False, # deprecated - pass deleter instead
  deleter: Optional[Callable[[int, int], None]] = None,
  ndim: Optional[int] = None,
  grad: Optional[array] = None,
- requires_grad: bool = False,
+ requires_grad: builtins.bool = False,
  ):
  """Constructs a new Warp array object

@@ -2939,7 +2955,7 @@ def from_ipc_handle(

  # A base class for non-contiguous arrays, providing the implementation of common methods like
  # contiguous(), to(), numpy(), list(), assign(), zero_(), and fill_().
- class noncontiguous_array_base(Generic[T]):
+ class noncontiguous_array_base(Array[T]):
  def __init__(self, array_type_id):
  self.type_id = array_type_id
  self.is_contiguous = False
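
Note (illustration, not part of the diff): the widened shape annotation on array.__init__ above matches what the runtime already accepted; both integer and tuple shapes work through the usual factory functions, e.g.:

    import numpy as np
    import warp as wp

    a = wp.array(np.arange(16, dtype=np.float32), dtype=wp.float32, device="cpu")  # shape inferred from data
    b = wp.zeros(16, dtype=wp.float32, device="cpu")                               # 1-D shape as an int
    c = wp.zeros((4, 4), dtype=wp.float32, device="cpu")                           # N-D shape as a tuple
    print(a.shape, b.shape, c.shape)  # (16,) (16,) (4, 4)
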
@@ -3036,12 +3052,18 @@ def check_index_array(indices, expected_device):
  raise ValueError(f"Index array device ({indices.device} does not match data array device ({expected_device}))")


- class indexedarray(noncontiguous_array_base[T]):
+ class indexedarray(noncontiguous_array_base):
  # member attributes available during code-gen (e.g.: d = arr.shape[0])
  # (initialized when needed)
  _vars = None

- def __init__(self, data: array = None, indices: Union[array, List[array]] = None, dtype=None, ndim=None):
+ def __init__(
+ self,
+ data: Optional[array] = None,
+ indices: Union[array, List[array], None] = None,
+ dtype=None,
+ ndim: Optional[int] = None,
+ ):
  super().__init__(ARRAY_TYPE_INDEXED)

  # canonicalize types
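
Note (illustration, not part of the diff): the reworked indexedarray constructor keeps the same call pattern; a minimal sketch using an int32 index array:

    import warp as wp

    data = wp.array([10.0, 20.0, 30.0, 40.0], dtype=wp.float32, device="cpu")
    idx = wp.array([0, 2], dtype=wp.int32, device="cpu")

    view = wp.indexedarray(data=data, indices=idx)  # non-contiguous view, no copy of data
    print(view.shape)    # (2,)
    print(view.numpy())  # [10. 30.]
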
@@ -3232,7 +3254,7 @@ class Tile:
  return f"wp::tile_alloc_empty<{Var.type_to_ctype(self.dtype)},wp::tile_shape_t<{','.join(map(str, self.shape))}>,{'true' if requires_grad else 'false'}>()"
  else:
  # tile will be initialized by another call, e.g.: tile_transpose()
- return "NULL"
+ return "nullptr"

  # return total tile size in bytes
  def size_in_bytes(self):
@@ -3634,7 +3656,7 @@ class Volume:
  instance.id = None
  return instance

- def __init__(self, data: array, copy: bool = True):
+ def __init__(self, data: array, copy: builtins.bool = True):
  """Class representing a sparse grid.

  Args:
@@ -4361,6 +4383,15 @@ class Volume:
  translation_buf = (ctypes.c_float * 3)(translation[0], translation[1], translation[2])
  return transform_buf, translation_buf

+ # nanovdb types for which we instantiate the grid builder
+ # Should be in sync with WP_VOLUME_BUILDER_INSTANTIATE_TYPES in volume_builder.h
+ _supported_allocation_types = [
+ "int32",
+ "float",
+ "Vec3f",
+ "Vec4f",
+ ]
+
  @classmethod
  def allocate_by_tiles(
  cls,
@@ -4388,7 +4419,8 @@
  or a floating point scalar type (2D N-by-3 array of :class:`warp.float32` or 1D array of `warp.vec3f` values), indicating world space positions.
  Repeated points per tile are allowed and will be efficiently deduplicated.
  voxel_size (float or array-like): Voxel size(s) of the new volume. Ignored if `transform` is given.
- bg_value (array-like, float, int or None): Value of unallocated voxels of the volume, also defines the volume's type. A :class:`warp.vec3` volume is created if this is `array-like`, an index volume will be created if `bg_value` is ``None``.
+ bg_value (array-like, scalar or None): Value of unallocated voxels of the volume, also defines the volume's type. An index volume will be created if `bg_value` is ``None``.
+ Other supported grid types are `int`, `float`, `vec3f`, and `vec4f`.
  translation (array-like): Translation between the index and world spaces.
  transform (array-like): Linear transform between the index and world spaces. If ``None``, deduced from `voxel_size`.
  device (Devicelike): The CUDA device to create the volume on, e.g.: "cuda" or "cuda:0".
@@ -4420,35 +4452,47 @@
  translation_buf,
  in_world_space,
  )
- elif hasattr(bg_value, "__len__"):
- volume.id = volume.runtime.core.volume_v_from_tiles_device(
- volume.device.context,
- ctypes.c_void_p(tile_points.ptr),
- tile_points.shape[0],
- transform_buf,
- translation_buf,
- in_world_space,
- (ctypes.c_float * 3)(bg_value[0], bg_value[1], bg_value[2]),
- )
- elif isinstance(bg_value, int):
- volume.id = volume.runtime.core.volume_i_from_tiles_device(
- volume.device.context,
- ctypes.c_void_p(tile_points.ptr),
- tile_points.shape[0],
- transform_buf,
- translation_buf,
- in_world_space,
- bg_value,
- )
  else:
- volume.id = volume.runtime.core.volume_f_from_tiles_device(
+ # normalize background value type
+ grid_type = type_to_warp(type(bg_value))
+ if not (is_value(bg_value) or type_is_vector(grid_type)) and (
+ hasattr(bg_value, "__len__") and is_value(bg_value[0])
+ ):
+ # non-warp vectors are considered float, for backward compatibility
+ grid_type = vector(len(bg_value), dtype=float)
+
+ # look for corresponding nvdb type
+ try:
+ nvdb_type = next(
+ typ
+ for typ in Volume._supported_allocation_types
+ if types_equal(grid_type, Volume._nvdb_type_to_dtype[typ])
+ )
+ except StopIteration as err:
+ raise TypeError(
+ f"Unsupported bg_value type for volume allocation {type_repr(grid_type)}. Supported volume types are {', '.join(Volume._supported_allocation_types)}."
+ ) from err
+
+ # cast to ctype
+ # wrap scalar values in length-1 vectors to handle specific ctype conversion
+ if not type_is_vector(grid_type):
+ grid_type = vector(length=1, dtype=grid_type)
+
+ cvalue = grid_type(bg_value)
+ cvalue_ptr = ctypes.pointer(cvalue)
+ cvalue_size = ctypes.sizeof(cvalue)
+ cvalue_type = nvdb_type.encode("ascii")
+
+ volume.id = volume.runtime.core.volume_from_tiles_device(
  volume.device.context,
  ctypes.c_void_p(tile_points.ptr),
  tile_points.shape[0],
  transform_buf,
  translation_buf,
  in_world_space,
- float(bg_value),
+ cvalue_ptr,
+ cvalue_size,
+ cvalue_type,
  )

  if volume.id == 0:
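
Note (illustration, not part of the diff): with the single volume_from_tiles_device entry point, the Python-side bg_value argument selects the grid type. A hedged sketch of Volume.allocate_by_tiles() usage, assuming a CUDA device and arbitrary tile coordinates:

    import warp as wp

    tiles = wp.array([[0, 0, 0]], dtype=wp.int32, device="cuda")  # tile ijk coordinates (N x 3)

    vol_f = wp.Volume.allocate_by_tiles(tiles, voxel_size=0.1, bg_value=0.0, device="cuda")             # float grid
    vol_i = wp.Volume.allocate_by_tiles(tiles, voxel_size=0.1, bg_value=0, device="cuda")               # int32 grid
    vol_v3 = wp.Volume.allocate_by_tiles(tiles, voxel_size=0.1, bg_value=wp.vec3f(0.0), device="cuda")  # vec3f grid
    vol_v4 = wp.Volume.allocate_by_tiles(tiles, voxel_size=0.1, bg_value=wp.vec4f(0.0), device="cuda")  # vec4f grid (new in 1.7)
    vol_idx = wp.Volume.allocate_by_tiles(tiles, voxel_size=0.1, bg_value=None, device="cuda")          # index grid
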
@@ -4606,6 +4650,8 @@ def matmul(
  ):
  """Computes a generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.

+ .. versionremoved:: 1.7
+
  .. deprecated:: 1.6
  Use :doc:`tile primitives </modules/tiles>` instead.

@@ -4619,80 +4665,8 @@ def matmul(
  allow_tf32x3_arith (bool): whether to use CUTLASS's 3xTF32 GEMMs, which enable accuracy similar to FP32
  while using Tensor Cores
  """
- from warp.context import runtime
-
- warp.utils.warn(
- "wp.matmul() is deprecated and will be removed in a\nfuture version. Use tile primitives instead.",
- category=DeprecationWarning,
- stacklevel=2,
- )
-
- device = a.device
-
- if b.device != device or c.device != device or d.device != device:
- raise RuntimeError("Matrices A, B, C, and D must all be on the same device as the runtime device.")
-
- if a.dtype != b.dtype or a.dtype != c.dtype or a.dtype != d.dtype:
- raise RuntimeError(
- "wp.matmul currently only supports operation between {A, B, C, D} matrices of the same type."
- )
-
- if (
- (not a.is_contiguous and not a.is_transposed)
- or (not b.is_contiguous and not b.is_transposed)
- or (not c.is_contiguous)
- or (not d.is_contiguous)
- ):
- raise RuntimeError(
- "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B may be transposed."
- )

- m = a.shape[0]
- n = b.shape[1]
- k = a.shape[1]
- if b.shape != (k, n) or c.shape != (m, n) or d.shape != (m, n):
- raise RuntimeError(
- "Invalid shapes for matrices: A = {} B = {} C = {} D = {}".format(a.shape, b.shape, c.shape, d.shape)
- )
-
- if runtime.tape:
- runtime.tape.record_func(
- backward=lambda: adj_matmul(a, b, c, a.grad, b.grad, c.grad, d.grad, alpha, beta, allow_tf32x3_arith),
- arrays=[a, b, c, d],
- )
- if warp.config.verify_autograd_array_access:
- d.mark_write()
- a.mark_read()
- b.mark_read()
- c.mark_read()
-
- # cpu fallback if no cuda devices found
- if device == "cpu":
- np_dtype = warp_type_to_np_dtype[a.dtype]
- d.assign(alpha * np.matmul(a.numpy(), b.numpy(), dtype=np_dtype) + beta * c.numpy())
- return
-
- cc = device.arch
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- m,
- n,
- k,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(a.ptr),
- ctypes.c_void_p(b.ptr),
- ctypes.c_void_p(c.ptr),
- ctypes.c_void_p(d.ptr),
- alpha,
- beta,
- not a.is_transposed,
- not b.is_transposed,
- allow_tf32x3_arith,
- 1,
- )
- if not ret:
- raise RuntimeError("matmul failed.")
+ raise RuntimeError("This function has been removed. Use tile primitives instead.")


  def adj_matmul(
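
Note (migration aid, not part of the diff): the recommended replacement is the tile API (see warp/examples/tile/example_tile_matmul.py in this release). As a stopgap with the same semantics as the removed CPU fallback above, a NumPy round-trip can be used; a minimal sketch assuming matching dtypes and contiguous arrays, with matmul_numpy() as a hypothetical helper:

    import numpy as np
    import warp as wp

    def matmul_numpy(a, b, c, d, alpha=1.0, beta=0.0):
        # d = alpha * (a @ b) + beta * c, mirroring the removed CPU fallback above
        d.assign(alpha * np.matmul(a.numpy(), b.numpy()) + beta * c.numpy())

    m, n, k = 4, 5, 3
    a = wp.array(np.random.rand(m, k).astype(np.float32), dtype=wp.float32)
    b = wp.array(np.random.rand(k, n).astype(np.float32), dtype=wp.float32)
    c = wp.zeros((m, n), dtype=wp.float32)
    d = wp.zeros((m, n), dtype=wp.float32)
    matmul_numpy(a, b, c, d)
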
@@ -4724,171 +4698,8 @@ def adj_matmul(
  allow_tf32x3_arith (bool): whether to use CUTLASS's 3xTF32 GEMMs, which enable accuracy similar to FP32
  while using Tensor Cores
  """
- from warp.context import runtime
-
- device = a.device
-
- if (
- b.device != device
- or c.device != device
- or adj_a.device != device
- or adj_b.device != device
- or adj_c.device != device
- or adj_d.device != device
- ):
- raise RuntimeError(
- "Matrices A, B, C, D, and their adjoints must all be on the same device as the runtime device."
- )
-
- if (
- a.dtype != b.dtype
- or a.dtype != c.dtype
- or a.dtype != adj_a.dtype
- or a.dtype != adj_b.dtype
- or a.dtype != adj_c.dtype
- or a.dtype != adj_d.dtype
- ):
- raise RuntimeError(
- "wp.adj_matmul currently only supports operation between {A, B, C, adj_D, adj_A, adj_B, adj_C} matrices of the same type."
- )
-
- if (
- (not a.is_contiguous and not a.is_transposed)
- or (not b.is_contiguous and not b.is_transposed)
- or (not c.is_contiguous)
- or (not adj_a.is_contiguous and not adj_a.is_transposed)
- or (not adj_b.is_contiguous and not adj_b.is_transposed)
- or (not adj_c.is_contiguous)
- or (not adj_d.is_contiguous)
- ):
- raise RuntimeError(
- "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B and their associated adjoints may be transposed."
- )

- m = a.shape[0]
- n = b.shape[1]
- k = a.shape[1]
- if (
- a.shape != (m, k)
- or b.shape != (k, n)
- or c.shape != (m, n)
- or adj_d.shape != (m, n)
- or adj_a.shape != (m, k)
- or adj_b.shape != (k, n)
- or adj_c.shape != (m, n)
- ):
- raise RuntimeError(
- "Invalid shapes for matrices: A = {} B = {} C = {} adj_D = {} adj_A = {} adj_B = {} adj_C = {}".format(
- a.shape, b.shape, c.shape, adj_d.shape, adj_a.shape, adj_b.shape, adj_c.shape
- )
- )
-
- # cpu fallback if no cuda devices found
- if device == "cpu":
- np_dtype = warp_type_to_np_dtype[a.dtype]
- adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose(), dtype=np_dtype) + adj_a.numpy())
- adj_b.assign(alpha * np.matmul(a.numpy().transpose(), adj_d.numpy(), dtype=np_dtype) + adj_b.numpy())
- adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
- return
-
- cc = device.arch
-
- # adj_a
- if not a.is_transposed:
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- m,
- k,
- n,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(adj_d.ptr),
- ctypes.c_void_p(b.ptr),
- ctypes.c_void_p(adj_a.ptr),
- ctypes.c_void_p(adj_a.ptr),
- alpha,
- 1.0,
- True,
- b.is_transposed,
- allow_tf32x3_arith,
- 1,
- )
- if not ret:
- raise RuntimeError("adj_matmul failed.")
- else:
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- k,
- m,
- n,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(b.ptr),
- ctypes.c_void_p(adj_d.ptr),
- ctypes.c_void_p(adj_a.ptr),
- ctypes.c_void_p(adj_a.ptr),
- alpha,
- 1.0,
- not b.is_transposed,
- False,
- allow_tf32x3_arith,
- 1,
- )
- if not ret:
- raise RuntimeError("adj_matmul failed.")
-
- # adj_b
- if not b.is_transposed:
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- k,
- n,
- m,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(a.ptr),
- ctypes.c_void_p(adj_d.ptr),
- ctypes.c_void_p(adj_b.ptr),
- ctypes.c_void_p(adj_b.ptr),
- alpha,
- 1.0,
- a.is_transposed,
- True,
- allow_tf32x3_arith,
- 1,
- )
- if not ret:
- raise RuntimeError("adj_matmul failed.")
- else:
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- n,
- k,
- m,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(adj_d.ptr),
- ctypes.c_void_p(a.ptr),
- ctypes.c_void_p(adj_b.ptr),
- ctypes.c_void_p(adj_b.ptr),
- alpha,
- 1.0,
- False,
- not a.is_transposed,
- allow_tf32x3_arith,
- 1,
- )
- if not ret:
- raise RuntimeError("adj_matmul failed.")
-
- # adj_c
- warp.launch(
- kernel=warp.utils.add_kernel_2d,
- dim=adj_c.shape,
- inputs=[adj_c, adj_d, adj_d.dtype(beta)],
- device=device,
- record_tape=False,
- )
+ raise RuntimeError("This function has been removed. Use tile primitives instead.")


  def batched_matmul(
@@ -4902,6 +4713,8 @@ def batched_matmul(
  ):
  """Computes a batched generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.

+ .. versionremoved:: 1.7
+
  .. deprecated:: 1.6
  Use :doc:`tile primitives </modules/tiles>` instead.

@@ -4915,107 +4728,8 @@ def batched_matmul(
  allow_tf32x3_arith (bool): whether to use CUTLASS's 3xTF32 GEMMs, which enable accuracy similar to FP32
  while using Tensor Cores
  """
- from warp.context import runtime
-
- device = a.device
-
- if b.device != device or c.device != device or d.device != device:
- raise RuntimeError("Matrices A, B, C, and D must all be on the same device as the runtime device.")
-
- if a.dtype != b.dtype or a.dtype != c.dtype or a.dtype != d.dtype:
- raise RuntimeError(
- "wp.batched_matmul currently only supports operation between {A, B, C, D} matrices of the same type."
- )
-
- if (
- (not a.is_contiguous and not a.is_transposed)
- or (not b.is_contiguous and not b.is_transposed)
- or (not c.is_contiguous)
- or (not d.is_contiguous)
- ):
- raise RuntimeError(
- "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B may be transposed."
- )
-
- m = a.shape[1]
- n = b.shape[2]
- k = a.shape[2]
- batch_count = a.shape[0]
- if b.shape != (batch_count, k, n) or c.shape != (batch_count, m, n) or d.shape != (batch_count, m, n):
- raise RuntimeError(
- "Invalid shapes for matrices: A = {} B = {} C = {} D = {}".format(a.shape, b.shape, c.shape, d.shape)
- )

- if runtime.tape:
- runtime.tape.record_func(
- backward=lambda: adj_batched_matmul(
- a, b, c, a.grad, b.grad, c.grad, d.grad, alpha, beta, allow_tf32x3_arith
- ),
- arrays=[a, b, c, d],
- )
- if warp.config.verify_autograd_array_access:
- d.mark_write()
- a.mark_read()
- b.mark_read()
- c.mark_read()
-
- # cpu fallback if no cuda devices found
- if device == "cpu":
- np_dtype = warp_type_to_np_dtype[a.dtype]
- d.assign(alpha * np.matmul(a.numpy(), b.numpy(), dtype=np_dtype) + beta * c.numpy())
- return
-
- # handle case in which batch_count exceeds max_batch_count, which is a CUDA array size maximum
- max_batch_count = 65535
- iters = int(batch_count / max_batch_count)
- remainder = batch_count % max_batch_count
-
- cc = device.arch
- for i in range(iters):
- idx_start = i * max_batch_count
- idx_end = (i + 1) * max_batch_count if i < iters - 1 else batch_count
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- m,
- n,
- k,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(a[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(b[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(c[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(d[idx_start:idx_end, :, :].ptr),
- alpha,
- beta,
- not a.is_transposed,
- not b.is_transposed,
- allow_tf32x3_arith,
- max_batch_count,
- )
- if not ret:
- raise RuntimeError("Batched matmul failed.")
-
- idx_start = iters * max_batch_count
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- m,
- n,
- k,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(a[idx_start:, :, :].ptr),
- ctypes.c_void_p(b[idx_start:, :, :].ptr),
- ctypes.c_void_p(c[idx_start:, :, :].ptr),
- ctypes.c_void_p(d[idx_start:, :, :].ptr),
- alpha,
- beta,
- not a.is_transposed,
- not b.is_transposed,
- allow_tf32x3_arith,
- remainder,
- )
- if not ret:
- raise RuntimeError("Batched matmul failed.")
+ raise RuntimeError("This function has been removed. Use tile primitives instead.")


  def adj_batched_matmul(
@@ -5045,270 +4759,8 @@ def adj_batched_matmul(
  allow_tf32x3_arith (bool): whether to use CUTLASS's 3xTF32 GEMMs, which enable accuracy similar to FP32
  while using Tensor Cores
  """
- from warp.context import runtime

- device = a.device
-
- if (
- b.device != device
- or c.device != device
- or adj_a.device != device
- or adj_b.device != device
- or adj_c.device != device
- or adj_d.device != device
- ):
- raise RuntimeError(
- "Matrices A, B, C, D, and their adjoints must all be on the same device as the runtime device."
- )
-
- if (
- a.dtype != b.dtype
- or a.dtype != c.dtype
- or a.dtype != adj_a.dtype
- or a.dtype != adj_b.dtype
- or a.dtype != adj_c.dtype
- or a.dtype != adj_d.dtype
- ):
- raise RuntimeError(
- "wp.adj_batched_matmul currently only supports operation between {A, B, C, adj_D, adj_A, adj_B, adj_C} matrices of the same type."
- )
-
- m = a.shape[1]
- n = b.shape[2]
- k = a.shape[2]
- batch_count = a.shape[0]
- if (
- b.shape != (batch_count, k, n)
- or c.shape != (batch_count, m, n)
- or adj_d.shape != (batch_count, m, n)
- or adj_a.shape != (batch_count, m, k)
- or adj_b.shape != (batch_count, k, n)
- or adj_c.shape != (batch_count, m, n)
- ):
- raise RuntimeError(
- "Invalid shapes for matrices: A = {} B = {} C = {} adj_D = {} adj_A = {} adj_B = {} adj_C = {}".format(
- a.shape, b.shape, c.shape, adj_d.shape, adj_a.shape, adj_b.shape, adj_c.shape
- )
- )
-
- if (
- (not a.is_contiguous and not a.is_transposed)
- or (not b.is_contiguous and not b.is_transposed)
- or (not c.is_contiguous)
- or (not adj_a.is_contiguous and not adj_a.is_transposed)
- or (not adj_b.is_contiguous and not adj_b.is_transposed)
- or (not adj_c.is_contiguous)
- or (not adj_d.is_contiguous)
- ):
- raise RuntimeError(
- "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B and their associated adjoints may be transposed."
- )
-
- # cpu fallback if no cuda devices found
- if device == "cpu":
- np_dtype = warp_type_to_np_dtype[a.dtype]
- adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose((0, 2, 1)), dtype=np_dtype) + adj_a.numpy())
- adj_b.assign(alpha * np.matmul(a.numpy().transpose((0, 2, 1)), adj_d.numpy(), dtype=np_dtype) + adj_b.numpy())
- adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
- return
-
- # handle case in which batch_count exceeds max_batch_count, which is a CUDA array size maximum
- max_batch_count = 65535
- iters = int(batch_count / max_batch_count)
- remainder = batch_count % max_batch_count
-
- cc = device.arch
-
- for i in range(iters):
- idx_start = i * max_batch_count
- idx_end = (i + 1) * max_batch_count if i < iters - 1 else batch_count
-
- # adj_a
- if not a.is_transposed:
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- m,
- k,
- n,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(adj_d[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(b[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(adj_a[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(adj_a[idx_start:idx_end, :, :].ptr),
- alpha,
- 1.0,
- True,
- b.is_transposed,
- allow_tf32x3_arith,
- max_batch_count,
- )
- if not ret:
- raise RuntimeError("adj_matmul failed.")
- else:
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- k,
- m,
- n,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(b[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(adj_d[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(adj_a[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(adj_a[idx_start:idx_end, :, :].ptr),
- alpha,
- 1.0,
- not b.is_transposed,
- False,
- allow_tf32x3_arith,
- max_batch_count,
- )
- if not ret:
- raise RuntimeError("adj_matmul failed.")
-
- # adj_b
- if not b.is_transposed:
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- k,
- n,
- m,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(a[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(adj_d[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(adj_b[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(adj_b[idx_start:idx_end, :, :].ptr),
- alpha,
- 1.0,
- a.is_transposed,
- True,
- allow_tf32x3_arith,
- max_batch_count,
- )
- if not ret:
- raise RuntimeError("adj_matmul failed.")
- else:
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- n,
- k,
- m,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(adj_d[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(a[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(adj_b[idx_start:idx_end, :, :].ptr),
- ctypes.c_void_p(adj_b[idx_start:idx_end, :, :].ptr),
- alpha,
- 1.0,
- False,
- not a.is_transposed,
- allow_tf32x3_arith,
- max_batch_count,
- )
- if not ret:
- raise RuntimeError("adj_matmul failed.")
-
- idx_start = iters * max_batch_count
-
- # adj_a
- if not a.is_transposed:
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- m,
- k,
- n,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(adj_d[idx_start:, :, :].ptr),
- ctypes.c_void_p(b[idx_start:, :, :].ptr),
- ctypes.c_void_p(adj_a[idx_start:, :, :].ptr),
- ctypes.c_void_p(adj_a[idx_start:, :, :].ptr),
- alpha,
- 1.0,
- True,
- b.is_transposed,
- allow_tf32x3_arith,
- remainder,
- )
- if not ret:
- raise RuntimeError("adj_matmul failed.")
- else:
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- k,
- m,
- n,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(b[idx_start:, :, :].ptr),
- ctypes.c_void_p(adj_d[idx_start:, :, :].ptr),
- ctypes.c_void_p(adj_a[idx_start:, :, :].ptr),
- ctypes.c_void_p(adj_a[idx_start:, :, :].ptr),
- alpha,
- 1.0,
- not b.is_transposed,
- False,
- allow_tf32x3_arith,
- remainder,
- )
- if not ret:
- raise RuntimeError("adj_matmul failed.")
-
- # adj_b
- if not b.is_transposed:
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- k,
- n,
- m,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(a[idx_start:, :, :].ptr),
- ctypes.c_void_p(adj_d[idx_start:, :, :].ptr),
- ctypes.c_void_p(adj_b[idx_start:, :, :].ptr),
- ctypes.c_void_p(adj_b[idx_start:, :, :].ptr),
- alpha,
- 1.0,
- a.is_transposed,
- True,
- allow_tf32x3_arith,
- remainder,
- )
- if not ret:
- raise RuntimeError("adj_matmul failed.")
- else:
- ret = runtime.core.cutlass_gemm(
- device.context,
- cc,
- n,
- k,
- m,
- type_typestr(a.dtype).encode(),
- ctypes.c_void_p(adj_d[idx_start:, :, :].ptr),
- ctypes.c_void_p(a[idx_start:, :, :].ptr),
- ctypes.c_void_p(adj_b[idx_start:, :, :].ptr),
- ctypes.c_void_p(adj_b[idx_start:, :, :].ptr),
- alpha,
- 1.0,
- False,
- not a.is_transposed,
- allow_tf32x3_arith,
- remainder,
- )
- if not ret:
- raise RuntimeError("adj_matmul failed.")
-
- # adj_c
- warp.launch(
- kernel=warp.utils.add_kernel_3d,
- dim=adj_c.shape,
- inputs=[adj_c, adj_d, adj_d.dtype(beta)],
- device=device,
- record_tape=False,
- )
+ raise RuntimeError("This function has been removed. Use tile primitives instead.")


  class HashGrid:
@@ -5691,7 +5143,7 @@ simple_type_codes = {
  }


- def get_type_code(arg_type):
+ def get_type_code(arg_type: type) -> str:
  if arg_type == Any:
  # special case for generics
  # note: since Python 3.11 Any is a type, so we check for it first
@@ -5755,8 +5207,8 @@ def get_type_code(arg_type):
  raise TypeError(f"Unrecognized type '{arg_type}'")


- def get_signature(arg_types, func_name=None, arg_names=None):
- type_codes = []
+ def get_signature(arg_types: List[type], func_name: Optional[str] = None, arg_names: Optional[List[str]] = None) -> str:
+ type_codes: List[str] = []
  for i, arg_type in enumerate(arg_types):
  try:
  type_codes.append(get_type_code(arg_type))