warp-lang 1.5.1__py3-none-manylinux2014_aarch64.whl → 1.6.1__py3-none-manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (131)
  1. warp/__init__.py +5 -0
  2. warp/autograd.py +414 -191
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +40 -12
  6. warp/build_dll.py +13 -6
  7. warp/builtins.py +1077 -481
  8. warp/codegen.py +250 -122
  9. warp/config.py +65 -21
  10. warp/context.py +500 -149
  11. warp/examples/assets/square_cloth.usd +0 -0
  12. warp/examples/benchmarks/benchmark_gemm.py +27 -18
  13. warp/examples/benchmarks/benchmark_interop_paddle.py +3 -3
  14. warp/examples/benchmarks/benchmark_interop_torch.py +3 -3
  15. warp/examples/core/example_marching_cubes.py +1 -1
  16. warp/examples/core/example_mesh.py +1 -1
  17. warp/examples/core/example_torch.py +18 -34
  18. warp/examples/core/example_wave.py +1 -1
  19. warp/examples/fem/example_apic_fluid.py +1 -0
  20. warp/examples/fem/example_mixed_elasticity.py +1 -1
  21. warp/examples/optim/example_bounce.py +1 -1
  22. warp/examples/optim/example_cloth_throw.py +1 -1
  23. warp/examples/optim/example_diffray.py +4 -15
  24. warp/examples/optim/example_drone.py +1 -1
  25. warp/examples/optim/example_softbody_properties.py +392 -0
  26. warp/examples/optim/example_trajectory.py +1 -3
  27. warp/examples/optim/example_walker.py +5 -0
  28. warp/examples/sim/example_cartpole.py +0 -2
  29. warp/examples/sim/example_cloth_self_contact.py +314 -0
  30. warp/examples/sim/example_granular_collision_sdf.py +4 -5
  31. warp/examples/sim/example_jacobian_ik.py +0 -2
  32. warp/examples/sim/example_quadruped.py +5 -2
  33. warp/examples/tile/example_tile_cholesky.py +79 -0
  34. warp/examples/tile/example_tile_convolution.py +2 -2
  35. warp/examples/tile/example_tile_fft.py +2 -2
  36. warp/examples/tile/example_tile_filtering.py +3 -3
  37. warp/examples/tile/example_tile_matmul.py +4 -4
  38. warp/examples/tile/example_tile_mlp.py +12 -12
  39. warp/examples/tile/example_tile_nbody.py +191 -0
  40. warp/examples/tile/example_tile_walker.py +319 -0
  41. warp/math.py +147 -0
  42. warp/native/array.h +12 -0
  43. warp/native/builtin.h +0 -1
  44. warp/native/bvh.cpp +149 -70
  45. warp/native/bvh.cu +287 -68
  46. warp/native/bvh.h +195 -85
  47. warp/native/clang/clang.cpp +6 -2
  48. warp/native/crt.h +1 -0
  49. warp/native/cuda_util.cpp +35 -0
  50. warp/native/cuda_util.h +5 -0
  51. warp/native/exports.h +40 -40
  52. warp/native/intersect.h +17 -0
  53. warp/native/mat.h +57 -3
  54. warp/native/mathdx.cpp +19 -0
  55. warp/native/mesh.cpp +25 -8
  56. warp/native/mesh.cu +153 -101
  57. warp/native/mesh.h +482 -403
  58. warp/native/quat.h +40 -0
  59. warp/native/solid_angle.h +7 -0
  60. warp/native/sort.cpp +85 -0
  61. warp/native/sort.cu +34 -0
  62. warp/native/sort.h +3 -1
  63. warp/native/spatial.h +11 -0
  64. warp/native/tile.h +1189 -664
  65. warp/native/tile_reduce.h +8 -6
  66. warp/native/vec.h +41 -0
  67. warp/native/warp.cpp +8 -1
  68. warp/native/warp.cu +263 -40
  69. warp/native/warp.h +19 -5
  70. warp/optim/linear.py +22 -4
  71. warp/render/render_opengl.py +132 -59
  72. warp/render/render_usd.py +10 -2
  73. warp/sim/__init__.py +6 -1
  74. warp/sim/collide.py +289 -32
  75. warp/sim/import_urdf.py +20 -5
  76. warp/sim/integrator_euler.py +25 -7
  77. warp/sim/integrator_featherstone.py +147 -35
  78. warp/sim/integrator_vbd.py +842 -40
  79. warp/sim/model.py +173 -112
  80. warp/sim/render.py +2 -2
  81. warp/stubs.py +249 -116
  82. warp/tape.py +28 -30
  83. warp/tests/aux_test_module_unload.py +15 -0
  84. warp/tests/{test_sim_grad.py → flaky_test_sim_grad.py} +104 -63
  85. warp/tests/test_array.py +100 -0
  86. warp/tests/test_assert.py +242 -0
  87. warp/tests/test_codegen.py +14 -61
  88. warp/tests/test_collision.py +8 -8
  89. warp/tests/test_examples.py +16 -1
  90. warp/tests/test_grad_debug.py +87 -2
  91. warp/tests/test_hash_grid.py +1 -1
  92. warp/tests/test_ipc.py +116 -0
  93. warp/tests/test_launch.py +77 -26
  94. warp/tests/test_mat.py +213 -168
  95. warp/tests/test_math.py +47 -1
  96. warp/tests/test_matmul.py +11 -7
  97. warp/tests/test_matmul_lite.py +4 -4
  98. warp/tests/test_mesh.py +84 -60
  99. warp/tests/test_mesh_query_aabb.py +165 -0
  100. warp/tests/test_mesh_query_point.py +328 -286
  101. warp/tests/test_mesh_query_ray.py +134 -121
  102. warp/tests/test_mlp.py +2 -2
  103. warp/tests/test_operators.py +43 -0
  104. warp/tests/test_overwrite.py +6 -5
  105. warp/tests/test_quat.py +77 -0
  106. warp/tests/test_reload.py +29 -0
  107. warp/tests/test_sim_grad_bounce_linear.py +204 -0
  108. warp/tests/test_static.py +16 -0
  109. warp/tests/test_tape.py +25 -0
  110. warp/tests/test_tile.py +134 -191
  111. warp/tests/test_tile_load.py +399 -0
  112. warp/tests/test_tile_mathdx.py +61 -8
  113. warp/tests/test_tile_mlp.py +17 -17
  114. warp/tests/test_tile_reduce.py +24 -18
  115. warp/tests/test_tile_shared_memory.py +66 -17
  116. warp/tests/test_tile_view.py +165 -0
  117. warp/tests/test_torch.py +35 -0
  118. warp/tests/test_utils.py +36 -24
  119. warp/tests/test_vec.py +110 -0
  120. warp/tests/unittest_suites.py +29 -4
  121. warp/tests/unittest_utils.py +30 -11
  122. warp/thirdparty/unittest_parallel.py +5 -2
  123. warp/types.py +419 -111
  124. warp/utils.py +9 -5
  125. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/METADATA +86 -45
  126. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/RECORD +129 -118
  127. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/WHEEL +1 -1
  128. warp/examples/benchmarks/benchmark_tile.py +0 -179
  129. warp/native/tile_gemm.h +0 -341
  130. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/LICENSE.md +0 -0
  131. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/top_level.txt +0 -0
