warp-lang 1.4.2__py3-none-manylinux2014_x86_64.whl → 1.5.1__py3-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +4 -0
- warp/autograd.py +43 -8
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +21 -2
- warp/build_dll.py +23 -6
- warp/builtins.py +1819 -7
- warp/codegen.py +197 -61
- warp/config.py +2 -2
- warp/context.py +379 -107
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
- warp/examples/benchmarks/benchmark_gemm.py +121 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
- warp/examples/benchmarks/benchmark_tile.py +179 -0
- warp/examples/fem/example_adaptive_grid.py +37 -10
- warp/examples/fem/example_apic_fluid.py +3 -2
- warp/examples/fem/example_convection_diffusion_dg.py +4 -5
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion_3d.py +47 -4
- warp/examples/fem/example_distortion_energy.py +220 -0
- warp/examples/fem/example_magnetostatics.py +127 -85
- warp/examples/fem/example_nonconforming_contact.py +5 -5
- warp/examples/fem/example_stokes.py +3 -1
- warp/examples/fem/example_streamlines.py +12 -19
- warp/examples/fem/utils.py +38 -15
- warp/examples/sim/example_cloth.py +4 -25
- warp/examples/sim/example_quadruped.py +2 -1
- warp/examples/tile/example_tile_convolution.py +58 -0
- warp/examples/tile/example_tile_fft.py +47 -0
- warp/examples/tile/example_tile_filtering.py +105 -0
- warp/examples/tile/example_tile_matmul.py +79 -0
- warp/examples/tile/example_tile_mlp.py +375 -0
- warp/fem/__init__.py +8 -0
- warp/fem/cache.py +16 -12
- warp/fem/dirichlet.py +1 -1
- warp/fem/domain.py +44 -1
- warp/fem/field/__init__.py +1 -2
- warp/fem/field/field.py +31 -19
- warp/fem/field/nodal_field.py +101 -49
- warp/fem/field/virtual.py +794 -0
- warp/fem/geometry/__init__.py +2 -2
- warp/fem/geometry/deformed_geometry.py +3 -105
- warp/fem/geometry/element.py +13 -0
- warp/fem/geometry/geometry.py +165 -7
- warp/fem/geometry/grid_2d.py +3 -6
- warp/fem/geometry/grid_3d.py +31 -28
- warp/fem/geometry/hexmesh.py +3 -46
- warp/fem/geometry/nanogrid.py +3 -2
- warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
- warp/fem/geometry/tetmesh.py +2 -43
- warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
- warp/fem/integrate.py +683 -261
- warp/fem/linalg.py +404 -0
- warp/fem/operator.py +101 -18
- warp/fem/polynomial.py +5 -5
- warp/fem/quadrature/quadrature.py +45 -21
- warp/fem/space/__init__.py +45 -11
- warp/fem/space/basis_function_space.py +451 -0
- warp/fem/space/basis_space.py +58 -11
- warp/fem/space/function_space.py +146 -5
- warp/fem/space/grid_2d_function_space.py +80 -66
- warp/fem/space/grid_3d_function_space.py +113 -68
- warp/fem/space/hexmesh_function_space.py +96 -108
- warp/fem/space/nanogrid_function_space.py +62 -110
- warp/fem/space/quadmesh_function_space.py +208 -0
- warp/fem/space/shape/__init__.py +45 -7
- warp/fem/space/shape/cube_shape_function.py +328 -54
- warp/fem/space/shape/shape_function.py +10 -1
- warp/fem/space/shape/square_shape_function.py +328 -60
- warp/fem/space/shape/tet_shape_function.py +269 -19
- warp/fem/space/shape/triangle_shape_function.py +238 -19
- warp/fem/space/tetmesh_function_space.py +69 -37
- warp/fem/space/topology.py +38 -0
- warp/fem/space/trimesh_function_space.py +179 -0
- warp/fem/utils.py +6 -331
- warp/jax_experimental.py +3 -1
- warp/native/array.h +15 -0
- warp/native/builtin.h +66 -26
- warp/native/bvh.h +4 -0
- warp/native/coloring.cpp +604 -0
- warp/native/cuda_util.cpp +68 -51
- warp/native/cuda_util.h +2 -1
- warp/native/fabric.h +8 -0
- warp/native/hashgrid.h +4 -0
- warp/native/marching.cu +8 -0
- warp/native/mat.h +14 -3
- warp/native/mathdx.cpp +59 -0
- warp/native/mesh.h +4 -0
- warp/native/range.h +13 -1
- warp/native/reduce.cpp +9 -1
- warp/native/reduce.cu +7 -0
- warp/native/runlength_encode.cpp +9 -1
- warp/native/runlength_encode.cu +7 -1
- warp/native/scan.cpp +8 -0
- warp/native/scan.cu +8 -0
- warp/native/scan.h +8 -1
- warp/native/sparse.cpp +8 -0
- warp/native/sparse.cu +8 -0
- warp/native/temp_buffer.h +7 -0
- warp/native/tile.h +1854 -0
- warp/native/tile_gemm.h +341 -0
- warp/native/tile_reduce.h +210 -0
- warp/native/volume_builder.cu +8 -0
- warp/native/volume_builder.h +8 -0
- warp/native/warp.cpp +10 -2
- warp/native/warp.cu +369 -15
- warp/native/warp.h +12 -2
- warp/optim/adam.py +39 -4
- warp/paddle.py +29 -12
- warp/render/render_opengl.py +140 -67
- warp/sim/graph_coloring.py +292 -0
- warp/sim/import_urdf.py +8 -8
- warp/sim/integrator_euler.py +4 -2
- warp/sim/integrator_featherstone.py +115 -44
- warp/sim/integrator_vbd.py +6 -0
- warp/sim/model.py +109 -32
- warp/sparse.py +1 -1
- warp/stubs.py +569 -4
- warp/tape.py +12 -7
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/aux_test_instancing_gc.py +18 -0
- warp/tests/test_array.py +39 -0
- warp/tests/test_codegen.py +81 -1
- warp/tests/test_codegen_instancing.py +30 -0
- warp/tests/test_collision.py +110 -0
- warp/tests/test_coloring.py +251 -0
- warp/tests/test_context.py +34 -0
- warp/tests/test_examples.py +21 -5
- warp/tests/test_fem.py +453 -113
- warp/tests/test_func.py +34 -4
- warp/tests/test_generics.py +52 -0
- warp/tests/test_iter.py +68 -0
- warp/tests/test_lerp.py +13 -87
- warp/tests/test_mat_scalar_ops.py +1 -1
- warp/tests/test_matmul.py +6 -9
- warp/tests/test_matmul_lite.py +6 -11
- warp/tests/test_mesh_query_point.py +1 -1
- warp/tests/test_module_hashing.py +23 -0
- warp/tests/test_overwrite.py +45 -0
- warp/tests/test_paddle.py +27 -87
- warp/tests/test_print.py +56 -1
- warp/tests/test_smoothstep.py +17 -83
- warp/tests/test_spatial.py +1 -1
- warp/tests/test_static.py +3 -3
- warp/tests/test_tile.py +744 -0
- warp/tests/test_tile_mathdx.py +144 -0
- warp/tests/test_tile_mlp.py +383 -0
- warp/tests/test_tile_reduce.py +374 -0
- warp/tests/test_tile_shared_memory.py +190 -0
- warp/tests/test_vbd.py +12 -20
- warp/tests/test_volume.py +43 -0
- warp/tests/unittest_suites.py +19 -2
- warp/tests/unittest_utils.py +4 -2
- warp/types.py +340 -74
- warp/utils.py +23 -3
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/METADATA +32 -7
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/RECORD +161 -134
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/WHEEL +1 -1
- warp/fem/field/test.py +0 -180
- warp/fem/field/trial.py +0 -183
- warp/fem/space/collocated_function_space.py +0 -102
- warp/fem/space/quadmesh_2d_function_space.py +0 -261
- warp/fem/space/trimesh_2d_function_space.py +0 -153
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/top_level.txt +0 -0
warp/stubs.py
CHANGED
|
@@ -11,9 +11,6 @@ Length = TypeVar("Length", bound=int)
|
|
|
11
11
|
Rows = TypeVar("Rows", bound=int)
|
|
12
12
|
Cols = TypeVar("Cols", bound=int)
|
|
13
13
|
DType = TypeVar("DType")
|
|
14
|
-
Int = TypeVar("Int")
|
|
15
|
-
Float = TypeVar("Float")
|
|
16
|
-
Scalar = TypeVar("Scalar")
|
|
17
14
|
Vector = Generic[Length, Scalar]
|
|
18
15
|
Matrix = Generic[Rows, Cols, Scalar]
|
|
19
16
|
Quaternion = Generic[Float]
|
|
@@ -39,6 +36,8 @@ from warp.types import transform, transformh, transformf, transformd
|
|
|
39
36
|
from warp.types import spatial_vector, spatial_vectorh, spatial_vectorf, spatial_vectord
|
|
40
37
|
from warp.types import spatial_matrix, spatial_matrixh, spatial_matrixf, spatial_matrixd
|
|
41
38
|
|
|
39
|
+
from warp.types import Int, Float, Scalar
|
|
40
|
+
|
|
42
41
|
from warp.types import Bvh, Mesh, HashGrid, Volume, MarchingCubes
|
|
43
42
|
from warp.types import BvhQuery, HashGridQuery, MeshQueryAABB, MeshQueryPoint, MeshQueryRay
|
|
44
43
|
|
|
@@ -67,6 +66,7 @@ from warp.context import (
|
|
|
67
66
|
copy,
|
|
68
67
|
from_numpy,
|
|
69
68
|
launch,
|
|
69
|
+
launch_tiled,
|
|
70
70
|
synchronize,
|
|
71
71
|
force_load,
|
|
72
72
|
load_module,
|
|
@@ -894,6 +894,475 @@ def spatial_mass(
|
|
|
894
894
|
...
|
|
895
895
|
|
|
896
896
|
|
|
897
|
+
@over
|
|
898
|
+
def tile_zeros(m: int32, n: int32, dtype: Any, storage: str) -> Tile:
|
|
899
|
+
"""Allocates a tile of zero-initialized items.
|
|
900
|
+
|
|
901
|
+
:param m: Size of the first dimension of the output tile
|
|
902
|
+
:param n: Size of the second dimension of the output tile
|
|
903
|
+
:param dtype: Datatype of output tile's elements
|
|
904
|
+
:param storage: The storage location for the tile: ``"register"`` for registers
|
|
905
|
+
(default) or ``"shared"`` for shared memory.
|
|
906
|
+
:returns: A zero-initialized tile with ``shape=(m,n)`` and the specified datatype
|
|
907
|
+
"""
|
|
908
|
+
...
|
|
909
|
+
|
|
910
|
+
|
|
911
|
+
@over
|
|
912
|
+
def tile_ones(m: int32, n: int32, dtype: Any, storage: str) -> Tile:
|
|
913
|
+
"""Allocates a tile of one-initialized items.
|
|
914
|
+
|
|
915
|
+
:param m: Size of the first dimension of the output tile
|
|
916
|
+
:param n: Size of the second dimension of the output tile
|
|
917
|
+
:param dtype: Datatype of output tile's elements
|
|
918
|
+
:param storage: The storage location for the tile: ``"register"`` for registers
|
|
919
|
+
(default) or ``"shared"`` for shared memory.
|
|
920
|
+
:returns: A one-initialized tile with ``shape=(m,n)`` and the specified dtype
|
|
921
|
+
"""
|
|
922
|
+
...
|
|
923
|
+
|
|
924
|
+
|
|
925
|
+
@over
|
|
926
|
+
def tile_arange(*args: Scalar, dtype: Any, storage: str) -> Tile:
|
|
927
|
+
"""Generates a tile of linearly spaced elements.
|
|
928
|
+
|
|
929
|
+
:param args: Variable-length positional arguments, interpreted as:
|
|
930
|
+
|
|
931
|
+
- ``(stop,)``: Generates values from ``0`` to ``stop - 1``
|
|
932
|
+
- ``(start, stop)``: Generates values from ``start`` to ``stop - 1``
|
|
933
|
+
- ``(start, stop, step)``: Generates values from ``start`` to ``stop - 1`` with a step size
|
|
934
|
+
|
|
935
|
+
:param dtype: Datatype of output tile's elements (optional, default: int)
|
|
936
|
+
:param storage: The storage location for the tile: ``"register"`` for registers
|
|
937
|
+
(default) or ``"shared"`` for shared memory.
|
|
938
|
+
:returns: A tile with ``shape=(1,n)`` with linearly spaced elements of specified dtype
|
|
939
|
+
"""
|
|
940
|
+
...
|
|
941
|
+
|
|
942
|
+
|
|
943
|
+
@over
|
|
944
|
+
def tile_load(a: Array[Any], i: int32, n: int32, storage: str) -> Tile:
|
|
945
|
+
"""Loads a 1D tile from a global memory array.
|
|
946
|
+
|
|
947
|
+
This method will cooperatively load a tile from global memory using all threads in the block.
|
|
948
|
+
|
|
949
|
+
:param a: The source array in global memory
|
|
950
|
+
:param i: Offset in the source array measured in multiples of ``n``, i.e.: ``offset=i*n``
|
|
951
|
+
:param n: The number of elements in the tile
|
|
952
|
+
:param storage: The storage location for the tile: ``"register"`` for registers
|
|
953
|
+
(default) or ``"shared"`` for shared memory.
|
|
954
|
+
:returns: A tile with ``shape=(1,n)`` and dtype the same as the source array
|
|
955
|
+
"""
|
|
956
|
+
...
|
|
957
|
+
|
|
958
|
+
|
|
959
|
+
@over
|
|
960
|
+
def tile_load(a: Array[Any], i: int32, j: int32, m: int32, n: int32, storage: str) -> Tile:
|
|
961
|
+
"""Loads a 2D tile from a global memory array.
|
|
962
|
+
|
|
963
|
+
This method will cooperatively load a tile from global memory using all threads in the block.
|
|
964
|
+
|
|
965
|
+
:param a: The source array in global memory
|
|
966
|
+
:param i: Offset in the source array measured in multiples of ``m``, i.e.: ``row=i*m``
|
|
967
|
+
:param j: Offset in the source array measured in multiples of ``n``, i.e.; ``col=j*n``
|
|
968
|
+
:param m: The size of the tile's first dimension
|
|
969
|
+
:param n: The size of the tile's second dimension
|
|
970
|
+
:param storage: The storage location for the tile: ``"register"`` for registers
|
|
971
|
+
(default) or ``"shared"`` for shared memory.
|
|
972
|
+
:returns: A tile with ``shape=(m,n)`` and dtype the same as the source array
|
|
973
|
+
"""
|
|
974
|
+
...
|
|
975
|
+
|
|
976
|
+
|
|
977
|
+
@over
|
|
978
|
+
def tile_store(a: Array[Any], i: int32, t: Tile):
|
|
979
|
+
"""Stores a 1D tile to a global memory array.
|
|
980
|
+
|
|
981
|
+
This method will cooperatively store a tile to global memory using all threads in the block.
|
|
982
|
+
|
|
983
|
+
:param a: The destination array in global memory
|
|
984
|
+
:param i: Offset in the destination array measured in multiples of ``n``, i.e.: ``offset=i*n``
|
|
985
|
+
:param t: The source tile to store data from, must have the same dtype as the destination array
|
|
986
|
+
"""
|
|
987
|
+
...
|
|
988
|
+
|
|
989
|
+
|
|
990
|
+
@over
|
|
991
|
+
def tile_store(a: Array[Any], i: int32, j: int32, t: Tile):
|
|
992
|
+
"""Stores a tile to a global memory array.
|
|
993
|
+
|
|
994
|
+
This method will cooperatively store a tile to global memory using all threads in the block.
|
|
995
|
+
|
|
996
|
+
:param a: The destination array in global memory
|
|
997
|
+
:param i: Offset in the destination array measured in multiples of ``m``, i.e.: ``row=i*m``
|
|
998
|
+
:param j: Offset in the destination array measured in multiples of ``n``, i.e.; ``col=j*n``
|
|
999
|
+
:param t: The source tile to store data from, must have the same dtype as the destination array
|
|
1000
|
+
"""
|
|
1001
|
+
...
|
|
1002
|
+
|
|
1003
|
+
|
|
1004
|
+
@over
|
|
1005
|
+
def tile_atomic_add(a: Array[Any], x: int32, y: int32, t: Tile) -> Tile:
|
|
1006
|
+
"""Atomically add a tile to the array `a`, each element will be updated atomically.
|
|
1007
|
+
|
|
1008
|
+
:param a: Array in global memory, should have the same ``dtype`` as the input tile
|
|
1009
|
+
:param x: Offset in the destination array measured in multiples of ``m``, i.e.: ``i=x*M`` where ``M`` is the first tile dimension
|
|
1010
|
+
:param y: Offset in the destination array measured in multiples of ``n``, i.e.: ``j=y*N`` where ``N`` is the second tile dimension
|
|
1011
|
+
:param t: Source tile to add to the destination array
|
|
1012
|
+
:returns: A tile with the same dimensions and type as the source tile, holding the original value of the destination elements
|
|
1013
|
+
"""
|
|
1014
|
+
...
|
|
1015
|
+
|
|
1016
|
+
|
|
1017
|
+
@over
|
|
1018
|
+
def tile_view(t: Tile, i: int32, j: int32, m: int32, n: int32) -> Tile:
|
|
1019
|
+
"""Return a subrange of a given tile from coordinates (i,j) to (i+m, j+n).
|
|
1020
|
+
|
|
1021
|
+
:param t: Input tile to extract a subrange from
|
|
1022
|
+
:param i: Offset in the source tile along the first dimension
|
|
1023
|
+
:param j: Offset in the source tile along the second dimensions
|
|
1024
|
+
:param m: Size of the subrange to return along the first dimension
|
|
1025
|
+
:param n: Size of the subrange to return along the second dimension
|
|
1026
|
+
:returns: A tile with dimensions (m,n) and the same datatype as the input tile
|
|
1027
|
+
"""
|
|
1028
|
+
...
|
|
1029
|
+
|
|
1030
|
+
|
|
1031
|
+
@over
|
|
1032
|
+
def tile_assign(dst: Tile, i: int32, j: int32, src: Tile):
|
|
1033
|
+
"""Assign a tile to a subrange of a destination tile at coordinates (i,j).
|
|
1034
|
+
|
|
1035
|
+
:param t: The destination tile to assign to
|
|
1036
|
+
:param i: Offset in the source tile along the first dimension
|
|
1037
|
+
:param j: Offset in the source tile along the second dimensions
|
|
1038
|
+
:param src: The source tile to read values from
|
|
1039
|
+
"""
|
|
1040
|
+
...
|
|
1041
|
+
|
|
1042
|
+
|
|
1043
|
+
@over
|
|
1044
|
+
def tile(x: Any) -> Tile:
|
|
1045
|
+
"""Constructs a new Tile from per-thread kernel values.
|
|
1046
|
+
|
|
1047
|
+
This function converts values computed using scalar kernel code to a tile representation for input into collective operations.
|
|
1048
|
+
|
|
1049
|
+
* If the input value is a scalar, then the resulting tile has ``shape=(1, block_dim)``
|
|
1050
|
+
* If the input value is a vector, then the resulting tile has ``shape=(length(vector), block_dim)``
|
|
1051
|
+
|
|
1052
|
+
:param x: A per-thread local value, e.g.: scalar, vector, or matrix.
|
|
1053
|
+
:returns: A tile with first dimension according to the value type length and a second dimension equal to ``block_dim``
|
|
1054
|
+
|
|
1055
|
+
This example shows how to create a linear sequence from thread variables:
|
|
1056
|
+
|
|
1057
|
+
.. code-block:: python
|
|
1058
|
+
|
|
1059
|
+
@wp.kernel
|
|
1060
|
+
def compute():
|
|
1061
|
+
i = wp.tid()
|
|
1062
|
+
t = wp.tile(i * 2)
|
|
1063
|
+
print(t)
|
|
1064
|
+
|
|
1065
|
+
|
|
1066
|
+
wp.launch(compute, dim=16, inputs=[], block_dim=16)
|
|
1067
|
+
|
|
1068
|
+
Prints:
|
|
1069
|
+
|
|
1070
|
+
.. code-block:: text
|
|
1071
|
+
|
|
1072
|
+
tile(m=1, n=16, storage=register) = [[0 2 4 6 8 ...]]
|
|
1073
|
+
|
|
1074
|
+
|
|
1075
|
+
"""
|
|
1076
|
+
...
|
|
1077
|
+
|
|
1078
|
+
|
|
1079
|
+
@over
|
|
1080
|
+
def untile(a: Tile) -> Scalar:
|
|
1081
|
+
"""Convert a Tile back to per-thread values.
|
|
1082
|
+
|
|
1083
|
+
This function converts a block-wide tile back to per-thread values.
|
|
1084
|
+
|
|
1085
|
+
* If the input tile is 1-dimensional then the resulting value will be a per-thread scalar
|
|
1086
|
+
* If the input tile is 2-dimensional then the resulting value will be a per-thread vector of length M
|
|
1087
|
+
|
|
1088
|
+
:param a: A tile with dimensions ``shape=(M, block_dim)``
|
|
1089
|
+
:returns: A single value per-thread with the same dtype as the tile
|
|
1090
|
+
|
|
1091
|
+
This example shows how to create a linear sequence from thread variables:
|
|
1092
|
+
|
|
1093
|
+
.. code-block:: python
|
|
1094
|
+
|
|
1095
|
+
@wp.kernel
|
|
1096
|
+
def compute():
|
|
1097
|
+
i = wp.tid()
|
|
1098
|
+
|
|
1099
|
+
# create block-wide tile
|
|
1100
|
+
t = wp.tile(i) * 2
|
|
1101
|
+
|
|
1102
|
+
# convert back to per-thread values
|
|
1103
|
+
s = wp.untile(t)
|
|
1104
|
+
|
|
1105
|
+
print(s)
|
|
1106
|
+
|
|
1107
|
+
|
|
1108
|
+
wp.launch(compute, dim=16, inputs=[], block_dim=16)
|
|
1109
|
+
|
|
1110
|
+
Prints:
|
|
1111
|
+
|
|
1112
|
+
.. code-block:: text
|
|
1113
|
+
|
|
1114
|
+
0
|
|
1115
|
+
2
|
|
1116
|
+
4
|
|
1117
|
+
6
|
|
1118
|
+
8
|
|
1119
|
+
...
|
|
1120
|
+
|
|
1121
|
+
"""
|
|
1122
|
+
...
|
|
1123
|
+
|
|
1124
|
+
|
|
1125
|
+
@over
|
|
1126
|
+
def tile_extract(a: Tile, i: int32, j: int32) -> Scalar:
|
|
1127
|
+
"""Extracts a single element from the tile and returns it as a scalar type.
|
|
1128
|
+
|
|
1129
|
+
This function will extract an element from the tile and broadcast its value to all threads in the block.
|
|
1130
|
+
|
|
1131
|
+
Note that this may incur additional synchronization if the source tile is a register tile.
|
|
1132
|
+
|
|
1133
|
+
:param a: Tile to extract the element from
|
|
1134
|
+
:param i: Coordinate of element on first dimension
|
|
1135
|
+
:param j: Coordinate of element on the second dimension
|
|
1136
|
+
:returns: The value of the element at the specified tile location, with the same type as the input tile's per-element dtype
|
|
1137
|
+
"""
|
|
1138
|
+
...
|
|
1139
|
+
|
|
1140
|
+
|
|
1141
|
+
@over
|
|
1142
|
+
def tile_transpose(a: Tile) -> Tile:
|
|
1143
|
+
"""Transpose a tile.
|
|
1144
|
+
|
|
1145
|
+
For shared memory tiles this operation will alias the input tile, register tiles will first be transferred to shared memory before transposition.
|
|
1146
|
+
|
|
1147
|
+
:param a: Tile to transpose with ``shape=(M,N)``
|
|
1148
|
+
:returns: Tile with ``shape=(N,M)``
|
|
1149
|
+
"""
|
|
1150
|
+
...
|
|
1151
|
+
|
|
1152
|
+
|
|
1153
|
+
@over
|
|
1154
|
+
def tile_broadcast(a: Tile, m: int32, n: int32) -> Tile:
|
|
1155
|
+
"""Broadcast a tile.
|
|
1156
|
+
|
|
1157
|
+
This function will attempt to broadcast the input tile ``a`` to the destination shape (m, n), broadcasting follows NumPy broadcast rules.
|
|
1158
|
+
|
|
1159
|
+
:param a: Tile to broadcast
|
|
1160
|
+
:returns: Tile with broadcast ``shape=(m, n)``
|
|
1161
|
+
"""
|
|
1162
|
+
...
|
|
1163
|
+
|
|
1164
|
+
|
|
1165
|
+
@over
|
|
1166
|
+
def tile_sum(a: Tile) -> Tile:
|
|
1167
|
+
"""Cooperatively compute the sum of the tile elements using all threads in the block.
|
|
1168
|
+
|
|
1169
|
+
:param a: The tile to compute the sum of
|
|
1170
|
+
:returns: A single-element tile with dimensions of (1,1) holding the sum
|
|
1171
|
+
|
|
1172
|
+
Example:
|
|
1173
|
+
|
|
1174
|
+
.. code-block:: python
|
|
1175
|
+
|
|
1176
|
+
@wp.kernel
|
|
1177
|
+
def compute():
|
|
1178
|
+
t = wp.tile_ones(dtype=float, m=16, n=16)
|
|
1179
|
+
s = wp.tile_sum(t)
|
|
1180
|
+
|
|
1181
|
+
print(s)
|
|
1182
|
+
|
|
1183
|
+
|
|
1184
|
+
wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
|
|
1185
|
+
|
|
1186
|
+
Prints:
|
|
1187
|
+
|
|
1188
|
+
.. code-block:: text
|
|
1189
|
+
|
|
1190
|
+
tile(m=1, n=1, storage=register) = [[256]]
|
|
1191
|
+
|
|
1192
|
+
|
|
1193
|
+
"""
|
|
1194
|
+
...
|
|
1195
|
+
|
|
1196
|
+
|
|
1197
|
+
@over
|
|
1198
|
+
def tile_min(a: Tile) -> Tile:
|
|
1199
|
+
"""Cooperatively compute the minimum of the tile elements using all threads in the block.
|
|
1200
|
+
|
|
1201
|
+
:param a: The tile to compute the minimum of
|
|
1202
|
+
:returns: A single-element tile with dimensions of (1,1) holding the minimum value
|
|
1203
|
+
|
|
1204
|
+
Example:
|
|
1205
|
+
|
|
1206
|
+
.. code-block:: python
|
|
1207
|
+
|
|
1208
|
+
@wp.kernel
|
|
1209
|
+
def compute():
|
|
1210
|
+
t = wp.tile_arange(64, 128)
|
|
1211
|
+
s = wp.tile_min(t)
|
|
1212
|
+
|
|
1213
|
+
print(s)
|
|
1214
|
+
|
|
1215
|
+
|
|
1216
|
+
wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
|
|
1217
|
+
|
|
1218
|
+
Prints:
|
|
1219
|
+
|
|
1220
|
+
.. code-block:: text
|
|
1221
|
+
|
|
1222
|
+
tile(m=1, n=1, storage=register) = [[64 ]]
|
|
1223
|
+
|
|
1224
|
+
|
|
1225
|
+
"""
|
|
1226
|
+
...
|
|
1227
|
+
|
|
1228
|
+
|
|
1229
|
+
@over
|
|
1230
|
+
def tile_max(a: Tile) -> Tile:
|
|
1231
|
+
"""Cooperatively compute the maximum of the tile elements using all threads in the block.
|
|
1232
|
+
|
|
1233
|
+
:param a: The tile to compute the maximum from
|
|
1234
|
+
:returns: A single-element tile with dimensions of (1,1) holding the maximum value
|
|
1235
|
+
|
|
1236
|
+
Example:
|
|
1237
|
+
|
|
1238
|
+
.. code-block:: python
|
|
1239
|
+
|
|
1240
|
+
@wp.kernel
|
|
1241
|
+
def compute():
|
|
1242
|
+
t = wp.tile_arange(64, 128)
|
|
1243
|
+
s = wp.tile_max(t)
|
|
1244
|
+
|
|
1245
|
+
print(s)
|
|
1246
|
+
|
|
1247
|
+
|
|
1248
|
+
wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
|
|
1249
|
+
|
|
1250
|
+
Prints:
|
|
1251
|
+
|
|
1252
|
+
.. code-block:: text
|
|
1253
|
+
|
|
1254
|
+
tile(m=1, n=1, storage=register) = [[127 ]]
|
|
1255
|
+
|
|
1256
|
+
|
|
1257
|
+
"""
|
|
1258
|
+
...
|
|
1259
|
+
|
|
1260
|
+
|
|
1261
|
+
@over
|
|
1262
|
+
def tile_reduce(op: Callable, a: Tile) -> Tile:
|
|
1263
|
+
"""Apply a custom reduction operator across the tile.
|
|
1264
|
+
|
|
1265
|
+
This function cooperatively performs a reduction using the provided operator across the tile.
|
|
1266
|
+
|
|
1267
|
+
:param op: A callable function that accepts two arguments and returns one argument, may be a user function or builtin
|
|
1268
|
+
:param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
|
|
1269
|
+
:returns: A single-element tile with ``shape=(1,1)`` with the same datatype as the input tile.
|
|
1270
|
+
|
|
1271
|
+
Example:
|
|
1272
|
+
|
|
1273
|
+
.. code-block:: python
|
|
1274
|
+
|
|
1275
|
+
@wp.kernel
|
|
1276
|
+
def factorial():
|
|
1277
|
+
t = wp.tile_arange(1, 10, dtype=int)
|
|
1278
|
+
s = wp.tile_reduce(wp.mul, t)
|
|
1279
|
+
|
|
1280
|
+
print(s)
|
|
1281
|
+
|
|
1282
|
+
|
|
1283
|
+
wp.launch_tiled(factorial, dim=[1], inputs=[], block_dim=16)
|
|
1284
|
+
|
|
1285
|
+
Prints:
|
|
1286
|
+
|
|
1287
|
+
.. code-block:: text
|
|
1288
|
+
|
|
1289
|
+
tile(m=1, n=1, storage=register) = [[362880]]
|
|
1290
|
+
|
|
1291
|
+
"""
|
|
1292
|
+
...
|
|
1293
|
+
|
|
1294
|
+
|
|
1295
|
+
@over
|
|
1296
|
+
def tile_map(op: Callable, a: Tile) -> Tile:
|
|
1297
|
+
"""Apply a unary function onto the tile.
|
|
1298
|
+
|
|
1299
|
+
This function cooperatively applies a unary function to each element of the tile using all threads in the block.
|
|
1300
|
+
|
|
1301
|
+
:param op: A callable function that accepts one argument and returns one argument, may be a user function or builtin
|
|
1302
|
+
:param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
|
|
1303
|
+
:returns: A tile with the same dimensions and datatype as the input tile.
|
|
1304
|
+
|
|
1305
|
+
Example:
|
|
1306
|
+
|
|
1307
|
+
.. code-block:: python
|
|
1308
|
+
|
|
1309
|
+
@wp.kernel
|
|
1310
|
+
def compute():
|
|
1311
|
+
t = wp.tile_arange(0.0, 1.0, 0.1, dtype=float)
|
|
1312
|
+
s = wp.tile_map(wp.sin, t)
|
|
1313
|
+
|
|
1314
|
+
print(s)
|
|
1315
|
+
|
|
1316
|
+
|
|
1317
|
+
wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=16)
|
|
1318
|
+
|
|
1319
|
+
Prints:
|
|
1320
|
+
|
|
1321
|
+
.. code-block:: text
|
|
1322
|
+
|
|
1323
|
+
tile(m=1, n=10, storage=register) = [[0 0.0998334 0.198669 0.29552 ...]]
|
|
1324
|
+
|
|
1325
|
+
"""
|
|
1326
|
+
...
|
|
1327
|
+
|
|
1328
|
+
|
|
1329
|
+
@over
|
|
1330
|
+
def tile_map(op: Callable, a: Tile, b: Tile) -> Tile:
|
|
1331
|
+
"""Apply a binary function onto the tile.
|
|
1332
|
+
|
|
1333
|
+
This function cooperatively applies a binary function to each element of the tiles using all threads in the block.
|
|
1334
|
+
Both input tiles must have the same dimensions and datatype.
|
|
1335
|
+
|
|
1336
|
+
:param op: A callable function that accepts two arguments and returns one argument, all of the same type, may be a user function or builtin
|
|
1337
|
+
:param a: The first input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
|
|
1338
|
+
:param b: The second input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
|
|
1339
|
+
:returns: A tile with the same dimensions and datatype as the input tiles.
|
|
1340
|
+
|
|
1341
|
+
Example:
|
|
1342
|
+
|
|
1343
|
+
.. code-block:: python
|
|
1344
|
+
|
|
1345
|
+
@wp.kernel
|
|
1346
|
+
def compute():
|
|
1347
|
+
a = wp.tile_arange(0.0, 1.0, 0.1, dtype=float)
|
|
1348
|
+
b = wp.tile_ones(m=1, n=10, dtype=float)
|
|
1349
|
+
|
|
1350
|
+
s = wp.tile_map(wp.add, a, b)
|
|
1351
|
+
|
|
1352
|
+
print(s)
|
|
1353
|
+
|
|
1354
|
+
|
|
1355
|
+
wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=16)
|
|
1356
|
+
|
|
1357
|
+
Prints:
|
|
1358
|
+
|
|
1359
|
+
.. code-block:: text
|
|
1360
|
+
|
|
1361
|
+
tile(m=1, n=10, storage=register) = [[1 1.1 1.2 1.3 1.4 1.5 1.6 1.7 1.8 1.9]]
|
|
1362
|
+
"""
|
|
1363
|
+
...
|
|
1364
|
+
|
|
1365
|
+
|
|
897
1366
|
@over
|
|
898
1367
|
def mlp(
|
|
899
1368
|
weights: Array[float32],
|
|
@@ -1162,6 +1631,12 @@ def closest_point_edge_edge(p1: vec3f, q1: vec3f, p2: vec3f, q2: vec3f, epsilon:
|
|
|
1162
1631
|
...
|
|
1163
1632
|
|
|
1164
1633
|
|
|
1634
|
+
@over
|
|
1635
|
+
def reversed(range: range_t) -> range_t:
|
|
1636
|
+
"""Returns the range in reversed order."""
|
|
1637
|
+
...
|
|
1638
|
+
|
|
1639
|
+
|
|
1165
1640
|
@over
|
|
1166
1641
|
def volume_sample(id: uint64, uvw: vec3f, sampling_mode: int32, dtype: Any) -> Any:
|
|
1167
1642
|
"""Sample the volume of type `dtype` given by ``id`` at the volume local-space point ``uvw``.
|
|
@@ -1376,7 +1851,7 @@ def randf(state: uint32, low: float32, high: float32) -> float:
|
|
|
1376
1851
|
|
|
1377
1852
|
@over
|
|
1378
1853
|
def randn(state: uint32) -> float:
|
|
1379
|
-
"""Sample a normal distribution."""
|
|
1854
|
+
"""Sample a normal (Gaussian) distribution of mean 0 and variance 1."""
|
|
1380
1855
|
...
|
|
1381
1856
|
|
|
1382
1857
|
|
|
@@ -2091,6 +2566,12 @@ def add(a: Transformation[Scalar], b: Transformation[Scalar]) -> Transformation[
|
|
|
2091
2566
|
...
|
|
2092
2567
|
|
|
2093
2568
|
|
|
2569
|
+
@over
|
|
2570
|
+
def add(a: Tile, b: Tile) -> Tile:
|
|
2571
|
+
"""Add each element of two tiles together"""
|
|
2572
|
+
...
|
|
2573
|
+
|
|
2574
|
+
|
|
2094
2575
|
@over
|
|
2095
2576
|
def sub(a: Scalar, b: Scalar) -> Scalar:
|
|
2096
2577
|
""" """
|
|
@@ -2241,6 +2722,18 @@ def mul(a: Transformation[Scalar], b: Scalar) -> Transformation[Scalar]:
|
|
|
2241
2722
|
...
|
|
2242
2723
|
|
|
2243
2724
|
|
|
2725
|
+
@over
|
|
2726
|
+
def mul(x: Tile, y: Scalar) -> Tile:
|
|
2727
|
+
"""Multiply each element of a tile by a scalar"""
|
|
2728
|
+
...
|
|
2729
|
+
|
|
2730
|
+
|
|
2731
|
+
@over
|
|
2732
|
+
def mul(x: Scalar, y: Tile) -> Tile:
|
|
2733
|
+
"""Multiply each element of a tile by a scalar"""
|
|
2734
|
+
...
|
|
2735
|
+
|
|
2736
|
+
|
|
2244
2737
|
@over
|
|
2245
2738
|
def mod(a: Scalar, b: Scalar) -> Scalar:
|
|
2246
2739
|
"""Modulo operation using truncated division."""
|
|
@@ -2349,6 +2842,12 @@ def neg(x: Matrix[Any, Any, Scalar]) -> Matrix[Any, Any, Scalar]:
|
|
|
2349
2842
|
...
|
|
2350
2843
|
|
|
2351
2844
|
|
|
2845
|
+
@over
|
|
2846
|
+
def neg(x: Tile) -> Tile:
|
|
2847
|
+
"""Negate each element of a tile"""
|
|
2848
|
+
...
|
|
2849
|
+
|
|
2850
|
+
|
|
2352
2851
|
@over
|
|
2353
2852
|
def unot(a: bool) -> bool:
|
|
2354
2853
|
""" """
|
|
@@ -2409,6 +2908,72 @@ def unot(a: Array[Any]) -> bool:
|
|
|
2409
2908
|
...
|
|
2410
2909
|
|
|
2411
2910
|
|
|
2911
|
+
@over
|
|
2912
|
+
def tile_matmul(a: Tile, b: Tile, out: Tile) -> Tile:
|
|
2913
|
+
"""Computes the matrix product and accumulates ``out += a*b``.
|
|
2914
|
+
|
|
2915
|
+
Supported datatypes are:
|
|
2916
|
+
* fp16, fp32, fp64 (real)
|
|
2917
|
+
* vec2h, vec2f, vec2d (complex)
|
|
2918
|
+
|
|
2919
|
+
All input and output tiles must have the same datatype. Tile data will be automatically be migrated
|
|
2920
|
+
to shared memory if necessary and will use TensorCore operations when available.
|
|
2921
|
+
|
|
2922
|
+
:param a: A tile with ``shape=(M, K)``
|
|
2923
|
+
:param b: A tile with ``shape=(K, N)``
|
|
2924
|
+
:param out: A tile with ``shape=(M, N)``
|
|
2925
|
+
|
|
2926
|
+
"""
|
|
2927
|
+
...
|
|
2928
|
+
|
|
2929
|
+
|
|
2930
|
+
@over
|
|
2931
|
+
def tile_matmul(a: Tile, b: Tile) -> Tile:
|
|
2932
|
+
"""Computes the matrix product ``out = a*b``.
|
|
2933
|
+
|
|
2934
|
+
Supported datatypes are:
|
|
2935
|
+
* fp16, fp32, fp64 (real)
|
|
2936
|
+
* vec2h, vec2f, vec2d (complex)
|
|
2937
|
+
|
|
2938
|
+
Both input tiles must have the same datatype. Tile data will be automatically be migrated
|
|
2939
|
+
to shared memory if necessary and will use TensorCore operations when available.
|
|
2940
|
+
|
|
2941
|
+
:param a: A tile with ``shape=(M, K)``
|
|
2942
|
+
:param b: A tile with ``shape=(K, N)``
|
|
2943
|
+
:returns: A tile with ``shape=(M, N)``
|
|
2944
|
+
|
|
2945
|
+
"""
|
|
2946
|
+
...
|
|
2947
|
+
|
|
2948
|
+
|
|
2949
|
+
@over
|
|
2950
|
+
def tile_fft(inout: Tile) -> Tile:
|
|
2951
|
+
"""Compute the forward FFT along the second dimension of a 2D tile of data.
|
|
2952
|
+
|
|
2953
|
+
This function cooperatively computes the forward FFT on a tile of data inplace, treating each row individually.
|
|
2954
|
+
|
|
2955
|
+
Supported datatypes are:
|
|
2956
|
+
* vec2f, vec2d
|
|
2957
|
+
|
|
2958
|
+
:param inout: The input/output tile
|
|
2959
|
+
"""
|
|
2960
|
+
...
|
|
2961
|
+
|
|
2962
|
+
|
|
2963
|
+
@over
|
|
2964
|
+
def tile_ifft(inout: Tile) -> Tile:
|
|
2965
|
+
"""Compute the inverse FFT along the second dimension of a 2D tile of data.
|
|
2966
|
+
|
|
2967
|
+
This function cooperatively computes the inverse FFT on a tile of data inplace, treating each row individually.
|
|
2968
|
+
|
|
2969
|
+
Supported datatypes are:
|
|
2970
|
+
* vec2f, vec2d
|
|
2971
|
+
|
|
2972
|
+
:param inout: The input/output tile
|
|
2973
|
+
"""
|
|
2974
|
+
...
|
|
2975
|
+
|
|
2976
|
+
|
|
2412
2977
|
@over
|
|
2413
2978
|
def static(expr: Any) -> Any:
|
|
2414
2979
|
"""Evaluates a static Python expression and replaces it with its result.
|
warp/tape.py
CHANGED
|
@@ -15,7 +15,7 @@ class Tape:
|
|
|
15
15
|
"""
|
|
16
16
|
Record kernel launches within a Tape scope to enable automatic differentiation.
|
|
17
17
|
Gradients can be computed after the operations have been recorded on the tape via
|
|
18
|
-
|
|
18
|
+
:meth:`Tape.backward()`.
|
|
19
19
|
|
|
20
20
|
Example
|
|
21
21
|
-------
|
|
@@ -131,6 +131,7 @@ class Tape:
|
|
|
131
131
|
inputs = launch[3]
|
|
132
132
|
outputs = launch[4]
|
|
133
133
|
device = launch[5]
|
|
134
|
+
block_dim = launch[6]
|
|
134
135
|
|
|
135
136
|
adj_inputs = []
|
|
136
137
|
adj_outputs = []
|
|
@@ -153,13 +154,14 @@ class Tape:
|
|
|
153
154
|
device=device,
|
|
154
155
|
adjoint=True,
|
|
155
156
|
max_blocks=max_blocks,
|
|
157
|
+
block_dim=block_dim,
|
|
156
158
|
)
|
|
157
159
|
|
|
158
160
|
# record a kernel launch on the tape
|
|
159
|
-
def record_launch(self, kernel, dim, max_blocks, inputs, outputs, device, metadata=None):
|
|
161
|
+
def record_launch(self, kernel, dim, max_blocks, inputs, outputs, device, block_dim=0, metadata=None):
|
|
160
162
|
if metadata is None:
|
|
161
163
|
metadata = {}
|
|
162
|
-
self.launches.append([kernel, dim, max_blocks, inputs, outputs, device, metadata])
|
|
164
|
+
self.launches.append([kernel, dim, max_blocks, inputs, outputs, device, block_dim, metadata])
|
|
163
165
|
|
|
164
166
|
def record_func(self, backward, arrays):
|
|
165
167
|
"""
|
|
@@ -614,7 +616,9 @@ class ArrayStatsVisitor(TapeVisitor):
|
|
|
614
616
|
self.array_grad_stats.insert(0, grad_stats)
|
|
615
617
|
|
|
616
618
|
|
|
617
|
-
Launch = namedtuple("Launch", ["id", "kernel", "dim", "max_blocks", "inputs", "outputs", "device", "metadata"])
|
|
619
|
+
Launch = namedtuple(
|
|
620
|
+
"Launch", ["id", "kernel", "dim", "max_blocks", "inputs", "outputs", "device", "block_dim", "metadata"]
|
|
621
|
+
)
|
|
618
622
|
RepeatedSequence = namedtuple("RepeatedSequence", ["start", "end", "repetitions"])
|
|
619
623
|
|
|
620
624
|
|
|
@@ -645,8 +649,8 @@ def visit_tape(
|
|
|
645
649
|
def get_launch_id(launch):
|
|
646
650
|
kernel = launch[0]
|
|
647
651
|
suffix = ""
|
|
648
|
-
if len(launch) > 6:
|
|
649
|
-
metadata = launch[6]
|
|
652
|
+
if len(launch) > 7:
|
|
653
|
+
metadata = launch[7]
|
|
650
654
|
# calling function helps to identify unique launches
|
|
651
655
|
if "caller" in metadata:
|
|
652
656
|
caller = metadata["caller"]
|
|
@@ -680,7 +684,8 @@ def visit_tape(
|
|
|
680
684
|
inputs=launch[3],
|
|
681
685
|
outputs=launch[4],
|
|
682
686
|
device=launch[5],
|
|
683
|
-
metadata=launch[6] if len(launch) > 6 else {},
|
687
|
+
block_dim=launch[6],
|
|
688
|
+
metadata=launch[7] if len(launch) > 7 else {},
|
|
684
689
|
)
|
|
685
690
|
for launch in kernel_launches
|
|
686
691
|
]
|
|
Binary file
|