warp-lang 1.6.2__py3-none-win_amd64.whl → 1.7.1__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +7 -1
- warp/autograd.py +12 -2
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +410 -0
- warp/build_dll.py +6 -14
- warp/builtins.py +463 -372
- warp/codegen.py +196 -124
- warp/config.py +42 -6
- warp/context.py +496 -271
- warp/dlpack.py +8 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_cloth.py +1 -1
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/distributed/example_jacobi_mpi.py +507 -0
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +2 -2
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_magnetostatics.py +6 -6
- warp/examples/fem/utils.py +9 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/tile/example_tile_matmul.py +2 -4
- warp/fem/__init__.py +11 -1
- warp/fem/adaptivity.py +4 -4
- warp/fem/field/field.py +11 -1
- warp/fem/field/nodal_field.py +56 -88
- warp/fem/field/virtual.py +62 -23
- warp/fem/geometry/adaptive_nanogrid.py +16 -13
- warp/fem/geometry/closest_point.py +1 -1
- warp/fem/geometry/deformed_geometry.py +5 -2
- warp/fem/geometry/geometry.py +5 -0
- warp/fem/geometry/grid_2d.py +12 -12
- warp/fem/geometry/grid_3d.py +12 -15
- warp/fem/geometry/hexmesh.py +5 -7
- warp/fem/geometry/nanogrid.py +9 -11
- warp/fem/geometry/quadmesh.py +13 -13
- warp/fem/geometry/tetmesh.py +3 -4
- warp/fem/geometry/trimesh.py +7 -20
- warp/fem/integrate.py +262 -93
- warp/fem/linalg.py +5 -5
- warp/fem/quadrature/pic_quadrature.py +37 -22
- warp/fem/quadrature/quadrature.py +194 -25
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +4 -2
- warp/fem/space/basis_space.py +25 -18
- warp/fem/space/hexmesh_function_space.py +2 -2
- warp/fem/space/partition.py +6 -2
- warp/fem/space/quadmesh_function_space.py +8 -8
- warp/fem/space/shape/cube_shape_function.py +23 -23
- warp/fem/space/shape/square_shape_function.py +12 -12
- warp/fem/space/shape/triangle_shape_function.py +1 -1
- warp/fem/space/tetmesh_function_space.py +3 -3
- warp/fem/space/trimesh_function_space.py +2 -2
- warp/fem/utils.py +12 -6
- warp/jax.py +14 -1
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -29
- warp/jax_experimental/ffi.py +702 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +89 -0
- warp/native/array.h +13 -0
- warp/native/builtin.h +29 -3
- warp/native/bvh.cpp +3 -1
- warp/native/bvh.cu +42 -14
- warp/native/bvh.h +2 -1
- warp/native/clang/clang.cpp +30 -3
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -0
- warp/native/exports.h +68 -63
- warp/native/intersect.h +26 -26
- warp/native/intersect_adj.h +33 -33
- warp/native/marching.cu +1 -1
- warp/native/mat.h +513 -9
- warp/native/mesh.h +10 -10
- warp/native/quat.h +99 -11
- warp/native/rand.h +6 -0
- warp/native/sort.cpp +122 -59
- warp/native/sort.cu +152 -15
- warp/native/sort.h +8 -1
- warp/native/sparse.cpp +43 -22
- warp/native/sparse.cu +52 -17
- warp/native/svd.h +116 -0
- warp/native/tile.h +312 -116
- warp/native/tile_reduce.h +46 -3
- warp/native/vec.h +68 -7
- warp/native/volume.cpp +85 -113
- warp/native/volume_builder.cu +25 -10
- warp/native/volume_builder.h +6 -0
- warp/native/warp.cpp +5 -6
- warp/native/warp.cu +100 -11
- warp/native/warp.h +19 -10
- warp/optim/linear.py +10 -10
- warp/render/render_opengl.py +19 -17
- warp/render/render_usd.py +93 -3
- warp/sim/articulation.py +4 -4
- warp/sim/collide.py +32 -19
- warp/sim/import_mjcf.py +449 -155
- warp/sim/import_urdf.py +32 -12
- warp/sim/inertia.py +189 -156
- warp/sim/integrator_euler.py +8 -5
- warp/sim/integrator_featherstone.py +3 -10
- warp/sim/integrator_vbd.py +207 -2
- warp/sim/integrator_xpbd.py +8 -5
- warp/sim/model.py +71 -25
- warp/sim/render.py +4 -0
- warp/sim/utils.py +2 -2
- warp/sparse.py +642 -555
- warp/stubs.py +217 -20
- warp/tests/__main__.py +0 -15
- warp/tests/assets/torus.usda +1 -1
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
- warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
- warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
- warp/tests/{test_collision.py → sim/test_collision.py} +236 -205
- warp/tests/sim/test_inertia.py +161 -0
- warp/tests/{test_model.py → sim/test_model.py} +40 -0
- warp/tests/{flaky_test_sim_grad.py → sim/test_sim_grad.py} +4 -0
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/sim/test_xpbd.py +399 -0
- warp/tests/test_bool.py +1 -1
- warp/tests/test_codegen.py +24 -3
- warp/tests/test_examples.py +40 -38
- warp/tests/test_fem.py +98 -14
- warp/tests/test_linear_solvers.py +0 -11
- warp/tests/test_mat.py +577 -156
- warp/tests/test_mat_scalar_ops.py +4 -4
- warp/tests/test_overwrite.py +0 -60
- warp/tests/test_quat.py +356 -151
- warp/tests/test_rand.py +44 -37
- warp/tests/test_sparse.py +47 -6
- warp/tests/test_spatial.py +75 -0
- warp/tests/test_static.py +1 -1
- warp/tests/test_utils.py +84 -4
- warp/tests/test_vec.py +336 -178
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +98 -1
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
- warp/tests/unittest_serial.py +1 -0
- warp/tests/unittest_suites.py +45 -62
- warp/tests/unittest_utils.py +2 -1
- warp/thirdparty/unittest_parallel.py +3 -1
- warp/types.py +175 -666
- warp/utils.py +137 -72
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/METADATA +46 -12
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/RECORD +184 -171
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/WHEEL +1 -1
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info/licenses}/LICENSE.md +0 -26
- warp/examples/optim/example_walker.py +0 -317
- warp/native/cutlass_gemm.cpp +0 -43
- warp/native/cutlass_gemm.cu +0 -382
- warp/tests/test_matmul.py +0 -511
- warp/tests/test_matmul_lite.py +0 -411
- warp/tests/test_vbd.py +0 -386
- warp/tests/unused_test_misc.py +0 -77
- /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
- /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
- /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
- /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
- /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
- /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
- /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
- /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
- /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
- /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
- /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
- /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
- /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
- /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
- /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
- /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
- /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/top_level.txt +0 -0
warp/stubs.py
CHANGED
|
@@ -96,7 +96,12 @@ from warp.context import Stream, get_stream, set_stream, wait_stream, synchroniz
|
|
|
96
96
|
from warp.context import Event, record_event, wait_event, synchronize_event, get_event_elapsed_time
|
|
97
97
|
from warp.context import RegisteredGLBuffer
|
|
98
98
|
from warp.context import is_mempool_supported, is_mempool_enabled, set_mempool_enabled
|
|
99
|
-
from warp.context import set_mempool_release_threshold, get_mempool_release_threshold
|
|
99
|
+
from warp.context import (
|
|
100
|
+
set_mempool_release_threshold,
|
|
101
|
+
get_mempool_release_threshold,
|
|
102
|
+
get_mempool_used_mem_current,
|
|
103
|
+
get_mempool_used_mem_high,
|
|
104
|
+
)
|
|
100
105
|
from warp.context import is_mempool_access_supported, is_mempool_access_enabled, set_mempool_access_enabled
|
|
101
106
|
from warp.context import is_peer_access_supported, is_peer_access_enabled, set_peer_access_enabled
|
|
102
107
|
|
|
@@ -132,6 +137,7 @@ from warp.paddle import device_from_paddle, device_to_paddle
|
|
|
132
137
|
from warp.paddle import stream_from_paddle
|
|
133
138
|
|
|
134
139
|
from warp.build import clear_kernel_cache
|
|
140
|
+
from warp.build import clear_lto_cache
|
|
135
141
|
|
|
136
142
|
from warp.constants import *
|
|
137
143
|
|
|
@@ -648,6 +654,18 @@ def matrix(*args: Scalar, shape: Tuple[int, int], dtype: Scalar) -> Matrix[Any,
|
|
|
648
654
|
...
|
|
649
655
|
|
|
650
656
|
|
|
657
|
+
@over
|
|
658
|
+
def matrix_from_cols(*args: Vector[Any, Scalar]) -> Matrix[Any, Any, Scalar]:
|
|
659
|
+
"""Construct a matrix from column vectors."""
|
|
660
|
+
...
|
|
661
|
+
|
|
662
|
+
|
|
663
|
+
@over
|
|
664
|
+
def matrix_from_rows(*args: Vector[Any, Scalar]) -> Matrix[Any, Any, Scalar]:
|
|
665
|
+
"""Construct a matrix from row vectors."""
|
|
666
|
+
...
|
|
667
|
+
|
|
668
|
+
|
|
651
669
|
@over
|
|
652
670
|
def identity(n: int32, dtype: Scalar) -> Matrix[Any, Any, Scalar]:
|
|
653
671
|
"""Create an identity matrix with shape=(n,n) with the type given by ``dtype``."""
|
|
@@ -662,6 +680,14 @@ def svd3(A: Matrix[3, 3, Float], U: Matrix[3, 3, Float], sigma: Vector[3, Float]
|
|
|
662
680
|
...
|
|
663
681
|
|
|
664
682
|
|
|
683
|
+
@over
|
|
684
|
+
def svd2(A: Matrix[2, 2, Float], U: Matrix[2, 2, Float], sigma: Vector[2, Float], V: Matrix[2, 2, Scalar]):
|
|
685
|
+
"""Compute the SVD of a 2x2 matrix ``A``. The singular values are returned in ``sigma``,
|
|
686
|
+
while the left and right basis vectors are returned in ``U`` and ``V``.
|
|
687
|
+
"""
|
|
688
|
+
...
|
|
689
|
+
|
|
690
|
+
|
|
665
691
|
@over
|
|
666
692
|
def qr3(A: Matrix[3, 3, Float], Q: Matrix[3, 3, Float], R: Matrix[3, 3, Float]):
|
|
667
693
|
"""Compute the QR decomposition of a 3x3 matrix ``A``. The orthogonal matrix is returned in ``Q``,
|
|
@@ -687,7 +713,7 @@ def quaternion(dtype: Float) -> Quaternion[Float]:
|
|
|
687
713
|
|
|
688
714
|
|
|
689
715
|
@over
|
|
690
|
-
def quaternion(x: Float, y: Float, z: Float, w: Float) -> Quaternion[Float]:
|
|
716
|
+
def quaternion(x: Float, y: Float, z: Float, w: Float, dtype: Scalar) -> Quaternion[Float]:
|
|
691
717
|
"""Create a quaternion using the supplied components (type inferred from component type)."""
|
|
692
718
|
...
|
|
693
719
|
|
|
@@ -724,7 +750,19 @@ def quat_to_axis_angle(quat: Quaternion[Float], axis: Vector[3, Float], angle: F
|
|
|
724
750
|
|
|
725
751
|
@over
|
|
726
752
|
def quat_from_matrix(mat: Matrix[3, 3, Float]) -> Quaternion[Float]:
|
|
727
|
-
"""Construct a quaternion from a 3x3 matrix.
|
|
753
|
+
"""Construct a quaternion from a 3x3 matrix.
|
|
754
|
+
|
|
755
|
+
If the matrix is not a pure rotation, but for example includes scaling or skewing, the result is undefined.
|
|
756
|
+
"""
|
|
757
|
+
...
|
|
758
|
+
|
|
759
|
+
|
|
760
|
+
@over
|
|
761
|
+
def quat_from_matrix(mat: Matrix[4, 4, Float]) -> Quaternion[Float]:
|
|
762
|
+
"""Construct a quaternion from a 4x4 matrix.
|
|
763
|
+
|
|
764
|
+
If the top-left 3x3 block of the matrix is not a pure rotation, but for example includes scaling or skewing, the result is undefined.
|
|
765
|
+
"""
|
|
728
766
|
...
|
|
729
767
|
|
|
730
768
|
|
|
@@ -1028,7 +1066,7 @@ def tile(x: Any) -> Tile:
|
|
|
1028
1066
|
|
|
1029
1067
|
This function converts values computed using scalar kernel code to a tile representation for input into collective operations.
|
|
1030
1068
|
|
|
1031
|
-
* If the input value is a scalar, then the resulting tile has ``shape=(1, block_dim)``
|
|
1069
|
+
* If the input value is a scalar, then the resulting tile has ``shape=(block_dim,)``
|
|
1032
1070
|
* If the input value is a vector, then the resulting tile has ``shape=(length(vector), block_dim)``
|
|
1033
1071
|
|
|
1034
1072
|
:param x: A per-thread local value, e.g. scalar, vector, or matrix.
|
|
@@ -1121,13 +1159,12 @@ def tile_transpose(a: Tile) -> Tile:
|
|
|
1121
1159
|
def tile_broadcast(a: Tile, shape: Tuple[int, ...]) -> Tile:
|
|
1122
1160
|
"""Broadcast a tile.
|
|
1123
1161
|
|
|
1124
|
-
|
|
1125
|
-
|
|
1162
|
+
Broadcasts the input tile ``a`` to the destination shape.
|
|
1126
1163
|
Broadcasting follows NumPy broadcast rules.
|
|
1127
1164
|
|
|
1128
1165
|
:param a: Tile to broadcast
|
|
1129
1166
|
:param shape: The shape to broadcast to
|
|
1130
|
-
:returns: Tile with broadcast
|
|
1167
|
+
:returns: Tile with broadcast shape
|
|
1131
1168
|
"""
|
|
1132
1169
|
...
|
|
1133
1170
|
|
|
@@ -1810,6 +1847,18 @@ def randi(state: uint32, low: int32, high: int32) -> int:
|
|
|
1810
1847
|
...
|
|
1811
1848
|
|
|
1812
1849
|
|
|
1850
|
+
@over
|
|
1851
|
+
def randu(state: uint32) -> uint32:
|
|
1852
|
+
"""Return a random unsigned integer in the range [0, 2^32)."""
|
|
1853
|
+
...
|
|
1854
|
+
|
|
1855
|
+
|
|
1856
|
+
@over
|
|
1857
|
+
def randu(state: uint32, low: uint32, high: uint32) -> uint32:
|
|
1858
|
+
"""Return a random unsigned integer between [low, high)."""
|
|
1859
|
+
...
|
|
1860
|
+
|
|
1861
|
+
|
|
1813
1862
|
@over
|
|
1814
1863
|
def randf(state: uint32) -> float:
|
|
1815
1864
|
"""Return a random float between [0.0, 1.0)."""
|
|
@@ -2029,61 +2078,171 @@ def tid() -> Tuple[int, int, int, int]:
|
|
|
2029
2078
|
|
|
2030
2079
|
@over
|
|
2031
2080
|
def select(cond: bool, value_if_false: Any, value_if_true: Any) -> Any:
|
|
2032
|
-
"""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``"""
|
|
2081
|
+
"""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
2082
|
+
|
|
2083
|
+
.. deprecated:: 1.7
|
|
2084
|
+
Use :func:`where` instead, which has the more intuitive argument order:
|
|
2085
|
+
``where(cond, value_if_true, value_if_false)``.
|
|
2086
|
+
"""
|
|
2033
2087
|
...
|
|
2034
2088
|
|
|
2035
2089
|
|
|
2036
2090
|
@over
|
|
2037
2091
|
def select(cond: int8, value_if_false: Any, value_if_true: Any) -> Any:
|
|
2038
|
-
"""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``"""
|
|
2092
|
+
"""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
2093
|
+
|
|
2094
|
+
.. deprecated:: 1.7
|
|
2095
|
+
Use :func:`where` instead, which has the more intuitive argument order:
|
|
2096
|
+
``where(cond, value_if_true, value_if_false)``.
|
|
2097
|
+
"""
|
|
2039
2098
|
...
|
|
2040
2099
|
|
|
2041
2100
|
|
|
2042
2101
|
@over
|
|
2043
2102
|
def select(cond: uint8, value_if_false: Any, value_if_true: Any) -> Any:
|
|
2044
|
-
"""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``"""
|
|
2103
|
+
"""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
2104
|
+
|
|
2105
|
+
.. deprecated:: 1.7
|
|
2106
|
+
Use :func:`where` instead, which has the more intuitive argument order:
|
|
2107
|
+
``where(cond, value_if_true, value_if_false)``.
|
|
2108
|
+
"""
|
|
2045
2109
|
...
|
|
2046
2110
|
|
|
2047
2111
|
|
|
2048
2112
|
@over
|
|
2049
2113
|
def select(cond: int16, value_if_false: Any, value_if_true: Any) -> Any:
|
|
2050
|
-
"""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``"""
|
|
2114
|
+
"""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
2115
|
+
|
|
2116
|
+
.. deprecated:: 1.7
|
|
2117
|
+
Use :func:`where` instead, which has the more intuitive argument order:
|
|
2118
|
+
``where(cond, value_if_true, value_if_false)``.
|
|
2119
|
+
"""
|
|
2051
2120
|
...
|
|
2052
2121
|
|
|
2053
2122
|
|
|
2054
2123
|
@over
|
|
2055
2124
|
def select(cond: uint16, value_if_false: Any, value_if_true: Any) -> Any:
|
|
2056
|
-
"""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``"""
|
|
2125
|
+
"""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
2126
|
+
|
|
2127
|
+
.. deprecated:: 1.7
|
|
2128
|
+
Use :func:`where` instead, which has the more intuitive argument order:
|
|
2129
|
+
``where(cond, value_if_true, value_if_false)``.
|
|
2130
|
+
"""
|
|
2057
2131
|
...
|
|
2058
2132
|
|
|
2059
2133
|
|
|
2060
2134
|
@over
|
|
2061
2135
|
def select(cond: int32, value_if_false: Any, value_if_true: Any) -> Any:
|
|
2062
|
-
"""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``"""
|
|
2136
|
+
"""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
2137
|
+
|
|
2138
|
+
.. deprecated:: 1.7
|
|
2139
|
+
Use :func:`where` instead, which has the more intuitive argument order:
|
|
2140
|
+
``where(cond, value_if_true, value_if_false)``.
|
|
2141
|
+
"""
|
|
2063
2142
|
...
|
|
2064
2143
|
|
|
2065
2144
|
|
|
2066
2145
|
@over
|
|
2067
2146
|
def select(cond: uint32, value_if_false: Any, value_if_true: Any) -> Any:
|
|
2068
|
-
"""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``"""
|
|
2147
|
+
"""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
2148
|
+
|
|
2149
|
+
.. deprecated:: 1.7
|
|
2150
|
+
Use :func:`where` instead, which has the more intuitive argument order:
|
|
2151
|
+
``where(cond, value_if_true, value_if_false)``.
|
|
2152
|
+
"""
|
|
2069
2153
|
...
|
|
2070
2154
|
|
|
2071
2155
|
|
|
2072
2156
|
@over
|
|
2073
2157
|
def select(cond: int64, value_if_false: Any, value_if_true: Any) -> Any:
|
|
2074
|
-
"""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``"""
|
|
2158
|
+
"""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
2159
|
+
|
|
2160
|
+
.. deprecated:: 1.7
|
|
2161
|
+
Use :func:`where` instead, which has the more intuitive argument order:
|
|
2162
|
+
``where(cond, value_if_true, value_if_false)``.
|
|
2163
|
+
"""
|
|
2075
2164
|
...
|
|
2076
2165
|
|
|
2077
2166
|
|
|
2078
2167
|
@over
|
|
2079
2168
|
def select(cond: uint64, value_if_false: Any, value_if_true: Any) -> Any:
|
|
2080
|
-
"""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``"""
|
|
2169
|
+
"""Select between two arguments, if ``cond`` is ``False`` then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
2170
|
+
|
|
2171
|
+
.. deprecated:: 1.7
|
|
2172
|
+
Use :func:`where` instead, which has the more intuitive argument order:
|
|
2173
|
+
``where(cond, value_if_true, value_if_false)``.
|
|
2174
|
+
"""
|
|
2081
2175
|
...
|
|
2082
2176
|
|
|
2083
2177
|
|
|
2084
2178
|
@over
|
|
2085
2179
|
def select(arr: Array[Any], value_if_false: Any, value_if_true: Any) -> Any:
|
|
2086
|
-
"""Select between two arguments, if ``arr`` is null then return ``value_if_false``, otherwise return ``value_if_true``"""
|
|
2180
|
+
"""Select between two arguments, if ``arr`` is null then return ``value_if_false``, otherwise return ``value_if_true``.
|
|
2181
|
+
|
|
2182
|
+
.. deprecated:: 1.7
|
|
2183
|
+
Use :func:`where` instead, which has the more intuitive argument order:
|
|
2184
|
+
``where(arr, value_if_true, value_if_false)``.
|
|
2185
|
+
"""
|
|
2186
|
+
...
|
|
2187
|
+
|
|
2188
|
+
|
|
2189
|
+
@over
|
|
2190
|
+
def where(cond: bool, value_if_true: Any, value_if_false: Any) -> Any:
|
|
2191
|
+
"""Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``."""
|
|
2192
|
+
...
|
|
2193
|
+
|
|
2194
|
+
|
|
2195
|
+
@over
|
|
2196
|
+
def where(cond: int8, value_if_true: Any, value_if_false: Any) -> Any:
|
|
2197
|
+
"""Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``."""
|
|
2198
|
+
...
|
|
2199
|
+
|
|
2200
|
+
|
|
2201
|
+
@over
|
|
2202
|
+
def where(cond: uint8, value_if_true: Any, value_if_false: Any) -> Any:
|
|
2203
|
+
"""Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``."""
|
|
2204
|
+
...
|
|
2205
|
+
|
|
2206
|
+
|
|
2207
|
+
@over
|
|
2208
|
+
def where(cond: int16, value_if_true: Any, value_if_false: Any) -> Any:
|
|
2209
|
+
"""Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``."""
|
|
2210
|
+
...
|
|
2211
|
+
|
|
2212
|
+
|
|
2213
|
+
@over
|
|
2214
|
+
def where(cond: uint16, value_if_true: Any, value_if_false: Any) -> Any:
|
|
2215
|
+
"""Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``."""
|
|
2216
|
+
...
|
|
2217
|
+
|
|
2218
|
+
|
|
2219
|
+
@over
|
|
2220
|
+
def where(cond: int32, value_if_true: Any, value_if_false: Any) -> Any:
|
|
2221
|
+
"""Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``."""
|
|
2222
|
+
...
|
|
2223
|
+
|
|
2224
|
+
|
|
2225
|
+
@over
|
|
2226
|
+
def where(cond: uint32, value_if_true: Any, value_if_false: Any) -> Any:
|
|
2227
|
+
"""Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``."""
|
|
2228
|
+
...
|
|
2229
|
+
|
|
2230
|
+
|
|
2231
|
+
@over
|
|
2232
|
+
def where(cond: int64, value_if_true: Any, value_if_false: Any) -> Any:
|
|
2233
|
+
"""Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``."""
|
|
2234
|
+
...
|
|
2235
|
+
|
|
2236
|
+
|
|
2237
|
+
@over
|
|
2238
|
+
def where(cond: uint64, value_if_true: Any, value_if_false: Any) -> Any:
|
|
2239
|
+
"""Select between two arguments, if ``cond`` is ``True`` then return ``value_if_true``, otherwise return ``value_if_false``."""
|
|
2240
|
+
...
|
|
2241
|
+
|
|
2242
|
+
|
|
2243
|
+
@over
|
|
2244
|
+
def where(arr: Array[Any], value_if_true: Any, value_if_false: Any) -> Any:
|
|
2245
|
+
"""Select between two arguments, if ``arr`` is not null then return ``value_if_true``, otherwise return ``value_if_false``."""
|
|
2087
2246
|
...
|
|
2088
2247
|
|
|
2089
2248
|
|
|
@@ -2492,7 +2651,19 @@ def expect_near(a: Float, b: Float, tolerance: Float):
|
|
|
2492
2651
|
|
|
2493
2652
|
|
|
2494
2653
|
@over
|
|
2495
|
-
def expect_near(a:
|
|
2654
|
+
def expect_near(a: Vector[Any, Float], b: Vector[Any, Float], tolerance: Float):
|
|
2655
|
+
"""Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude"""
|
|
2656
|
+
...
|
|
2657
|
+
|
|
2658
|
+
|
|
2659
|
+
@over
|
|
2660
|
+
def expect_near(a: Quaternion[Float], b: Quaternion[Float], tolerance: Float):
|
|
2661
|
+
"""Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude"""
|
|
2662
|
+
...
|
|
2663
|
+
|
|
2664
|
+
|
|
2665
|
+
@over
|
|
2666
|
+
def expect_near(a: Matrix[Any, Any, Float], b: Matrix[Any, Any, Float], tolerance: Float):
|
|
2496
2667
|
"""Prints an error to stdout if any element of ``a`` and ``b`` are not closer than tolerance in magnitude"""
|
|
2497
2668
|
...
|
|
2498
2669
|
|
|
@@ -2901,7 +3072,7 @@ def tile_matmul(a: Tile, b: Tile, out: Tile) -> Tile:
|
|
|
2901
3072
|
* fp16, fp32, fp64 (real)
|
|
2902
3073
|
* vec2h, vec2f, vec2d (complex)
|
|
2903
3074
|
|
|
2904
|
-
All input and output tiles must have the same datatype. Tile data will
|
|
3075
|
+
All input and output tiles must have the same datatype. Tile data will automatically be migrated
|
|
2905
3076
|
to shared memory if necessary and will use TensorCore operations when available.
|
|
2906
3077
|
|
|
2907
3078
|
:param a: A tile with ``shape=(M, K)``
|
|
@@ -2920,7 +3091,7 @@ def tile_matmul(a: Tile, b: Tile) -> Tile:
|
|
|
2920
3091
|
* fp16, fp32, fp64 (real)
|
|
2921
3092
|
* vec2h, vec2f, vec2d (complex)
|
|
2922
3093
|
|
|
2923
|
-
Both input tiles must have the same datatype. Tile data will
|
|
3094
|
+
Both input tiles must have the same datatype. Tile data will automatically be migrated
|
|
2924
3095
|
to shared memory if necessary and will use TensorCore operations when available.
|
|
2925
3096
|
|
|
2926
3097
|
:param a: A tile with ``shape=(M, K)``
|
|
@@ -3134,3 +3305,29 @@ def smooth_normalize(v: Any, delta: float):
|
|
|
3134
3305
|
Vector[Any,Float]: The normalized vector.
|
|
3135
3306
|
"""
|
|
3136
3307
|
...
|
|
3308
|
+
|
|
3309
|
+
|
|
3310
|
+
@over
|
|
3311
|
+
def transform_from_matrix(mat: Matrix[4, 4, float32]) -> Transformation[float32]:
|
|
3312
|
+
"""Construct a transformation from a 4x4 matrix.
|
|
3313
|
+
|
|
3314
|
+
Args:
|
|
3315
|
+
mat (Matrix[4, 4, Float]): Matrix to convert.
|
|
3316
|
+
|
|
3317
|
+
Returns:
|
|
3318
|
+
Transformation[Float]: The transformation.
|
|
3319
|
+
"""
|
|
3320
|
+
...
|
|
3321
|
+
|
|
3322
|
+
|
|
3323
|
+
@over
|
|
3324
|
+
def transform_to_matrix(xform: Transformation[float32]) -> Matrix[4, 4, float32]:
|
|
3325
|
+
"""Convert a transformation to a 4x4 matrix.
|
|
3326
|
+
|
|
3327
|
+
Args:
|
|
3328
|
+
xform (Transformation[Float]): Transformation to convert.
|
|
3329
|
+
|
|
3330
|
+
Returns:
|
|
3331
|
+
Matrix[4, 4, Float]: The matrix.
|
|
3332
|
+
"""
|
|
3333
|
+
...
|
warp/tests/__main__.py
CHANGED
|
@@ -1,18 +1,3 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
-
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
-
#
|
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
-
# you may not use this file except in compliance with the License.
|
|
6
|
-
# You may obtain a copy of the License at
|
|
7
|
-
#
|
|
8
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
-
#
|
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
-
# See the License for the specific language governing permissions and
|
|
14
|
-
# limitations under the License.
|
|
15
|
-
|
|
16
1
|
from warp.thirdparty.unittest_parallel import main
|
|
17
2
|
|
|
18
3
|
if __name__ == "__main__":
|
warp/tests/assets/torus.usda
CHANGED
|
File without changes
|
|
@@ -71,6 +71,44 @@ def test_mempool_release_threshold(test, device):
|
|
|
71
71
|
test.assertEqual(wp.get_mempool_release_threshold(device), saved_threshold)
|
|
72
72
|
|
|
73
73
|
|
|
74
|
+
def test_mempool_usage_queries(test, device):
|
|
75
|
+
"""Check API to query mempool memory usage."""
|
|
76
|
+
|
|
77
|
+
device = wp.get_device(device)
|
|
78
|
+
pre_alloc_mempool_usage_curr = wp.get_mempool_used_mem_current(device)
|
|
79
|
+
pre_alloc_mempool_usage_high = wp.get_mempool_used_mem_high(device)
|
|
80
|
+
|
|
81
|
+
# Allocate a 1 MiB array
|
|
82
|
+
test_data = wp.empty(262144, dtype=wp.float32, device=device)
|
|
83
|
+
wp.synchronize_device(device)
|
|
84
|
+
|
|
85
|
+
# Query memory usage again
|
|
86
|
+
post_alloc_mempool_usage_curr = wp.get_mempool_used_mem_current(device)
|
|
87
|
+
post_alloc_mempool_usage_high = wp.get_mempool_used_mem_high(device)
|
|
88
|
+
|
|
89
|
+
test.assertEqual(
|
|
90
|
+
post_alloc_mempool_usage_curr, pre_alloc_mempool_usage_curr + 1048576, "Memory usage did not increase by 1 MiB"
|
|
91
|
+
)
|
|
92
|
+
test.assertGreaterEqual(post_alloc_mempool_usage_high, 1048576, "High-water mark is not at least 1 MiB")
|
|
93
|
+
|
|
94
|
+
# Free the allocation
|
|
95
|
+
del test_data
|
|
96
|
+
wp.synchronize_device(device)
|
|
97
|
+
|
|
98
|
+
# Query memory usage
|
|
99
|
+
post_free_mempool_usage_curr = wp.get_mempool_used_mem_current(device)
|
|
100
|
+
post_free_mempool_usage_high = wp.get_mempool_used_mem_high(device)
|
|
101
|
+
|
|
102
|
+
test.assertEqual(
|
|
103
|
+
post_free_mempool_usage_curr,
|
|
104
|
+
pre_alloc_mempool_usage_curr,
|
|
105
|
+
"Test didn't end with the same amount of used memory as the test started with.",
|
|
106
|
+
)
|
|
107
|
+
test.assertEqual(
|
|
108
|
+
post_free_mempool_usage_high, post_alloc_mempool_usage_high, "High-water mark should not change after free"
|
|
109
|
+
)
|
|
110
|
+
|
|
111
|
+
|
|
74
112
|
def test_mempool_exceptions(test, device):
|
|
75
113
|
device = wp.get_device(device)
|
|
76
114
|
|
|
@@ -176,6 +214,7 @@ devices_without_mempools = [d for d in get_test_devices() if not d.is_mempool_su
|
|
|
176
214
|
add_function_test(
|
|
177
215
|
TestMempool, "test_mempool_release_threshold", test_mempool_release_threshold, devices=devices_with_mempools
|
|
178
216
|
)
|
|
217
|
+
add_function_test(TestMempool, "test_mempool_usage_queries", test_mempool_usage_queries, devices=devices_with_mempools)
|
|
179
218
|
add_function_test(TestMempool, "test_mempool_access_self", test_mempool_access_self, devices=devices_with_mempools)
|
|
180
219
|
|
|
181
220
|
# test devices without mempool support
|
|
@@ -342,6 +342,29 @@ def test_event_elapsed_time(test, device):
|
|
|
342
342
|
test.assertGreater(elapsed, 0)
|
|
343
343
|
|
|
344
344
|
|
|
345
|
+
def test_event_elapsed_time_graph(test, device):
|
|
346
|
+
stream = wp.get_stream(device)
|
|
347
|
+
e1 = wp.Event(device, enable_timing=True)
|
|
348
|
+
e2 = wp.Event(device, enable_timing=True)
|
|
349
|
+
|
|
350
|
+
a = wp.zeros(N, dtype=float, device=device)
|
|
351
|
+
|
|
352
|
+
wp.load_module(device=device)
|
|
353
|
+
|
|
354
|
+
with wp.ScopedCapture(device, force_module_load=False) as capture:
|
|
355
|
+
stream.record_event(e1)
|
|
356
|
+
wp.launch(inc, dim=N, inputs=[a], device=device)
|
|
357
|
+
stream.record_event(e2)
|
|
358
|
+
|
|
359
|
+
wp.capture_launch(capture.graph)
|
|
360
|
+
|
|
361
|
+
wp.synchronize_device(device)
|
|
362
|
+
|
|
363
|
+
elapsed = wp.get_event_elapsed_time(e1, e2)
|
|
364
|
+
|
|
365
|
+
test.assertGreater(elapsed, 0)
|
|
366
|
+
|
|
367
|
+
|
|
345
368
|
def test_stream_priority_basics(test, device):
|
|
346
369
|
standard_stream = wp.Stream(device)
|
|
347
370
|
test.assertEqual(standard_stream.priority, 0, "Default priority of streams must be 0.")
|
|
@@ -401,6 +424,52 @@ def test_stream_priority_timings(test, device):
|
|
|
401
424
|
test.assertLess(elapsed_hi, elapsed_lo, "Copies on higher-priority stream should be faster.")
|
|
402
425
|
|
|
403
426
|
|
|
427
|
+
@wp.kernel
|
|
428
|
+
def sum_threads(sum: wp.array(dtype=wp.uint64)):
|
|
429
|
+
i = wp.tid()
|
|
430
|
+
wp.atomic_add(sum, 0, wp.uint64(1))
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
def test_stream_event_is_complete(test, device):
|
|
434
|
+
with wp.ScopedDevice(device):
|
|
435
|
+
stream = wp.Stream()
|
|
436
|
+
event = wp.Event()
|
|
437
|
+
# No operations on stream, should be complete
|
|
438
|
+
test.assertTrue(stream.is_complete)
|
|
439
|
+
|
|
440
|
+
# Event not recorded yet, should be complete
|
|
441
|
+
test.assertTrue(event.is_complete)
|
|
442
|
+
|
|
443
|
+
a = wp.zeros(1, dtype=wp.uint64)
|
|
444
|
+
|
|
445
|
+
threads = 1024 * 1024 * 8
|
|
446
|
+
|
|
447
|
+
with wp.ScopedStream(stream):
|
|
448
|
+
# Launch some work on the stream and reuse the event
|
|
449
|
+
|
|
450
|
+
for iter in range(5):
|
|
451
|
+
# Kernel takes about 1 ms to run on an RTX 3090
|
|
452
|
+
wp.launch(sum_threads, dim=threads, outputs=[a])
|
|
453
|
+
|
|
454
|
+
stream.record_event(event)
|
|
455
|
+
|
|
456
|
+
# Kernel should still be running
|
|
457
|
+
test.assertFalse(stream.is_complete)
|
|
458
|
+
|
|
459
|
+
# Event should not be finished
|
|
460
|
+
test.assertFalse(event.is_complete)
|
|
461
|
+
|
|
462
|
+
# Force the stream operations to complete
|
|
463
|
+
wp.synchronize_stream(stream)
|
|
464
|
+
|
|
465
|
+
# Now all operations are complete
|
|
466
|
+
test.assertTrue(stream.is_complete)
|
|
467
|
+
test.assertTrue(event.is_complete)
|
|
468
|
+
|
|
469
|
+
# Verify result
|
|
470
|
+
test.assertEqual(a.numpy()[0], (iter + 1) * threads)
|
|
471
|
+
|
|
472
|
+
|
|
404
473
|
devices = get_selected_cuda_test_devices()
|
|
405
474
|
|
|
406
475
|
|
|
@@ -554,9 +623,11 @@ add_function_test(TestStreams, "test_stream_scope_wait_event", test_stream_scope
|
|
|
554
623
|
add_function_test(TestStreams, "test_stream_scope_wait_stream", test_stream_scope_wait_stream, devices=devices)
|
|
555
624
|
add_function_test(TestStreams, "test_stream_priority_basics", test_stream_priority_basics, devices=devices)
|
|
556
625
|
add_function_test(TestStreams, "test_stream_priority_timings", test_stream_priority_timings, devices=devices)
|
|
626
|
+
add_function_test(TestStreams, "test_stream_event_is_complete", test_stream_event_is_complete, devices=devices)
|
|
557
627
|
|
|
558
628
|
add_function_test(TestStreams, "test_event_synchronize", test_event_synchronize, devices=devices)
|
|
559
629
|
add_function_test(TestStreams, "test_event_elapsed_time", test_event_elapsed_time, devices=devices)
|
|
630
|
+
add_function_test(TestStreams, "test_event_elapsed_time_graph", test_event_elapsed_time_graph, devices=devices)
|
|
560
631
|
|
|
561
632
|
if __name__ == "__main__":
|
|
562
633
|
wp.clear_kernel_cache()
|
|
File without changes
|