warp-lang 0.11.0__py3-none-manylinux2014_x86_64.whl → 1.0.0__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (170)
  1. warp/__init__.py +8 -0
  2. warp/bin/warp-clang.so +0 -0
  3. warp/bin/warp.so +0 -0
  4. warp/build.py +7 -6
  5. warp/build_dll.py +70 -79
  6. warp/builtins.py +10 -6
  7. warp/codegen.py +51 -19
  8. warp/config.py +7 -8
  9. warp/constants.py +3 -0
  10. warp/context.py +948 -245
  11. warp/dlpack.py +198 -113
  12. warp/examples/assets/bunny.usd +0 -0
  13. warp/examples/assets/cartpole.urdf +110 -0
  14. warp/examples/assets/crazyflie.usd +0 -0
  15. warp/examples/assets/cube.usda +42 -0
  16. warp/examples/assets/nv_ant.xml +92 -0
  17. warp/examples/assets/nv_humanoid.xml +183 -0
  18. warp/examples/assets/quadruped.urdf +268 -0
  19. warp/examples/assets/rocks.nvdb +0 -0
  20. warp/examples/assets/rocks.usd +0 -0
  21. warp/examples/assets/sphere.usda +56 -0
  22. warp/examples/assets/torus.usda +105 -0
  23. warp/examples/benchmarks/benchmark_api.py +383 -0
  24. warp/examples/benchmarks/benchmark_cloth.py +279 -0
  25. warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -0
  26. warp/examples/benchmarks/benchmark_cloth_jax.py +100 -0
  27. warp/examples/benchmarks/benchmark_cloth_numba.py +142 -0
  28. warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -0
  29. warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -0
  30. warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -0
  31. warp/examples/benchmarks/benchmark_cloth_warp.py +146 -0
  32. warp/examples/benchmarks/benchmark_launches.py +295 -0
  33. warp/examples/core/example_dem.py +221 -0
  34. warp/examples/core/example_fluid.py +267 -0
  35. warp/examples/core/example_graph_capture.py +129 -0
  36. warp/examples/core/example_marching_cubes.py +177 -0
  37. warp/examples/core/example_mesh.py +154 -0
  38. warp/examples/core/example_mesh_intersect.py +193 -0
  39. warp/examples/core/example_nvdb.py +169 -0
  40. warp/examples/core/example_raycast.py +89 -0
  41. warp/examples/core/example_raymarch.py +178 -0
  42. warp/examples/core/example_render_opengl.py +141 -0
  43. warp/examples/core/example_sph.py +389 -0
  44. warp/examples/core/example_torch.py +181 -0
  45. warp/examples/core/example_wave.py +249 -0
  46. warp/examples/fem/bsr_utils.py +380 -0
  47. warp/examples/fem/example_apic_fluid.py +391 -0
  48. warp/examples/fem/example_convection_diffusion.py +168 -0
  49. warp/examples/fem/example_convection_diffusion_dg.py +209 -0
  50. warp/examples/fem/example_convection_diffusion_dg0.py +194 -0
  51. warp/examples/fem/example_deformed_geometry.py +159 -0
  52. warp/examples/fem/example_diffusion.py +173 -0
  53. warp/examples/fem/example_diffusion_3d.py +152 -0
  54. warp/examples/fem/example_diffusion_mgpu.py +214 -0
  55. warp/examples/fem/example_mixed_elasticity.py +222 -0
  56. warp/examples/fem/example_navier_stokes.py +243 -0
  57. warp/examples/fem/example_stokes.py +192 -0
  58. warp/examples/fem/example_stokes_transfer.py +249 -0
  59. warp/examples/fem/mesh_utils.py +109 -0
  60. warp/examples/fem/plot_utils.py +287 -0
  61. warp/examples/optim/example_bounce.py +248 -0
  62. warp/examples/optim/example_cloth_throw.py +210 -0
  63. warp/examples/optim/example_diffray.py +535 -0
  64. warp/examples/optim/example_drone.py +850 -0
  65. warp/examples/optim/example_inverse_kinematics.py +169 -0
  66. warp/examples/optim/example_inverse_kinematics_torch.py +170 -0
  67. warp/examples/optim/example_spring_cage.py +234 -0
  68. warp/examples/optim/example_trajectory.py +201 -0
  69. warp/examples/sim/example_cartpole.py +128 -0
  70. warp/examples/sim/example_cloth.py +184 -0
  71. warp/examples/sim/example_granular.py +113 -0
  72. warp/examples/sim/example_granular_collision_sdf.py +185 -0
  73. warp/examples/sim/example_jacobian_ik.py +213 -0
  74. warp/examples/sim/example_particle_chain.py +106 -0
  75. warp/examples/sim/example_quadruped.py +179 -0
  76. warp/examples/sim/example_rigid_chain.py +191 -0
  77. warp/examples/sim/example_rigid_contact.py +176 -0
  78. warp/examples/sim/example_rigid_force.py +126 -0
  79. warp/examples/sim/example_rigid_gyroscopic.py +97 -0
  80. warp/examples/sim/example_rigid_soft_contact.py +124 -0
  81. warp/examples/sim/example_soft_body.py +178 -0
  82. warp/fabric.py +29 -20
  83. warp/fem/cache.py +0 -1
  84. warp/fem/dirichlet.py +0 -2
  85. warp/fem/integrate.py +0 -1
  86. warp/jax.py +45 -0
  87. warp/jax_experimental.py +339 -0
  88. warp/native/builtin.h +12 -0
  89. warp/native/bvh.cu +18 -18
  90. warp/native/clang/clang.cpp +8 -3
  91. warp/native/cuda_util.cpp +94 -5
  92. warp/native/cuda_util.h +35 -6
  93. warp/native/cutlass_gemm.cpp +1 -1
  94. warp/native/cutlass_gemm.cu +4 -1
  95. warp/native/error.cpp +66 -0
  96. warp/native/error.h +27 -0
  97. warp/native/mesh.cu +2 -2
  98. warp/native/reduce.cu +4 -4
  99. warp/native/runlength_encode.cu +2 -2
  100. warp/native/scan.cu +2 -2
  101. warp/native/sparse.cu +0 -1
  102. warp/native/temp_buffer.h +2 -2
  103. warp/native/warp.cpp +95 -60
  104. warp/native/warp.cu +1053 -218
  105. warp/native/warp.h +49 -32
  106. warp/optim/linear.py +33 -16
  107. warp/render/render_opengl.py +202 -101
  108. warp/render/render_usd.py +82 -40
  109. warp/sim/__init__.py +13 -4
  110. warp/sim/articulation.py +4 -5
  111. warp/sim/collide.py +320 -175
  112. warp/sim/import_mjcf.py +25 -30
  113. warp/sim/import_urdf.py +94 -63
  114. warp/sim/import_usd.py +51 -36
  115. warp/sim/inertia.py +3 -2
  116. warp/sim/integrator.py +233 -0
  117. warp/sim/integrator_euler.py +447 -469
  118. warp/sim/integrator_featherstone.py +1991 -0
  119. warp/sim/integrator_xpbd.py +1420 -640
  120. warp/sim/model.py +765 -487
  121. warp/sim/particles.py +2 -1
  122. warp/sim/render.py +35 -13
  123. warp/sim/utils.py +222 -11
  124. warp/stubs.py +8 -0
  125. warp/tape.py +16 -1
  126. warp/tests/aux_test_grad_customs.py +23 -0
  127. warp/tests/test_array.py +190 -1
  128. warp/tests/test_async.py +656 -0
  129. warp/tests/test_bool.py +50 -0
  130. warp/tests/test_dlpack.py +164 -11
  131. warp/tests/test_examples.py +166 -74
  132. warp/tests/test_fem.py +8 -1
  133. warp/tests/test_generics.py +15 -5
  134. warp/tests/test_grad.py +1 -1
  135. warp/tests/test_grad_customs.py +172 -12
  136. warp/tests/test_jax.py +254 -0
  137. warp/tests/test_large.py +29 -6
  138. warp/tests/test_launch.py +25 -0
  139. warp/tests/test_linear_solvers.py +20 -3
  140. warp/tests/test_matmul.py +61 -16
  141. warp/tests/test_matmul_lite.py +13 -13
  142. warp/tests/test_mempool.py +186 -0
  143. warp/tests/test_multigpu.py +3 -0
  144. warp/tests/test_options.py +16 -2
  145. warp/tests/test_peer.py +137 -0
  146. warp/tests/test_print.py +3 -1
  147. warp/tests/test_quat.py +23 -0
  148. warp/tests/test_sim_kinematics.py +97 -0
  149. warp/tests/test_snippet.py +126 -3
  150. warp/tests/test_streams.py +108 -79
  151. warp/tests/test_torch.py +16 -8
  152. warp/tests/test_utils.py +32 -27
  153. warp/tests/test_verify_fp.py +65 -0
  154. warp/tests/test_volume.py +1 -1
  155. warp/tests/unittest_serial.py +2 -0
  156. warp/tests/unittest_suites.py +12 -0
  157. warp/tests/unittest_utils.py +14 -7
  158. warp/thirdparty/unittest_parallel.py +15 -3
  159. warp/torch.py +10 -8
  160. warp/types.py +363 -246
  161. warp/utils.py +143 -19
  162. warp_lang-1.0.0.dist-info/LICENSE.md +126 -0
  163. warp_lang-1.0.0.dist-info/METADATA +394 -0
  164. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/RECORD +167 -86
  165. warp/sim/optimizer.py +0 -138
  166. warp_lang-0.11.0.dist-info/LICENSE.md +0 -36
  167. warp_lang-0.11.0.dist-info/METADATA +0 -238
  168. /warp/tests/{walkthough_debug.py → walkthrough_debug.py} +0 -0
  169. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/WHEEL +0 -0
  170. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/top_level.txt +0 -0
warp/types.py CHANGED
@@ -98,7 +98,13 @@ def vector(length, dtype):
  # ctypes.Array data for length, shape and c type:
  _length_ = 0 if length is Any else length
  _shape_ = (_length_,)
- _type_ = ctypes.c_float if dtype in [Scalar, Float] else dtype._type_
+
+ if dtype is bool:
+ _type_ = ctypes.c_bool
+ elif dtype in [Scalar, Float]:
+ _type_ = ctypes.c_float
+ else:
+ _type_ = dtype._type_
 
  # warp scalar type:
  _wp_scalar_type_ = dtype
@@ -271,7 +277,13 @@ def matrix(shape, dtype):
  class mat_t(ctypes.Array):
  _length_ = 0 if shape[0] == Any or shape[1] == Any else shape[0] * shape[1]
  _shape_ = (0, 0) if _length_ == 0 else shape
- _type_ = ctypes.c_float if dtype in [Scalar, Float] else dtype._type_
+
+ if dtype is bool:
+ _type_ = ctypes.c_bool
+ elif dtype in [Scalar, Float]:
+ _type_ = ctypes.c_float
+ else:
+ _type_ = dtype._type_
 
  # warp scalar type:
  # used in type checking and when writing out c++ code for constructors:
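The new `dtype is bool` branches above let the generic `vector()`/`matrix()` type factories back their storage with `ctypes.c_bool`. A minimal sketch of what this enables through the public `wp.vec()`/`wp.mat()` constructors (usage assumed from the Warp 1.0 type API; concrete aliases such as `vec3b` are not shown in this diff):

    import warp as wp

    # generic type factories with a boolean scalar type
    vec3b = wp.vec(length=3, dtype=wp.bool)
    mat22b = wp.mat(shape=(2, 2), dtype=wp.bool)

    flags = vec3b(True, False, True)
    print(flags[0], flags[1], flags[2])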
@@ -391,8 +403,7 @@ def matrix(shape, dtype):
  iter(v)
  except TypeError:
  raise TypeError(
- f"Expected to assign a slice from a sequence of values "
- f"but got `{type(v).__name__}` instead"
+ f"Expected to assign a slice from a sequence of values " f"but got `{type(v).__name__}` instead"
  ) from None
 
  row_start = r * self._shape_[1]
@@ -1103,6 +1114,7 @@ class range_t:
  # definition just for kernel type (cannot be a parameter), see bvh.h
  class bvh_query_t:
  """Object used to track state during BVH traversal."""
+
  def __init__(self):
  pass
 
@@ -1110,6 +1122,7 @@ class bvh_query_t:
  # definition just for kernel type (cannot be a parameter), see mesh.h
  class mesh_query_aabb_t:
  """Object used to track state during mesh traversal."""
+
  def __init__(self):
  pass
 
@@ -1117,6 +1130,7 @@ class mesh_query_aabb_t:
  # definition just for kernel type (cannot be a parameter), see hash_grid.h
  class hash_grid_query_t:
  """Object used to track state during neighbor traversal."""
+
  def __init__(self):
  pass
@@ -1250,18 +1264,30 @@ def type_scalar_type(dtype):
  return getattr(dtype, "_wp_scalar_type_", dtype)
 
 
+ # Cache results of type_size_in_bytes(), because the function is actually quite slow.
+ _type_size_cache = {
+ float: 4,
+ int: 4,
+ }
+
+
  def type_size_in_bytes(dtype):
- if dtype.__module__ == "ctypes":
- return ctypes.sizeof(dtype)
- elif isinstance(dtype, warp.codegen.Struct):
- return ctypes.sizeof(dtype.ctype)
- elif dtype == float or dtype == int:
- return 4
- elif hasattr(dtype, "_type_"):
- return getattr(dtype, "_length_", 1) * ctypes.sizeof(dtype._type_)
+ size = _type_size_cache.get(dtype)
 
- else:
- return 0
+ if size is None:
+ if dtype.__module__ == "ctypes":
+ size = ctypes.sizeof(dtype)
+ elif hasattr(dtype, "_type_"):
+ size = getattr(dtype, "_length_", 1) * ctypes.sizeof(dtype._type_)
+ elif isinstance(dtype, warp.codegen.Struct):
+ size = ctypes.sizeof(dtype.ctype)
+ elif dtype == Any:
+ raise TypeError(f"A concrete type is required")
+ else:
+ raise TypeError(f"Invalid data type: {dtype}")
+ _type_size_cache[dtype] = size
+
+ return size
 
 
  def type_to_warp(dtype):
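The rewritten helper above memoizes its result per dtype and now raises instead of returning 0 for unknown types. A small sanity check of the reported sizes (a sketch; it assumes the standard `wp.float32`/`wp.vec3` type definitions and calls the helper through `warp.types`):

    import warp as wp
    from warp.types import type_size_in_bytes

    assert type_size_in_bytes(wp.float32) == 4    # scalar ctypes-backed type
    assert type_size_in_bytes(wp.vec3) == 3 * 4   # _length_ * sizeof(_type_)
    assert type_size_in_bytes(float) == 4         # Python float maps to 32-bit float, served from the cache
    assert type_size_in_bytes(wp.float64) == 8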
@@ -1399,9 +1425,9 @@ def types_equal(a, b, match_generic=False):
  if match_generic:
  if p1 == Any or p2 == Any:
  return True
- if p1 == Scalar and p2 in scalar_types:
+ if p1 == Scalar and p2 in scalar_types + [bool]:
  return True
- if p2 == Scalar and p1 in scalar_types:
+ if p2 == Scalar and p1 in scalar_types + [bool]:
  return True
  if p1 == Scalar and p2 == Scalar:
  return True
@@ -1454,6 +1480,17 @@ def strides_from_shape(shape: Tuple, dtype):
  return tuple(strides)
 
 
+ def check_array_shape(shape: Tuple):
+ """Checks that the size in each dimension is positive and less than 2^32."""
+
+ for dim_index, dim_size in enumerate(shape):
+ if dim_size < 0:
+ raise ValueError(f"Array shapes must be non-negative, got {dim_size} in dimension {dim_index}")
+ if dim_size >= 2**31:
+ raise ValueError("Array shapes must not exceed the maximum representable value of a signed 32-bit integer, "
+ f"got {dim_size} in dimension {dim_index}.")
+
+
  class array(Array):
  # member attributes available during code-gen (e.g.: d = array.shape[0])
  # (initialized when needed)
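`check_array_shape()` is called from the array constructors further down in this diff, so invalid shapes now fail early with a `ValueError`. A quick sketch of the behaviour (assuming allocation through `wp.zeros()`):

    import warp as wp

    wp.init()

    try:
        wp.zeros(shape=(-1, 3), dtype=wp.float32)    # negative dimension
    except ValueError as e:
        print(e)  # Array shapes must be non-negative, got -1 in dimension 0

    try:
        wp.zeros(shape=(2**31,), dtype=wp.float32)   # exceeds signed 32-bit range
    except ValueError as e:
        print(e)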
@@ -1471,7 +1508,8 @@ class array(Array):
  device=None,
  pinned=False,
  copy=True,
- owner=True, # TODO: replace with deleter=None
+ owner=False, # deprecated - pass deleter instead
+ deleter=None,
  ndim=None,
  grad=None,
  requires_grad=False,
@@ -1505,15 +1543,18 @@ class array(Array):
  capacity (int): Maximum size in bytes of the ptr allocation (data should be None)
  device (Devicelike): Device the array lives on
  copy (bool): Whether the incoming data will be copied or aliased, this is only possible when the incoming `data` already lives on the device specified and types match
- owner (bool): Should the array object try to deallocate memory when it is deleted
+ owner (bool): Should the array object try to deallocate memory when it is deleted (deprecated, pass `deleter` if you wish to transfer ownership to Warp)
+ deleter (Callable): Function to be called when deallocating the array, taking two arguments, pointer and size
  requires_grad (bool): Whether or not gradients will be tracked for this array, see :class:`warp.Tape` for details
  grad (array): The gradient array to use
  pinned (bool): Whether to allocate pinned host memory, which allows asynchronous host-device transfers (only applicable with device="cpu")
 
  """
 
- self.owner = False
+ self.deleter = None
  self.ctype = None
+
+ # properties
  self._requires_grad = False
  self._grad = None
  # __array_interface__ or __cuda_array_interface__, evaluated lazily and cached
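With `owner` deprecated, ownership of externally allocated memory is transferred by passing a `deleter(ptr, size)` callback instead, as documented above. A minimal sketch for a CPU buffer (the `release()` bookkeeping is illustrative only):

    import ctypes
    import warp as wp

    wp.init()

    n = 16
    buf = (ctypes.c_float * n)()    # externally owned CPU storage
    released = []

    def release(ptr, size):         # called as deleter(ptr, size) when the array is collected
        released.append((ptr, size))

    a = wp.array(
        ptr=ctypes.addressof(buf),
        dtype=wp.float32,
        shape=(n,),
        device="cpu",
        deleter=release,
    )
    a.zero_()

    del a                           # release() is invoked with the pointer and capacity
    print(released)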
@@ -1547,7 +1588,7 @@ class array(Array):
1547
1588
  raise RuntimeError("Can only construct arrays with either `data` or `ptr` arguments, not both")
1548
1589
  self._init_from_data(data, dtype, shape, device, copy, pinned)
1549
1590
  elif ptr is not None:
1550
- self._init_from_ptr(ptr, dtype, shape, strides, capacity, device, owner, pinned)
1591
+ self._init_from_ptr(ptr, dtype, shape, strides, capacity, device, pinned, deleter)
1551
1592
  elif shape is not None:
1552
1593
  self._init_new(dtype, shape, strides, device, pinned)
1553
1594
  else:
@@ -1562,8 +1603,7 @@ class array(Array):
1562
1603
  # allocate gradient if needed
1563
1604
  self._requires_grad = requires_grad
1564
1605
  if requires_grad:
1565
- with warp.ScopedStream(self.device.null_stream):
1566
- self._alloc_grad()
1606
+ self._alloc_grad()
1567
1607
 
1568
1608
  def _init_from_data(self, data, dtype, shape, device, copy, pinned):
1569
1609
  if not hasattr(data, "__len__"):
@@ -1674,11 +1714,16 @@ class array(Array):
1674
1714
  shape = arr.shape or (1,)
1675
1715
  strides = arr.strides or (type_size_in_bytes(dtype),)
1676
1716
 
1677
- device = warp.get_device(device)
1717
+ try:
1718
+ # Performance note: try first, ask questions later
1719
+ device = warp.context.runtime.get_device(device)
1720
+ except:
1721
+ warp.context.assert_initialized()
1722
+ raise
1678
1723
 
1679
1724
  if device.is_cpu and not copy and not pinned:
1680
1725
  # reference numpy memory directly
1681
- self._init_from_ptr(arr.ctypes.data, dtype, shape, strides, None, device, False, False)
1726
+ self._init_from_ptr(arr.ctypes.data, dtype, shape, strides, None, device, False, None)
1682
1727
  # keep a ref to the source array to keep allocation alive
1683
1728
  self._ref = arr
1684
1729
  else:
@@ -1691,82 +1736,106 @@ class array(Array):
1691
1736
  strides=strides,
1692
1737
  device="cpu",
1693
1738
  copy=False,
1694
- owner=False,
1695
1739
  )
1696
1740
  warp.copy(self, src)
1697
1741
 
1698
- def _init_from_ptr(self, ptr, dtype, shape, strides, capacity, device, owner, pinned):
1699
- if dtype == Any:
1700
- raise RuntimeError("A concrete data type is required to create the array")
1701
-
1702
- device = warp.get_device(device)
1703
-
1704
- size = 1
1705
- for d in shape:
1706
- size *= d
1707
-
1708
- contiguous_strides = strides_from_shape(shape, dtype)
1742
+ def _init_from_ptr(self, ptr, dtype, shape, strides, capacity, device, pinned, deleter):
1743
+ try:
1744
+ # Performance note: try first, ask questions later
1745
+ device = warp.context.runtime.get_device(device)
1746
+ except:
1747
+ warp.context.assert_initialized()
1748
+ raise
1749
+
1750
+ check_array_shape(shape)
1751
+ ndim = len(shape)
1752
+ dtype_size = type_size_in_bytes(dtype)
1753
+
1754
+ # compute size and contiguous strides
1755
+ # Performance note: we could use strides_from_shape() here, but inlining it is faster.
1756
+ contiguous_strides = [None] * ndim
1757
+ i = ndim - 1
1758
+ contiguous_strides[i] = dtype_size
1759
+ size = shape[i]
1760
+ while i > 0:
1761
+ contiguous_strides[i - 1] = contiguous_strides[i] * shape[i]
1762
+ i -= 1
1763
+ size *= shape[i]
1764
+ contiguous_strides = tuple(contiguous_strides)
1709
1765
 
1710
1766
  if strides is None:
1711
1767
  strides = contiguous_strides
1712
1768
  is_contiguous = True
1713
1769
  if capacity is None:
1714
- capacity = size * type_size_in_bytes(dtype)
1770
+ capacity = size * dtype_size
1715
1771
  else:
1772
+ strides = tuple(strides)
1716
1773
  is_contiguous = strides == contiguous_strides
1717
1774
  if capacity is None:
1718
1775
  capacity = shape[0] * strides[0]
1719
1776
 
1720
1777
  self.dtype = dtype
1721
- self.ndim = len(shape)
1778
+ self.ndim = ndim
1722
1779
  self.size = size
1723
1780
  self.capacity = capacity
1724
1781
  self.shape = shape
1725
1782
  self.strides = strides
1726
1783
  self.ptr = ptr
1727
1784
  self.device = device
1728
- self.owner = owner
1729
1785
  self.pinned = pinned if device.is_cpu else False
1730
1786
  self.is_contiguous = is_contiguous
1787
+ self.deleter = deleter
1731
1788
 
1732
1789
  def _init_new(self, dtype, shape, strides, device, pinned):
1733
- if dtype == Any:
1734
- raise RuntimeError("A concrete data type is required to create the array")
1735
-
1736
- device = warp.get_device(device)
1737
-
1738
- size = 1
1739
- for d in shape:
1740
- size *= d
1741
-
1742
- contiguous_strides = strides_from_shape(shape, dtype)
1790
+ try:
1791
+ # Performance note: try first, ask questions later
1792
+ device = warp.context.runtime.get_device(device)
1793
+ except:
1794
+ warp.context.assert_initialized()
1795
+ raise
1796
+
1797
+ check_array_shape(shape)
1798
+ ndim = len(shape)
1799
+ dtype_size = type_size_in_bytes(dtype)
1800
+
1801
+ # compute size and contiguous strides
1802
+ # Performance note: we could use strides_from_shape() here, but inlining it is faster.
1803
+ contiguous_strides = [None] * ndim
1804
+ i = ndim - 1
1805
+ contiguous_strides[i] = dtype_size
1806
+ size = shape[i]
1807
+ while i > 0:
1808
+ contiguous_strides[i - 1] = contiguous_strides[i] * shape[i]
1809
+ i -= 1
1810
+ size *= shape[i]
1811
+ contiguous_strides = tuple(contiguous_strides)
1743
1812
 
1744
1813
  if strides is None:
1745
1814
  strides = contiguous_strides
1746
1815
  is_contiguous = True
1747
- capacity = size * type_size_in_bytes(dtype)
1816
+ capacity = size * dtype_size
1748
1817
  else:
1818
+ strides = tuple(strides)
1749
1819
  is_contiguous = strides == contiguous_strides
1750
1820
  capacity = shape[0] * strides[0]
1751
1821
 
1822
+ allocator = device.get_allocator(pinned=pinned)
1752
1823
  if capacity > 0:
1753
- ptr = device.allocator.alloc(capacity, pinned=pinned)
1754
- if ptr is None:
1755
- raise RuntimeError(f"Array allocation failed on device: {device} for {capacity} bytes")
1824
+ ptr = allocator.alloc(capacity)
1756
1825
  else:
1757
1826
  ptr = None
1758
1827
 
1759
1828
  self.dtype = dtype
1760
- self.ndim = len(shape)
1829
+ self.ndim = ndim
1761
1830
  self.size = size
1762
1831
  self.capacity = capacity
1763
1832
  self.shape = shape
1764
1833
  self.strides = strides
1765
1834
  self.ptr = ptr
1766
1835
  self.device = device
1767
- self.owner = True
1768
1836
  self.pinned = pinned if device.is_cpu else False
1769
1837
  self.is_contiguous = is_contiguous
1838
+ self.deleter = allocator.deleter
1770
1839
 
1771
1840
  def _init_annotation(self, dtype, ndim):
1772
1841
  self.dtype = dtype
@@ -1777,12 +1846,20 @@ class array(Array):
1777
1846
  self.strides = (0,) * ndim
1778
1847
  self.ptr = None
1779
1848
  self.device = None
1780
- self.owner = False
1781
1849
  self.pinned = False
1782
1850
  self.is_contiguous = False
1783
1851
 
1852
+ def __del__(self):
1853
+
1854
+ if self.deleter is None:
1855
+ return
1856
+
1857
+ with self.device.context_guard:
1858
+ self.deleter(self.ptr, self.capacity)
1859
+
1784
1860
  @property
1785
1861
  def __array_interface__(self):
1862
+
1786
1863
  # raising an AttributeError here makes hasattr() return False
1787
1864
  if self.device is None or not self.device.is_cpu:
1788
1865
  raise AttributeError(f"__array_interface__ not supported because device is {self.device}")
@@ -1819,6 +1896,7 @@ class array(Array):
1819
1896
 
1820
1897
  @property
1821
1898
  def __cuda_array_interface__(self):
1899
+
1822
1900
  # raising an AttributeError here makes hasattr() return False
1823
1901
  if self.device is None or not self.device.is_cuda:
1824
1902
  raise AttributeError(f"__cuda_array_interface__ is not supported because device is {self.device}")
@@ -1845,11 +1923,45 @@ class array(Array):
1845
1923
 
1846
1924
  return self._array_interface
1847
1925
 
1848
- def __del__(self):
1849
- if self.owner:
1850
- # use CUDA context guard to avoid side effects during garbage collection
1851
- with self.device.context_guard:
1852
- self.device.allocator.free(self.ptr, self.capacity, self.pinned)
1926
+ def __dlpack__(self, stream=None):
1927
+ # See https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.array.__dlpack__.html
1928
+
1929
+ if self.device is None:
1930
+ raise RuntimeError("Array has no device assigned")
1931
+
1932
+ if self.device.is_cuda and stream != -1:
1933
+ if not isinstance(stream, int):
1934
+ raise TypeError("DLPack stream must be an integer or None")
1935
+
1936
+ # assume that the array is being used on its device's current stream
1937
+ array_stream = self.device.stream
1938
+
1939
+ # the external stream should wait for outstanding operations to complete
1940
+ if stream in (None, 0, 1):
1941
+ external_stream = 0
1942
+ else:
1943
+ external_stream = stream
1944
+
1945
+ # Performance note: avoid wrapping the external stream in a temporary Stream object
1946
+ if external_stream != array_stream.cuda_stream:
1947
+ warp.context.runtime.core.cuda_stream_wait_stream(
1948
+ external_stream, array_stream.cuda_stream, array_stream.cached_event.cuda_event
1949
+ )
1950
+
1951
+ return warp.dlpack.to_dlpack(self)
1952
+
1953
+ def __dlpack_device__(self):
1954
+ # See https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.array.__dlpack_device__.html
1955
+
1956
+ if self.device is None:
1957
+ raise RuntimeError("Array has no device assigned")
1958
+
1959
+ if self.device.is_cuda:
1960
+ return (warp.dlpack.DLDeviceType.kDLCUDA, self.device.ordinal)
1961
+ elif self.pinned:
1962
+ return (warp.dlpack.DLDeviceType.kDLCUDAHost, 0)
1963
+ else:
1964
+ return (warp.dlpack.DLDeviceType.kDLCPU, 0)
1853
1965
 
1854
1966
  def __len__(self):
1855
1967
  return self.shape[0]
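The `__dlpack__` and `__dlpack_device__` methods added in this hunk let other frameworks consume Warp arrays through the standard DLPack protocol without an explicit `warp.dlpack.to_dlpack()` call. A minimal sketch of zero-copy sharing with NumPy (assumes NumPy >= 1.22 for `np.from_dlpack` and a CPU array):

    import numpy as np
    import warp as wp

    wp.init()

    a = wp.zeros(8, dtype=wp.float32, device="cpu")

    view = np.from_dlpack(a)   # goes through a.__dlpack_device__() / a.__dlpack__()
    view[:] = 1.0

    print(a.numpy())           # the Warp array sees the same memory: all ones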
@@ -1949,7 +2061,6 @@ class array(Array):
1949
2061
  strides=tuple(new_strides),
1950
2062
  device=self.grad.device,
1951
2063
  pinned=self.grad.pinned,
1952
- owner=False,
1953
2064
  )
1954
2065
  # store back-ref to stop data being destroyed
1955
2066
  new_grad._ref = self.grad
@@ -1963,7 +2074,6 @@ class array(Array):
1963
2074
  strides=tuple(new_strides),
1964
2075
  device=self.device,
1965
2076
  pinned=self.pinned,
1966
- owner=False,
1967
2077
  grad=new_grad,
1968
2078
  )
1969
2079
 
@@ -2002,7 +2112,7 @@ class array(Array):
2002
2112
  n = other.shape[1]
2003
2113
  c = warp.zeros(shape=(m, n), dtype=self.dtype, device=self.device, requires_grad=True)
2004
2114
  d = warp.zeros(shape=(m, n), dtype=self.dtype, device=self.device, requires_grad=True)
2005
- matmul(self, other, c, d, device=self.device)
2115
+ matmul(self, other, c, d)
2006
2116
  return d
2007
2117
 
2008
2118
  @property
@@ -2046,10 +2156,9 @@ class array(Array):
2046
2156
  self.ctype = None
2047
2157
 
2048
2158
  def _alloc_grad(self):
2049
- self._grad = array(
2159
+ self._grad = warp.zeros(
2050
2160
  dtype=self.dtype, shape=self.shape, strides=self.strides, device=self.device, pinned=self.pinned
2051
2161
  )
2052
- self._grad.zero_()
2053
2162
 
2054
2163
  # trigger re-creation of C-representation
2055
2164
  self.ctype = None
@@ -2246,7 +2355,6 @@ class array(Array):
2246
2355
  device=self.device,
2247
2356
  pinned=self.pinned,
2248
2357
  copy=False,
2249
- owner=False,
2250
2358
  grad=None if self.grad is None else self.grad.flatten(),
2251
2359
  )
2252
2360
 
@@ -2308,7 +2416,6 @@ class array(Array):
2308
2416
  device=self.device,
2309
2417
  pinned=self.pinned,
2310
2418
  copy=False,
2311
- owner=False,
2312
2419
  grad=None if self.grad is None else self.grad.reshape(shape),
2313
2420
  )
2314
2421
 
@@ -2332,7 +2439,6 @@ class array(Array):
2332
2439
  device=self.device,
2333
2440
  pinned=self.pinned,
2334
2441
  copy=False,
2335
- owner=False,
2336
2442
  grad=None if self.grad is None else self.grad.view(dtype),
2337
2443
  )
2338
2444
 
@@ -2384,7 +2490,6 @@ class array(Array):
2384
2490
  device=self.device,
2385
2491
  pinned=self.pinned,
2386
2492
  copy=False,
2387
- owner=False,
2388
2493
  grad=None if self.grad is None else self.grad.transpose(axes=axes),
2389
2494
  )
2390
2495
 
@@ -2427,7 +2532,6 @@ def from_ptr(ptr, length, dtype=None, shape=None, device=None):
2427
2532
  ptr=0 if ptr == 0 else ctypes.cast(ptr, ctypes.POINTER(ctypes.c_size_t)).contents.value,
2428
2533
  shape=shape,
2429
2534
  device=device,
2430
- owner=False,
2431
2535
  requires_grad=False,
2432
2536
  )
2433
2537
 
@@ -2682,6 +2786,8 @@ class Bvh:
2682
2786
  uppers (:class:`warp.array`): Array of upper bounds :class:`warp.vec3`
2683
2787
  """
2684
2788
 
2789
+ self.id = 0
2790
+
2685
2791
  if len(lowers) != len(uppers):
2686
2792
  raise RuntimeError("Bvh the same number of lower and upper bounds must be provided")
2687
2793
 
@@ -2704,39 +2810,35 @@ class Bvh:
2704
2810
  else:
2705
2811
  return ctypes.c_void_p(0)
2706
2812
 
2707
- from warp.context import runtime
2813
+ self.runtime = warp.context.runtime
2708
2814
 
2709
2815
  if self.device.is_cpu:
2710
- self.id = runtime.core.bvh_create_host(get_data(lowers), get_data(uppers), int(len(lowers)))
2816
+ self.id = self.runtime.core.bvh_create_host(get_data(lowers), get_data(uppers), int(len(lowers)))
2711
2817
  else:
2712
- self.id = runtime.core.bvh_create_device(
2818
+ self.id = self.runtime.core.bvh_create_device(
2713
2819
  self.device.context, get_data(lowers), get_data(uppers), int(len(lowers))
2714
2820
  )
2715
2821
 
2716
2822
  def __del__(self):
2717
- try:
2718
- from warp.context import runtime
2719
2823
 
2720
- if self.device.is_cpu:
2721
- runtime.core.bvh_destroy_host(self.id)
2722
- else:
2723
- # use CUDA context guard to avoid side effects during garbage collection
2724
- with self.device.context_guard:
2725
- runtime.core.bvh_destroy_device(self.id)
2824
+ if not self.id:
2825
+ return
2726
2826
 
2727
- except Exception:
2728
- pass
2827
+ if self.device.is_cpu:
2828
+ self.runtime.core.bvh_destroy_host(self.id)
2829
+ else:
2830
+ # use CUDA context guard to avoid side effects during garbage collection
2831
+ with self.device.context_guard:
2832
+ self.runtime.core.bvh_destroy_device(self.id)
2729
2833
 
2730
2834
  def refit(self):
2731
2835
  """Refit the BVH. This should be called after users modify the `lowers` and `uppers` arrays."""
2732
2836
 
2733
- from warp.context import runtime
2734
-
2735
2837
  if self.device.is_cpu:
2736
- runtime.core.bvh_refit_host(self.id)
2838
+ self.runtime.core.bvh_refit_host(self.id)
2737
2839
  else:
2738
- runtime.core.bvh_refit_device(self.id)
2739
- runtime.verify_cuda_device(self.device)
2840
+ self.runtime.core.bvh_refit_device(self.id)
2841
+ self.runtime.verify_cuda_device(self.device)
2740
2842
 
2741
2843
 
2742
2844
  class Mesh:
@@ -2762,6 +2864,8 @@ class Mesh:
2762
2864
  support_winding_number (bool): If true the mesh will build additional datastructures to support `wp.mesh_query_point_sign_winding_number()` queries
2763
2865
  """
2764
2866
 
2867
+ self.id = 0
2868
+
2765
2869
  if points.device != indices.device:
2766
2870
  raise RuntimeError("Mesh points and indices must live on the same device")
2767
2871
 
@@ -2782,10 +2886,10 @@ class Mesh:
2782
2886
  self.velocities = velocities
2783
2887
  self.indices = indices
2784
2888
 
2785
- from warp.context import runtime
2889
+ self.runtime = warp.context.runtime
2786
2890
 
2787
2891
  if self.device.is_cpu:
2788
- self.id = runtime.core.mesh_create_host(
2892
+ self.id = self.runtime.core.mesh_create_host(
2789
2893
  points.__ctype__(),
2790
2894
  velocities.__ctype__() if velocities else array().__ctype__(),
2791
2895
  indices.__ctype__(),
@@ -2794,7 +2898,7 @@ class Mesh:
2794
2898
  int(support_winding_number),
2795
2899
  )
2796
2900
  else:
2797
- self.id = runtime.core.mesh_create_device(
2901
+ self.id = self.runtime.core.mesh_create_device(
2798
2902
  self.device.context,
2799
2903
  points.__ctype__(),
2800
2904
  velocities.__ctype__() if velocities else array().__ctype__(),
@@ -2805,28 +2909,25 @@ class Mesh:
2805
2909
  )
2806
2910
 
2807
2911
  def __del__(self):
2808
- try:
2809
- from warp.context import runtime
2810
2912
 
2811
- if self.device.is_cpu:
2812
- runtime.core.mesh_destroy_host(self.id)
2813
- else:
2814
- # use CUDA context guard to avoid side effects during garbage collection
2815
- with self.device.context_guard:
2816
- runtime.core.mesh_destroy_device(self.id)
2817
- except Exception:
2818
- pass
2913
+ if not self.id:
2914
+ return
2915
+
2916
+ if self.device.is_cpu:
2917
+ self.runtime.core.mesh_destroy_host(self.id)
2918
+ else:
2919
+ # use CUDA context guard to avoid side effects during garbage collection
2920
+ with self.device.context_guard:
2921
+ self.runtime.core.mesh_destroy_device(self.id)
2819
2922
 
2820
2923
  def refit(self):
2821
2924
  """Refit the BVH to points. This should be called after users modify the `points` data."""
2822
2925
 
2823
- from warp.context import runtime
2824
-
2825
2926
  if self.device.is_cpu:
2826
- runtime.core.mesh_refit_host(self.id)
2927
+ self.runtime.core.mesh_refit_host(self.id)
2827
2928
  else:
2828
- runtime.core.mesh_refit_device(self.id)
2829
- runtime.verify_cuda_device(self.device)
2929
+ self.runtime.core.mesh_refit_device(self.id)
2930
+ self.runtime.verify_cuda_device(self.device)
2830
2931
 
2831
2932
 
2832
2933
  class Volume:
@@ -2844,9 +2945,8 @@ class Volume:
2844
2945
 
2845
2946
  self.id = 0
2846
2947
 
2847
- from warp.context import runtime
2848
-
2849
- self.context = runtime
2948
+ # keep a runtime reference for orderly destruction
2949
+ self.runtime = warp.context.runtime
2850
2950
 
2851
2951
  if data is None:
2852
2952
  return
@@ -2856,9 +2956,9 @@ class Volume:
2856
2956
  self.device = data.device
2857
2957
 
2858
2958
  if self.device.is_cpu:
2859
- self.id = self.context.core.volume_create_host(ctypes.cast(data.ptr, ctypes.c_void_p), data.size)
2959
+ self.id = self.runtime.core.volume_create_host(ctypes.cast(data.ptr, ctypes.c_void_p), data.size)
2860
2960
  else:
2861
- self.id = self.context.core.volume_create_device(
2961
+ self.id = self.runtime.core.volume_create_device(
2862
2962
  self.device.context, ctypes.cast(data.ptr, ctypes.c_void_p), data.size
2863
2963
  )
2864
2964
 
@@ -2866,31 +2966,26 @@ class Volume:
2866
2966
  raise RuntimeError("Failed to create volume from input array")
2867
2967
 
2868
2968
  def __del__(self):
2869
- if self.id == 0:
2870
- return
2871
-
2872
- try:
2873
- from warp.context import runtime
2874
2969
 
2875
- if self.device.is_cpu:
2876
- runtime.core.volume_destroy_host(self.id)
2877
- else:
2878
- # use CUDA context guard to avoid side effects during garbage collection
2879
- with self.device.context_guard:
2880
- runtime.core.volume_destroy_device(self.id)
2970
+ if not self.id:
2971
+ return
2881
2972
 
2882
- except Exception:
2883
- pass
2973
+ if self.device.is_cpu:
2974
+ self.runtime.core.volume_destroy_host(self.id)
2975
+ else:
2976
+ # use CUDA context guard to avoid side effects during garbage collection
2977
+ with self.device.context_guard:
2978
+ self.runtime.core.volume_destroy_device(self.id)
2884
2979
 
2885
2980
  def array(self) -> array:
2886
2981
  """Returns the raw memory buffer of the Volume as an array"""
2887
2982
  buf = ctypes.c_void_p(0)
2888
2983
  size = ctypes.c_uint64(0)
2889
2984
  if self.device.is_cpu:
2890
- self.context.core.volume_get_buffer_info_host(self.id, ctypes.byref(buf), ctypes.byref(size))
2985
+ self.runtime.core.volume_get_buffer_info_host(self.id, ctypes.byref(buf), ctypes.byref(size))
2891
2986
  else:
2892
- self.context.core.volume_get_buffer_info_device(self.id, ctypes.byref(buf), ctypes.byref(size))
2893
- return array(ptr=buf.value, dtype=uint8, shape=size.value, device=self.device, owner=False)
2987
+ self.runtime.core.volume_get_buffer_info_device(self.id, ctypes.byref(buf), ctypes.byref(size))
2988
+ return array(ptr=buf.value, dtype=uint8, shape=size.value, device=self.device)
2894
2989
 
2895
2990
  def get_tiles(self) -> array:
2896
2991
  if self.id == 0:
@@ -2899,18 +2994,24 @@ class Volume:
2899
2994
  buf = ctypes.c_void_p(0)
2900
2995
  size = ctypes.c_uint64(0)
2901
2996
  if self.device.is_cpu:
2902
- self.context.core.volume_get_tiles_host(self.id, ctypes.byref(buf), ctypes.byref(size))
2997
+ self.runtime.core.volume_get_tiles_host(self.id, ctypes.byref(buf), ctypes.byref(size))
2998
+ deleter = self.device.default_allocator.deleter
2903
2999
  else:
2904
- self.context.core.volume_get_tiles_device(self.id, ctypes.byref(buf), ctypes.byref(size))
3000
+ self.runtime.core.volume_get_tiles_device(self.id, ctypes.byref(buf), ctypes.byref(size))
3001
+ if self.device.is_mempool_supported:
3002
+ deleter = self.device.mempool_allocator.deleter
3003
+ else:
3004
+ deleter = self.device.default_allocator.deleter
2905
3005
  num_tiles = size.value // (3 * 4)
2906
- return array(ptr=buf.value, dtype=int32, shape=(num_tiles, 3), device=self.device, owner=True)
3006
+
3007
+ return array(ptr=buf.value, dtype=int32, shape=(num_tiles, 3), device=self.device, deleter=deleter)
2907
3008
 
2908
3009
  def get_voxel_size(self) -> Tuple[float, float, float]:
2909
3010
  if self.id == 0:
2910
3011
  raise RuntimeError("Invalid Volume")
2911
3012
 
2912
3013
  dx, dy, dz = ctypes.c_float(0), ctypes.c_float(0), ctypes.c_float(0)
2913
- self.context.core.volume_get_voxel_size(self.id, ctypes.byref(dx), ctypes.byref(dy), ctypes.byref(dz))
3014
+ self.runtime.core.volume_get_voxel_size(self.id, ctypes.byref(dx), ctypes.byref(dy), ctypes.byref(dz))
2914
3015
  return (dx.value, dy.value, dz.value)
2915
3016
 
2916
3017
  @classmethod
@@ -3109,9 +3210,7 @@ class Volume:
3109
3210
  device (Devicelike): The CUDA device to create the volume on, e.g.: "cuda" or "cuda:0".
3110
3211
 
3111
3212
  """
3112
- from warp.context import runtime
3113
-
3114
- device = runtime.get_device(device)
3213
+ device = warp.get_device(device)
3115
3214
 
3116
3215
  if voxel_size <= 0.0:
3117
3216
  raise RuntimeError(f"Voxel size must be positive! Got {voxel_size}")
@@ -3130,7 +3229,7 @@ class Volume:
3130
3229
  volume.device = device
3131
3230
  in_world_space = tile_points.dtype == vec3
3132
3231
  if hasattr(bg_value, "__len__"):
3133
- volume.id = volume.context.core.volume_v_from_tiles_device(
3232
+ volume.id = volume.runtime.core.volume_v_from_tiles_device(
3134
3233
  volume.device.context,
3135
3234
  ctypes.c_void_p(tile_points.ptr),
3136
3235
  tile_points.shape[0],
@@ -3144,7 +3243,7 @@ class Volume:
3144
3243
  in_world_space,
3145
3244
  )
3146
3245
  elif isinstance(bg_value, int):
3147
- volume.id = volume.context.core.volume_i_from_tiles_device(
3246
+ volume.id = volume.runtime.core.volume_i_from_tiles_device(
3148
3247
  volume.device.context,
3149
3248
  ctypes.c_void_p(tile_points.ptr),
3150
3249
  tile_points.shape[0],
@@ -3156,7 +3255,7 @@ class Volume:
3156
3255
  in_world_space,
3157
3256
  )
3158
3257
  else:
3159
- volume.id = volume.context.core.volume_f_from_tiles_device(
3258
+ volume.id = volume.runtime.core.volume_f_from_tiles_device(
3160
3259
  volume.device.context,
3161
3260
  ctypes.c_void_p(tile_points.ptr),
3162
3261
  tile_points.shape[0],
@@ -3194,6 +3293,7 @@ class mesh_query_point_t:
3194
3293
  :func:`mesh_query_point_sign_normal`,
3195
3294
  and :func:`mesh_query_point_sign_winding_number`.
3196
3295
  """
3296
+
3197
3297
  from warp.codegen import Var
3198
3298
 
3199
3299
  vars = {
@@ -3222,6 +3322,7 @@ class mesh_query_ray_t:
3222
3322
  See Also:
3223
3323
  :func:`mesh_query_ray`.
3224
3324
  """
3325
+
3225
3326
  from warp.codegen import Var
3226
3327
 
3227
3328
  vars = {
@@ -3243,7 +3344,6 @@ def matmul(
3243
3344
  alpha: float = 1.0,
3244
3345
  beta: float = 0.0,
3245
3346
  allow_tf32x3_arith: builtins.bool = False,
3246
- device=None,
3247
3347
  ):
3248
3348
  """Computes a generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.
3249
3349
 
@@ -3256,14 +3356,12 @@ def matmul(
3256
3356
  beta (float): parameter beta of GEMM
3257
3357
  allow_tf32x3_arith (bool): whether to use CUTLASS's 3xTF32 GEMMs, which enable accuracy similar to FP32
3258
3358
  while using Tensor Cores
3259
- device: device we want to use to multiply matrices. Defaults to active runtime device. If "cpu", resorts to using numpy multiplication.
3260
3359
  """
3261
3360
  from warp.context import runtime
3262
3361
 
3263
- if device is None:
3264
- device = runtime.get_device(device)
3362
+ device = a.device
3265
3363
 
3266
- if a.device != device or b.device != device or c.device != device or d.device != device:
3364
+ if b.device != device or c.device != device or d.device != device:
3267
3365
  raise RuntimeError("Matrices A, B, C, and D must all be on the same device as the runtime device.")
3268
3366
 
3269
3367
  if a.dtype != b.dtype or a.dtype != c.dtype or a.dtype != d.dtype:
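`wp.matmul()` (and, further below, its adjoint and batched variants) no longer takes a `device` argument; the device is now taken from `a`, and all operands must live on it. A minimal sketch of the updated call (float32 operands assumed; the CUTLASS path requires a CUDA device):

    import warp as wp

    wp.init()

    m, k, n = 64, 32, 16
    a = wp.zeros((m, k), dtype=wp.float32, device="cuda:0")
    b = wp.zeros((k, n), dtype=wp.float32, device="cuda:0")
    c = wp.zeros((m, n), dtype=wp.float32, device="cuda:0")
    d = wp.zeros((m, n), dtype=wp.float32, device="cuda:0")

    # d = alpha * (a @ b) + beta * c, run on a.device -- no device argument anymore
    wp.matmul(a, b, c, d, alpha=1.0, beta=0.0)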
@@ -3271,7 +3369,12 @@ def matmul(
3271
3369
  "wp.matmul currently only supports operation between {A, B, C, D} matrices of the same type."
3272
3370
  )
3273
3371
 
3274
- if (not a.is_contiguous and not a.is_transposed) or (not b.is_contiguous and not b.is_transposed) or (not c.is_contiguous) or (not d.is_contiguous):
3372
+ if (
3373
+ (not a.is_contiguous and not a.is_transposed)
3374
+ or (not b.is_contiguous and not b.is_transposed)
3375
+ or (not c.is_contiguous)
3376
+ or (not d.is_contiguous)
3377
+ ):
3275
3378
  raise RuntimeError(
3276
3379
  "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B may be transposed."
3277
3380
  )
@@ -3286,9 +3389,7 @@ def matmul(
3286
3389
 
3287
3390
  if runtime.tape:
3288
3391
  runtime.tape.record_func(
3289
- backward=lambda: adj_matmul(
3290
- a, b, c, a.grad, b.grad, c.grad, d.grad, alpha, beta, allow_tf32x3_arith, device
3291
- ),
3392
+ backward=lambda: adj_matmul(a, b, c, a.grad, b.grad, c.grad, d.grad, alpha, beta, allow_tf32x3_arith),
3292
3393
  arrays=[a, b, c, d],
3293
3394
  )
3294
3395
 
@@ -3299,6 +3400,7 @@ def matmul(
3299
3400
 
3300
3401
  cc = device.arch
3301
3402
  ret = runtime.core.cutlass_gemm(
3403
+ device.context,
3302
3404
  cc,
3303
3405
  m,
3304
3406
  n,
@@ -3330,7 +3432,6 @@ def adj_matmul(
3330
3432
  alpha: float = 1.0,
3331
3433
  beta: float = 0.0,
3332
3434
  allow_tf32x3_arith: builtins.bool = False,
3333
- device=None,
3334
3435
  ):
3335
3436
  """Computes the adjoint of a generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.
3336
3437
  note: the adjoint of parameter alpha is not included but can be computed as `adj_alpha = np.sum(np.concatenate(np.multiply(a @ b, adj_d)))`.
@@ -3348,16 +3449,13 @@ def adj_matmul(
3348
3449
  beta (float): parameter beta of GEMM
3349
3450
  allow_tf32x3_arith (bool): whether to use CUTLASS's 3xTF32 GEMMs, which enable accuracy similar to FP32
3350
3451
  while using Tensor Cores
3351
- device: device we want to use to multiply matrices. Defaults to active runtime device. If "cpu", resorts to using numpy multiplication.
3352
3452
  """
3353
3453
  from warp.context import runtime
3354
3454
 
3355
- if device is None:
3356
- device = runtime.get_device(device)
3455
+ device = a.device
3357
3456
 
3358
3457
  if (
3359
- a.device != device
3360
- or b.device != device
3458
+ b.device != device
3361
3459
  or c.device != device
3362
3460
  or adj_a.device != device
3363
3461
  or adj_b.device != device
@@ -3392,7 +3490,7 @@ def adj_matmul(
3392
3490
  raise RuntimeError(
3393
3491
  "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B and their associated adjoints may be transposed."
3394
3492
  )
3395
-
3493
+
3396
3494
  m = a.shape[0]
3397
3495
  n = b.shape[1]
3398
3496
  k = a.shape[1]
@@ -3423,6 +3521,7 @@ def adj_matmul(
3423
3521
  # adj_a
3424
3522
  if not a.is_transposed:
3425
3523
  ret = runtime.core.cutlass_gemm(
3524
+ device.context,
3426
3525
  cc,
3427
3526
  m,
3428
3527
  k,
@@ -3443,6 +3542,7 @@ def adj_matmul(
3443
3542
  raise RuntimeError("adj_matmul failed.")
3444
3543
  else:
3445
3544
  ret = runtime.core.cutlass_gemm(
3545
+ device.context,
3446
3546
  cc,
3447
3547
  k,
3448
3548
  m,
@@ -3465,6 +3565,7 @@ def adj_matmul(
3465
3565
  # adj_b
3466
3566
  if not b.is_transposed:
3467
3567
  ret = runtime.core.cutlass_gemm(
3568
+ device.context,
3468
3569
  cc,
3469
3570
  k,
3470
3571
  n,
@@ -3485,6 +3586,7 @@ def adj_matmul(
3485
3586
  raise RuntimeError("adj_matmul failed.")
3486
3587
  else:
3487
3588
  ret = runtime.core.cutlass_gemm(
3589
+ device.context,
3488
3590
  cc,
3489
3591
  n,
3490
3592
  k,
@@ -3502,7 +3604,7 @@ def adj_matmul(
3502
3604
  1,
3503
3605
  )
3504
3606
  if not ret:
3505
- raise RuntimeError("adj_matmul failed.")
3607
+ raise RuntimeError("adj_matmul failed.")
3506
3608
 
3507
3609
  # adj_c
3508
3610
  warp.launch(
@@ -3510,7 +3612,7 @@ def adj_matmul(
3510
3612
  dim=adj_c.shape,
3511
3613
  inputs=[adj_c, adj_d, adj_d.dtype(beta)],
3512
3614
  device=device,
3513
- record_tape=False
3615
+ record_tape=False,
3514
3616
  )
3515
3617
 
3516
3618
 
@@ -3522,7 +3624,6 @@ def batched_matmul(
3522
3624
  alpha: float = 1.0,
3523
3625
  beta: float = 0.0,
3524
3626
  allow_tf32x3_arith: builtins.bool = False,
3525
- device=None,
3526
3627
  ):
3527
3628
  """Computes a batched generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.
3528
3629
 
@@ -3535,14 +3636,12 @@ def batched_matmul(
3535
3636
  beta (float): parameter beta of GEMM
3536
3637
  allow_tf32x3_arith (bool): whether to use CUTLASS's 3xTF32 GEMMs, which enable accuracy similar to FP32
3537
3638
  while using Tensor Cores
3538
- device: device we want to use to multiply matrices. Defaults to active runtime device. If "cpu", resorts to using numpy multiplication.
3539
3639
  """
3540
3640
  from warp.context import runtime
3541
3641
 
3542
- if device is None:
3543
- device = runtime.get_device(device)
3642
+ device = a.device
3544
3643
 
3545
- if a.device != device or b.device != device or c.device != device or d.device != device:
3644
+ if b.device != device or c.device != device or d.device != device:
3546
3645
  raise RuntimeError("Matrices A, B, C, and D must all be on the same device as the runtime device.")
3547
3646
 
3548
3647
  if a.dtype != b.dtype or a.dtype != c.dtype or a.dtype != d.dtype:
@@ -3550,7 +3649,12 @@ def batched_matmul(
3550
3649
  "wp.batched_matmul currently only supports operation between {A, B, C, D} matrices of the same type."
3551
3650
  )
3552
3651
 
3553
- if (not a.is_contiguous and not a.is_transposed) or (not b.is_contiguous and not b.is_transposed) or (not c.is_contiguous) or (not d.is_contiguous):
3652
+ if (
3653
+ (not a.is_contiguous and not a.is_transposed)
3654
+ or (not b.is_contiguous and not b.is_transposed)
3655
+ or (not c.is_contiguous)
3656
+ or (not d.is_contiguous)
3657
+ ):
3554
3658
  raise RuntimeError(
3555
3659
  "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B may be transposed."
3556
3660
  )
@@ -3567,7 +3671,7 @@ def batched_matmul(
3567
3671
  if runtime.tape:
3568
3672
  runtime.tape.record_func(
3569
3673
  backward=lambda: adj_batched_matmul(
3570
- a, b, c, a.grad, b.grad, c.grad, d.grad, alpha, beta, allow_tf32x3_arith, device
3674
+ a, b, c, a.grad, b.grad, c.grad, d.grad, alpha, beta, allow_tf32x3_arith
3571
3675
  ),
3572
3676
  arrays=[a, b, c, d],
3573
3677
  )
@@ -3587,15 +3691,16 @@ def batched_matmul(
3587
3691
  idx_start = i * max_batch_count
3588
3692
  idx_end = (i + 1) * max_batch_count if i < iters - 1 else batch_count
3589
3693
  ret = runtime.core.cutlass_gemm(
3694
+ device.context,
3590
3695
  cc,
3591
3696
  m,
3592
3697
  n,
3593
3698
  k,
3594
3699
  type_typestr(a.dtype).encode(),
3595
- ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
3596
- ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
3597
- ctypes.c_void_p(c[idx_start:idx_end,:,:].ptr),
3598
- ctypes.c_void_p(d[idx_start:idx_end,:,:].ptr),
3700
+ ctypes.c_void_p(a[idx_start:idx_end, :, :].ptr),
3701
+ ctypes.c_void_p(b[idx_start:idx_end, :, :].ptr),
3702
+ ctypes.c_void_p(c[idx_start:idx_end, :, :].ptr),
3703
+ ctypes.c_void_p(d[idx_start:idx_end, :, :].ptr),
3599
3704
  alpha,
3600
3705
  beta,
3601
3706
  not a.is_transposed,
@@ -3605,18 +3710,19 @@ def batched_matmul(
3605
3710
  )
3606
3711
  if not ret:
3607
3712
  raise RuntimeError("Batched matmul failed.")
3608
-
3713
+
3609
3714
  idx_start = iters * max_batch_count
3610
3715
  ret = runtime.core.cutlass_gemm(
3716
+ device.context,
3611
3717
  cc,
3612
3718
  m,
3613
3719
  n,
3614
3720
  k,
3615
3721
  type_typestr(a.dtype).encode(),
3616
- ctypes.c_void_p(a[idx_start:,:,:].ptr),
3617
- ctypes.c_void_p(b[idx_start:,:,:].ptr),
3618
- ctypes.c_void_p(c[idx_start:,:,:].ptr),
3619
- ctypes.c_void_p(d[idx_start:,:,:].ptr),
3722
+ ctypes.c_void_p(a[idx_start:, :, :].ptr),
3723
+ ctypes.c_void_p(b[idx_start:, :, :].ptr),
3724
+ ctypes.c_void_p(c[idx_start:, :, :].ptr),
3725
+ ctypes.c_void_p(d[idx_start:, :, :].ptr),
3620
3726
  alpha,
3621
3727
  beta,
3622
3728
  not a.is_transposed,
@@ -3625,7 +3731,7 @@ def batched_matmul(
3625
3731
  remainder,
3626
3732
  )
3627
3733
  if not ret:
3628
- raise RuntimeError("Batched matmul failed.")
3734
+ raise RuntimeError("Batched matmul failed.")
3629
3735
 
3630
3736
 
3631
3737
  def adj_batched_matmul(
@@ -3639,7 +3745,6 @@ def adj_batched_matmul(
3639
3745
  alpha: float = 1.0,
3640
3746
  beta: float = 0.0,
3641
3747
  allow_tf32x3_arith: builtins.bool = False,
3642
- device=None,
3643
3748
  ):
3644
3749
  """Computes a batched generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.
3645
3750
 
@@ -3655,16 +3760,13 @@ def adj_batched_matmul(
3655
3760
  beta (float): parameter beta of GEMM
3656
3761
  allow_tf32x3_arith (bool): whether to use CUTLASS's 3xTF32 GEMMs, which enable accuracy similar to FP32
3657
3762
  while using Tensor Cores
3658
- device: device we want to use to multiply matrices. Defaults to active runtime device. If "cpu", resorts to using numpy multiplication.
3659
3763
  """
3660
3764
  from warp.context import runtime
3661
3765
 
3662
- if device is None:
3663
- device = runtime.get_device(device)
3766
+ device = a.device
3664
3767
 
3665
3768
  if (
3666
- a.device != device
3667
- or b.device != device
3769
+ b.device != device
3668
3770
  or c.device != device
3669
3771
  or adj_a.device != device
3670
3772
  or adj_b.device != device
@@ -3739,15 +3841,16 @@ def adj_batched_matmul(
3739
3841
  # adj_a
3740
3842
  if not a.is_transposed:
3741
3843
  ret = runtime.core.cutlass_gemm(
3844
+ device.context,
3742
3845
  cc,
3743
3846
  m,
3744
3847
  k,
3745
3848
  n,
3746
3849
  type_typestr(a.dtype).encode(),
3747
- ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
3748
- ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
3749
- ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
3750
- ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
3850
+ ctypes.c_void_p(adj_d[idx_start:idx_end, :, :].ptr),
3851
+ ctypes.c_void_p(b[idx_start:idx_end, :, :].ptr),
3852
+ ctypes.c_void_p(adj_a[idx_start:idx_end, :, :].ptr),
3853
+ ctypes.c_void_p(adj_a[idx_start:idx_end, :, :].ptr),
3751
3854
  alpha,
3752
3855
  1.0,
3753
3856
  True,
@@ -3759,15 +3862,16 @@ def adj_batched_matmul(
3759
3862
  raise RuntimeError("adj_matmul failed.")
3760
3863
  else:
3761
3864
  ret = runtime.core.cutlass_gemm(
3865
+ device.context,
3762
3866
  cc,
3763
3867
  k,
3764
3868
  m,
3765
3869
  n,
3766
3870
  type_typestr(a.dtype).encode(),
3767
- ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
3768
- ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
3769
- ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
3770
- ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
3871
+ ctypes.c_void_p(b[idx_start:idx_end, :, :].ptr),
3872
+ ctypes.c_void_p(adj_d[idx_start:idx_end, :, :].ptr),
3873
+ ctypes.c_void_p(adj_a[idx_start:idx_end, :, :].ptr),
3874
+ ctypes.c_void_p(adj_a[idx_start:idx_end, :, :].ptr),
3771
3875
  alpha,
3772
3876
  1.0,
3773
3877
  not b.is_transposed,
@@ -3781,15 +3885,16 @@ def adj_batched_matmul(
3781
3885
  # adj_b
3782
3886
  if not b.is_transposed:
3783
3887
  ret = runtime.core.cutlass_gemm(
3888
+ device.context,
3784
3889
  cc,
3785
3890
  k,
3786
3891
  n,
3787
3892
  m,
3788
3893
  type_typestr(a.dtype).encode(),
3789
- ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
3790
- ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
3791
- ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
3792
- ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
3894
+ ctypes.c_void_p(a[idx_start:idx_end, :, :].ptr),
3895
+ ctypes.c_void_p(adj_d[idx_start:idx_end, :, :].ptr),
3896
+ ctypes.c_void_p(adj_b[idx_start:idx_end, :, :].ptr),
3897
+ ctypes.c_void_p(adj_b[idx_start:idx_end, :, :].ptr),
3793
3898
  alpha,
3794
3899
  1.0,
3795
3900
  a.is_transposed,
@@ -3801,15 +3906,16 @@ def adj_batched_matmul(
3801
3906
  raise RuntimeError("adj_matmul failed.")
3802
3907
  else:
3803
3908
  ret = runtime.core.cutlass_gemm(
3909
+ device.context,
3804
3910
  cc,
3805
3911
  n,
3806
3912
  k,
3807
3913
  m,
3808
3914
  type_typestr(a.dtype).encode(),
3809
- ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
3810
- ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
3811
- ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
3812
- ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
3915
+ ctypes.c_void_p(adj_d[idx_start:idx_end, :, :].ptr),
3916
+ ctypes.c_void_p(a[idx_start:idx_end, :, :].ptr),
3917
+ ctypes.c_void_p(adj_b[idx_start:idx_end, :, :].ptr),
3918
+ ctypes.c_void_p(adj_b[idx_start:idx_end, :, :].ptr),
3813
3919
  alpha,
3814
3920
  1.0,
3815
3921
  False,
@@ -3818,22 +3924,23 @@ def adj_batched_matmul(
3818
3924
  max_batch_count,
3819
3925
  )
3820
3926
  if not ret:
3821
- raise RuntimeError("adj_matmul failed.")
3822
-
3927
+ raise RuntimeError("adj_matmul failed.")
3928
+
3823
3929
  idx_start = iters * max_batch_count
3824
-
3930
+
3825
3931
  # adj_a
3826
3932
  if not a.is_transposed:
3827
3933
  ret = runtime.core.cutlass_gemm(
3934
+ device.context,
3828
3935
  cc,
3829
3936
  m,
3830
3937
  k,
3831
3938
  n,
3832
3939
  type_typestr(a.dtype).encode(),
3833
- ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
3834
- ctypes.c_void_p(b[idx_start:,:,:].ptr),
3835
- ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
3836
- ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
3940
+ ctypes.c_void_p(adj_d[idx_start:, :, :].ptr),
3941
+ ctypes.c_void_p(b[idx_start:, :, :].ptr),
3942
+ ctypes.c_void_p(adj_a[idx_start:, :, :].ptr),
3943
+ ctypes.c_void_p(adj_a[idx_start:, :, :].ptr),
3837
3944
  alpha,
3838
3945
  1.0,
3839
3946
  True,
@@ -3845,15 +3952,16 @@ def adj_batched_matmul(
3845
3952
  raise RuntimeError("adj_matmul failed.")
3846
3953
  else:
3847
3954
  ret = runtime.core.cutlass_gemm(
3955
+ device.context,
3848
3956
  cc,
3849
3957
  k,
3850
3958
  m,
3851
3959
  n,
3852
3960
  type_typestr(a.dtype).encode(),
3853
- ctypes.c_void_p(b[idx_start:,:,:].ptr),
3854
- ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
3855
- ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
3856
- ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
3961
+ ctypes.c_void_p(b[idx_start:, :, :].ptr),
3962
+ ctypes.c_void_p(adj_d[idx_start:, :, :].ptr),
3963
+ ctypes.c_void_p(adj_a[idx_start:, :, :].ptr),
3964
+ ctypes.c_void_p(adj_a[idx_start:, :, :].ptr),
3857
3965
  alpha,
3858
3966
  1.0,
3859
3967
  not b.is_transposed,
@@ -3867,15 +3975,16 @@ def adj_batched_matmul(
3867
3975
  # adj_b
3868
3976
  if not b.is_transposed:
3869
3977
  ret = runtime.core.cutlass_gemm(
3978
+ device.context,
3870
3979
  cc,
3871
3980
  k,
3872
3981
  n,
3873
3982
  m,
3874
3983
  type_typestr(a.dtype).encode(),
3875
- ctypes.c_void_p(a[idx_start:,:,:].ptr),
3876
- ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
3877
- ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
3878
- ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
3984
+ ctypes.c_void_p(a[idx_start:, :, :].ptr),
3985
+ ctypes.c_void_p(adj_d[idx_start:, :, :].ptr),
3986
+ ctypes.c_void_p(adj_b[idx_start:, :, :].ptr),
3987
+ ctypes.c_void_p(adj_b[idx_start:, :, :].ptr),
3879
3988
  alpha,
3880
3989
  1.0,
3881
3990
  a.is_transposed,
@@ -3887,15 +3996,16 @@ def adj_batched_matmul(
3887
3996
  raise RuntimeError("adj_matmul failed.")
3888
3997
  else:
3889
3998
  ret = runtime.core.cutlass_gemm(
3999
+ device.context,
3890
4000
  cc,
3891
4001
  n,
3892
4002
  k,
3893
4003
  m,
3894
4004
  type_typestr(a.dtype).encode(),
3895
- ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
3896
- ctypes.c_void_p(a[idx_start:,:,:].ptr),
3897
- ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
3898
- ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
4005
+ ctypes.c_void_p(adj_d[idx_start:, :, :].ptr),
4006
+ ctypes.c_void_p(a[idx_start:, :, :].ptr),
4007
+ ctypes.c_void_p(adj_b[idx_start:, :, :].ptr),
4008
+ ctypes.c_void_p(adj_b[idx_start:, :, :].ptr),
3899
4009
  alpha,
3900
4010
  1.0,
3901
4011
  False,
@@ -3904,7 +4014,7 @@ def adj_batched_matmul(
3904
4014
  remainder,
3905
4015
  )
3906
4016
  if not ret:
3907
- raise RuntimeError("adj_matmul failed.")
4017
+ raise RuntimeError("adj_matmul failed.")
3908
4018
 
3909
4019
  # adj_c
3910
4020
  warp.launch(
@@ -3912,9 +4022,10 @@ def adj_batched_matmul(
3912
4022
  dim=adj_c.shape,
3913
4023
  inputs=[adj_c, adj_d, adj_d.dtype(beta)],
3914
4024
  device=device,
3915
- record_tape=False
4025
+ record_tape=False,
3916
4026
  )
3917
4027
 
4028
+
3918
4029
  class HashGrid:
3919
4030
  def __init__(self, dim_x, dim_y, dim_z, device=None):
3920
4031
  """Class representing a hash grid object for accelerated point queries.
@@ -3929,14 +4040,16 @@ class HashGrid:
3929
4040
  dim_z (int): Number of cells in z-axis
3930
4041
  """
3931
4042
 
3932
- from warp.context import runtime
4043
+ self.id = 0
4044
+
4045
+ self.runtime = warp.context.runtime
3933
4046
 
3934
- self.device = runtime.get_device(device)
4047
+ self.device = self.runtime.get_device(device)
3935
4048
 
3936
4049
  if self.device.is_cpu:
3937
- self.id = runtime.core.hash_grid_create_host(dim_x, dim_y, dim_z)
4050
+ self.id = self.runtime.core.hash_grid_create_host(dim_x, dim_y, dim_z)
3938
4051
  else:
3939
- self.id = runtime.core.hash_grid_create_device(self.device.context, dim_x, dim_y, dim_z)
4052
+ self.id = self.runtime.core.hash_grid_create_device(self.device.context, dim_x, dim_y, dim_z)
3940
4053
 
3941
4054
  # indicates whether the grid data has been reserved for use by a kernel
3942
4055
  self.reserved = False
@@ -3954,43 +4067,45 @@ class HashGrid:
3954
4067
  the radius used when performing queries.
3955
4068
  """
3956
4069
 
3957
- from warp.context import runtime
3958
-
3959
4070
  if self.device.is_cpu:
3960
- runtime.core.hash_grid_update_host(self.id, radius, ctypes.cast(points.ptr, ctypes.c_void_p), len(points))
4071
+ self.runtime.core.hash_grid_update_host(
4072
+ self.id, radius, ctypes.cast(points.ptr, ctypes.c_void_p), len(points)
4073
+ )
3961
4074
  else:
3962
- runtime.core.hash_grid_update_device(self.id, radius, ctypes.cast(points.ptr, ctypes.c_void_p), len(points))
4075
+ self.runtime.core.hash_grid_update_device(
4076
+ self.id, radius, ctypes.cast(points.ptr, ctypes.c_void_p), len(points)
4077
+ )
3963
4078
  self.reserved = True
3964
4079
 
3965
4080
  def reserve(self, num_points):
3966
- from warp.context import runtime
3967
4081
 
3968
4082
  if self.device.is_cpu:
3969
- runtime.core.hash_grid_reserve_host(self.id, num_points)
4083
+ self.runtime.core.hash_grid_reserve_host(self.id, num_points)
3970
4084
  else:
3971
- runtime.core.hash_grid_reserve_device(self.id, num_points)
4085
+ self.runtime.core.hash_grid_reserve_device(self.id, num_points)
3972
4086
  self.reserved = True
3973
4087
 
3974
4088
  def __del__(self):
3975
- try:
3976
- from warp.context import runtime
3977
4089
 
3978
- if self.device.is_cpu:
3979
- runtime.core.hash_grid_destroy_host(self.id)
3980
- else:
3981
- # use CUDA context guard to avoid side effects during garbage collection
3982
- with self.device.context_guard:
3983
- runtime.core.hash_grid_destroy_device(self.id)
4090
+ if not self.id:
4091
+ return
3984
4092
 
3985
- except Exception:
3986
- pass
4093
+ if self.device.is_cpu:
4094
+ self.runtime.core.hash_grid_destroy_host(self.id)
4095
+ else:
4096
+ # use CUDA context guard to avoid side effects during garbage collection
4097
+ with self.device.context_guard:
4098
+ self.runtime.core.hash_grid_destroy_device(self.id)
3987
4099
 
3988
4100
 
3989
4101
  class MarchingCubes:
3990
4102
  def __init__(self, nx: int, ny: int, nz: int, max_verts: int, max_tris: int, device=None):
3991
- from warp.context import runtime
3992
4103
 
3993
- self.device = runtime.get_device(device)
4104
+ self.id = 0
4105
+
4106
+ self.runtime = warp.context.runtime
4107
+
4108
+ self.device = self.runtime.get_device(device)
3994
4109
 
3995
4110
  if not self.device.is_cuda:
3996
4111
  raise RuntimeError("Only CUDA devices are supported for marching cubes")
@@ -4003,10 +4118,10 @@ class MarchingCubes:
4003
4118
  self.max_tris = max_tris
4004
4119
 
4005
4120
  # bindings to warp.so
4006
- self.alloc = runtime.core.marching_cubes_create_device
4121
+ self.alloc = self.runtime.core.marching_cubes_create_device
4007
4122
  self.alloc.argtypes = [ctypes.c_void_p]
4008
4123
  self.alloc.restype = ctypes.c_uint64
4009
- self.free = runtime.core.marching_cubes_destroy_device
4124
+ self.free = self.runtime.core.marching_cubes_destroy_device
4010
4125
 
4011
4126
  from warp.context import zeros
4012
4127
 
@@ -4017,6 +4132,10 @@ class MarchingCubes:
4017
4132
  self.id = ctypes.c_uint64(self.alloc(self.device.context))
4018
4133
 
4019
4134
  def __del__(self):
4135
+
4136
+ if not self.id:
4137
+ return
4138
+
4020
4139
  # use CUDA context guard to avoid side effects during garbage collection
4021
4140
  with self.device.context_guard:
4022
4141
  # destroy surfacer
@@ -4031,15 +4150,13 @@ class MarchingCubes:
4031
4150
  self.max_tris = max_tris
4032
4151
 
4033
4152
  def surface(self, field: array(dtype=float), threshold: float):
4034
- from warp.context import runtime
4035
-
4036
4153
  # WP_API int marching_cubes_surface_host(const float* field, int nx, int ny, int nz, float threshold, wp::vec3* verts, int* triangles, int max_verts, int max_tris, int* out_num_verts, int* out_num_tris);
4037
4154
  num_verts = ctypes.c_int(0)
4038
4155
  num_tris = ctypes.c_int(0)
4039
4156
 
4040
- runtime.core.marching_cubes_surface_device.restype = ctypes.c_int
4157
+ self.runtime.core.marching_cubes_surface_device.restype = ctypes.c_int
4041
4158
 
4042
- error = runtime.core.marching_cubes_surface_device(
4159
+ error = self.runtime.core.marching_cubes_surface_device(
4043
4160
  self.id,
4044
4161
  ctypes.cast(field.ptr, ctypes.c_void_p),
4045
4162
  self.nx,
@@ -4155,7 +4272,7 @@ def infer_argument_types(args, template_types, arg_names=None):
4155
4272
  arg_name = arg_names[i] if arg_names else str(i)
4156
4273
  if arg_type in warp.types.array_types:
4157
4274
  arg_types.append(arg_type(dtype=arg.dtype, ndim=arg.ndim))
4158
- elif arg_type in warp.types.scalar_types:
4275
+ elif arg_type in warp.types.scalar_types + [bool]:
4159
4276
  arg_types.append(arg_type)
4160
4277
  elif arg_type in [int, float]:
4161
4278
  # canonicalize type