warp/types.py CHANGED
@@ -171,8 +171,7 @@ def vector(length, dtype):
171
171
  iter(value)
172
172
  except TypeError:
173
173
  raise TypeError(
174
- f"Expected to assign a slice from a sequence of values "
175
- f"but got `{type(value).__name__}` instead"
174
+ f"Expected to assign a slice from a sequence of values but got `{type(value).__name__}` instead"
176
175
  ) from None
177
176
 
178
177
  if self._wp_scalar_type_ == float16:
@@ -350,6 +349,9 @@ def matrix(shape, dtype):
350
349
  f"Invalid number of arguments in matrix constructor, expected {self._length_} elements, got {num_args}"
351
350
  )
352
351
 
352
+ def __len__(self):
353
+ return self._shape_[0]
354
+
353
355
  def __add__(self, y):
354
356
  return warp.add(self, y)
355
357
 
@@ -419,7 +421,7 @@ def matrix(shape, dtype):
419
421
  iter(v)
420
422
  except TypeError:
421
423
  raise TypeError(
422
- f"Expected to assign a slice from a sequence of values " f"but got `{type(v).__name__}` instead"
424
+ f"Expected to assign a slice from a sequence of values but got `{type(v).__name__}` instead"
423
425
  ) from None
424
426
 
425
427
  row_start = r * self._shape_[1]
@@ -676,6 +678,10 @@ def transformation(dtype=Any):
676
678
 
677
679
  def __init__(self, *args, **kwargs):
678
680
  if len(args) == 1 and len(kwargs) == 0:
681
+ if is_float(args[0]):
682
+ # Initialize from a single scalar.
683
+ super().__init__(args[0])
684
+ return
679
685
  if args[0]._wp_generic_type_str_ == self._wp_generic_type_str_:
680
686
  # Copy constructor.
681
687
  super().__init__(*args[0])
@@ -1314,7 +1320,7 @@ def type_repr(t):
1314
1320
  if is_array(t):
1315
1321
  return str(f"array(ndim={t.ndim}, dtype={t.dtype})")
1316
1322
  if is_tile(t):
1317
- return str(f"tile(dtype={t.dtype}, m={t.M}, n={t.N})")
1323
+ return str(f"tile(dtype={t.dtype}, shape={t.shape}")
1318
1324
  if type_is_vector(t):
1319
1325
  return str(f"vector(length={t._shape_[0]}, dtype={t._wp_scalar_type_})")
1320
1326
  if type_is_matrix(t):
@@ -1357,6 +1363,11 @@ def type_is_matrix(t):
1357
1363
  return getattr(t, "_wp_generic_type_hint_", None) is Matrix
1358
1364
 
1359
1365
 
1366
+ # returns True if the passed *type* is a transformation
1367
+ def type_is_transformation(t):
1368
+ return getattr(t, "_wp_generic_type_hint_", None) is Transformation
1369
+
1370
+
1360
1371
  value_types = (int, float, builtins.bool) + scalar_types
1361
1372
 
1362
1373
 
@@ -1514,7 +1525,7 @@ def strides_from_shape(shape: Tuple, dtype):
1514
1525
 
1515
1526
 
1516
1527
  def check_array_shape(shape: Tuple):
1517
- """Checks that the size in each dimension is positive and less than 2^32."""
1528
+ """Checks that the size in each dimension is positive and less than 2^31."""
1518
1529
 
1519
1530
  for dim_index, dim_size in enumerate(shape):
1520
1531
  if dim_size < 0:
@@ -1701,8 +1712,22 @@ class array(Array):
1701
1712
  )
1702
1713
  elif length is not None:
1703
1714
  # backward compatibility
1715
+ warp.utils.warn(
1716
+ "The 'length' keyword is deprecated and will be removed in a future version. Use 'shape' instead.",
1717
+ category=DeprecationWarning,
1718
+ stacklevel=2,
1719
+ )
1704
1720
  shape = (length,)
1705
1721
 
1722
+ if owner:
1723
+ warp.utils.warn(
1724
+ "The 'owner' keyword in the array initializer is\n"
1725
+ "deprecated and will be removed in a future version. It currently has no effect.\n"
1726
+ "Pass a function to the 'deleter' keyword instead.",
1727
+ category=DeprecationWarning,
1728
+ stacklevel=2,
1729
+ )
1730
+
1706
1731
  # determine the construction path from the given arguments
1707
1732
  if data is not None:
1708
1733
  # data or ptr, not both
@@ -1734,32 +1759,6 @@ class array(Array):
1734
1759
  if not hasattr(data, "__len__"):
1735
1760
  raise RuntimeError(f"Data must be a sequence or array, got scalar {data}")
1736
1761
 
1737
- if hasattr(data, "__cuda_array_interface__"):
1738
- try:
1739
- # Performance note: try first, ask questions later
1740
- device = warp.context.runtime.get_device(device)
1741
- except Exception:
1742
- # Fallback to using the public API for retrieving the device,
1743
- # which takes take of initializing Warp if needed.
1744
- device = warp.context.get_device(device)
1745
-
1746
- if device.is_cuda:
1747
- desc = data.__cuda_array_interface__
1748
- shape = desc.get("shape")
1749
- strides = desc.get("strides")
1750
- dtype = np_dtype_to_warp_type[np.dtype(desc.get("typestr"))]
1751
- ptr = desc.get("data")[0]
1752
-
1753
- self._init_from_ptr(ptr, dtype, shape, strides, None, device, False, None)
1754
-
1755
- # keep a ref to the source data to keep allocation alive
1756
- self._ref = data
1757
- return
1758
- else:
1759
- raise RuntimeError(
1760
- f"Trying to construct a Warp array from data argument's __cuda_array_interface__ but {device} is not CUDA-capable"
1761
- )
1762
-
1763
1762
  if hasattr(dtype, "_wp_scalar_type_"):
1764
1763
  dtype_shape = dtype._shape_
1765
1764
  dtype_ndim = len(dtype_shape)
@@ -1769,6 +1768,76 @@ class array(Array):
1769
1768
  dtype_ndim = 0
1770
1769
  scalar_dtype = dtype
1771
1770
 
