warp-lang 1.6.2__py3-none-win_amd64.whl → 1.7.1__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (191)
  1. warp/__init__.py +7 -1
  2. warp/autograd.py +12 -2
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +410 -0
  6. warp/build_dll.py +6 -14
  7. warp/builtins.py +463 -372
  8. warp/codegen.py +196 -124
  9. warp/config.py +42 -6
  10. warp/context.py +496 -271
  11. warp/dlpack.py +8 -6
  12. warp/examples/assets/nonuniform.usd +0 -0
  13. warp/examples/assets/nvidia_logo.png +0 -0
  14. warp/examples/benchmarks/benchmark_cloth.py +1 -1
  15. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  16. warp/examples/core/example_sample_mesh.py +300 -0
  17. warp/examples/distributed/example_jacobi_mpi.py +507 -0
  18. warp/examples/fem/example_apic_fluid.py +1 -1
  19. warp/examples/fem/example_burgers.py +2 -2
  20. warp/examples/fem/example_deformed_geometry.py +1 -1
  21. warp/examples/fem/example_distortion_energy.py +1 -1
  22. warp/examples/fem/example_magnetostatics.py +6 -6
  23. warp/examples/fem/utils.py +9 -3
  24. warp/examples/interop/example_jax_callable.py +116 -0
  25. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  26. warp/examples/interop/example_jax_kernel.py +205 -0
  27. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  28. warp/examples/tile/example_tile_matmul.py +2 -4
  29. warp/fem/__init__.py +11 -1
  30. warp/fem/adaptivity.py +4 -4
  31. warp/fem/field/field.py +11 -1
  32. warp/fem/field/nodal_field.py +56 -88
  33. warp/fem/field/virtual.py +62 -23
  34. warp/fem/geometry/adaptive_nanogrid.py +16 -13
  35. warp/fem/geometry/closest_point.py +1 -1
  36. warp/fem/geometry/deformed_geometry.py +5 -2
  37. warp/fem/geometry/geometry.py +5 -0
  38. warp/fem/geometry/grid_2d.py +12 -12
  39. warp/fem/geometry/grid_3d.py +12 -15
  40. warp/fem/geometry/hexmesh.py +5 -7
  41. warp/fem/geometry/nanogrid.py +9 -11
  42. warp/fem/geometry/quadmesh.py +13 -13
  43. warp/fem/geometry/tetmesh.py +3 -4
  44. warp/fem/geometry/trimesh.py +7 -20
  45. warp/fem/integrate.py +262 -93
  46. warp/fem/linalg.py +5 -5
  47. warp/fem/quadrature/pic_quadrature.py +37 -22
  48. warp/fem/quadrature/quadrature.py +194 -25
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_function_space.py +4 -2
  51. warp/fem/space/basis_space.py +25 -18
  52. warp/fem/space/hexmesh_function_space.py +2 -2
  53. warp/fem/space/partition.py +6 -2
  54. warp/fem/space/quadmesh_function_space.py +8 -8
  55. warp/fem/space/shape/cube_shape_function.py +23 -23
  56. warp/fem/space/shape/square_shape_function.py +12 -12
  57. warp/fem/space/shape/triangle_shape_function.py +1 -1
  58. warp/fem/space/tetmesh_function_space.py +3 -3
  59. warp/fem/space/trimesh_function_space.py +2 -2
  60. warp/fem/utils.py +12 -6
  61. warp/jax.py +14 -1
  62. warp/jax_experimental/__init__.py +16 -0
  63. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -29
  64. warp/jax_experimental/ffi.py +702 -0
  65. warp/jax_experimental/xla_ffi.py +602 -0
  66. warp/math.py +89 -0
  67. warp/native/array.h +13 -0
  68. warp/native/builtin.h +29 -3
  69. warp/native/bvh.cpp +3 -1
  70. warp/native/bvh.cu +42 -14
  71. warp/native/bvh.h +2 -1
  72. warp/native/clang/clang.cpp +30 -3
  73. warp/native/cuda_util.cpp +14 -0
  74. warp/native/cuda_util.h +2 -0
  75. warp/native/exports.h +68 -63
  76. warp/native/intersect.h +26 -26
  77. warp/native/intersect_adj.h +33 -33
  78. warp/native/marching.cu +1 -1
  79. warp/native/mat.h +513 -9
  80. warp/native/mesh.h +10 -10
  81. warp/native/quat.h +99 -11
  82. warp/native/rand.h +6 -0
  83. warp/native/sort.cpp +122 -59
  84. warp/native/sort.cu +152 -15
  85. warp/native/sort.h +8 -1
  86. warp/native/sparse.cpp +43 -22
  87. warp/native/sparse.cu +52 -17
  88. warp/native/svd.h +116 -0
  89. warp/native/tile.h +312 -116
  90. warp/native/tile_reduce.h +46 -3
  91. warp/native/vec.h +68 -7
  92. warp/native/volume.cpp +85 -113
  93. warp/native/volume_builder.cu +25 -10
  94. warp/native/volume_builder.h +6 -0
  95. warp/native/warp.cpp +5 -6
  96. warp/native/warp.cu +100 -11
  97. warp/native/warp.h +19 -10
  98. warp/optim/linear.py +10 -10
  99. warp/render/render_opengl.py +19 -17
  100. warp/render/render_usd.py +93 -3
  101. warp/sim/articulation.py +4 -4
  102. warp/sim/collide.py +32 -19
  103. warp/sim/import_mjcf.py +449 -155
  104. warp/sim/import_urdf.py +32 -12
  105. warp/sim/inertia.py +189 -156
  106. warp/sim/integrator_euler.py +8 -5
  107. warp/sim/integrator_featherstone.py +3 -10
  108. warp/sim/integrator_vbd.py +207 -2
  109. warp/sim/integrator_xpbd.py +8 -5
  110. warp/sim/model.py +71 -25
  111. warp/sim/render.py +4 -0
  112. warp/sim/utils.py +2 -2
  113. warp/sparse.py +642 -555
  114. warp/stubs.py +217 -20
  115. warp/tests/__main__.py +0 -15
  116. warp/tests/assets/torus.usda +1 -1
  117. warp/tests/cuda/__init__.py +0 -0
  118. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  119. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  120. warp/tests/geometry/__init__.py +0 -0
  121. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  122. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  123. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  124. warp/tests/interop/__init__.py +0 -0
  125. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  126. warp/tests/sim/__init__.py +0 -0
  127. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  128. warp/tests/{test_collision.py → sim/test_collision.py} +236 -205
  129. warp/tests/sim/test_inertia.py +161 -0
  130. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  131. warp/tests/{flaky_test_sim_grad.py → sim/test_sim_grad.py} +4 -0
  132. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  133. warp/tests/sim/test_vbd.py +597 -0
  134. warp/tests/sim/test_xpbd.py +399 -0
  135. warp/tests/test_bool.py +1 -1
  136. warp/tests/test_codegen.py +24 -3
  137. warp/tests/test_examples.py +40 -38
  138. warp/tests/test_fem.py +98 -14
  139. warp/tests/test_linear_solvers.py +0 -11
  140. warp/tests/test_mat.py +577 -156
  141. warp/tests/test_mat_scalar_ops.py +4 -4
  142. warp/tests/test_overwrite.py +0 -60
  143. warp/tests/test_quat.py +356 -151
  144. warp/tests/test_rand.py +44 -37
  145. warp/tests/test_sparse.py +47 -6
  146. warp/tests/test_spatial.py +75 -0
  147. warp/tests/test_static.py +1 -1
  148. warp/tests/test_utils.py +84 -4
  149. warp/tests/test_vec.py +336 -178
  150. warp/tests/tile/__init__.py +0 -0
  151. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  152. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +98 -1
  153. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  154. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  155. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  156. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  157. warp/tests/unittest_serial.py +1 -0
  158. warp/tests/unittest_suites.py +45 -62
  159. warp/tests/unittest_utils.py +2 -1
  160. warp/thirdparty/unittest_parallel.py +3 -1
  161. warp/types.py +175 -666
  162. warp/utils.py +137 -72
  163. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/METADATA +46 -12
  164. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/RECORD +184 -171
  165. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/WHEEL +1 -1
  166. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info/licenses}/LICENSE.md +0 -26
  167. warp/examples/optim/example_walker.py +0 -317
  168. warp/native/cutlass_gemm.cpp +0 -43
  169. warp/native/cutlass_gemm.cu +0 -382
  170. warp/tests/test_matmul.py +0 -511
  171. warp/tests/test_matmul_lite.py +0 -411
  172. warp/tests/test_vbd.py +0 -386
  173. warp/tests/unused_test_misc.py +0 -77
  174. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  175. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  176. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  177. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  178. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  179. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  180. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  181. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  182. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  183. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  184. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  185. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  186. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  187. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  188. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  189. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  190. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  191. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/top_level.txt +0 -0