1771
+ try:
1772
+ # Performance note: try first, ask questions later
1773
+ device = warp.context.runtime.get_device(device)
1774
+ except Exception:
1775
+ # Fallback to using the public API for retrieving the device,
1776
+ # which takes take of initializing Warp if needed.
1777
+ device = warp.context.get_device(device)
1778
+
1779
+ if device.is_cuda and hasattr(data, "__cuda_array_interface__"):
1780
+ desc = data.__cuda_array_interface__
1781
+ data_shape = desc.get("shape")
1782
+ data_strides = desc.get("strides")
1783
+ data_dtype = np.dtype(desc.get("typestr"))
1784
+ data_ptr = desc.get("data")[0]
1785
+
1786
+ if dtype == Any:
1787
+ dtype = np_dtype_to_warp_type[data_dtype]
1788
+
1789
+ if data_strides is None:
1790
+ data_strides = strides_from_shape(data_shape, dtype)
1791
+
1792
+ data_ndim = len(data_shape)
1793
+
1794
+ # determine whether the input needs reshaping
1795
+ target_npshape = None
1796
+ if shape is not None:
1797
+ target_npshape = (*shape, *dtype_shape)
1798
+ elif dtype_ndim > 0:
1799
+ # prune inner dimensions of length 1
1800
+ while data_ndim > 1 and data_shape[-1] == 1:
1801
+ data_shape = data_shape[:-1]
1802
+ # if the inner dims don't match exactly, check if the innermost dim is a multiple of type length
1803
+ if data_ndim < dtype_ndim or data_shape[-dtype_ndim:] != dtype_shape:
1804
+ if data_shape[-1] == dtype._length_:
1805
+ target_npshape = (*data_shape[:-1], *dtype_shape)
1806
+ elif data_shape[-1] % dtype._length_ == 0:
1807
+ target_npshape = (*data_shape[:-1], data_shape[-1] // dtype._length_, *dtype_shape)
1808
+ else:
1809
+ if dtype_ndim == 1:
1810
+ raise RuntimeError(
1811
+ f"The inner dimensions of the input data are not compatible with the requested vector type {warp.context.type_str(dtype)}: expected an inner dimension that is a multiple of {dtype._length_}"
1812
+ )
1813
+ else:
1814
+ raise RuntimeError(
1815
+ f"The inner dimensions of the input data are not compatible with the requested matrix type {warp.context.type_str(dtype)}: expected inner dimensions {dtype._shape_} or a multiple of {dtype._length_}"
1816
+ )
1817
+
1818
+ if target_npshape is None:
1819
+ target_npshape = data_shape if shape is None else shape
1820
+
1821
+ # determine final shape and strides
1822
+ if dtype_ndim > 0:
1823
+ # make sure the inner dims are contiguous for vector/matrix types
1824
+ scalar_size = type_size_in_bytes(dtype._wp_scalar_type_)
1825
+ inner_contiguous = data_strides[-1] == scalar_size
1826
+ if inner_contiguous and dtype_ndim > 1:
1827
+ inner_contiguous = data_strides[-2] == scalar_size * dtype_shape[-1]
1828
+
1829
+ shape = target_npshape[:-dtype_ndim] or (1,)
1830
+ strides = data_strides if shape == data_shape else strides_from_shape(shape, dtype)
1831
+ else:
1832
+ shape = target_npshape or (1,)
1833
+ strides = data_strides if shape == data_shape else strides_from_shape(shape, dtype)
1834
+
1835
+ self._init_from_ptr(data_ptr, dtype, shape, strides, None, device, False, None)
1836
+
1837
+ # keep a ref to the source data to keep allocation alive
1838
+ self._ref = data
1839
+ return
1840
+
1772
1841
  # convert input data to ndarray (handles lists, tuples, etc.) and determine dtype
1773
1842
  if dtype == Any:
1774
1843
  # infer dtype from data
@@ -1971,7 +2040,21 @@ class array(Array):
1971
2040
  else:
1972
2041
  strides = tuple(strides)
1973
2042
  is_contiguous = strides == contiguous_strides
1974
- capacity = shape[0] * strides[0]
2043
+
2044
+ # To calculate the required capacity, find the dimension with largest stride.
2045
+ # Normally it is the first one, but it could be different (e.g., transposed arrays).
2046
+ max_stride = strides[0]
2047
+ max_dim = 0
2048
+ for i in range(1, ndim):
2049
+ if strides[i] > max_stride:
2050
+ max_stride = strides[i]
2051
+ max_dim = i
2052
+
2053
+ if max_stride > 0:
2054
+ capacity = shape[max_dim] * strides[max_dim]
2055
+ else:
2056
+ # single element storage with zero strides
2057
+ capacity = dtype_size
1975
2058
 
1976
2059
  allocator = device.get_allocator(pinned=pinned)
1977
2060
  if capacity > 0:
@@ -1990,6 +2073,7 @@ class array(Array):
1990
2073
  self.pinned = pinned if device.is_cpu else False
1991
2074
  self.is_contiguous = is_contiguous
1992
2075
  self.deleter = allocator.deleter
2076
+ self._allocator = allocator
1993
2077
 
1994
2078
  def _init_annotation(self, dtype, ndim):
1995
2079
  self.dtype = dtype
@@ -2706,6 +2790,52 @@ class array(Array):
2706
2790
  a._ref = self
2707
2791
  return a
2708
2792
 
2793
+ def ipc_handle(self) -> bytes:
2794
+ """Return an IPC handle of the array as a 64-byte ``bytes`` object
2795
+
2796
+ :func:`from_ipc_handle` can be used with this handle in another process
2797
+ to obtain a :class:`array` that shares the same underlying memory
2798
+ allocation.
2799
+
2800
+ IPC is currently only supported on Linux.
2801
+ Additionally, IPC is only supported for arrays allocated using
2802
+ the default memory allocator.
2803
+
2804
+ :class:`Event` objects created with the ``interprocess=True`` argument
2805
+ may similarly be shared between processes to synchronize GPU work.
2806
+
2807
+ Example:
2808
+ Temporarily using the default memory allocator to allocate an array
2809
+ and get its IPC handle::
2810
+
2811
+ with wp.ScopedMempool("cuda:0", False):
2812
+ test_array = wp.full(1024, value=42.0, dtype=wp.float32, device="cuda:0")
2813
+ ipc_handle = test_array.ipc_handle()
2814
+
2815
+ Raises:
2816
+ RuntimeError: The array is not associated with a CUDA device.
2817
+ RuntimeError: The CUDA device does not appear to support IPC.
2818
+ RuntimeError: The array was allocated using the :ref:`mempool memory allocator <mempool_allocators>`.
2819
+ """
2820
+
2821
+ if self.device is None or not self.device.is_cuda:
2822
+ raise RuntimeError("IPC requires a CUDA device")
2823
+ elif self.device.is_ipc_supported is False:
2824
+ raise RuntimeError("IPC does not appear to be supported on this CUDA device")
2825
+ elif isinstance(self._allocator, warp.context.CudaMempoolAllocator):
2826
+ raise RuntimeError(
2827
+ "Currently, IPC is only supported for arrays using the default memory allocator.\n"
2828
+ "See https://nvidia.github.io/warp/modules/allocators.html for instructions on how to disable\n"
2829
+ f"the mempool allocator on device {self.device}."
2830
+ )
2831
+
2832
+ # Allocate a buffer for the data (64-element char array)
2833
+ ipc_handle_buffer = (ctypes.c_char * 64)()
2834
+
2835
+ warp.context.runtime.core.cuda_ipc_get_mem_handle(self.ptr, ipc_handle_buffer)
2836
+
2837
+ return ipc_handle_buffer.raw
2838
+
2709
2839
 
2710
2840
  # aliases for arrays with small dimensions
2711
2841
  def array1d(*args, **kwargs):
@@ -2733,7 +2863,13 @@ def array4d(*args, **kwargs):
2733
2863
 
2734
2864
  def from_ptr(ptr, length, dtype=None, shape=None, device=None):
2735
2865
  warp.utils.warn(
2736
- "This version of wp.from_ptr() is deprecated. OmniGraph applications should use from_omni_graph_ptr() instead. In the future, wp.from_ptr() will work only with regular pointers.",
2866
+ """This version of wp.from_ptr() is deprecated. OmniGraph
2867
+ applications should use from_omni_graph_ptr() instead. To create an array
2868
+ from a C pointer, use the array constructor and pass the ptr argument as a
2869
+ uint64 value representing the start address in memory where the existing
2870
+ array resides. For example, if using ctypes, pass
2871
+ ptr=ctypes.cast(pointer, ctypes.POINTER(ctypes.c_size_t)).contents.value.
2872
+ Be sure to also specify the dtype and shape parameters.""",
2737
2873
  category=DeprecationWarning,
2738
2874
  )
2739
2875
 
@@ -2748,6 +2884,51 @@ def from_ptr(ptr, length, dtype=None, shape=None, device=None):
2748
2884
  )