warp/stubs.py CHANGED
@@ -96,7 +96,12 @@ from warp.context import Stream, get_stream, set_stream, wait_stream, synchroniz
96
96
  from warp.context import Event, record_event, wait_event, synchronize_event, get_event_elapsed_time
97
97
  from warp.context import RegisteredGLBuffer
98
98
  from warp.context import is_mempool_supported, is_mempool_enabled, set_mempool_enabled
99
- from warp.context import set_mempool_release_threshold, get_mempool_release_threshold
99
+ from warp.context import (
100
+ set_mempool_release_threshold,
101
+ get_mempool_release_threshold,
102
+ get_mempool_used_mem_current,
103
+ get_mempool_used_mem_high,
104
+ )
100
105
  from warp.context import is_mempool_access_supported, is_mempool_access_enabled, set_mempool_access_enabled
101
106
  from warp.context import is_peer_access_supported, is_peer_access_enabled, set_peer_access_enabled
102
107
 
@@ -132,6 +137,7 @@ from warp.paddle import device_from_paddle, device_to_paddle
132
137
  from warp.paddle import stream_from_paddle
133
138
 
134
139
  from warp.build import clear_kernel_cache
140
+ from warp.build import clear_lto_cache
135
141
 
136
142
  from warp.constants import *
137
143
 