2749
2885
 
2750
2886
 
2887
+ def _close_cuda_ipc_handle(ptr, size):
2888
+ warp.context.runtime.core.cuda_ipc_close_mem_handle(ptr)
2889
+
2890
+
2891
+ def from_ipc_handle(
2892
+ handle: bytes, dtype, shape: Tuple[int, ...], strides: Optional[Tuple[int, ...]] = None, device=None
2893
+ ) -> array:
2894
+ """Create an array from an IPC handle.
2895
+
2896
+ The ``dtype``, ``shape``, and optional ``strides`` arguments should
2897
+ match the values from the :class:`array` from which ``handle`` was created.
2898
+
2899
+ Args:
2900
+ handle: The interprocess memory handle for an existing device memory allocation.
2901
+ dtype: One of the available `data types <#data-types>`_, such as :class:`warp.float32`, :class:`warp.mat33`, or a custom `struct <#structs>`_.
2902
+ shape: Dimensions of the array.
2903
+ strides: Number of bytes in each dimension between successive elements of the array.
2904
+ device (Devicelike): Device to associate with the array.
2905
+
2906
+ Returns:
2907
+ An array created from the existing memory allocation described by the interprocess memory handle ``handle``.
2908
+
2909
+ A copy of the underlying data is not made. Modifications to the array's data will be reflected in the
2910
+ original process from which ``handle`` was exported.
2911
+
2912
+ Raises:
2913
+ RuntimeError: IPC is not supported on ``device``.
2914
+ """
2915
+
2916
+ try:
2917
+ # Performance note: try first, ask questions later
2918
+ device = warp.context.runtime.get_device(device)
2919
+ except Exception:
2920
+ # Fallback to using the public API for retrieving the device,
2921
+ # which takes take of initializing Warp if needed.
2922
+ device = warp.context.get_device(device)
2923
+
2924
+ if device.is_ipc_supported is False:
2925
+ raise RuntimeError(f"IPC is not supported on device {device}.")
2926
+
2927
+ ptr = warp.context.runtime.core.cuda_ipc_open_mem_handle(device.context, handle)
2928
+
2929
+ return array(ptr=ptr, dtype=dtype, shape=shape, strides=strides, device=device, deleter=_close_cuda_ipc_handle)
2930
+
2931
+
2751
2932
  # A base class for non-contiguous arrays, providing the implementation of common methods like
2752
2933
  # contiguous(), to(), numpy(), list(), assign(), zero_(), and fill_().
2753
2934
  class noncontiguous_array_base(Generic[T]):
@@ -2985,25 +3166,38 @@ def array_type_id(a):
2985
3166
  raise ValueError("Invalid array type")
2986
3167
 
2987
3168
 
2988
- # tile expression objects
3169
+ # tile object
2989
3170
  class Tile:
2990
3171
  alignment = 16
2991
3172
 
2992
- def __init__(self, dtype, M, N, op=None, storage="register", layout="rowmajor", strides=None, owner=True):
3173
+ def __init__(self, dtype, shape, op=None, storage="register", layout="rowmajor", strides=None, owner=True):
2993
3174
  self.dtype = type_to_warp(dtype)
2994
- self.M = M
2995
- self.N = N
3175
+ self.shape = shape
2996
3176
  self.op = op
2997
3177
  self.storage = storage
2998
3178
  self.layout = layout
3179
+ self.strides = strides
2999
3180
 
3000
- if strides is None:
3001
- if layout == "rowmajor":
3002
- self.strides = (N, 1)
3003
- elif layout == "colmajor":
3004
- self.strides = (1, M)
3005
- else:
3006
- self.strides = strides
3181
+ # handle case where shape is concrete (rather than just Any)
3182
+ if isinstance(self.shape, (list, tuple)):
3183
+ if len(shape) == 0:
3184
+ raise RuntimeError("Empty shape specified, must have at least 1 dimension")
3185
+
3186
+ # compute total size
3187
+ self.size = 1
3188
+ for s in self.shape:
3189
+ self.size *= s
3190
+
3191
+ # if strides are not provided compute default strides
3192
+ if self.strides is None:
3193
+ self.strides = [1] * len(self.shape)
3194
+
3195
+ if layout == "rowmajor":
3196
+ for i in range(len(self.shape) - 2, -1, -1):
3197
+ self.strides[i] = self.strides[i + 1] * self.shape[i + 1]
3198
+ else:
3199
+ for i in range(1, len(shape)):
3200
+ self.strides[i] = self.strides[i - 1] * self.shape[i - 1]
3007
3201
 
3008
3202
  self.owner = owner
3009
3203
 
@@ -3012,9 +3206,9 @@ class Tile:
3012
3206
  from warp.codegen import Var
3013
3207
 
3014
3208
  if self.storage == "register":
3015
- return f"wp::tile_register_t<{Var.type_to_ctype(self.dtype)},{self.M},{self.N}>"
3209
+ return f"wp::tile_register_t<{Var.type_to_ctype(self.dtype)},wp::tile_layout_register_t<wp::tile_shape_t<{','.join(map(str, self.shape))}>>>"
3016
3210
  elif self.storage == "shared":
3017
- return f"wp::tile_shared_t<{Var.type_to_ctype(self.dtype)},{self.M},{self.N},{self.strides[0]}, {self.strides[1]}, {'true' if self.owner else 'false'}>"
3211
+ return f"wp::tile_shared_t<{Var.type_to_ctype(self.dtype)},wp::tile_layout_strided_t<wp::tile_shape_t<{','.join(map(str, self.shape))}>, wp::tile_stride_t<{','.join(map(str, self.strides))}>>, {'true' if self.owner else 'false'}>"
3018
3212
  else:
3019
3213
  raise RuntimeError(f"Unrecognized tile storage type {self.storage}")
3020
3214
 
@@ -3027,24 +3221,33 @@ class Tile:
3027
3221
  elif self.storage == "shared":
3028
3222
  if self.owner:
3029
3223
  # allocate new shared memory tile
3030
- return f"wp::tile_alloc_empty<{Var.type_to_ctype(self.dtype)},{self.M},{self.N},{'true' if requires_grad else 'false'}>()"
3224
+ return f"wp::tile_alloc_empty<{Var.type_to_ctype(self.dtype)},wp::tile_shape_t<{','.join(map(str, self.shape))}>,{'true' if requires_grad else 'false'}>()"
3031
3225
  else:
3032
3226
  # tile will be initialized by another call, e.g.: tile_transpose()
3033
3227
  return "NULL"
3034
3228
 
3035
3229
  # return total tile size in bytes
3036
3230
  def size_in_bytes(self):
3037
- num_bytes = self.align(type_size_in_bytes(self.dtype) * self.M * self.N)
3231
+ num_bytes = self.align(type_size_in_bytes(self.dtype) * self.size)
3038
3232
  return num_bytes
3039
3233
 
3234
+ @staticmethod
3235
+ def round_up(bytes):
3236
+ return ((bytes + Tile.alignment - 1) // Tile.alignment) * Tile.alignment
3237
+
3040
3238
  # align tile size to natural boundary, default 16-bytes
3041
3239
  def align(self, bytes):
3042
- return ((bytes + self.alignment - 1) // self.alignment) * self.alignment
3240
+ return Tile.round_up(bytes)
3043
3241
 
3044
3242
 
3045
3243
  class TileZeros(Tile):
3046
- def __init__(self, dtype, M, N, storage="register"):
3047
- Tile.__init__(self, dtype, M, N, op="zeros", storage=storage)
3244
+ def __init__(self, dtype, shape, storage="register"):
3245
+ Tile.__init__(self, dtype, shape, op="zeros", storage=storage)
3246
+
3247
+
3248
+ class TileOnes(Tile):
3249
+ def __init__(self, dtype, shape, storage="register"):
3250
+ Tile.__init__(self, dtype, shape, op="ones", storage=storage)
3048
3251
 
3049
3252
 
3050
3253
  class TileRange(Tile):
@@ -3053,32 +3256,39 @@ class TileRange(Tile):
3053
3256
  self.stop = stop
3054
3257
  self.step = step
3055
3258
 
3056
- M = 1
3057
- N = int((stop - start) / step)
3259
+ n = int((stop - start) / step)
3058
3260
 
3059
- Tile.__init__(self, dtype, M, N, op="arange", storage=storage)
3261
+ Tile.__init__(self, dtype, shape=(n,), op="arange", storage=storage)
3060
3262
 
3061
3263
 
3062
3264
  class TileConstant(Tile):
3063
- def __init__(self, dtype, M, N):
3064
- Tile.__init__(self, dtype, M, N, op="constant", storage="register")
3265
+ def __init__(self, dtype, shape):
3266
+ Tile.__init__(self, dtype, shape, op="constant", storage="register")
3065
3267
 
3066
3268
 
3067
3269
  class TileLoad(Tile):
3068
- def __init__(self, array, M, N, storage="register"):
3069
- Tile.__init__(self, array.dtype, M, N, op="load", storage=storage)
3270
+ def __init__(self, array, shape, storage="register"):
3271
+ Tile.__init__(self, array.dtype, shape, op="load", storage=storage)
3070
3272
 
3071
3273
 
3072
3274
  class TileUnaryMap(Tile):
3073
- def __init__(self, t, storage="register"):
3074
- Tile.__init__(self, t.dtype, t.M, t.N, op="unary_map", storage=storage)
3275
+ def __init__(self, t, dtype=None, storage="register"):
3276
+ Tile.__init__(self, dtype, t.shape, op="unary_map", storage=storage)
3277
+
3278
+ # if no output dtype specified then assume it's the same as the first arg
3279
+ if self.dtype is None:
3280
+ self.dtype = t.dtype
3075
3281
 
3076
3282
  self.t = t
3077
3283
 
3078
3284
 
3079
3285
  class TileBinaryMap(Tile):
3080
- def __init__(self, a, b, storage="register"):
3081
- Tile.__init__(self, a.dtype, a.M, a.N, op="binary_map", storage=storage)
3286
+ def __init__(self, a, b, dtype=None, storage="register"):
3287
+ Tile.__init__(self, dtype, a.shape, op="binary_map", storage=storage)
3288
+
3289
+ # if no output dtype specified then assume it's the same as the first arg
3290
+ if self.dtype is None:
3291
+ self.dtype = a.dtype
3082
3292
 
3083
3293
  self.a = a
3084
3294
  self.b = b
@@ -3086,7 +3296,7 @@ class TileBinaryMap(Tile):
3086
3296
 
3087
3297
  class TileShared(Tile):
3088
3298
  def __init__(self, t):
3089
- Tile.__init__(self, t.dtype, t.M, t.N, "shared", storage="shared")
3299
+ Tile.__init__(self, t.dtype, t.shape, "shared", storage="shared")
3090
3300
 
3091
3301
  self.t = t
3092
3302
 
@@ -3095,35 +3305,66 @@ def is_tile(t):
3095
3305
  return isinstance(t, Tile)
3096
3306
 
3097
3307
 
3308
+ bvh_constructor_values = {"sah": 0, "median": 1, "lbvh": 2}
3309
+
3310
+
3098
3311
  class Bvh:
3099
3312
  def __new__(cls, *args, **kwargs):
3100
3313
  instance = super(Bvh, cls).__new__(cls)
3101
3314
  instance.id = None
3102
3315
  return instance
3103
3316
 
3104
- def __init__(self, lowers, uppers):
3317
+ def __init__(self, lowers: array, uppers: array, constructor: Optional[str] = None):
3105
3318
  """Class representing a bounding volume hierarchy.
3106
3319
 
3320
+ Depending on which device the input bounds live, it can be either a CPU tree or a GPU tree.
3321
+
3107
3322
  Attributes:
3108
- id: Unique identifier for this bvh object, can be passed to kernels.
3323
+ id: Unique identifier for this BVH object, can be passed to kernels.
3109
3324
  device: Device this object lives on, all buffers must live on the same device.
3110
3325
 
3111
3326
  Args:
3112
- lowers (:class:`warp.array`): Array of lower bounds :class:`warp.vec3`
3113
- uppers (:class:`warp.array`): Array of upper bounds :class:`warp.vec3`
3327
+ lowers: Array of lower bounds of data type :class:`warp.vec3`.
3328
+ uppers: Array of upper bounds of data type :class:`warp.vec3`.
3329
+ ``lowers`` and ``uppers`` must live on the same device.
3330
+ constructor: The construction algorithm used to build the tree.
3331
+ Valid choices are ``"sah"``, ``"median"``, ``"lbvh"``, or ``None``.
3332
+ When ``None``, the default constructor will be used (see the note).
3333
+
3334
+ Note:
3335
+ Explanation of BVH constructors:
3336
+
3337
+ - ``"sah"``: A CPU-based top-down constructor where the AABBs are split based on Surface Area
3338
+ Heuristics (SAH). Construction takes slightly longer than others but has the best query
3339
+ performance.
3340
+ - ``"median"``: A CPU-based top-down constructor where the AABBs are split based on the median
3341
+ of centroids of primitives in an AABB. This constructor is faster than SAH but offers
3342
+ inferior query performance.
3343
+ - ``"lbvh"``: A GPU-based bottom-up constructor which maximizes parallelism. Construction is very
3344
+ fast, especially for large models. Query performance is slightly slower than ``"sah"``.
3345
+ - ``None``: The constructor will be automatically chosen based on the device where the tree
3346
+ lives. For a GPU tree, the ``"lbvh"`` constructor will be selected; for a CPU tree, the ``"sah"``
3347
+ constructor will be selected.
3348
+
3349
+ All three constructors are supported for GPU trees. When a CPU-based constructor is selected
3350
+ for a GPU tree, bounds will be copied back to the CPU to run the CPU-based constructor. After
3351
+ construction, the CPU tree will be copied to the GPU.
3352
+
3353
+ Only ``"sah"`` and ``"median"`` are supported for CPU trees. If ``"lbvh"`` is selected for a CPU tree, a
3354
+ warning message will be issued, and the constructor will automatically fall back to ``"sah"``.
3114
3355
  """
3115
3356
 
3116
3357
  if len(lowers) != len(uppers):
3117
- raise RuntimeError("Bvh the same number of lower and upper bounds must be provided")
3358
+ raise RuntimeError("The same number of lower and upper bounds must be provided")
3118
3359
 
3119
3360
  if lowers.device != uppers.device:
3120
- raise RuntimeError("Bvh lower and upper bounds must live on the same device")
3361
+ raise RuntimeError("Lower and upper bounds must live on the same device")
3121
3362
 
3122
3363
  if lowers.dtype != vec3 or not lowers.is_contiguous:
3123
- raise RuntimeError("Bvh lowers should be a contiguous array of type wp.vec3")
3364
+ raise RuntimeError("lowers should be a contiguous array of type wp.vec3")
3124
3365
 
3125
3366
  if uppers.dtype != vec3 or not uppers.is_contiguous:
3126
- raise RuntimeError("Bvh uppers should be a contiguous array of type wp.vec3")
3367
+ raise RuntimeError("uppers should be a contiguous array of type wp.vec3")
3127
3368
 
3128
3369
  self.device = lowers.device
3129
3370
  self.lowers = lowers
@@ -3137,11 +3378,32 @@ class Bvh:
3137
3378
 
3138
3379
  self.runtime = warp.context.runtime
3139
3380
 
3381
+ if constructor is None:
3382
+ if self.device.is_cpu:
3383
+ constructor = "sah"
3384
+ else:
3385
+ constructor = "lbvh"
3386
+
3387
+ if constructor not in bvh_constructor_values:
3388
+ raise ValueError(f"Unrecognized BVH constructor type: {constructor}")
3389
+
3140
3390
  if self.device.is_cpu:
3141
- self.id = self.runtime.core.bvh_create_host(get_data(lowers), get_data(uppers), int(len(lowers)))
3391
+ if constructor == "lbvh":
3392
+ warp.utils.warn(
3393
+ "LBVH constructor is not available for a CPU tree. Falling back to SAH constructor.", stacklevel=2
3394
+ )
3395
+ constructor = "sah"
3396
+
3397
+ self.id = self.runtime.core.bvh_create_host(
3398
+ get_data(lowers), get_data(uppers), int(len(lowers)), bvh_constructor_values[constructor]
3399
+ )
3142
3400
  else:
3143
3401
  self.id = self.runtime.core.bvh_create_device(
3144
- self.device.context, get_data(lowers), get_data(uppers), int(len(lowers))
3402
+ self.device.context,
3403
+ get_data(lowers),
3404
+ get_data(uppers),
3405
+ int(len(lowers)),
3406
+ bvh_constructor_values[constructor],
3145
3407
  )
3146
3408
 
3147
3409
  def __del__(self):
@@ -3156,7 +3418,10 @@ class Bvh:
3156
3418
  self.runtime.core.bvh_destroy_device(self.id)
3157
3419
 
3158
3420
  def refit(self):
3159
- """Refit the BVH. This should be called after users modify the `lowers` and `uppers` arrays."""
3421
+ """Refit the BVH.
3422
+
3423
+ This should be called after users modify the ``lowers`` or ``uppers`` arrays.
3424
+ """
3160
3425
 
3161
3426
  if self.device.is_cpu:
3162
3427
  self.runtime.core.bvh_refit_host(self.id)
@@ -3179,7 +3444,14 @@ class Mesh:
3179
3444
  instance.id = None
3180
3445
  return instance
3181
3446
 
3182
- def __init__(self, points=None, indices=None, velocities=None, support_winding_number=False):
3447
+ def __init__(
3448
+ self,
3449
+ points: array,
3450
+ indices: array,
3451
+ velocities: Optional[array] = None,
3452
+ support_winding_number: bool = False,
3453
+ bvh_constructor: Optional[str] = None,
3454
+ ):
3183
3455
  """Class representing a triangle mesh.
3184
3456
 
3185
3457
  Attributes:
@@ -3187,10 +3459,15 @@ class Mesh:
3187
3459
  device: Device this object lives on, all buffers must live on the same device.
3188
3460
 
3189
3461
  Args:
3190
- points (:class:`warp.array`): Array of vertex positions of type :class:`warp.vec3`
3191
- indices (:class:`warp.array`): Array of triangle indices of type :class:`warp.int32`, should be a 1d array with shape (num_tris * 3)
3192
- velocities (:class:`warp.array`): Array of vertex velocities of type :class:`warp.vec3` (optional)
3193
- support_winding_number (bool): If true the mesh will build additional datastructures to support `wp.mesh_query_point_sign_winding_number()` queries
3462
+ points: Array of vertex positions of data type :class:`warp.vec3`.
3463
+ indices: Array of triangle indices of data type :class:`warp.int32`.
3464
+ Should be a 1D array with shape ``(num_tris * 3)``.
3465
+ velocities: Optional array of vertex velocities of data type :class:`warp.vec3`.
3466
+ support_winding_number: If ``True``, the mesh will build additional
3467
+ data structures to support ``wp.mesh_query_point_sign_winding_number()`` queries.
3468
+ bvh_constructor: The construction algorithm for the underlying BVH
3469
+ (see the docstring of :class:`Bvh` for explanation).
3470
+ Valid choices are ``"sah"``, ``"median"``, ``"lbvh"``, or ``None``.
3194
3471
  """
3195
3472
 
3196
3473
  if points.device != indices.device:
@@ -3215,7 +3492,22 @@ class Mesh:
3215
3492
 
3216
3493
  self.runtime = warp.context.runtime
3217
3494
 
3495
+ if bvh_constructor is None:
3496
+ if self.device.is_cpu:
3497
+ bvh_constructor = "sah"
3498
+ else:
3499
+ bvh_constructor = "lbvh"
3500
+
3501
+ if bvh_constructor not in bvh_constructor_values:
3502
+ raise ValueError(f"Unrecognized BVH constructor type: {bvh_constructor}")
3503
+
3218
3504
  if self.device.is_cpu:
3505
+ if bvh_constructor == "lbvh":
3506
+ warp.utils.warn(
3507
+ "LBVH constructor is not available for a CPU tree. Falling back to SAH constructor.", stacklevel=2
3508
+ )
3509
+ bvh_constructor = "sah"
3510
+
3219
3511
  self.id = self.runtime.core.mesh_create_host(
3220
3512
  points.__ctype__(),
3221
3513
  velocities.__ctype__() if velocities else array().__ctype__(),
@@ -3223,6 +3515,7 @@ class Mesh:
3223
3515
  int(len(points)),
3224
3516
  int(indices.size / 3),
3225
3517
  int(support_winding_number),
3518
+ bvh_constructor_values[bvh_constructor],
3226
3519
  )
3227
3520
  else:
3228
3521
  self.id = self.runtime.core.mesh_create_device(
@@ -3233,6 +3526,7 @@ class Mesh:
3233
3526
  int(len(points)),
3234
3527
  int(indices.size / 3),
3235
3528
  int(support_winding_number),
3529
+ bvh_constructor_values[bvh_constructor],
3236
3530
  )
3237
3531
 
3238
3532
  def __del__(self):
@@ -3247,7 +3541,10 @@ class Mesh:
3247
3541
  self.runtime.core.mesh_destroy_device(self.id)
3248
3542
 
3249
3543
  def refit(self):
3250
- """Refit the BVH to points. This should be called after users modify the `points` data."""
3544
+ """Refit the BVH to points.
3545
+
3546
+ This should be called after users modify the ``points`` data.
3547
+ """
3251
3548
 
3252
3549
  if self.device.is_cpu:
3253
3550
  self.runtime.core.mesh_refit_host(self.id)
@@ -3260,9 +3557,9 @@ class Mesh:
3260
3557
  """The array of mesh's vertex positions of type :class:`warp.vec3`.
3261
3558
 
3262
3559
  The `Mesh.points` property has a custom setter method. Users can modify the vertex positions in-place,
3263
- but the `refit()` method must be called manually after such modifications. Alternatively, assigning a new array
3560
+ but :meth:`refit` must be called manually after such modifications. Alternatively, assigning a new array
3264
3561
  to this property is also supported. The new array must have the same shape as the original, and once assigned,
3265
- the `Mesh` class will automatically perform a refit operation based on the new vertex positions.
3562
+ the :class:`Mesh` will automatically perform a refit operation based on the new vertex positions.
3266
3563
  """
3267
3564
  return self._points
3268
3565
 
@@ -3270,16 +3567,14 @@ class Mesh:
3270
3567
  def points(self, points_new):
3271
3568
  if points_new.device != self._points.device:
3272
3569
  raise RuntimeError(
3273
- "The new points and the original points must live on the same device, currently "
3274
- "the new points lives on {} while the old points lives on {}.".format(
3275
- points_new.device, self._points.device
3276
- )
3570
+ "The new points and the original points must live on the same device, the "
3571
+ f"new points are on {points_new.device} while the old points are on {self._points.device}."
3277
3572
  )
3278
3573
 
3279
3574
  if points_new.ndim != 1 or points_new.shape[0] != self._points.shape[0]:
3280
3575
  raise RuntimeError(
3281
- "the new points and the original points must have the same shape, currently new points shape is: {},"
3282
- " while the old points' shape is: {}".format(points_new.shape, self._points.shape)
3576
+ "The new points and the original points must have the same shape, the "
3577
+ f"new points' shape is {points_new.shape}, while the old points' shape is {self._points.shape}."
3283
3578
  )
3284
3579
 
3285
3580
  self._points = points_new
@@ -3294,7 +3589,7 @@ class Mesh:
3294
3589
  """The array of mesh's velocities of type :class:`warp.vec3`.
3295
3590
 
3296
3591
  This is a property with a custom setter method. Users can modify the velocities in-place,
3297
- or assigning a new array to this property. No refitting is needed for changing velocities.
3592
+ or assign a new array to this property. No refitting is needed for changing velocities.
3298
3593
  """
3299
3594
  return self._velocities
3300
3595
 
@@ -3302,16 +3597,14 @@ class Mesh:
3302
3597
  def velocities(self, velocities_new):
3303
3598
  if velocities_new.device != self._velocities.device:
3304
3599
  raise RuntimeError(
3305
- "The new points and the original points must live on the same device, currently "
3306
- "the new points lives on {} while the old points lives on {}.".format(
3307
- velocities_new.device, self._velocities.device
3308
- )
3600
+ "The new points and the original points must live on the same device, the "
3601
+ f"new points are on {velocities_new.device} while the old points are on {self._velocities.device}."
3309
3602
  )
3310
3603
 
3311
3604
  if velocities_new.ndim != 1 or velocities_new.shape[0] != self._velocities.shape[0]:
3312
3605
  raise RuntimeError(
3313
- "the new points and the original points must have the same shape, currently new points shape is: {},"
3314
- " while the old points' shape is: {}".format(velocities_new.shape, self._velocities.shape)
3606
+ "The new points and the original points must have the same shape, the "
3607
+ f"new points' shape is {velocities_new.shape}, while the old points' shape is {self._velocities.shape}."
3315
3608
  )
3316
3609
 
3317
3610
  self._velocities = velocities_new
@@ -3337,8 +3630,8 @@ class Volume:
3337
3630
  """Class representing a sparse grid.
3338
3631
 
3339
3632
  Args:
3340
- data (:class:`warp.array`): Array of bytes representing the volume in NanoVDB format
3341
- copy (bool): Whether the incoming data will be copied or aliased
3633
+ data: Array of bytes representing the volume in NanoVDB format.
3634
+ copy: Whether the incoming data will be copied or aliased.
3342
3635
  """
3343
3636
 
3344
3637
  # keep a runtime reference for orderly destruction
@@ -3373,14 +3666,15 @@ class Volume:
3373
3666
  self.runtime.core.volume_destroy_device(self.id)
3374
3667
 
3375
3668
  def array(self) -> array:
3376
- """Returns the raw memory buffer of the Volume as an array"""
3669
+ """Return the raw memory buffer of the :class:`Volume` as an array."""
3670
+
3377
3671
  buf = ctypes.c_void_p(0)
3378
3672
  size = ctypes.c_uint64(0)
3379
3673
  self.runtime.core.volume_get_buffer_info(self.id, ctypes.byref(buf), ctypes.byref(size))
3380
3674
  return array(ptr=buf.value, dtype=uint8, shape=size.value, device=self.device, owner=False)
3381
3675
 
3382
3676
  def get_tile_count(self) -> int:
3383
- """Returns the number of tiles (NanoVDB leaf nodes) of the volume"""
3677
+ """Return the number of tiles (NanoVDB leaf nodes) of the volume."""
3384
3678
 
3385
3679
  voxel_count, tile_count = (
3386
3680
  ctypes.c_uint64(0),
@@ -3390,11 +3684,12 @@ class Volume:
3390
3684
  return tile_count.value
3391
3685
 
3392
3686
  def get_tiles(self, out: Optional[array] = None) -> array:
3393
- """Returns the integer coordinates of all allocated tiles for this volume.
3687
+ """Return the integer coordinates of all allocated tiles for this volume.
3394
3688
 
3395
3689
  Args:
3396
- out (:class:`warp.array`, optional): If provided, use the `out` array to store the tile coordinates, otherwise
3397
- a new array will be allocated. `out` must be a contiguous array of ``tile_count`` ``vec3i`` or ``tile_count x 3`` ``int32``
3690
+ out: If provided, use the `out` array to store the tile coordinates, otherwise
3691
+ a new array will be allocated. ``out`` must be a contiguous array
3692
+ of ``tile_count`` ``vec3i`` or ``tile_count x 3`` ``int32``
3398
3693
  on the same device as this volume.
3399
3694
  """
3400
3695
 
@@ -3419,7 +3714,7 @@ class Volume:
3419
3714
  return out
3420
3715
 
3421
3716
  def get_voxel_count(self) -> int:
3422
- """Returns the total number of allocated voxels for this volume"""
3717
+ """Return the total number of allocated voxels for this volume"""
3423
3718
 
3424
3719
  voxel_count, tile_count = (
3425
3720
  ctypes.c_uint64(0),
@@ -3429,10 +3724,10 @@ class Volume:
3429
3724
  return voxel_count.value
3430
3725
 
3431
3726
  def get_voxels(self, out: Optional[array] = None) -> array:
3432
- """Returns the integer coordinates of all allocated voxels for this volume.
3727
+ """Return the integer coordinates of all allocated voxels for this volume.
3433
3728
 
3434
3729
  Args:
3435
- out (:class:`warp.array`, optional): If provided, use the `out` array to store the voxel coordinates, otherwise
3730
+ out: If provided, use the `out` array to store the voxel coordinates, otherwise
3436
3731
  a new array will be allocated. `out` must be a contiguous array of ``voxel_count`` ``vec3i`` or ``voxel_count x 3`` ``int32``
3437
3732
  on the same device as this volume.
3438
3733
  """
@@ -3458,7 +3753,7 @@ class Volume:
3458
3753
  return out
3459
3754
 
3460
3755
  def get_voxel_size(self) -> Tuple[float, float, float]:
3461
- """Voxel size, i.e, world coordinates of voxel's diagonal vector"""
3756
+ """Return the voxel size, i.e., world coordinates of voxel's diagonal vector"""
3462
3757
 
3463
3758
  if self.id == 0:
3464
3759
  raise RuntimeError("Invalid Volume")
@@ -3558,7 +3853,7 @@ class Volume:
3558
3853
  return self.get_grid_info().type_str in Volume._nvdb_index_types
3559
3854
 
3560
3855
  def get_feature_array_count(self) -> int:
3561
- """Returns the number of supplemental data arrays stored alongside the grid"""
3856
+ """Return the number of supplemental data arrays stored alongside the grid"""
3562
3857
 
3563
3858
  return self.runtime.core.volume_get_blind_data_count(self.id)
3564
3859
 
@@ -3578,7 +3873,7 @@ class Volume:
3578
3873
  """String describing the type of the array values"""
3579
3874
 
3580
3875
  def get_feature_array_info(self, feature_index: int) -> Volume.FeatureArrayInfo:
3581
- """Returns the metadata associated to the feature array at `feature_index`"""
3876
+ """Return the metadata associated to the feature array at ``feature_index``."""
3582
3877
 
3583
3878
  buf = ctypes.c_void_p(0)
3584
3879
  value_count = ctypes.c_uint64(0)
@@ -3606,11 +3901,12 @@ class Volume:
3606
3901
  )
3607
3902
 
3608
3903
  def feature_array(self, feature_index: int, dtype=None) -> array:
3609
- """Returns one the grid's feature data arrays as a Warp array
3904
+ """Return one of the grid's feature data arrays as a Warp array.
3610
3905
 
3611
3906
  Args:
3612
3907
  feature_index: Index of the supplemental data array in the grid
3613
- dtype: Type for the returned Warp array. If not provided, will be deduced from the array metadata.
3908
+ dtype: Data type for the returned Warp array.
3909
+ If not provided, will be deduced from the array metadata.
3614
3910
  """
3615
3911
 
3616
3912
  info = self.get_feature_array_info(feature_index)
@@ -3641,7 +3937,7 @@ class Volume:
3641
3937
 
3642
3938
  @classmethod
3643
3939
  def load_from_nvdb(cls, file_or_buffer, device=None) -> Volume:
3644
- """Creates a Volume object from a serialized NanoVDB file or in-memory buffer.
3940
+ """Create a :class:`Volume` object from a serialized NanoVDB file or in-memory buffer.
3645
3941
 
3646
3942
  Returns:
3647
3943
 
@@ -4302,6 +4598,9 @@ def matmul(
4302
4598
  ):
4303
4599
  """Computes a generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.
4304
4600
 
4601
+ .. deprecated:: 1.6
4602
+ Use :doc:`tile primitives </modules/tiles>` instead.
4603
+
4305
4604
  Args:
4306
4605
  a (array2d): two-dimensional array containing matrix A
4307
4606
  b (array2d): two-dimensional array containing matrix B
@@ -4314,6 +4613,12 @@ def matmul(
4314
4613
  """
4315
4614
  from warp.context import runtime
4316
4615
 
4616
+ warp.utils.warn(
4617
+ "wp.matmul() is deprecated and will be removed in a\nfuture version. Use tile primitives instead.",
4618
+ category=DeprecationWarning,
4619
+ stacklevel=2,
4620
+ )
4621
+
4317
4622
  device = a.device
4318
4623
 
4319
4624
  if b.device != device or c.device != device or d.device != device:
@@ -4589,6 +4894,9 @@ def batched_matmul(
4589
4894
  ):
4590
4895
  """Computes a batched generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.
4591
4896
 
4897
+ .. deprecated:: 1.6
4898
+ Use :doc:`tile primitives </modules/tiles>` instead.
4899
+
4592
4900
  Args:
4593
4901
  a (array3d): three-dimensional array containing A matrices. Overall array dimension is {batch_count, M, K}
4594
4902
  b (array3d): three-dimensional array containing B matrices. Overall array dimension is {batch_count, K, N}