@@ -648,6 +654,18 @@ def matrix(*args: Scalar, shape: Tuple[int, int], dtype: Scalar) -> Matrix[Any,
648
654
  ...
649
655
 
650
656
 
657
+ @over
658
+ def matrix_from_cols(*args: Vector[Any, Scalar]) -> Matrix[Any, Any, Scalar]:
659
+ """Construct a matrix from column vectors."""
660
+ ...
661
+
662
+
663
+ @over
664
+ def matrix_from_rows(*args: Vector[Any, Scalar]) -> Matrix[Any, Any, Scalar]:
665
+ """Construct a matrix from row vectors."""
666
+ ...
667
+
668
+
651
669
  @over
652
670
  def identity(n: int32, dtype: Scalar) -> Matrix[Any, Any, Scalar]:
653
671
  """Create an identity matrix with shape=(n,n) with the type given by ``dtype``."""
@@ -662,6 +680,14 @@ def svd3(A: Matrix[3, 3, Float], U: Matrix[3, 3, Float], sigma: Vector[3, Float]
662
680
  ...
663
681
 
664
682
 
683
+ @over
684
+ def svd2(A: Matrix[2, 2, Float], U: Matrix[2, 2, Float], sigma: Vector[2, Float], V: Matrix[2, 2, Scalar]):
685
+ """Compute the SVD of a 2x2 matrix ``A``. The singular values are returned in ``sigma``,
686
+ while the left and right basis vectors are returned in ``U`` and ``V``.
687
+ """
688
+ ...
689
+
690
+
665
691
  @over
666
692
  def qr3(A: Matrix[3, 3, Float], Q: Matrix[3, 3, Float], R: Matrix[3, 3, Float]):
667
693
  """Compute the QR decomposition of a 3x3 matrix ``A``. The orthogonal matrix is returned in ``Q``,
@@ -687,7 +713,7 @@ def quaternion(dtype: Float) -> Quaternion[Float]:
687
713
 
688
714
 
689
715
  @over
690
- def quaternion(x: Float, y: Float, z: Float, w: Float) -> Quaternion[Float]:
716
+ def quaternion(x: Float, y: Float, z: Float, w: Float, dtype: Scalar) -> Quaternion[Float]:
691
717
  """Create a quaternion using the supplied components (type inferred from component type)."""
692
718
  ...
693
719
 
@@ -724,7 +750,19 @@ def quat_to_axis_angle(quat: Quaternion[Float], axis: Vector[3, Float], angle: F
724
750
 
725
751
  @over
726
752
  def quat_from_matrix(mat: Matrix[3, 3, Float]) -> Quaternion[Float]:
727
- """Construct a quaternion from a 3x3 matrix."""
753
+ """Construct a quaternion from a 3x3 matrix.
754
+
755
+ If the matrix is not a pure rotation, but for example includes scaling or skewing, the result is undefined.
756
+ """
757
+ ...
758
+
759
+
760
+ @over
761
+ def quat_from_matrix(mat: Matrix[4, 4, Float]) -> Quaternion[Float]:
762
+ """Construct a quaternion from a 4x4 matrix.
763
+
764
+ If the top-left 3x3 block of the matrix is not a pure rotation, but for example includes scaling or skewing, the result is undefined.
765
+ """
728
766
  ...
729
767
 
730
768
 
@@ -1028,7 +1066,7 @@ def tile(x: Any) -> Tile:
1028
1066
 
1029
1067
  This function converts values computed using scalar kernel code to a tile representation for input into collective operations.
1030
1068
 
1031
- * If the input value is a scalar, then the resulting tile has ``shape=(1, block_dim)``
1069
+ * If the input value is a scalar, then the resulting tile has ``shape=(block_dim,)``
1032
1070
  * If the input value is a vector, then the resulting tile has ``shape=(length(vector), block_dim)``
1033
1071
 
1034
1072
  :param x: A per-thread local value, e.g. scalar, vector, or matrix.
@@ -1121,13 +1159,12 @@ def tile_transpose(a: Tile) -> Tile:
1121
1159
  def tile_broadcast(a: Tile, shape: Tuple[int, ...]) -> Tile:
1122
1160
  """Broadcast a tile.
1123
1161
 
1124
- This function will attempt to broadcast the input tile ``a`` to the destination shape (m, n).
1125
-
1162
+ Broadcasts the input tile ``a`` to the destination shape.
1126
1163
  Broadcasting follows NumPy broadcast rules.
1127
1164
 
1128
1165
  :param a: Tile to broadcast
1129
1166
  :param shape: The shape to broadcast to
1130
- :returns: Tile with broadcast ``shape=(m, n)``
1167
+ :returns: Tile with broadcast shape
1131
1168
  """
1132
1169
  ...
1133
1170
 
@@ -1810,6 +1847,18 @@ def randi(state: uint32, low: int32, high: int32) -> int:
1810
1847
  ...
1811
1848
 
1812
1849
 
1850
+ @over
1851
+ def randu(state: uint32) -> uint32:
1852
+ """Return a random unsigned integer in the range [0, 2^32)."""
1853
+ ...
1854
+
1855
+
1856
+ @over
1857
+ def randu(state: uint32, low: uint32, high: uint32) -> uint32:
1858
+ """Return a random unsigned integer between [low, high)."""
1859
+ ...
1860
+
1861
+
1813
1862
  @over
1814
1863
  def randf(state: uint32) -> float:
1815
1864
  """Return a random float between [0.0, 1.0)."""
@@ -2029,61 +2078,171 @@ def tid() -> Tuple[int, int, int, int]:
2029
2078
 
2030
2079
  @over
2031
2080
  def select(cond: bool, value_if_false: Any, value_if_true: Any) -> Any:
2032
- """Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``"""
2081
+ """Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
2082
+
2083
+ .. deprecated:: 1.7
2084
+ Use :func:`where` instead, which has the more intuitive argument order:
2085
+ ``where(cond, value_if_true, value_if_false)``.
2086
+ """
2033
2087
  ...
2034
2088
 
2035
2089
 
2036
2090
  @over
2037
2091
  def select(cond: int8, value_if_false: Any, value_if_true: Any) -> Any:
2038
- """Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``"""
2092
+ """Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
2093
+
2094
+ .. deprecated:: 1.7
2095
+ Use :func:`where` instead, which has the more intuitive argument order:
2096
+ ``where(cond, value_if_true, value_if_false)``.
2097
+ """
2039
2098
  ...
2040
2099
 
2041
2100
 
2042
2101
  @over
2043
2102
  def select(cond: uint8, value_if_false: Any, value_if_true: Any) -> Any:
2044
- """Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``"""
2103
+ """Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
2104
+
2105
+ .. deprecated:: 1.7
2106
+ Use :func:`where` instead, which has the more intuitive argument order:
2107
+ ``where(cond, value_if_true, value_if_false)``.
2108
+ """
2045
2109
  ...
2046
2110
 
2047
2111
 
2048
2112
  @over
2049
2113
  def select(cond: int16, value_if_false: Any, value_if_true: Any) -> Any:
2050
- """Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``"""
2114
+ """Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
2115
+
2116
+ .. deprecated:: 1.7
2117
+ Use :func:`where` instead, which has the more intuitive argument order:
2118
+ ``where(cond, value_if_true, value_if_false)``.
2119
+ """
2051
2120
  ...
2052
2121
 
2053
2122
 
2054
2123
  @over
2055
2124
  def select(cond: uint16, value_if_false: Any, value_if_true: Any) -> Any:
2056
- """Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``"""
2125
+ """Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
2126
+
2127
+ .. deprecated:: 1.7
2128
+ Use :func:`where` instead, which has the more intuitive argument order:
2129
+ ``where(cond, value_if_true, value_if_false)``.
2130
+ """
2057
2131
  ...
2058
2132
 
2059
2133
 
2060
2134
  @over
2061
2135
  def select(cond: int32, value_if_false: Any, value_if_true: Any) -> Any:
2062
- """Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``"""
2136
+ """Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
2137
+
2138
+ .. deprecated:: 1.7
2139
+ Use :func:`where` instead, which has the more intuitive argument order:
2140
+ ``where(cond, value_if_true, value_if_false)``.
2141
+ """
2063
2142
  ...
2064
2143
 
2065
2144
 
2066
2145
  @over
2067
2146
  def select(cond: uint32, value_if_false: Any, value_if_true: Any) -> Any:
2068
- """Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``"""
2147
+ """Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
2148
+
2149
+ .. deprecated:: 1.7
2150
+ Use :func:`where` instead, which has the more intuitive argument order:
2151
+ ``where(cond, value_if_true, value_if_false)``.
2152
+ """
2069
2153
  ...
2070
2154
 
2071
2155
 
2072
2156
  @over
2073
2157
  def select(cond: int64, value_if_false: Any, value_if_true: Any) -> Any:
2074
- """Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``"""
2158
+ """Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
2159
+
2160
+ .. deprecated:: 1.7
2161
+ Use :func:`where` instead, which has the more intuitive argument order:
2162
+ ``where(cond, value_if_true, value_if_false)``.
2163
+ """
2075
2164
  ...
2076
2165
 
2077
2166
 
2078
2167
  @over
2079
2168
  def select(cond: uint64, value_if_false: Any, value_if_true: Any) -> Any:
2080
- """Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``"""
2169
+ """Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
2170
+
2171
+ .. deprecated:: 1.7
2172
+ Use :func:`where` instead, which has the more intuitive argument order:
2173
+ ``where(cond, value_if_true, value_if_false)``.
2174
+ """
2081
2175
  ...
2082
2176
 
2083
2177
 
2084
2178
  @over
2085
2179
  def select(arr: Array[Any], value_if_false: Any, value_if_true: Any) -> Any:
2086
- """Select between two arguments, if ``arr`` is null then return ``value_if_false``, otherwise return ``value_if_true``"""
2180
+ """Select between two arguments, if ``arr`` is null then return ``value_if_false``, otherwise return ``value_if_true``.
2181
+
2182
+ .. deprecated:: 1.7
2183
+ Use :func:`where` instead, which has the more intuitive argument order:
2184
+ ``where(arr, value_if_true, value_if_false)``.
2185
+ """
2186
+ ...
2187
+
2188
+
2189
+ @over
2190
+ def where(cond: bool, value_if_true: Any, value_if_false: Any) -> Any:
2191
+ """Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``."""
2192
+ ...
2193
+
2194
+
2195
+ @over
2196
+ def where(cond: int8, value_if_true: Any, value_if_false: Any) -> Any:
2197
+ """Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``."""
2198
+ ...
2199
+
2200
+
2201
+ @over
2202
+ def where(cond: uint8, value_if_true: Any, value_if_false: Any) -> Any:
2203
+ """Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``."""
2204
+ ...
2205
+
2206
+
2207
+ @over
2208
+ def where(cond: int16, value_if_true: Any, value_if_false: Any) -> Any:
2209
+ """Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``."""
2210
+ ...
2211
+
2212
+
2213
+ @over
2214
+ def where(cond: uint16, value_if_true: Any, value_if_false: Any) -> Any:
2215
+ """Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``."""
2216
+ ...
2217
+
2218
+
2219
+ @over
2220
+ def where(cond: int32, value_if_true: Any, value_if_false: Any) -> Any:
2221
+ """Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``."""
2222
+ ...
2223
+
2224
+
2225
+ @over
2226
+ def where(cond: uint32, value_if_true: Any, value_if_false: Any) -> Any:
2227
+ """Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``."""
2228
+ ...
2229
+
2230
+
2231
+ @over
2232
+ def where(cond: int64, value_if_true: Any, value_if_false: Any) -> Any:
2233
+ """Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``."""
2234
+ ...
2235
+
2236
+
2237
+ @over
2238
+ def where(cond: uint64, value_if_true: Any, value_if_false: Any) -> Any:
2239
+ """Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``."""
2240
+ ...
2241
+
2242
+
2243
+ @over
2244
+ def where(arr: Array[Any], value_if_true: Any, value_if_false: Any) -> Any:
2245
+ """Select between two arguments, if ``arr`` is not null then return ``value_if_true``, otherwise return ``value_if_false``."""
2087
2246
  ...
2088
2247
 
2089
2248
 
@@ -2492,7 +2651,19 @@ def expect_near(a: Float, b: Float, tolerance: Float):
2492
2651
 
2493
2652
 
2494
2653
  @over
2495
- def expect_near(a: vec3f, b: vec3f, tolerance: float32):
2654
+ def expect_near(a: Vector[Any, Float], b: Vector[Any, Float], tolerance: Float):
2655
+ """Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude"""
2656
+ ...
2657
+
2658
+
2659
+ @over
2660
+ def expect_near(a: Quaternion[Float], b: Quaternion[Float], tolerance: Float):
2661
+ """Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude"""
2662
+ ...
2663
+
2664
+
2665
+ @over
2666
+ def expect_near(a: Matrix[Any, Any, Float], b: Matrix[Any, Any, Float], tolerance: Float):
2496
2667
  """Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude"""
2497
2668
  ...
2498
2669
 
@@ -2901,7 +3072,7 @@ def tile_matmul(a: Tile, b: Tile, out: Tile) -> Tile:
2901
3072
  * fp16, fp32, fp64 (real)
2902
3073
  * vec2h, vec2f, vec2d (complex)
2903
3074
 
2904
- All input and output tiles must have the same datatype. Tile data will be automatically be migrated
3075
+ All input and output tiles must have the same datatype. Tile data will automatically be migrated
2905
3076
  to shared memory if necessary and will use TensorCore operations when available.
2906
3077
 
2907
3078
  :param a: A tile with ``shape=(M, K)``
@@ -2920,7 +3091,7 @@ def tile_matmul(a: Tile, b: Tile) -> Tile:
2920
3091
  * fp16, fp32, fp64 (real)
2921
3092
  * vec2h, vec2f, vec2d (complex)
2922
3093
 
2923
- Both input tiles must have the same datatype. Tile data will be automatically be migrated
3094
+ Both input tiles must have the same datatype. Tile data will automatically be migrated
2924
3095
  to shared memory if necessary and will use TensorCore operations when available.
2925
3096
 
2926
3097
  :param a: A tile with ``shape=(M, K)``
@@ -3134,3 +3305,29 @@ def smooth_normalize(v: Any, delta: float):
3134
3305
  Vector[Any,Float]: The normalized vector.
3135
3306
  """
3136
3307
  ...
3308
+
3309
+
3310
+ @over
3311
+ def transform_from_matrix(mat: Matrix[4, 4, float32]) -> Transformation[float32]:
3312
+ """Construct a transformation from a 4x4 matrix.
3313
+
3314
+ Args:
3315
+ mat (Matrix[4, 4, Float]): Matrix to convert.
3316
+
3317
+ Returns:
3318
+ Transformation[Float]: The transformation.
3319
+ """
3320
+ ...
3321
+
3322
+
3323
+ @over
3324
+ def transform_to_matrix(xform: Transformation[float32]) -> Matrix[4, 4, float32]:
3325
+ """Convert a transformation to a 4x4 matrix.
3326
+
3327
+ Args:
3328
+ xform (Transformation[Float]): Transformation to convert.
3329
+
3330
+ Returns:
3331
+ Matrix[4, 4, Float]: The matrix.
3332
+ """
3333
+ ...
warp/tests/__main__.py CHANGED
@@ -1,18 +1,3 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
1
  from warp.thirdparty.unittest_parallel import main
17
2
 
18
3
  if __name__ == "__main__":
warp/tests/assets/torus.usda CHANGED
@@ -59,9 +59,9 @@
59
59
  }
60
60
  defaultPrim = "World"
61
61
  endTimeCode = 100
62
+ framesPerSecond = 24
62
63
  metersPerUnit = 0.01
63
64
  startTimeCode = 0
64
- timeCodesPerSecond = 24
65
65
  upAxis = "Y"
66
66
  )
67
67
 
warp/tests/cuda/__init__.py (presumed from the file-list order)
File without changes
warp/tests/{test_mempool.py → cuda/test_mempool.py} CHANGED
@@ -71,6 +71,44 @@ def test_mempool_release_threshold(test, device):
71
71
  test.assertEqual(wp.get_mempool_release_threshold(device), saved_threshold)
72
72
 
73
73
 
74
+ def test_mempool_usage_queries(test, device):
75
+ """Check API to query mempool memory usage."""
76
+
77
+ device = wp.get_device(device)
78
+ pre_alloc_mempool_usage_curr = wp.get_mempool_used_mem_current(device)
79
+ pre_alloc_mempool_usage_high = wp.get_mempool_used_mem_high(device)
80
+
81
+ # Allocate a 1 MiB array
82
+ test_data = wp.empty(262144, dtype=wp.float32, device=device)
83
+ wp.synchronize_device(device)
84
+
85
+ # Query memory usage again
86
+ post_alloc_mempool_usage_curr = wp.get_mempool_used_mem_current(device)
87
+ post_alloc_mempool_usage_high = wp.get_mempool_used_mem_high(device)
88
+
89
+ test.assertEqual(
90
+ post_alloc_mempool_usage_curr, pre_alloc_mempool_usage_curr + 1048576, "Memory usage did not increase by 1 MiB"
91
+ )
92
+ test.assertGreaterEqual(post_alloc_mempool_usage_high, 1048576, "High-water mark is not at least 1 MiB")
93
+
94
+ # Free the allocation
95
+ del test_data
96
+ wp.synchronize_device(device)
97
+
98
+ # Query memory usage
99
+ post_free_mempool_usage_curr = wp.get_mempool_used_mem_current(device)
100
+ post_free_mempool_usage_high = wp.get_mempool_used_mem_high(device)
101
+
102
+ test.assertEqual(
103
+ post_free_mempool_usage_curr,
104
+ pre_alloc_mempool_usage_curr,
105
+ "Test didn't end with the same amount of used memory as the test started with.",
106
+ )
107
+ test.assertEqual(
108
+ post_free_mempool_usage_high, post_alloc_mempool_usage_high, "High-water mark should not change after free"
109
+ )
110
+
111
+
74
112
  def test_mempool_exceptions(test, device):
75
113
  device = wp.get_device(device)
76
114
 
@@ -176,6 +214,7 @@ devices_without_mempools = [d for d in get_test_devices() if not d.is_mempool_su
176
214
  add_function_test(
177
215
  TestMempool, "test_mempool_release_threshold", test_mempool_release_threshold, devices=devices_with_mempools
178
216
  )
217
+ add_function_test(TestMempool, "test_mempool_usage_queries", test_mempool_usage_queries, devices=devices_with_mempools)
179
218
  add_function_test(TestMempool, "test_mempool_access_self", test_mempool_access_self, devices=devices_with_mempools)
180
219
 
181
220
  # test devices without mempool support
warp/tests/{test_streams.py → cuda/test_streams.py} CHANGED
@@ -342,6 +342,29 @@ def test_event_elapsed_time(test, device):
342
342
  test.assertGreater(elapsed, 0)
343
343
 
344
344
 
345
+ def test_event_elapsed_time_graph(test, device):
346
+ stream = wp.get_stream(device)
347
+ e1 = wp.Event(device, enable_timing=True)
348
+ e2 = wp.Event(device, enable_timing=True)
349
+
350
+ a = wp.zeros(N, dtype=float, device=device)
351
+
352
+ wp.load_module(device=device)
353
+
354
+ with wp.ScopedCapture(device, force_module_load=False) as capture:
355
+ stream.record_event(e1)
356
+ wp.launch(inc, dim=N, inputs=[a], device=device)
357
+ stream.record_event(e2)
358
+
359
+ wp.capture_launch(capture.graph)
360
+
361
+ wp.synchronize_device(device)
362
+
363
+ elapsed = wp.get_event_elapsed_time(e1, e2)
364
+
365
+ test.assertGreater(elapsed, 0)
366
+
367
+
345
368
  def test_stream_priority_basics(test, device):
346
369
  standard_stream = wp.Stream(device)
347
370
  test.assertEqual(standard_stream.priority, 0, "Default priority of streams must be 0.")
@@ -401,6 +424,52 @@ def test_stream_priority_timings(test, device):
401
424
  test.assertLess(elapsed_hi, elapsed_lo, "Copies on higher-priority stream should be faster.")
402
425
 
403
426
 
427
+ @wp.kernel
428
+ def sum_threads(sum: wp.array(dtype=wp.uint64)):
429
+ i = wp.tid()
430
+ wp.atomic_add(sum, 0, wp.uint64(1))
431
+
432
+
433
+ def test_stream_event_is_complete(test, device):
434
+ with wp.ScopedDevice(device):
435
+ stream = wp.Stream()
436
+ event = wp.Event()
437
+ # No operations on stream, should be complete
438
+ test.assertTrue(stream.is_complete)
439
+
440
+ # Event not recorded yet, should be complete
441
+ test.assertTrue(event.is_complete)
442
+
443
+ a = wp.zeros(1, dtype=wp.uint64)
444
+
445
+ threads = 1024 * 1024 * 8
446
+
447
+ with wp.ScopedStream(stream):
448
+ # Launch some work on the stream and reuse the event
449
+
450
+ for iter in range(5):
451
+ # Kernel takes about 1 ms to run on an RTX 3090
452
+ wp.launch(sum_threads, dim=threads, outputs=[a])
453
+
454
+ stream.record_event(event)
455
+
456
+ # Kernel should still be running
457
+ test.assertFalse(stream.is_complete)
458
+
459
+ # Event should not be finished
460
+ test.assertFalse(event.is_complete)
461
+
462
+ # Force the stream operations to complete
463
+ wp.synchronize_stream(stream)
464
+
465
+ # Now all operations are complete
466
+ test.assertTrue(stream.is_complete)
467
+ test.assertTrue(event.is_complete)
468
+
469
+ # Verify result
470
+ test.assertEqual(a.numpy()[0], (iter + 1) * threads)
471
+
472
+
404
473
  devices = get_selected_cuda_test_devices()
405
474
 
406
475
 
@@ -554,9 +623,11 @@ add_function_test(TestStreams, "test_stream_scope_wait_event", test_stream_scope
554
623
  add_function_test(TestStreams, "test_stream_scope_wait_stream", test_stream_scope_wait_stream, devices=devices)
555
624
  add_function_test(TestStreams, "test_stream_priority_basics", test_stream_priority_basics, devices=devices)
556
625
  add_function_test(TestStreams, "test_stream_priority_timings", test_stream_priority_timings, devices=devices)
626
+ add_function_test(TestStreams, "test_stream_event_is_complete", test_stream_event_is_complete, devices=devices)
557
627
 
558
628
  add_function_test(TestStreams, "test_event_synchronize", test_event_synchronize, devices=devices)
559
629
  add_function_test(TestStreams, "test_event_elapsed_time", test_event_elapsed_time, devices=devices)
630
+ add_function_test(TestStreams, "test_event_elapsed_time_graph", test_event_elapsed_time_graph, devices=devices)
560
631
 
561
632
  if __name__ == "__main__":
562
633
  wp.clear_kernel_cache()
warp/tests/geometry/__init__.py (presumed from the file-list order)
File without changes