warp-lang 1.4.1__py3-none-win_amd64.whl → 1.5.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +4 -0
- warp/autograd.py +43 -8
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +21 -2
- warp/build_dll.py +23 -6
- warp/builtins.py +1920 -111
- warp/codegen.py +186 -62
- warp/config.py +2 -2
- warp/context.py +322 -73
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
- warp/examples/benchmarks/benchmark_gemm.py +121 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
- warp/examples/benchmarks/benchmark_tile.py +179 -0
- warp/examples/core/example_dem.py +2 -1
- warp/examples/core/example_mesh_intersect.py +3 -3
- warp/examples/fem/example_adaptive_grid.py +37 -10
- warp/examples/fem/example_apic_fluid.py +3 -2
- warp/examples/fem/example_convection_diffusion_dg.py +4 -5
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion_3d.py +47 -4
- warp/examples/fem/example_distortion_energy.py +220 -0
- warp/examples/fem/example_magnetostatics.py +127 -85
- warp/examples/fem/example_nonconforming_contact.py +5 -5
- warp/examples/fem/example_stokes.py +3 -1
- warp/examples/fem/example_streamlines.py +12 -19
- warp/examples/fem/utils.py +38 -15
- warp/examples/optim/example_walker.py +2 -2
- warp/examples/sim/example_cloth.py +2 -25
- warp/examples/sim/example_jacobian_ik.py +6 -2
- warp/examples/sim/example_quadruped.py +2 -1
- warp/examples/tile/example_tile_convolution.py +58 -0
- warp/examples/tile/example_tile_fft.py +47 -0
- warp/examples/tile/example_tile_filtering.py +105 -0
- warp/examples/tile/example_tile_matmul.py +79 -0
- warp/examples/tile/example_tile_mlp.py +375 -0
- warp/fem/__init__.py +8 -0
- warp/fem/cache.py +16 -12
- warp/fem/dirichlet.py +1 -1
- warp/fem/domain.py +44 -1
- warp/fem/field/__init__.py +1 -2
- warp/fem/field/field.py +31 -19
- warp/fem/field/nodal_field.py +101 -49
- warp/fem/field/virtual.py +794 -0
- warp/fem/geometry/__init__.py +2 -2
- warp/fem/geometry/deformed_geometry.py +3 -105
- warp/fem/geometry/element.py +13 -0
- warp/fem/geometry/geometry.py +165 -5
- warp/fem/geometry/grid_2d.py +3 -6
- warp/fem/geometry/grid_3d.py +31 -28
- warp/fem/geometry/hexmesh.py +3 -46
- warp/fem/geometry/nanogrid.py +3 -2
- warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
- warp/fem/geometry/tetmesh.py +2 -43
- warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
- warp/fem/integrate.py +683 -261
- warp/fem/linalg.py +404 -0
- warp/fem/operator.py +101 -18
- warp/fem/polynomial.py +5 -5
- warp/fem/quadrature/quadrature.py +45 -21
- warp/fem/space/__init__.py +45 -11
- warp/fem/space/basis_function_space.py +451 -0
- warp/fem/space/basis_space.py +58 -11
- warp/fem/space/function_space.py +146 -5
- warp/fem/space/grid_2d_function_space.py +80 -66
- warp/fem/space/grid_3d_function_space.py +113 -68
- warp/fem/space/hexmesh_function_space.py +96 -108
- warp/fem/space/nanogrid_function_space.py +62 -110
- warp/fem/space/quadmesh_function_space.py +208 -0
- warp/fem/space/shape/__init__.py +45 -7
- warp/fem/space/shape/cube_shape_function.py +328 -54
- warp/fem/space/shape/shape_function.py +10 -1
- warp/fem/space/shape/square_shape_function.py +328 -60
- warp/fem/space/shape/tet_shape_function.py +269 -19
- warp/fem/space/shape/triangle_shape_function.py +238 -19
- warp/fem/space/tetmesh_function_space.py +69 -37
- warp/fem/space/topology.py +38 -0
- warp/fem/space/trimesh_function_space.py +179 -0
- warp/fem/utils.py +6 -331
- warp/jax_experimental.py +3 -1
- warp/native/array.h +55 -40
- warp/native/builtin.h +124 -43
- warp/native/bvh.h +4 -0
- warp/native/coloring.cpp +600 -0
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -1
- warp/native/fabric.h +8 -0
- warp/native/hashgrid.h +4 -0
- warp/native/marching.cu +8 -0
- warp/native/mat.h +14 -3
- warp/native/mathdx.cpp +59 -0
- warp/native/mesh.h +4 -0
- warp/native/range.h +13 -1
- warp/native/reduce.cpp +9 -1
- warp/native/reduce.cu +7 -0
- warp/native/runlength_encode.cpp +9 -1
- warp/native/runlength_encode.cu +7 -1
- warp/native/scan.cpp +8 -0
- warp/native/scan.cu +8 -0
- warp/native/scan.h +8 -1
- warp/native/sparse.cpp +8 -0
- warp/native/sparse.cu +8 -0
- warp/native/temp_buffer.h +7 -0
- warp/native/tile.h +1857 -0
- warp/native/tile_gemm.h +341 -0
- warp/native/tile_reduce.h +210 -0
- warp/native/volume_builder.cu +8 -0
- warp/native/volume_builder.h +8 -0
- warp/native/warp.cpp +10 -2
- warp/native/warp.cu +369 -15
- warp/native/warp.h +12 -2
- warp/optim/adam.py +39 -4
- warp/paddle.py +29 -12
- warp/render/render_opengl.py +137 -65
- warp/sim/graph_coloring.py +292 -0
- warp/sim/integrator_euler.py +4 -2
- warp/sim/integrator_featherstone.py +115 -44
- warp/sim/integrator_vbd.py +6 -0
- warp/sim/model.py +90 -17
- warp/stubs.py +651 -85
- warp/tape.py +12 -7
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/aux_test_instancing_gc.py +18 -0
- warp/tests/test_array.py +207 -48
- warp/tests/test_closest_point_edge_edge.py +8 -8
- warp/tests/test_codegen.py +120 -1
- warp/tests/test_codegen_instancing.py +30 -0
- warp/tests/test_collision.py +110 -0
- warp/tests/test_coloring.py +241 -0
- warp/tests/test_context.py +34 -0
- warp/tests/test_examples.py +18 -4
- warp/tests/test_fabricarray.py +33 -0
- warp/tests/test_fem.py +453 -113
- warp/tests/test_func.py +48 -1
- warp/tests/test_generics.py +52 -0
- warp/tests/test_iter.py +68 -0
- warp/tests/test_mat_scalar_ops.py +1 -1
- warp/tests/test_mesh_query_point.py +5 -4
- warp/tests/test_module_hashing.py +23 -0
- warp/tests/test_paddle.py +27 -87
- warp/tests/test_print.py +191 -1
- warp/tests/test_spatial.py +1 -1
- warp/tests/test_tile.py +700 -0
- warp/tests/test_tile_mathdx.py +144 -0
- warp/tests/test_tile_mlp.py +383 -0
- warp/tests/test_tile_reduce.py +374 -0
- warp/tests/test_tile_shared_memory.py +190 -0
- warp/tests/test_vbd.py +12 -20
- warp/tests/test_volume.py +43 -0
- warp/tests/unittest_suites.py +23 -2
- warp/tests/unittest_utils.py +4 -0
- warp/types.py +339 -73
- warp/utils.py +22 -1
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/METADATA +33 -7
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/RECORD +159 -132
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/WHEEL +1 -1
- warp/fem/field/test.py +0 -180
- warp/fem/field/trial.py +0 -183
- warp/fem/space/collocated_function_space.py +0 -102
- warp/fem/space/quadmesh_2d_function_space.py +0 -261
- warp/fem/space/trimesh_2d_function_space.py +0 -153
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/top_level.txt +0 -0
warp/builtins.py
CHANGED
|
@@ -5,6 +5,9 @@
|
|
|
5
5
|
# distribution of this software and related documentation without an express
|
|
6
6
|
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
|
7
7
|
import builtins
|
|
8
|
+
import functools
|
|
9
|
+
import tempfile
|
|
10
|
+
from pathlib import Path
|
|
8
11
|
from typing import Any, Callable, Mapping, Sequence
|
|
9
12
|
|
|
10
13
|
from warp.codegen import Reference, Var, strip_reference
|
|
@@ -1497,7 +1500,8 @@ add_builtin(
|
|
|
1497
1500
|
doc="""Apply the transform to a point ``point`` treating the homogeneous coordinate as w=1.
|
|
1498
1501
|
|
|
1499
1502
|
The transformation is applied treating ``point`` as a column vector, e.g.: ``y = mat*point``.
|
|
1500
|
-
|
|
1503
|
+
|
|
1504
|
+
This is in contrast to some libraries, notably USD, which applies transforms to row vectors, ``y^T = point^T*mat^T``.
|
|
1501
1505
|
If the transform is coming from a library that uses row-vectors, then users should transpose the transformation
|
|
1502
1506
|
matrix before calling this method.""",
|
|
1503
1507
|
)
|
|
@@ -1515,8 +1519,9 @@ add_builtin(
|
|
|
1515
1519
|
group="Vector Math",
|
|
1516
1520
|
doc="""Apply the transform to a vector ``vec`` treating the homogeneous coordinate as w=0.
|
|
1517
1521
|
|
|
1518
|
-
The transformation is applied treating ``vec`` as a column vector, e.g.: ``y = mat*vec
|
|
1519
|
-
|
|
1522
|
+
The transformation is applied treating ``vec`` as a column vector, e.g.: ``y = mat*vec``.
|
|
1523
|
+
|
|
1524
|
+
This is in contrast to some libraries, notably USD, which applies transforms to row vectors, ``y^T = vec^T*mat^T``.
|
|
1520
1525
|
If the transform is coming from a library that uses row-vectors, then users should transpose the transformation
|
|
1521
1526
|
matrix before calling this method.""",
|
|
1522
1527
|
)
|
|
@@ -1622,84 +1627,1343 @@ add_builtin(
|
|
|
1622
1627
|
|
|
1623
1628
|
|
|
1624
1629
|
# Spatial Math builtins operating on 6D screw vectors and 6x6 spatial matrices.
add_builtin(
    "spatial_adjoint",
    group="Spatial Math",
    input_types={"r": matrix(shape=(3, 3), dtype=Float), "s": matrix(shape=(3, 3), dtype=Float)},
    value_func=lambda arg_types, arg_values: matrix(shape=(6, 6), dtype=float_infer_type(arg_types)),
    doc="Construct a 6x6 spatial inertial matrix from two 3x3 diagonal blocks.",
    export=False,
)
add_builtin(
    "spatial_dot",
    group="Spatial Math",
    input_types={"a": vector(length=6, dtype=Float), "b": vector(length=6, dtype=Float)},
    value_func=float_sametypes_value_func,
    doc="Compute the dot product of two 6D screw vectors.",
)
add_builtin(
    "spatial_cross",
    group="Spatial Math",
    input_types={"a": vector(length=6, dtype=Float), "b": vector(length=6, dtype=Float)},
    value_func=sametypes_create_value_func(vector(length=6, dtype=Float)),
    doc="Compute the cross product of two 6D screw vectors.",
)
add_builtin(
    "spatial_cross_dual",
    group="Spatial Math",
    input_types={"a": vector(length=6, dtype=Float), "b": vector(length=6, dtype=Float)},
    value_func=sametypes_create_value_func(vector(length=6, dtype=Float)),
    doc="Compute the dual cross product of two 6D screw vectors.",
)

# The top/bottom accessors return a 3-vector whose scalar type follows the input
# screw vector; with no argument types (doc builds) a generic Float 3-vector is used.
add_builtin(
    "spatial_top",
    group="Spatial Math",
    input_types={"svec": vector(length=6, dtype=Float)},
    value_func=lambda arg_types, arg_values: (
        vector(length=3, dtype=arg_types["svec"]._wp_scalar_type_)
        if arg_types is not None
        else vector(length=3, dtype=Float)
    ),
    doc="Return the top (first) part of a 6D screw vector.",
)
add_builtin(
    "spatial_bottom",
    group="Spatial Math",
    input_types={"svec": vector(length=6, dtype=Float)},
    value_func=lambda arg_types, arg_values: (
        vector(length=3, dtype=arg_types["svec"]._wp_scalar_type_)
        if arg_types is not None
        else vector(length=3, dtype=Float)
    ),
    doc="Return the bottom (second) part of a 6D screw vector.",
)

# NOTE(review): spatial_jacobian and spatial_mass are registered with an empty
# user-facing doc string; they write into the output arrays and return nothing.
add_builtin(
    "spatial_jacobian",
    group="Spatial Math",
    input_types={
        "S": array(dtype=vector(length=6, dtype=Float)),
        "joint_parents": array(dtype=int),
        "joint_qd_start": array(dtype=int),
        "joint_start": int,
        "joint_count": int,
        "J_start": int,
        "J_out": array(dtype=Float),
    },
    value_type=None,
    doc="",
)

add_builtin(
    "spatial_mass",
    group="Spatial Math",
    input_types={
        "I_s": array(dtype=matrix(shape=(6, 6), dtype=Float)),
        "joint_start": int,
        "joint_count": int,
        "M_start": int,
        "M": array(dtype=Float),
    },
    value_type=None,
    doc="",
)

# ------------------
# Tile-based primitives
shared_memory_id = 0
|
|
1711
|
+
|
|
1712
|
+
|
|
1713
|
+
def tile_zeros_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Infer the result type of ``wp.tile_zeros()``.

    Checks that ``m``, ``n``, ``dtype`` and ``storage`` were supplied and that ``storage``
    is a supported location, then describes the zero-initialized tile.

    Returns a generic ``Tile`` when ``arg_types`` is ``None`` (documentation builds).

    Raises:
        RuntimeError: if ``m``, ``n`` or ``dtype`` is missing.
        ValueError: if ``storage`` is missing or is neither ``"shared"`` nor ``"register"``.
    """
    if arg_types is None:
        return Tile(dtype=Any, M=Any, N=Any)

    # Checked in this order so the first missing argument is the one reported.
    for required in ("m", "n", "dtype"):
        if required not in arg_values:
            raise RuntimeError(f"'{required}' keyword argument must be specified when calling tile_zeros() function")

    if "storage" not in arg_values:
        raise ValueError("'storage' keyword not provided for tile_zeros")

    storage = arg_values["storage"]
    if storage not in {"shared", "register"}:
        raise ValueError(f"'storage' keyword argument must be either 'shared' or 'register', got {storage}")

    return TileZeros(dtype=arg_values["dtype"], M=arg_values["m"], N=arg_values["n"], storage=storage)
|
|
1739
|
+
|
|
1740
|
+
|
|
1741
|
+
def tile_zeros_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
    """Build the native call for ``wp.tile_zeros()``.

    No runtime arguments are forwarded. The element type and the constant values of
    ``m`` and ``n`` become the template arguments ``<dtype, m, n>``.
    """
    rows = arg_values["m"].constant
    cols = arg_values["n"].constant
    return ([], [arg_values["dtype"], rows, cols])
|
|
1750
|
+
|
|
1751
|
+
|
|
1752
|
+
# Register ``wp.tile_zeros`` (not exported: export=False).
add_builtin(
    "tile_zeros",
    group="Tile Primitives",
    input_types={"m": int, "n": int, "dtype": Any, "storage": str},
    defaults={"storage": "register"},
    value_func=tile_zeros_value_func,
    dispatch_func=tile_zeros_dispatch_func,
    variadic=False,
    missing_grad=True,
    export=False,
    doc="""Allocates a tile of zero-initialized items.

    :param m: Size of the first dimension of the output tile
    :param n: Size of the second dimension of the output tile
    :param dtype: Datatype of output tile's elements
    :param storage: The storage location for the tile: ``"register"`` for registers
      (default) or ``"shared"`` for shared memory.
    :returns: A zero-initialized tile with ``shape=(m,n)`` and the specified datatype""",
)
|
|
1771
|
+
|
|
1772
|
+
|
|
1773
|
+
def tile_ones_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Infer the result type of ``wp.tile_ones()``.

    Checks that ``m``, ``n``, ``dtype`` and ``storage`` were supplied and that ``storage``
    is a supported location, then describes the one-initialized tile.

    Returns a generic ``Tile`` when ``arg_types`` is ``None`` (documentation builds).

    Raises:
        RuntimeError: if ``m``, ``n`` or ``dtype`` is missing.
        ValueError: if ``storage`` is missing or is neither ``"shared"`` nor ``"register"``.
    """
    # return generic type (for doc builds)
    if arg_types is None:
        return Tile(dtype=Any, M=Any, N=Any)

    if "m" not in arg_values:
        raise RuntimeError("'m' keyword argument must be specified when calling tile_ones() function")

    if "n" not in arg_values:
        raise RuntimeError("'n' keyword argument must be specified when calling tile_ones() function")

    if "dtype" not in arg_values:
        raise RuntimeError("'dtype' keyword argument must be specified when calling tile_ones() function")

    # Same explicit check as tile_zeros, instead of failing with a bare KeyError below.
    if "storage" not in arg_values:
        raise ValueError("'storage' keyword not provided for tile_ones")

    if arg_values["storage"] not in {"shared", "register"}:
        raise ValueError(
            f"'storage' keyword argument must be either 'shared' or 'register', got {arg_values['storage']}"
        )

    m, n = arg_values["m"], arg_values["n"]
    dtype = arg_values["dtype"]

    # NOTE(review): this returns TileZeros even though this is tile_ones. Presumably the
    # type only carries dtype/shape/storage and the registered builtin name selects the
    # native initializer. Confirm this reuse is intentional.
    return TileZeros(dtype=dtype, M=m, N=n, storage=arg_values["storage"])
|
|
1796
|
+
|
|
1797
|
+
|
|
1798
|
+
def tile_ones_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
    """Build the native call for ``wp.tile_ones()``.

    No runtime arguments are forwarded. The template arguments are ``<dtype, m, n>``,
    where ``m`` and ``n`` are the constant values of the corresponding arguments.
    """
    element_type = arg_values["dtype"]
    extents = [arg_values[key].constant for key in ("m", "n")]
    return ([], [element_type, *extents])
|
|
1807
|
+
|
|
1808
|
+
|
|
1809
|
+
# Register ``wp.tile_ones`` (not exported: export=False).
add_builtin(
    "tile_ones",
    group="Tile Primitives",
    input_types={"m": int, "n": int, "dtype": Any, "storage": str},
    defaults={"storage": "register"},
    value_func=tile_ones_value_func,
    dispatch_func=tile_ones_dispatch_func,
    missing_grad=True,
    export=False,
    doc="""Allocates a tile of one-initialized items.

    :param m: Size of the first dimension of the output tile
    :param n: Size of the second dimension of the output tile
    :param dtype: Datatype of output tile's elements
    :param storage: The storage location for the tile: ``"register"`` for registers
      (default) or ``"shared"`` for shared memory.
    :returns: A one-initialized tile with ``shape=(m,n)`` and the specified dtype""",
)
|
|
1827
|
+
|
|
1828
|
+
|
|
1829
|
+
def tile_arange_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Infer the result type of ``wp.tile_arange()``.

    The positional ``args`` follow the ``range()`` call forms: ``(stop,)``,
    ``(start, stop)`` or ``(start, stop, step)``. Every value must be a compile-time
    constant; a ``None`` value is treated as non-constant and rejected.

    Returns a generic ``Tile`` when ``arg_types`` is ``None`` (documentation builds).

    Raises:
        RuntimeError: if the number of positional arguments is not 1, 2 or 3, if any of
            them is not a compile-time constant, or if ``step`` is zero.
        ValueError: if ``storage`` is neither ``"shared"`` nor ``"register"``.
    """
    # return generic type (for doc builds)
    if arg_types is None:
        return Tile(dtype=Any, M=Any, N=Any)

    args = arg_values["args"]

    if len(args) == 1:
        start, stop, step = 0, args[0], 1
    elif len(args) == 2:
        start, stop, step = args[0], args[1], 1
    elif len(args) == 3:
        start, stop, step = args[0], args[1], args[2]
    else:
        # Any other arity is rejected instead of silently producing an empty 0..0 range.
        raise RuntimeError(
            f"wp.tile_arange() expects 1, 2 or 3 positional arguments (stop | start, stop | start, stop, step), got {len(args)}"
        )

    if start is None or stop is None or step is None:
        raise RuntimeError("wp.tile_arange() arguments must be compile time constants")

    # As with range(), a zero step cannot describe a sequence.
    if step == 0:
        raise RuntimeError("wp.tile_arange() step must be non-zero")

    # NOTE(review): the builtin registers dtype with a default of None and documents the
    # default as int. This fallback, however, picks float only when the key is absent,
    # and a None dtype is passed through unchanged. Confirm the intended default.
    if "dtype" in arg_values:
        dtype = arg_values["dtype"]
    else:
        dtype = float

    if arg_values["storage"] not in {"shared", "register"}:
        raise ValueError(
            f"'storage' keyword argument must be either 'shared' or 'register', got {arg_values['storage']}"
        )

    return TileRange(dtype=dtype, start=start, stop=stop, step=step, storage=arg_values["storage"])
|
|
1868
|
+
|
|
1869
|
+
|
|
1870
|
+
def tile_arange_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
    """Build the native call for ``wp.tile_arange()``.

    Template arguments are ``<dtype, M, N>`` taken from the inferred return type. The
    runtime arguments are ``start``, ``stop`` and ``step`` as constant ``Var`` objects.
    """
    elem_type = return_type.dtype
    template_args = [elem_type, return_type.M, return_type.N]

    # todo: it is somewhat redundant to create new vars here since some of start,stop,step
    # already exist depending on which form the function was called by the user
    function_args = [
        Var(label=None, type=elem_type, constant=bound)
        for bound in (return_type.start, return_type.stop, return_type.step)
    ]

    return (function_args, template_args)
|
|
1890
|
+
|
|
1891
|
+
|
|
1892
|
+
# Register ``wp.tile_arange``. It is variadic over the positional range arguments.
add_builtin(
    "tile_arange",
    group="Tile Primitives",
    input_types={"*args": Scalar, "dtype": Any, "storage": str},
    defaults={"dtype": None, "storage": "register"},
    value_func=tile_arange_value_func,
    dispatch_func=tile_arange_dispatch_func,
    variadic=True,
    missing_grad=True,
    export=False,
    doc="""Generates a tile of linearly spaced elements.

    :param args: Variable-length positional arguments, interpreted as:

        - ``(stop,)``: Generates values from ``0`` to ``stop - 1``
        - ``(start, stop)``: Generates values from ``start`` to ``stop - 1``
        - ``(start, stop, step)``: Generates values from ``start`` to ``stop - 1`` with a step size

    :param dtype: Datatype of output tile's elements (optional, default: int)
    :param storage: The storage location for the tile: ``"register"`` for registers
      (default) or ``"shared"`` for shared memory.
    :returns: A tile with ``shape=(1,n)`` with linearly spaced elements of specified dtype""",
)
|
|
1915
|
+
|
|
1916
|
+
|
|
1917
|
+
def tile_load_1d_value_func(arg_types, arg_values):
    """Infer the tile type produced by the 1D form ``wp.tile_load(a, i, n)``.

    Checks that ``a`` is a 1-dimensional array, ``i`` is an integer, ``n`` was supplied
    and ``storage`` is a supported location. Returns a generic ``Tile`` when
    ``arg_types`` is ``None`` (documentation builds).
    """
    if arg_types is None:
        return Tile(dtype=Any, M=Any, N=Any)

    src = arg_types["a"]
    if not is_array(src):
        raise RuntimeError("tile_load() argument 0 must be an array")
    if src.ndim != 1:
        raise RuntimeError(
            "tile_load() argument 0 must be 1-dimensional if using the ``wp.tile_load(array, i, n)`` syntax."
        )
    if not type_is_int(arg_types["i"]):
        raise RuntimeError("tile_load() argument 1 must be an integer")
    if "n" not in arg_values:
        raise RuntimeError("'n' keyword argument must be specified when calling tile_load() function")

    storage = arg_values["storage"]
    if storage not in {"shared", "register"}:
        raise ValueError(f"'storage' keyword argument must be either 'shared' or 'register', got {storage}")

    # The 1D load is described as a single-row tile: shape (1, n).
    return TileLoad(src, 1, arg_values["n"], storage)
|
|
1945
|
+
|
|
1946
|
+
|
|
1947
|
+
def tile_load_1d_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
    """Build the native call for the 1D ``wp.tile_load(a, i, n)`` form.

    Forwards ``(a, i)`` as runtime arguments. The template arguments are
    ``<array dtype, n>``, using the constant value of ``n``.
    """
    src, offset = arg_values["a"], arg_values["i"]
    tile_len = arg_values["n"].constant
    return ((src, offset), [src.type.dtype, tile_len])
|
|
1958
|
+
|
|
1959
|
+
|
|
1960
|
+
# Register the 1D overload ``wp.tile_load(a, i, n)``.
add_builtin(
    "tile_load",
    group="Tile Primitives",
    input_types={"a": array(dtype=Any), "i": int, "n": int, "storage": str},
    defaults={"storage": "register"},
    value_func=tile_load_1d_value_func,
    dispatch_func=tile_load_1d_dispatch_func,
    variadic=False,
    export=False,
    doc="""Loads a 1D tile from a global memory array.

    This method will cooperatively load a tile from global memory using all threads in the block.

    :param a: The source array in global memory
    :param i: Offset in the source array measured in multiples of ``n``, i.e.: ``offset=i*n``
    :param n: The number of elements in the tile
    :param storage: The storage location for the tile: ``"register"`` for registers
      (default) or ``"shared"`` for shared memory.
    :returns: A tile with ``shape=(1,n)`` and dtype the same as the source array""",
)
|
|
1980
|
+
|
|
1981
|
+
|
|
1982
|
+
def tile_load_2d_value_func(arg_types, arg_values):
    """Infer the tile type produced by the 2D form ``wp.tile_load(a, i, j, m, n)``.

    Checks that ``a`` is a 2-dimensional array, ``i`` and ``j`` are integers, ``m`` and
    ``n`` were supplied and ``storage`` is a supported location. Returns a generic
    ``Tile`` when ``arg_types`` is ``None`` (documentation builds).

    Raises:
        RuntimeError: on a wrong argument kind or a missing ``m``/``n``.
        ValueError: if ``storage`` is neither ``"shared"`` nor ``"register"``.
    """
    # return generic type (for doc builds)
    if arg_types is None:
        return Tile(dtype=Any, M=Any, N=Any)

    if not is_array(arg_types["a"]):
        raise RuntimeError("tile_load() argument 0 must be an array")

    if arg_types["a"].ndim != 2:
        raise RuntimeError(
            "tile_load() argument 0 must be 2-dimensional if using the ``wp.tile_load(array, i, j, m, n)`` syntax."
        )

    if not type_is_int(arg_types["i"]):
        raise RuntimeError("tile_load() argument 1 must be an integer")

    # ``j`` is the third positional argument (index 2).
    if not type_is_int(arg_types["j"]):
        raise RuntimeError("tile_load() argument 2 must be an integer")

    if "m" not in arg_values:
        raise RuntimeError("'m' keyword argument must be specified when calling tile_load() function")

    if "n" not in arg_values:
        raise RuntimeError("'n' keyword argument must be specified when calling tile_load() function")

    if arg_values["storage"] not in {"shared", "register"}:
        raise ValueError(
            f"'storage' keyword argument must be either 'shared' or 'register', got {arg_values['storage']}"
        )

    a = arg_types["a"]
    m, n = arg_values["m"], arg_values["n"]

    return TileLoad(a, m, n, arg_values["storage"])
|
|
2016
|
+
|
|
2017
|
+
|
|
2018
|
+
def tile_load_2d_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
    """Build the native call for the 2D ``wp.tile_load(a, i, j, m, n)`` form.

    Forwards ``(a, i, j)`` as runtime arguments. The template arguments are
    ``<array dtype, m, n>``, using the constant values of ``m`` and ``n``.
    """
    src = arg_values["a"]
    runtime_args = (src, arg_values["i"], arg_values["j"])
    tile_shape = [arg_values[key].constant for key in ("m", "n")]
    return (runtime_args, [src.type.dtype, *tile_shape])
|
|
2030
|
+
|
|
2031
|
+
|
|
2032
|
+
# Register the 2D overload ``wp.tile_load(a, i, j, m, n)``.
add_builtin(
    "tile_load",
    input_types={"a": array(dtype=Any), "i": int, "j": int, "m": int, "n": int, "storage": str},
    defaults={"storage": "register"},
    value_func=tile_load_2d_value_func,
    dispatch_func=tile_load_2d_dispatch_func,
    variadic=False,
    # The ``j`` line previously read "i.e.;" (typo); now matches the ``i`` line's "i.e.:".
    doc="""Loads a 2D tile from a global memory array.

    This method will cooperatively load a tile from global memory using all threads in the block.

    :param a: The source array in global memory
    :param i: Offset in the source array measured in multiples of ``m``, i.e.: ``row=i*m``
    :param j: Offset in the source array measured in multiples of ``n``, i.e.: ``col=j*n``
    :param m: The size of the tile's first dimension
    :param n: The size of the tile's second dimension
    :param storage: The storage location for the tile: ``"register"`` for registers
      (default) or ``"shared"`` for shared memory.
    :returns: A tile with ``shape=(m,n)`` and dtype the same as the source array""",
    group="Tile Primitives",
    export=False,
)
|
|
2054
|
+
|
|
2055
|
+
|
|
2056
|
+
def tile_store_1d_value_func(arg_types, arg_values):
    """Validate the 1D form ``wp.tile_store(a, i, t)``; the call produces no value.

    Checks that exactly three arguments were given and that:

    - ``a`` is a 1-dimensional array;
    - ``i`` is an integer;
    - ``t`` is a tile whose dtype matches the array's.

    Raises:
        RuntimeError: on any of the above violations.
    """
    # return generic type (for doc builds)
    if arg_types is None:
        return None

    if len(arg_types) != 3:
        raise RuntimeError("tile_store() requires 3 positional args")

    if not is_array(arg_types["a"]):
        raise RuntimeError("tile_store() argument 0 must be an array")

    # Report the failure against the function the user actually called (tile_store).
    if arg_types["a"].ndim != 1:
        raise RuntimeError(
            "tile_store() argument 0 must be a 1-dimensional array if using the ``wp.tile_store(array, i, t)`` syntax."
        )

    if not type_is_int(arg_types["i"]):
        raise RuntimeError("tile_store() argument 1 must be an integer")

    if not is_tile(arg_types["t"]):
        raise RuntimeError("tile_store() argument 2 must be a tile")

    if not types_equal(arg_types["a"].dtype, arg_types["t"].dtype):
        raise RuntimeError("tile_store() destination array must have same type as source tile")

    return None
|
|
2082
|
+
|
|
2083
|
+
|
|
2084
|
+
# Register the 1D overload ``wp.tile_store(a, i, t)``.
add_builtin(
    "tile_store",
    group="Tile Primitives",
    input_types={"a": array(dtype=Any), "i": int, "t": Any},
    value_func=tile_store_1d_value_func,
    variadic=False,
    skip_replay=True,
    export=False,
    doc="""Stores a 1D tile to a global memory array.

    This method will cooperatively store a tile to global memory using all threads in the block.

    :param a: The destination array in global memory
    :param i: Offset in the destination array measured in multiples of ``n``, i.e.: ``offset=i*n``
    :param t: The source tile to store data from, must have the same dtype as the destination array""",
)
|
|
2100
|
+
|
|
2101
|
+
|
|
2102
|
+
def tile_store_2d_value_func(arg_types, arg_values):
    """Validate the 2D form ``wp.tile_store(a, i, j, t)``; the call produces no value.

    Checks that exactly four arguments were given and that:

    - ``a`` is a 2-dimensional array;
    - ``i`` and ``j`` are integers;
    - ``t`` is a tile whose dtype matches the array's.

    Raises:
        RuntimeError: on any of the above violations.
    """
    # return generic type (for doc builds)
    if arg_types is None:
        return None

    if len(arg_types) != 4:
        raise RuntimeError("tile_store() requires 4 positional args")

    if not is_array(arg_types["a"]):
        raise RuntimeError("tile_store() argument 0 must be an array")

    # Report the failure against the function the user actually called (tile_store).
    if arg_types["a"].ndim != 2:
        raise RuntimeError(
            "tile_store() argument 0 must be a 2-dimensional array if using the ``wp.tile_store(array, i, j, t)`` syntax."
        )

    if not type_is_int(arg_types["i"]):
        raise RuntimeError("tile_store() argument 1 must be an integer")

    if not type_is_int(arg_types["j"]):
        raise RuntimeError("tile_store() argument 2 must be an integer")

    if not is_tile(arg_types["t"]):
        raise RuntimeError("tile_store() argument 3 must be a tile")

    if not types_equal(arg_types["a"].dtype, arg_types["t"].dtype):
        raise RuntimeError("tile_store() destination array must have same type as source tile")

    return None
|
|
2131
|
+
|
|
2132
|
+
|
|
2133
|
+
# Register the 2D overload ``wp.tile_store(a, i, j, t)``.
add_builtin(
    "tile_store",
    input_types={"a": array(dtype=Any), "i": int, "j": int, "t": Any},
    value_func=tile_store_2d_value_func,
    variadic=False,
    skip_replay=True,
    # The summary now says "2D", matching the 2D tile_load overload.
    # The ``j`` line's "i.e.;" typo is corrected to "i.e.:".
    doc="""Stores a 2D tile to a global memory array.

    This method will cooperatively store a tile to global memory using all threads in the block.

    :param a: The destination array in global memory
    :param i: Offset in the destination array measured in multiples of ``m``, i.e.: ``row=i*m``
    :param j: Offset in the destination array measured in multiples of ``n``, i.e.: ``col=j*n``
    :param t: The source tile to store data from, must have the same dtype as the destination array""",
    group="Tile Primitives",
    export=False,
)
|
|
2150
|
+
|
|
2151
|
+
|
|
2152
|
+
def tile_atomic_add_value_func(arg_types, arg_values):
    """Type-check ``tile_atomic_add()`` and compute its result type.

    Expects ``a`` (array), ``x`` and ``y`` (integer tile offsets) and ``t``
    (source tile). The result is a tile with the same dtype and dimensions as
    ``t``, holding the original values of the destination elements.

    Raises:
        RuntimeError: If the argument count or any argument type is invalid,
            or if the array dtype and the tile dtype differ.
    """
    # return generic type (for doc builds)
    if arg_types is None:
        return Tile(dtype=Any, M=Any, N=Any)

    if len(arg_types) != 4:
        raise RuntimeError("tile_atomic_add() requires 4 positional args")

    if not is_array(arg_types["a"]):
        raise RuntimeError("tile_atomic_add() argument 0 must be an array")

    if not type_is_int(arg_types["x"]):
        raise RuntimeError("tile_atomic_add() argument 1 must be an integer")

    if not type_is_int(arg_types["y"]):
        raise RuntimeError("tile_atomic_add() argument 2 must be an integer")

    if not is_tile(arg_types["t"]):
        raise RuntimeError("tile_atomic_add() argument 3 must be a tile")

    # compare dtypes with types_equal() so the rule matches the dtype checks
    # performed by tile_store() and tile_map()
    if not types_equal(arg_types["a"].dtype, arg_types["t"].dtype):
        raise RuntimeError("tile_atomic_add() tile dtype and array dtype must match")

    return Tile(dtype=arg_types["t"].dtype, M=arg_types["t"].M, N=arg_types["t"].N)
|
|
2176
|
+
|
|
2177
|
+
|
|
2178
|
+
# tile_atomic_add(a, x, y, t): element-wise atomic accumulation of a tile into
# a global array; returns the previous destination values as a tile.
add_builtin(
    "tile_atomic_add",
    group="Tile Primitives",
    export=False,
    input_types={"a": array(dtype=Any), "x": int, "y": int, "t": Any},
    value_func=tile_atomic_add_value_func,
    skip_replay=True,
    variadic=True,
    doc="""Atomically add a tile to the array `a`, each element will be updated atomically.

    :param a: Array in global memory, should have the same ``dtype`` as the input tile
    :param x: Offset in the destination array measured in multiples of ``m``, i.e.: ``i=x*M`` where ``M`` is the first tile dimension
    :param y: Offset in the destination array measured in multiples of ``n``, i.e.: ``j=y*N`` where ``N`` is the second tile dimension
    :param t: Source tile to add to the destination array
    :returns: A tile with the same dimensions and type as the source tile, holding the original value of the destination elements""",
)
|
|
2194
|
+
|
|
2195
|
+
|
|
2196
|
+
def tile_view_value_func(arg_types, arg_values):
    """Compute the result type of ``tile_view()``.

    The view has dimensions ``(m, n)``. When ``m`` is omitted it defaults to 1;
    when ``n`` is omitted it defaults to the source tile's ``N``. An explicit
    ``None`` (the registered default for both) is treated like an omission.

    Note: this marks the *source* tile type as shared-memory storage, because
    the returned view does not own its data (``owner=False``).

    Raises:
        RuntimeError: If the requested view is larger than the source tile.
    """
    # return generic type (for doc builds)
    if arg_types is None:
        return Tile(dtype=Any, M=Any, N=Any)

    tile = arg_types["t"]

    # a missing key and a ``None`` default both mean "use the default size"
    m = arg_values.get("m")
    if m is None:
        m = 1

    n = arg_values.get("n")
    if n is None:
        n = tile.N

    if m > tile.M or n > tile.N:
        raise RuntimeError(
            f"Trying to view a tile subrange with dimensions ({m}, {n}) which is larger than source tile with dimensions ({tile.M}, {tile.N})"
        )

    # force source tile to shared memory
    tile.storage = "shared"

    output = Tile(dtype=tile.dtype, M=m, N=n, strides=tile.strides, layout=tile.layout, storage="shared", owner=False)
    return output
|
|
2223
|
+
|
|
2224
|
+
|
|
2225
|
+
def tile_view_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
    """Build the native call arguments for ``tile_view()``.

    Passes ``(t, i, j)`` as runtime arguments and the view dimensions
    ``(M, N)`` of ``return_type`` as template arguments. When ``j`` is omitted
    (or left at its ``None`` default) a constant 0 column offset is used.
    """
    tile = arg_values["t"]
    i = arg_values["i"]

    # a missing key and a ``None`` default both mean "start at column 0"
    j = arg_values.get("j")
    if j is None:
        j = warp.codegen.Var(label=None, type=int, constant=0)

    template_args = [return_type.M, return_type.N]

    return ((tile, i, j), template_args)
|
|
2239
|
+
|
|
2240
|
+
|
|
2241
|
+
# tile_view(t, i, j, m, n): non-owning (m, n) subrange of a tile starting at (i, j).
add_builtin(
    "tile_view",
    group="Tile Primitives",
    export=False,
    input_types={"t": Tile(dtype=Any, M=Any, N=Any), "i": int, "j": int, "m": int, "n": int},
    defaults={"j": None, "m": None, "n": None},
    value_func=tile_view_value_func,
    dispatch_func=tile_view_dispatch_func,
    variadic=True,
    doc="""Return a subrange of a given tile from coordinates (i,j) to (i+m, j+n).

    :param t: Input tile to extract a subrange from
    :param i: Offset in the source tile along the first dimension
    :param j: Offset in the source tile along the second dimensions
    :param m: Size of the subrange to return along the first dimension
    :param n: Size of the subrange to return along the second dimension
    :returns: A tile with dimensions (m,n) and the same datatype as the input tile""",
)
|
|
2259
|
+
|
|
2260
|
+
|
|
2261
|
+
def tile_assign_value_func(arg_types, arg_values):
    """Return the result type of ``tile_assign()``.

    ``tile_assign()`` produces no value, so the result type is ``None``
    regardless of the arguments (this also covers doc builds, where
    ``arg_types`` is ``None``).
    """
    result_type = None
    return result_type
|
|
2264
|
+
|
|
2265
|
+
|
|
2266
|
+
# tile_assign(dst, i, j, src): copy ``src`` into ``dst`` at offset (i, j).
add_builtin(
    "tile_assign",
    input_types={"dst": Tile(dtype=Any, M=Any, N=Any), "i": int, "j": int, "src": Tile(dtype=Any, M=Any, N=Any)},
    value_func=tile_assign_value_func,
    doc="""Assign a tile to a subrange of a destination tile at coordinates (i,j).

    :param dst: The destination tile to assign to
    :param i: Offset in the destination tile along the first dimension
    :param j: Offset in the destination tile along the second dimension
    :param src: The source tile to read values from""",
    group="Tile Primitives",
    export=False,
)
|
|
2280
|
+
|
|
2281
|
+
|
|
2282
|
+
def tile_value_func(arg_types, arg_values):
    """Compute the result type of ``tile()``.

    A vector input yields a tile of shape ``(length, block_dim)`` whose dtype
    is the vector's scalar type; any other input yields ``(1, block_dim)``
    with the input type as dtype.

    Raises:
        RuntimeError: If not called with exactly one argument.
    """
    # generic type for doc builds
    if arg_types is None:
        return Tile

    if len(arg_types) != 1:
        raise RuntimeError("tile() requires 1 positional arg")

    value_type = arg_types["x"]
    if type_is_vector(value_type):
        dtype, length = value_type._wp_scalar_type_, value_type._shape_[0]
    else:
        dtype, length = value_type, 1

    block_dim = warp.codegen.options["block_dim"]
    return Tile(dtype=dtype, M=length, N=block_dim, op="tile")
|
|
2301
|
+
|
|
2302
|
+
|
|
2303
|
+
# tile(x): gather one per-thread value from every thread in the block into a tile.
add_builtin(
    "tile",
    group="Tile Primitives",
    export=False,
    input_types={"x": Any},
    value_func=tile_value_func,
    variadic=True,
    doc="""Constructs a new Tile from per-thread kernel values.

    This function converts values computed using scalar kernel code to a tile representation for input into collective operations.

    * If the input value is a scalar, then the resulting tile has ``shape=(1, block_dim)``
    * If the input value is a vector, then the resulting tile has ``shape=(length(vector), block_dim)``

    :param x: A per-thread local value, e.g.: scalar, vector, or matrix.
    :returns: A tile with first dimension according to the value type length and a second dimension equal to ``block_dim``

    This example shows how to create a linear sequence from thread variables:

    .. code-block:: python

        @wp.kernel
        def compute():
            i = wp.tid()
            t = wp.tile(i*2)
            print(t)

        wp.launch(compute, dim=16, inputs=[], block_dim=16)

    Prints:

    .. code-block:: text

        tile(m=1, n=16, storage=register) = [[0 2 4 6 8 ...]]

    """,
)
|
|
2340
|
+
|
|
2341
|
+
|
|
2342
|
+
def untile_value_func(arg_types, arg_values):
    """Compute the per-thread result type of ``untile()``.

    The input tile must have ``N == block_dim``. A tile with ``M == 1`` yields
    a scalar of the tile dtype; ``M > 1`` yields a vector of length ``M``.

    Raises:
        RuntimeError: On a wrong argument count, a non-tile argument, a second
            dimension different from ``block_dim``, or ``M < 1``.
    """
    # return generic type (for doc builds)
    if arg_types is None:
        return Scalar

    if len(arg_types) != 1:
        raise RuntimeError("untile() requires 1 positional arg")

    t = arg_types["a"]

    if not is_tile(t):
        # report the offending type directly; ``arg_types`` is keyed by
        # argument name, so ``arg_types[0]`` would raise a KeyError instead
        raise RuntimeError(f"untile() accepts arguments of type tile only, got {t}")

    block_dim = warp.codegen.options["block_dim"]
    if t.N != block_dim:
        raise RuntimeError(
            f"untile() argument must have the same length as the block width, got {t.N}, expected {block_dim}"
        )

    if t.M == 1:
        return t.dtype
    if t.M > 1:
        return warp.types.vector(t.M, t.dtype)

    # previously this case fell through and silently produced a None type
    raise RuntimeError(f"untile() argument must have a first dimension of at least 1, got {t.M}")
|
|
2364
|
+
|
|
2365
|
+
|
|
2366
|
+
# untile(a): scatter a (M, block_dim) tile back to one value per thread.
add_builtin(
    "untile",
    input_types={"a": Any},
    value_func=untile_value_func,
    variadic=True,
    doc="""Convert a Tile back to per-thread values.

    This function converts a block-wide tile back to per-thread values.

    * If the input tile has ``shape=(1, block_dim)`` then the resulting value will be a per-thread scalar
    * If the input tile has ``shape=(M, block_dim)`` with ``M > 1`` then the resulting value will be a per-thread vector of length M

    :param a: A tile with dimensions ``shape=(M, block_dim)``
    :returns: A single value per-thread: a scalar of the tile's dtype, or a vector of length M with the tile's dtype

    This example shows how to create a linear sequence from thread variables:

    .. code-block:: python

        @wp.kernel
        def compute():
            i = wp.tid()

            # create block-wide tile
            t = wp.tile(i)*2

            # convert back to per-thread values
            s = wp.untile(t)

            print(s)

        wp.launch(compute, dim=16, inputs=[], block_dim=16)

    Prints:

    .. code-block:: text

        0
        2
        4
        6
        8
        ...
    """,
    group="Tile Primitives",
    export=False,
)
|
|
2413
|
+
|
|
2414
|
+
|
|
2415
|
+
def tile_extract_value_func(arg_types, arg_values):
    """Return the scalar type produced by ``tile_extract()``.

    The extracted element has the tile's dtype; doc builds (``arg_types`` is
    ``None``) get the generic ``Scalar`` type.

    Raises:
        RuntimeError: On a wrong argument count or a non-tile first argument.
    """
    if arg_types is None:
        return Scalar

    if len(arg_types) != 3:
        raise RuntimeError("tile_extract() requires 3 positional args")

    source = arg_types["a"]
    if is_tile(source):
        return source.dtype

    raise RuntimeError("tile_extract() argument 0 must be a tile")
|
|
2427
|
+
|
|
2428
|
+
|
|
2429
|
+
# tile_extract(a, i, j): read element (i, j) of a tile as a block-uniform scalar.
add_builtin(
    "tile_extract",
    group="Tile Primitives",
    export=False,
    input_types={"a": Tile(dtype=Any, M=Any, N=Any), "i": int, "j": int},
    value_func=tile_extract_value_func,
    variadic=True,
    doc="""Extracts a single element from the tile and returns it as a scalar type.

    This function will extract an element from the tile and broadcast its value to all threads in the block.

    Note that this may incur additional synchronization if the source tile is a register tile.

    :param a: Tile to extract the element from
    :param i: Coordinate of element on first dimension
    :param j: Coordinate of element on the second dimension
    :returns: The value of the element at the specified tile location, with the same type as the input tile's per-element dtype""",
)
|
|
2447
|
+
|
|
2448
|
+
|
|
2449
|
+
def tile_transpose_value_func(arg_types, arg_values):
    """Compute the result type of ``tile_transpose()``.

    The result swaps ``M``/``N``, reverses the strides, flips a
    ``rowmajor``/``colmajor`` layout (any other layout becomes ``None``) and
    does not own its data (``owner=False``).

    Note: this marks the *source* tile type as shared-memory storage.

    Raises:
        RuntimeError: On a wrong argument count or a non-tile argument.
    """
    # generic type for doc builds
    if arg_types is None:
        return Tile

    if len(arg_types) != 1:
        raise RuntimeError("tile_transpose() requires 1 positional args")

    src = arg_types["a"]
    if not is_tile(src):
        raise RuntimeError("tile_transpose() argument 0 must be a tile")

    flipped_layout = {"rowmajor": "colmajor", "colmajor": "rowmajor"}.get(src.layout)

    # the transposed tile is a non-owning view, so move the source to shared memory
    src.storage = "shared"

    return Tile(
        dtype=src.dtype,
        M=src.N,
        N=src.M,
        op="transpose",
        storage=src.storage,
        strides=src.strides[::-1],
        layout=flipped_layout,
        owner=False,
    )
|
|
2483
|
+
|
|
2484
|
+
|
|
2485
|
+
# tile_transpose(a): (M, N) -> (N, M) view of a tile.
add_builtin(
    "tile_transpose",
    group="Tile Primitives",
    export=False,
    input_types={"a": Tile(dtype=Any, M=Any, N=Any)},
    value_func=tile_transpose_value_func,
    variadic=True,
    doc="""Transpose a tile.

    For shared memory tiles this operation will alias the input tile, register tiles will first be transferred to shared memory before transposition.

    :param a: Tile to transpose with ``shape=(M,N)``
    :returns: Tile with ``shape=(N,M)``""",
)
|
|
2499
|
+
|
|
2500
|
+
|
|
2501
|
+
def tile_broadcast_value_func(arg_types, arg_values):
    """Compute the result type of ``tile_broadcast()``.

    Each source dimension must either be 1 (broadcast with stride 0) or equal
    the corresponding destination dimension (its stride is kept).

    Note: this marks the *source* tile type as shared-memory storage, because
    the returned tile does not own its data (``owner=False``).

    Raises:
        RuntimeError: On a wrong argument count, a non-tile input, or source
            dimensions that cannot be broadcast to ``(m, n)``.
    """
    # return generic type (for doc builds)
    if arg_types is None:
        return Tile

    if len(arg_types) != 3:
        raise RuntimeError("tile_broadcast() requires 3 positional args")

    t = arg_types["a"]
    m = arg_values["m"]
    n = arg_values["n"]

    if not is_tile(t):
        raise RuntimeError("tile_broadcast() argument 0 must be a tile")

    # Tile dimensions are named ``M``/``N``; the error messages previously used
    # ``t.m``/``t.n`` and raised AttributeError instead of the intended error.

    # try to broadcast last dimension
    if t.N == 1:
        stride_n = 0
    elif t.N == n:
        stride_n = t.strides[1]
    else:
        raise RuntimeError(
            f"Broadcast dimension must be 1 or match destination, shape(src) = {t.M, t.N}, shape(dest) = {m, n}"
        )

    # try to broadcast first dimension
    if t.M == 1:
        stride_m = 0
    elif t.M == m:
        stride_m = t.strides[0]
    else:
        raise RuntimeError(
            f"Broadcast dimension must be 1 or match destination, shape(src) = {t.M, t.N}, shape(dest) = {m, n}"
        )

    # force the input tile to shared memory
    t.storage = "shared"

    tile_type = Tile(
        dtype=t.dtype, M=m, N=n, op="broadcast", storage=t.storage, strides=(stride_m, stride_n), owner=False
    )
    return tile_type
|
|
2543
|
+
|
|
2544
|
+
|
|
2545
|
+
def tile_broadcast_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
    """Build the native call arguments for ``tile_broadcast()``.

    The source tile is the only runtime argument; the destination shape and
    its two strides are passed as template arguments ``(M, N, stride_m, stride_n)``.
    """
    source = arg_values["a"]

    shape_args = [return_type.M, return_type.N]
    stride_args = [return_type.strides[0], return_type.strides[1]]

    return ((source,), shape_args + stride_args)
|
|
2555
|
+
|
|
2556
|
+
|
|
2557
|
+
# tile_broadcast(a, m, n): NumPy-style broadcast of a tile to shape (m, n).
add_builtin(
    "tile_broadcast",
    input_types={"a": Tile(dtype=Any, M=Any, N=Any), "m": int, "n": int},
    value_func=tile_broadcast_value_func,
    dispatch_func=tile_broadcast_dispatch_func,
    variadic=True,
    doc="""Broadcast a tile.

    This method will attempt to broadcast the input tile ``a`` to the destination shape (m, n), broadcasting follows NumPy broadcast rules.

    :param a: Tile to broadcast
    :param m: Size of the destination tile along the first dimension
    :param n: Size of the destination tile along the second dimension
    :returns: Tile with broadcast ``shape=(m, n)``""",
    group="Tile Primitives",
    export=False,
)
|
|
2572
|
+
|
|
2573
|
+
|
|
2574
|
+
def tile_matmul_value_func(arg_types, arg_values):
    """Type-check ``tile_matmul_scalar()``; the call produces no value.

    Expects the tiles ``a`` and ``b`` and the accumulator tile ``out``.

    Raises:
        RuntimeError: On a wrong argument count or any non-tile argument.
    """
    # return generic type (for doc builds)
    if arg_types is None:
        return Tile(dtype=Any, M=Any, N=Any)

    # the message previously claimed 4 args while the check requires 3
    if len(arg_types) != 3:
        raise RuntimeError("tile_matmul() requires 3 positional args")

    if not is_tile(arg_types["a"]):
        raise RuntimeError("tile_matmul() argument 0 must be a tile")

    if not is_tile(arg_types["b"]):
        raise RuntimeError("tile_matmul() argument 1 must be a tile")

    if not isinstance(arg_types["out"], Tile):
        raise RuntimeError("tile_matmul() output argument must be a tile")

    return None
|
|
2592
|
+
|
|
2593
|
+
|
|
2594
|
+
def tile_matmul_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
    """Build the native call arguments for ``tile_matmul_scalar()``.

    Passes ``(a, b, out)`` unchanged with no template arguments, after
    forcing the storage of all three tile variables to shared memory.
    """
    operands = (arg_values["a"], arg_values["b"], arg_values["out"])

    # force the storage type of every operand variable to shared memory
    for operand in operands:
        operand.type.storage = "shared"

    return (operands, [])
|
|
2606
|
+
|
|
2607
|
+
|
|
2608
|
+
# tile_matmul_scalar(a, b, out): hidden builtin computing out += a*b.
add_builtin(
    "tile_matmul_scalar",
    group="Tile Primitives",
    hidden=True,
    export=False,
    input_types={"a": Tile, "b": Tile, "out": Tile},
    value_func=tile_matmul_value_func,
    dispatch_func=tile_matmul_dispatch_func,
    variadic=True,
    doc="Compute matrix product and accumulate out += a*b.",
)
|
|
2619
|
+
|
|
2620
|
+
|
|
2621
|
+
def tile_sum_value_func(arg_types, arg_values):
    """Return the ``(1, 1)`` tile type produced by ``tile_sum()``.

    Raises:
        RuntimeError: On a wrong argument count or a non-tile argument.
    """
    if arg_types is None:
        # generic type for doc builds
        return Tile(dtype=Any, M=1, N=1)

    if len(arg_types) != 1:
        raise RuntimeError("tile_sum() requires 1 positional args")

    src = arg_types["a"]
    if is_tile(src):
        return Tile(dtype=src.dtype, M=1, N=1, op="sum")

    raise RuntimeError("tile_sum() argument 0 must be a tile")
|
|
2635
|
+
|
|
2636
|
+
|
|
2637
|
+
# tile_sum(a): block-cooperative sum reduction producing a (1, 1) tile.
add_builtin(
    "tile_sum",
    input_types={"a": Tile},
    value_func=tile_sum_value_func,
    variadic=True,
    doc="""Cooperatively compute the sum of the tile elements using all threads in the block.

    :param a: The tile to compute the sum of
    :returns: A single-element tile with dimensions of (1,1) holding the sum

    Example:

    .. code-block:: python

        @wp.kernel
        def compute():

            t = wp.tile_ones(dtype=float, m=16, n=16)
            s = wp.tile_sum(t)

            print(s)

        wp.launch(compute, dim=[64], inputs=[])

    Prints:

    .. code-block:: text

        tile(m=1, n=1, storage=register) = [[256]]

    """,
    group="Tile Primitives",
    export=False,
)
|
|
2671
|
+
|
|
2672
|
+
|
|
2673
|
+
def tile_min_value_func(arg_types, arg_values):
    """Return the ``(1, 1)`` tile type produced by ``tile_min()``.

    Raises:
        RuntimeError: On a wrong argument count or a non-tile argument.
    """
    if arg_types is None:
        # generic type for doc builds
        return Tile(dtype=Any, M=1, N=1)

    if len(arg_types) != 1:
        raise RuntimeError("tile_min() requires 1 positional args")

    src = arg_types["a"]
    if is_tile(src):
        return Tile(dtype=src.dtype, M=1, N=1, op="min")

    raise RuntimeError("tile_min() argument 0 must be a tile")
|
|
2687
|
+
|
|
2688
|
+
|
|
2689
|
+
# tile_min(a): block-cooperative minimum reduction producing a (1, 1) tile.
add_builtin(
    "tile_min",
    input_types={"a": Tile},
    value_func=tile_min_value_func,
    variadic=True,
    doc="""Cooperatively compute the minimum of the tile elements using all threads in the block.

    :param a: The tile to compute the minimum of
    :returns: A single-element tile with dimensions of (1,1) holding the minimum value

    Example:

    .. code-block:: python

        @wp.kernel
        def compute():

            t = wp.tile_arange(start=-10, stop=10, dtype=float)
            s = wp.tile_min(t)

            print(s)

        wp.launch(compute, dim=[64], inputs=[])

    Prints:

    .. code-block:: text

        tile(m=1, n=1, storage=register) = [[-10]]

    """,
    group="Tile Primitives",
    export=False,
)
|
|
2723
|
+
|
|
2724
|
+
|
|
2725
|
+
def tile_max_value_func(arg_types, arg_values):
    """Return the ``(1, 1)`` tile type produced by ``tile_max()``.

    Raises:
        RuntimeError: On a wrong argument count or a non-tile argument.
    """
    # return generic type (for doc builds)
    if arg_types is None:
        return Tile(dtype=Any, M=1, N=1)

    if len(arg_types) != 1:
        raise RuntimeError("tile_max() requires 1 positional args")

    a = arg_types["a"]

    if not is_tile(a):
        raise RuntimeError("tile_max() argument 0 must be a tile")

    # tag the result with its own op (was copy-pasted as "min")
    return Tile(dtype=a.dtype, M=1, N=1, op="max")
|
|
2739
|
+
|
|
2740
|
+
|
|
2741
|
+
# tile_max(a): block-cooperative maximum reduction producing a (1, 1) tile.
add_builtin(
    "tile_max",
    input_types={"a": Tile},
    value_func=tile_max_value_func,
    variadic=True,
    doc="""Cooperatively compute the maximum of the tile elements using all threads in the block.

    :param a: The tile to compute the maximum from
    :returns: A single-element tile with dimensions of (1,1) holding the maximum value

    Example:

    .. code-block:: python

        @wp.kernel
        def compute():

            t = wp.tile_arange(start=-10, stop=10, dtype=float)
            s = wp.tile_max(t)

            print(s)

        wp.launch(compute, dim=[64], inputs=[])

    Prints:

    .. code-block:: text

        tile(m=1, n=1, storage=register) = [[9]]

    """,
    group="Tile Primitives",
    export=False,
)
|
|
2775
|
+
|
|
2776
|
+
|
|
2777
|
+
# type propagation for tile_reduce()
def tile_reduce_value_func(arg_types, arg_values):
    """Return the ``(1, 1)`` tile type produced by ``tile_reduce()``.

    Raises:
        RuntimeError: If the reduced argument ``a`` is not a tile.
    """
    if arg_types is None:
        # generic type for doc builds
        return Tile(dtype=Any, M=Any, N=Any)

    src = arg_types["a"]
    if not is_tile(src):
        raise RuntimeError(f"tile_reduce() arguments must be tiles, got type {src}")

    # the reduction collapses the whole tile into a single element
    return Tile(dtype=src.dtype, M=1, N=1, op="reduce")
|
|
2789
|
+
|
|
2790
|
+
|
|
2791
|
+
def tile_reduce_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
    """Prepend the reduction operator to the argument list; no template args."""
    op = args["op"]
    call_args = (op, *args["args"])
    return (call_args, ())
|
|
2795
|
+
|
|
2796
|
+
|
|
2797
|
+
# tile_reduce(op, a): block-cooperative reduction with a user/builtin binary operator.
add_builtin(
    "tile_reduce",
    group="Tile Primitives",
    export=False,
    input_types={"op": Callable, "a": Any},
    native_func="tile_reduce",
    value_func=tile_reduce_value_func,
    doc="""Apply a custom reduction operator across the tile.

    This function cooperatively performs a reduction using the provided operator across the tile.

    :param op: A callable function that accepts two arguments and returns one argument, may be a user function or builtin
    :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
    :returns: A single-element tile with ``shape=(1,1)`` with the same datatype as the input tile.

    Example:

    .. code-block:: python

        @wp.kernel
        def factorial():

            t = wp.tile_arange(1, 10, dtype=int)
            s = wp.tile_reduce(wp.mul, t)

            print(s)

        wp.launch(factorial, dim=[16], inputs=[], block_dim=16)

    Prints:

    .. code-block:: text

        tile(m=1, n=1, storage=register) = [[362880]]
    """,
)
|
|
2833
|
+
|
|
2834
|
+
# maps
|
|
2835
|
+
|
|
2836
|
+
|
|
2837
|
+
# type propagation for the unary overload of tile_map()
def tile_unary_map_value_func(arg_types, arg_values):
    """Return the result type of ``tile_map(op, a)``.

    Raises:
        RuntimeError: If ``a`` is not a tile.
    """
    if arg_types is None:
        # generic type for doc builds
        return Tile(dtype=Any, M=Any, N=Any)

    src = arg_types["a"]
    if is_tile(src):
        return TileUnaryMap(src)

    raise RuntimeError(f"tile_map() arguments must be tiles, got type {src}")
|
|
2849
|
+
|
|
2850
|
+
|
|
2851
|
+
# def tile_map_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
2852
|
+
# func_args = (args["op"], *args["args"])
|
|
2853
|
+
# template_args = ()
|
|
2854
|
+
# return (func_args, template_args)
|
|
2855
|
+
|
|
2856
|
+
|
|
2857
|
+
# tile_map(op, a): element-wise unary map over a tile.
add_builtin(
    "tile_map",
    group="Tile Primitives",
    export=False,
    input_types={"op": Callable, "a": Any},
    native_func="tile_unary_map",
    value_func=tile_unary_map_value_func,
    doc="""Apply a unary function onto the tile.

    This function cooperatively applies a unary function to each element of the tile using all threads in the block.

    :param op: A callable function that accepts one argument and returns one argument, may be a user function or builtin
    :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
    :returns: A tile with the same dimensions and datatype as the input tile.

    Example:

    .. code-block:: python

        @wp.kernel
        def compute():

            t = wp.tile_arange(0.0, 1.0, 0.1, dtype=float)
            s = wp.tile_map(wp.sin, t)

            print(s)

        wp.launch(compute, dim=[16], inputs=[])

    Prints:

    .. code-block:: text

        tile(m=1, n=10, storage=register) = [[0 0.0998334 0.198669 0.29552 ...]]
    """,
)
|
|
1632
|
-
# Spatial (6D screw vector) helpers.
add_builtin(
    "spatial_dot",
    group="Spatial Math",
    input_types={"a": vector(length=6, dtype=Float), "b": vector(length=6, dtype=Float)},
    value_func=float_sametypes_value_func,
    doc="Compute the dot product of two 6D screw vectors.",
)
add_builtin(
    "spatial_cross",
    group="Spatial Math",
    input_types={"a": vector(length=6, dtype=Float), "b": vector(length=6, dtype=Float)},
    value_func=sametypes_create_value_func(vector(length=6, dtype=Float)),
    doc="Compute the cross product of two 6D screw vectors.",
)
add_builtin(
    "spatial_cross_dual",
    group="Spatial Math",
    input_types={"a": vector(length=6, dtype=Float), "b": vector(length=6, dtype=Float)},
    value_func=sametypes_create_value_func(vector(length=6, dtype=Float)),
    doc="Compute the dual cross product of two 6D screw vectors.",
)

# top/bottom split a 6D screw vector into two 3-vectors of the same scalar type
add_builtin(
    "spatial_top",
    group="Spatial Math",
    input_types={"svec": vector(length=6, dtype=Float)},
    value_func=lambda arg_types, arg_values: (
        vector(length=3, dtype=arg_types["svec"]._wp_scalar_type_)
        if arg_types is not None
        else vector(length=3, dtype=Float)
    ),
    doc="Return the top (first) part of a 6D screw vector.",
)
add_builtin(
    "spatial_bottom",
    group="Spatial Math",
    input_types={"svec": vector(length=6, dtype=Float)},
    value_func=lambda arg_types, arg_values: (
        vector(length=3, dtype=arg_types["svec"]._wp_scalar_type_)
        if arg_types is not None
        else vector(length=3, dtype=Float)
    ),
    doc="Return the bottom (second) part of a 6D screw vector.",
)
|
|
1672
2896
|
|
|
1673
|
-
|
|
1674
|
-
|
|
1675
|
-
|
|
1676
|
-
|
|
1677
|
-
|
|
1678
|
-
|
|
1679
|
-
|
|
1680
|
-
|
|
1681
|
-
|
|
1682
|
-
|
|
1683
|
-
|
|
1684
|
-
|
|
1685
|
-
|
|
1686
|
-
|
|
1687
|
-
|
|
2897
|
+
def tile_binary_map_value_func(arg_types, arg_values):
    """Return the result type of ``tile_map(op, a, b)``.

    Both operands must be tiles with equal dtype (per ``types_equal``) and
    equal ``M`` and ``N`` dimensions.

    Raises:
        RuntimeError: If an operand is not a tile, or the operands differ in
            dtype or dimensions.
    """
    if arg_types is None:
        # generic type for doc builds
        return Tile(dtype=Any, M=Any, N=Any)

    lhs, rhs = arg_types["a"], arg_types["b"]

    for operand in (lhs, rhs):
        if not is_tile(operand):
            raise RuntimeError(f"tile_map() arguments must be tiles, got type {operand}")

    if not types_equal(lhs.dtype, rhs.dtype):
        raise RuntimeError(f"tile_map() arguments must all have the same type {lhs.dtype} != {rhs.dtype}")
    if lhs.M != rhs.M:
        raise RuntimeError(f"tile_map() arguments must all have the same m dimension {lhs.M} != {rhs.M}")
    if lhs.N != rhs.N:
        raise RuntimeError(f"tile_map() arguments must all have the same n dimension {lhs.N} != {rhs.N}")

    return TileBinaryMap(lhs, rhs)
|
|
2922
|
+
|
|
1688
2923
|
|
|
1689
2924
|
# tile_map(op, a, b): element-wise binary map over two same-shaped tiles.
add_builtin(
    "tile_map",
    group="Tile Primitives",
    export=False,
    input_types={"op": Callable, "a": Any, "b": Any},
    native_func="tile_binary_map",
    value_func=tile_binary_map_value_func,
    doc="""Apply a binary function onto the tile.

    This function cooperatively applies a binary function to each element of the tiles using all threads in the block.
    Both input tiles must have the same dimensions and datatype.

    :param op: A callable function that accepts two arguments and returns one argument, all of the same type, may be a user function or builtin
    :param a: The first input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
    :param b: The second input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
    :returns: A tile with the same dimensions and datatype as the input tiles.

    Example:

    .. code-block:: python

        @wp.kernel
        def compute():

            a = wp.tile_arange(0.0, 1.0, 0.1, dtype=float)
            b = wp.tile_ones(m=1, n=10, dtype=float)

            s = wp.tile_map(wp.add, a, b)

            print(s)

        wp.launch(compute, dim=[16], inputs=[])

    Prints:

    .. code-block:: text

        tile(m=1, n=10, storage=register) = [[1 1.1 1.2 1.3 1.4 1.5 1.6 1.7 1.8 1.9]]""",
)
|
|
1702
2965
|
|
|
2966
|
+
|
|
1703
2967
|
# ---------------------------------
|
|
1704
2968
|
# Linear Algebra
|
|
1705
2969
|
|
|
@@ -2387,6 +3651,16 @@ add_builtin(
|
|
|
2387
3651
|
"iter_next", input_types={"query": mesh_query_aabb_t}, value_type=int, group="Utility", export=False, hidden=True
|
|
2388
3652
|
)
|
|
2389
3653
|
|
|
3654
|
+
# Builtin `reversed(range)`: lowered to the native `iter_reverse` helper and yields another range_t.
add_builtin(
    "reversed",
    input_types={"range": range_t},
    value_type=range_t,
    native_func="iter_reverse",
    doc="""Returns the range in reversed order.""",
    group="Utility",
    export=False,
)
|
|
3663
|
+
|
|
2390
3664
|
# ---------------------------------
|
|
2391
3665
|
# Volumes
|
|
2392
3666
|
|
|
@@ -2802,7 +4076,11 @@ add_builtin(
|
|
|
2802
4076
|
doc="Return a random float between [low, high).",
|
|
2803
4077
|
)
|
|
2804
4078
|
# Builtin `randn(state)`: draws a float from the standard normal distribution using a uint32 RNG state.
add_builtin(
    "randn",
    input_types={"state": uint32},
    value_type=float,
    doc="Sample a normal (Gaussian) distribution of mean 0 and variance 1. ",
    group="Random",
)
|
|
2807
4085
|
|
|
2808
4086
|
add_builtin(
|
|
@@ -2974,12 +4252,20 @@ add_builtin(
|
|
|
2974
4252
|
)
|
|
2975
4253
|
|
|
2976
4254
|
|
|
4255
|
+
def printf_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Value function for ``printf``: validates the variadic argument count.

    ``printf`` has no return value, so ``None`` is always returned. When concrete
    argument types are known, a call with more than 32 variadic arguments is rejected.

    Raises:
        RuntimeError: if more than 32 variadic arguments are supplied.
    """
    max_variadic_args = 32

    # No concrete types (e.g. during documentation builds): nothing to validate.
    if arg_types is None:
        return None

    variadic_types = arg_types.get("args", ())
    if len(variadic_types) > max_variadic_args:
        raise RuntimeError("the maximum number of variadic arguments that can be passed to `printf` is 32")

    return None
|
|
4261
|
+
|
|
4262
|
+
|
|
2977
4263
|
def printf_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
    """Map ``printf`` call arguments onto the native function's runtime/template parameters.

    This runs during codegen, when the call to the builtin is emitted. The format
    string comes first, followed by any variadic arguments. No template parameters
    are used.

    Returns:
        A ``(func_args, template_args)`` tuple.
    """
    fmt = args["fmt"]
    variadic = args.get("args", ())
    return ((fmt, *variadic), ())
|
|
2985
4271
|
|
|
@@ -2990,6 +4276,7 @@ add_builtin(
|
|
|
2990
4276
|
input_types={"fmt": str, "*args": Any},
|
|
2991
4277
|
namespace="",
|
|
2992
4278
|
variadic=True,
|
|
4279
|
+
value_func=printf_value_func,
|
|
2993
4280
|
dispatch_func=printf_dispatch_func,
|
|
2994
4281
|
group="Utility",
|
|
2995
4282
|
doc="Allows printing formatted strings using C-style format specifiers.",
|
|
@@ -3167,7 +4454,8 @@ def address_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, A
|
|
|
3167
4454
|
for array_type in array_types:
|
|
3168
4455
|
add_builtin(
|
|
3169
4456
|
"address",
|
|
3170
|
-
input_types={"arr": array_type(dtype=Any), "i":
|
|
4457
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int},
|
|
4458
|
+
constraint=sametypes,
|
|
3171
4459
|
defaults={"j": None, "k": None, "l": None},
|
|
3172
4460
|
hidden=True,
|
|
3173
4461
|
value_func=address_value_func,
|
|
@@ -3211,8 +4499,9 @@ def view_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]
|
|
|
3211
4499
|
for array_type in array_types:
|
|
3212
4500
|
add_builtin(
|
|
3213
4501
|
"view",
|
|
3214
|
-
input_types={"arr": array_type(dtype=Any), "i":
|
|
4502
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int},
|
|
3215
4503
|
defaults={"j": None, "k": None},
|
|
4504
|
+
constraint=sametypes,
|
|
3216
4505
|
hidden=True,
|
|
3217
4506
|
value_func=view_value_func,
|
|
3218
4507
|
group="Utility",
|
|
@@ -3254,7 +4543,8 @@ def array_store_value_func(arg_types: Mapping[str, type], arg_values: Mapping[st
|
|
|
3254
4543
|
for array_type in array_types:
|
|
3255
4544
|
add_builtin(
|
|
3256
4545
|
"array_store",
|
|
3257
|
-
input_types={"arr": array_type(dtype=Any), "i":
|
|
4546
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
|
|
4547
|
+
constraint=sametypes,
|
|
3258
4548
|
hidden=True,
|
|
3259
4549
|
value_func=array_store_value_func,
|
|
3260
4550
|
skip_replay=True,
|
|
@@ -3262,7 +4552,8 @@ for array_type in array_types:
|
|
|
3262
4552
|
)
|
|
3263
4553
|
add_builtin(
|
|
3264
4554
|
"array_store",
|
|
3265
|
-
input_types={"arr": array_type(dtype=Any), "i":
|
|
4555
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
|
|
4556
|
+
constraint=sametypes,
|
|
3266
4557
|
hidden=True,
|
|
3267
4558
|
value_func=array_store_value_func,
|
|
3268
4559
|
skip_replay=True,
|
|
@@ -3270,7 +4561,8 @@ for array_type in array_types:
|
|
|
3270
4561
|
)
|
|
3271
4562
|
add_builtin(
|
|
3272
4563
|
"array_store",
|
|
3273
|
-
input_types={"arr": array_type(dtype=Any), "i":
|
|
4564
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
|
|
4565
|
+
constraint=sametypes,
|
|
3274
4566
|
hidden=True,
|
|
3275
4567
|
value_func=array_store_value_func,
|
|
3276
4568
|
skip_replay=True,
|
|
@@ -3278,7 +4570,8 @@ for array_type in array_types:
|
|
|
3278
4570
|
)
|
|
3279
4571
|
add_builtin(
|
|
3280
4572
|
"array_store",
|
|
3281
|
-
input_types={"arr": array_type(dtype=Any), "i":
|
|
4573
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
|
|
4574
|
+
constraint=sametypes,
|
|
3282
4575
|
hidden=True,
|
|
3283
4576
|
value_func=array_store_value_func,
|
|
3284
4577
|
skip_replay=True,
|
|
@@ -3330,6 +4623,11 @@ add_builtin(
|
|
|
3330
4623
|
)
|
|
3331
4624
|
|
|
3332
4625
|
|
|
4626
|
+
def atomic_op_constraint(arg_types: Mapping[str, Any]):
    """Overload constraint shared by the ``atomic_*`` builtins.

    The call is accepted when two conditions hold:

    - all supplied indices (``i``, ``j``, ``k``, ``l``; entries that are ``None`` are
      ignored) have mutually equal types;
    - the number of supplied indices equals ``arr.ndim``.
    """
    indices = []
    for name in "ijkl":
        if arg_types.get(name, None) is not None:
            indices.append(arg_types[name])

    # Compare every later index type against the first one. indices[0] is only read
    # when there is at least a second index, so an empty index set never raises here.
    for other in indices[1:]:
        if not types_equal(indices[0], other):
            return False

    return arg_types["arr"].ndim == len(indices)
|
|
4629
|
+
|
|
4630
|
+
|
|
3333
4631
|
def atomic_op_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
3334
4632
|
if arg_types is None:
|
|
3335
4633
|
return Any
|
|
@@ -3374,7 +4672,8 @@ for array_type in array_types:
|
|
|
3374
4672
|
add_builtin(
|
|
3375
4673
|
"atomic_add",
|
|
3376
4674
|
hidden=hidden,
|
|
3377
|
-
input_types={"arr": array_type(dtype=Any), "i":
|
|
4675
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
|
|
4676
|
+
constraint=atomic_op_constraint,
|
|
3378
4677
|
value_func=atomic_op_value_func,
|
|
3379
4678
|
doc="Atomically add ``value`` onto ``arr[i]`` and return the old value.",
|
|
3380
4679
|
group="Utility",
|
|
@@ -3383,7 +4682,8 @@ for array_type in array_types:
|
|
|
3383
4682
|
add_builtin(
|
|
3384
4683
|
"atomic_add",
|
|
3385
4684
|
hidden=hidden,
|
|
3386
|
-
input_types={"arr": array_type(dtype=Any), "i":
|
|
4685
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
|
|
4686
|
+
constraint=atomic_op_constraint,
|
|
3387
4687
|
value_func=atomic_op_value_func,
|
|
3388
4688
|
doc="Atomically add ``value`` onto ``arr[i,j]`` and return the old value.",
|
|
3389
4689
|
group="Utility",
|
|
@@ -3392,7 +4692,8 @@ for array_type in array_types:
|
|
|
3392
4692
|
add_builtin(
|
|
3393
4693
|
"atomic_add",
|
|
3394
4694
|
hidden=hidden,
|
|
3395
|
-
input_types={"arr": array_type(dtype=Any), "i":
|
|
4695
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
|
|
4696
|
+
constraint=atomic_op_constraint,
|
|
3396
4697
|
value_func=atomic_op_value_func,
|
|
3397
4698
|
doc="Atomically add ``value`` onto ``arr[i,j,k]`` and return the old value.",
|
|
3398
4699
|
group="Utility",
|
|
@@ -3401,7 +4702,8 @@ for array_type in array_types:
|
|
|
3401
4702
|
add_builtin(
|
|
3402
4703
|
"atomic_add",
|
|
3403
4704
|
hidden=hidden,
|
|
3404
|
-
input_types={"arr": array_type(dtype=Any), "i":
|
|
4705
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
|
|
4706
|
+
constraint=atomic_op_constraint,
|
|
3405
4707
|
value_func=atomic_op_value_func,
|
|
3406
4708
|
doc="Atomically add ``value`` onto ``arr[i,j,k,l]`` and return the old value.",
|
|
3407
4709
|
group="Utility",
|
|
@@ -3411,7 +4713,8 @@ for array_type in array_types:
|
|
|
3411
4713
|
add_builtin(
|
|
3412
4714
|
"atomic_sub",
|
|
3413
4715
|
hidden=hidden,
|
|
3414
|
-
input_types={"arr": array_type(dtype=Any), "i":
|
|
4716
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
|
|
4717
|
+
constraint=atomic_op_constraint,
|
|
3415
4718
|
value_func=atomic_op_value_func,
|
|
3416
4719
|
doc="Atomically subtract ``value`` onto ``arr[i]`` and return the old value.",
|
|
3417
4720
|
group="Utility",
|
|
@@ -3420,7 +4723,8 @@ for array_type in array_types:
|
|
|
3420
4723
|
add_builtin(
|
|
3421
4724
|
"atomic_sub",
|
|
3422
4725
|
hidden=hidden,
|
|
3423
|
-
input_types={"arr": array_type(dtype=Any), "i":
|
|
4726
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
|
|
4727
|
+
constraint=atomic_op_constraint,
|
|
3424
4728
|
value_func=atomic_op_value_func,
|
|
3425
4729
|
doc="Atomically subtract ``value`` onto ``arr[i,j]`` and return the old value.",
|
|
3426
4730
|
group="Utility",
|
|
@@ -3429,7 +4733,8 @@ for array_type in array_types:
|
|
|
3429
4733
|
add_builtin(
|
|
3430
4734
|
"atomic_sub",
|
|
3431
4735
|
hidden=hidden,
|
|
3432
|
-
input_types={"arr": array_type(dtype=Any), "i":
|
|
4736
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
|
|
4737
|
+
constraint=atomic_op_constraint,
|
|
3433
4738
|
value_func=atomic_op_value_func,
|
|
3434
4739
|
doc="Atomically subtract ``value`` onto ``arr[i,j,k]`` and return the old value.",
|
|
3435
4740
|
group="Utility",
|
|
@@ -3438,7 +4743,8 @@ for array_type in array_types:
|
|
|
3438
4743
|
add_builtin(
|
|
3439
4744
|
"atomic_sub",
|
|
3440
4745
|
hidden=hidden,
|
|
3441
|
-
input_types={"arr": array_type(dtype=Any), "i":
|
|
4746
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
|
|
4747
|
+
constraint=atomic_op_constraint,
|
|
3442
4748
|
value_func=atomic_op_value_func,
|
|
3443
4749
|
doc="Atomically subtract ``value`` onto ``arr[i,j,k,l]`` and return the old value.",
|
|
3444
4750
|
group="Utility",
|
|
@@ -3448,44 +4754,48 @@ for array_type in array_types:
|
|
|
3448
4754
|
add_builtin(
|
|
3449
4755
|
"atomic_min",
|
|
3450
4756
|
hidden=hidden,
|
|
3451
|
-
input_types={"arr": array_type(dtype=Any), "i":
|
|
4757
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
|
|
4758
|
+
constraint=atomic_op_constraint,
|
|
3452
4759
|
value_func=atomic_op_value_func,
|
|
3453
4760
|
doc="""Compute the minimum of ``value`` and ``arr[i]``, atomically update the array, and return the old value.
|
|
3454
4761
|
|
|
3455
|
-
|
|
4762
|
+
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
3456
4763
|
group="Utility",
|
|
3457
4764
|
skip_replay=True,
|
|
3458
4765
|
)
|
|
3459
4766
|
add_builtin(
|
|
3460
4767
|
"atomic_min",
|
|
3461
4768
|
hidden=hidden,
|
|
3462
|
-
input_types={"arr": array_type(dtype=Any), "i":
|
|
4769
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
|
|
4770
|
+
constraint=atomic_op_constraint,
|
|
3463
4771
|
value_func=atomic_op_value_func,
|
|
3464
4772
|
doc="""Compute the minimum of ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
|
|
3465
4773
|
|
|
3466
|
-
|
|
4774
|
+
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
3467
4775
|
group="Utility",
|
|
3468
4776
|
skip_replay=True,
|
|
3469
4777
|
)
|
|
3470
4778
|
add_builtin(
|
|
3471
4779
|
"atomic_min",
|
|
3472
4780
|
hidden=hidden,
|
|
3473
|
-
input_types={"arr": array_type(dtype=Any), "i":
|
|
4781
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
|
|
4782
|
+
constraint=atomic_op_constraint,
|
|
3474
4783
|
value_func=atomic_op_value_func,
|
|
3475
4784
|
doc="""Compute the minimum of ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
|
|
3476
4785
|
|
|
3477
|
-
|
|
4786
|
+
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
3478
4787
|
group="Utility",
|
|
3479
4788
|
skip_replay=True,
|
|
3480
4789
|
)
|
|
3481
4790
|
add_builtin(
|
|
3482
4791
|
"atomic_min",
|
|
3483
4792
|
hidden=hidden,
|
|
3484
|
-
input_types={"arr": array_type(dtype=Any), "i":
|
|
4793
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
|
|
4794
|
+
constraint=atomic_op_constraint,
|
|
3485
4795
|
value_func=atomic_op_value_func,
|
|
3486
4796
|
doc="""Compute the minimum of ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
|
|
3487
4797
|
|
|
3488
|
-
|
|
4798
|
+
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
3489
4799
|
group="Utility",
|
|
3490
4800
|
skip_replay=True,
|
|
3491
4801
|
)
|
|
@@ -3493,44 +4803,48 @@ for array_type in array_types:
|
|
|
3493
4803
|
add_builtin(
|
|
3494
4804
|
"atomic_max",
|
|
3495
4805
|
hidden=hidden,
|
|
3496
|
-
input_types={"arr": array_type(dtype=Any), "i":
|
|
4806
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
|
|
4807
|
+
constraint=atomic_op_constraint,
|
|
3497
4808
|
value_func=atomic_op_value_func,
|
|
3498
4809
|
doc="""Compute the maximum of ``value`` and ``arr[i]``, atomically update the array, and return the old value.
|
|
3499
4810
|
|
|
3500
|
-
|
|
4811
|
+
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
3501
4812
|
group="Utility",
|
|
3502
4813
|
skip_replay=True,
|
|
3503
4814
|
)
|
|
3504
4815
|
add_builtin(
|
|
3505
4816
|
"atomic_max",
|
|
3506
4817
|
hidden=hidden,
|
|
3507
|
-
input_types={"arr": array_type(dtype=Any), "i":
|
|
4818
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
|
|
4819
|
+
constraint=atomic_op_constraint,
|
|
3508
4820
|
value_func=atomic_op_value_func,
|
|
3509
4821
|
doc="""Compute the maximum of ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
|
|
3510
4822
|
|
|
3511
|
-
|
|
4823
|
+
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
3512
4824
|
group="Utility",
|
|
3513
4825
|
skip_replay=True,
|
|
3514
4826
|
)
|
|
3515
4827
|
add_builtin(
|
|
3516
4828
|
"atomic_max",
|
|
3517
4829
|
hidden=hidden,
|
|
3518
|
-
input_types={"arr": array_type(dtype=Any), "i":
|
|
4830
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
|
|
4831
|
+
constraint=atomic_op_constraint,
|
|
3519
4832
|
value_func=atomic_op_value_func,
|
|
3520
4833
|
doc="""Compute the maximum of ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
|
|
3521
4834
|
|
|
3522
|
-
|
|
4835
|
+
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
3523
4836
|
group="Utility",
|
|
3524
4837
|
skip_replay=True,
|
|
3525
4838
|
)
|
|
3526
4839
|
add_builtin(
|
|
3527
4840
|
"atomic_max",
|
|
3528
4841
|
hidden=hidden,
|
|
3529
|
-
input_types={"arr": array_type(dtype=Any), "i":
|
|
4842
|
+
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
|
|
4843
|
+
constraint=atomic_op_constraint,
|
|
3530
4844
|
value_func=atomic_op_value_func,
|
|
3531
4845
|
doc="""Compute the maximum of ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
|
|
3532
4846
|
|
|
3533
|
-
|
|
4847
|
+
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
3534
4848
|
group="Utility",
|
|
3535
4849
|
skip_replay=True,
|
|
3536
4850
|
)
|
|
@@ -3746,6 +5060,15 @@ for t in scalar_types + vector_types + (bool,):
|
|
|
3746
5060
|
hidden=True,
|
|
3747
5061
|
)
|
|
3748
5062
|
|
|
5063
|
+
# Registered once per type `t` of the enclosing loop over scalar, vector and bool types.
# The doc previously said "are not equal", which describes expect_eq, not expect_neq.
add_builtin(
    "expect_neq",
    input_types={"a": t, "b": t},
    value_type=None,
    doc="Prints an error to stdout if ``a`` and ``b`` are equal",
    group="Utility",
    hidden=True,
)
|
|
5071
|
+
|
|
3749
5072
|
|
|
3750
5073
|
def expect_eq_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
3751
5074
|
if not types_equal(arg_types["a"], arg_types["b"]):
|
|
@@ -4286,6 +5609,493 @@ for t in int_types:
|
|
|
4286
5609
|
|
|
4287
5610
|
# Logical-not over an array argument; returns a Python bool-typed value.
add_builtin(
    "unot",
    input_types={"a": array(dtype=Any)},
    value_type=builtins.bool,
    group="Operators",
    doc="",
)
|
|
4288
5611
|
|
|
5612
|
+
|
|
5613
|
+
# Tile operators
|
|
5614
|
+
def tile_unary_value_func(arg_types, arg_values):
    """Result type of a unary operator (such as ``neg``) applied to a tile.

    Returns a fully generic ``Tile`` when no concrete types are available.
    Otherwise the operand ``x`` is wrapped in a ``TileUnaryMap``.

    Raises:
        RuntimeError: if ``x`` is not a tile.
    """
    if arg_types is None:
        return Tile(dtype=Any, M=Any, N=Any)

    operand = arg_types["x"]
    if is_tile(operand):
        return TileUnaryMap(operand)

    raise RuntimeError("Expected tile for unary expression")
|
|
5624
|
+
|
|
5625
|
+
|
|
5626
|
+
def tile_scalar_mul_value_func(arg_types, arg_values):
    """Result type of ``tile * scalar`` and ``scalar * tile``.

    The scalar is broadcast to a ``TileConstant`` with the tile's shape. The product
    is expressed as a ``TileBinaryMap`` of the two, in the original operand order.

    Raises:
        RuntimeError: if the scalar's type differs from the tile's dtype, or if neither
            operand is a tile.
    """
    if arg_types is None:
        return Tile(dtype=Any, M=Any, N=Any)

    x = arg_types["x"]
    y = arg_types["y"]

    # tile*scalar
    if is_tile(x):
        if x.dtype != y:
            raise RuntimeError(
                f"Scalar factor should have the same type as tile for tile*scalar, tile type: {x} scalar type: {y}"
            )

        return TileBinaryMap(x, TileConstant(y, x.M, x.N))

    # scalar*tile: here `y` is the tile and `x` is the scalar
    if is_tile(y):
        if y.dtype != x:
            raise RuntimeError(
                f"Scalar factor should have the same type as tile for scalar*tile, tile type: {y} scalar type: {x}"
            )

        return TileBinaryMap(TileConstant(x, y.M, y.N), y)

    # Previously fell through and silently returned None (i.e. "no return value").
    raise RuntimeError(f"Tile-scalar multiplication requires one tile operand, got {x} and {y}")
|
|
5650
|
+
|
|
5651
|
+
|
|
5652
|
+
# Unary negation of a tile, lowered to the native `tile_neg`.
add_builtin(
    "neg",
    input_types={"x": Tile(dtype=Any, M=Any, N=Any)},
    native_func="tile_neg",
    value_func=tile_unary_value_func,
    group="Operators",
    doc="Negate each element of a tile",
    export=False,
)
|
|
5661
|
+
|
|
5662
|
+
# Element-wise addition of two tiles, lowered to the native `tile_add`.
add_builtin(
    "add",
    input_types={
        "a": Tile(dtype=Any, M=Any, N=Any),
        "b": Tile(dtype=Any, M=Any, N=Any),
    },
    native_func="tile_add",
    value_func=tile_binary_map_value_func,
    group="Tile Primitives",
    doc="Add each element of two tiles together",
    export=False,
)
|
|
5673
|
+
|
|
5674
|
+
# tile * scalar, lowered to the native `tile_mul`.
add_builtin(
    "mul",
    input_types={
        "x": Tile(dtype=Any, M=Any, N=Any),
        "y": Scalar,
    },
    native_func="tile_mul",
    value_func=tile_scalar_mul_value_func,
    group="Operators",
    doc="Multiply each element of a tile by a scalar",
    export=False,
)
|
|
5683
|
+
|
|
5684
|
+
# scalar * tile, lowered to the same native `tile_mul`.
add_builtin(
    "mul",
    input_types={
        "x": Scalar,
        "y": Tile(dtype=Any, M=Any, N=Any),
    },
    native_func="tile_mul",
    value_func=tile_scalar_mul_value_func,
    group="Operators",
    doc="Multiply each element of a tile by a scalar",
    export=False,
)
|
|
5693
|
+
|
|
5694
|
+
|
|
5695
|
+
##
|
|
5696
|
+
## MathDx, LTOIR-based, Tile functions
|
|
5697
|
+
##
|
|
5698
|
+
|
|
5699
|
+
|
|
5700
|
+
##
|
|
5701
|
+
## Matmul
|
|
5702
|
+
##
|
|
5703
|
+
def tile_matmul_generic_value_func(arg_types, arg_values):
    """Result type of ``tile_matmul``.

    - ``out = tile_matmul(a, b)``: returns a new shared-storage tile of shape ``(a.M, b.N)``
      with ``a``'s dtype.
    - ``tile_matmul(a, b, out)``: accumulates into ``out`` and returns ``None``.

    Raises:
        RuntimeError: if ``a``, ``b`` or (when given) ``out`` is not a tile.
    """
    # return generic type (for doc builds)
    if arg_types is None:
        return Tile(dtype=Any, M=Any, N=Any)

    a = arg_types["a"]
    b = arg_types["b"]

    if not is_tile(a):
        raise RuntimeError("tile_matmul() argument 0 must be a tile")
    if not is_tile(b):
        raise RuntimeError("tile_matmul() argument 1 must be a tile")

    # out = wp.tile_matmul(a, b)
    if len(arg_types) == 2:
        return Tile(dtype=a.dtype, M=a.M, N=b.N, storage="shared")

    # wp.tile_matmul(a, b, out)
    if len(arg_types) == 3 and not is_tile(arg_types["out"]):
        raise RuntimeError("tile_matmul() output argument must be a tile")

    return None
|
|
5726
|
+
|
|
5727
|
+
|
|
5728
|
+
def tile_matmul_generic_lto_dispatch_func(
    arg_types: Mapping[str, type],
    return_type: Any,
    return_values: List[Var],
    arg_values: Mapping[str, Var],
    options: Mapping[str, Any],
    builder: warp.context.ModuleBuilder,
):
    """Validate ``tile_matmul`` operands and produce the cuBLASDx LTO-IR functions it calls.

    Three LTO-IR functions are compiled, or reused from ``builder.ltoirs``:

    - the forward product ``C += A*B``;
    - the adjoint ``adjA += adjC*B^T``;
    - the adjoint ``adjB += A^T*adjC``.

    A transpose is expressed by flipping the operand's layout.

    Side effects:
        - Sets the storage of ``a``, ``b`` and ``out`` to ``"shared"``.
        - Records newly compiled LTO-IR in ``builder.ltoirs`` and its declaration in
          ``builder.ltoirs_decl``.

    Returns:
        ``(func_args, template_args, ltoirs)``, where:

        - ``func_args`` is the three LTO symbol names followed by ``a``, ``b`` and ``out``;
        - ``template_args`` is ``[accumulate]``, with 0 to overwrite ``out`` and 1 to add to it;
        - ``ltoirs`` is the three compiled LTO-IR blobs.

    Raises:
        RuntimeError: in any of these cases:

        - an operand is not a tile;
        - an operand has an unsupported dtype or layout;
        - the operand shapes are inconsistent;
        - real and complex operands are mixed;
        - LTO compilation fails.
    """
    a = arg_values["a"]
    b = arg_values["b"]

    if len(return_values) > 0:
        accumulate = 0  # for c = tile_matmul(a,b) case we want to overwrite c value
        out = return_values[0]
    else:
        accumulate = 1  # for tile_matmul(a,b,c) case we want to add to c value
        out = arg_values["out"]

    if any(not is_tile(arg.type) for arg in [a, b, out]):
        raise RuntimeError("tile_matmul() requires three Tile arguments")

    if any(arg.type.dtype not in [float16, float32, float64, vec2h, vec2f, vec2d] for arg in [a, b, out]):
        raise RuntimeError(
            "tile_matmul() arguments must be tiles of float16, float32 or float64, vec2h, vec2f, vec2d entries"
        )

    if (a.type.N != b.type.M) or (a.type.M != out.type.M) or (b.type.N != out.type.N):
        raise RuntimeError("tile_matmul(A, B, C) requires sizes of A, B and C to be consistent for a matmul")

    # set the storage type to the inputs to shared
    a.type.storage = "shared"
    b.type.storage = "shared"
    out.type.storage = "shared"
    template_args = [accumulate]

    # Maps Python/Warp types to (C++ type name, precision enum, 0 = real / 1 = complex)
    def cublasdx_type_map(dtype):
        if dtype == float16:
            return ("wp::float16", 3, 0)
        if dtype == float32:
            return ("wp::float32", 5, 0)
        if dtype == float64:
            return ("wp::float64", 6, 0)
        if dtype == vec2h:
            return ("wp::vec2h", 3, 1)
        if dtype == vec2f:
            return ("wp::vec2f", 5, 1)
        if dtype == vec2d:
            return ("wp::vec2d", 6, 1)
        raise RuntimeError("Unsupported input type in tile_matmul")

    def cublasdx_arrangement_map(layout):
        if layout == "colmajor":
            return 0  # CUBLASDX_ARRANGEMENT_COL_MAJOR
        if layout == "rowmajor":
            return 1  # CUBLASDX_ARRANGEMENT_ROW_MAJOR
        raise RuntimeError("Unsupported layout in tile_matmul")

    # generate the LTO
    M, K = a.type.M, a.type.N
    N = b.type.N
    num_threads = options["block_dim"]
    arch = options["output_arch"]

    def make_function(M, N, K, adtype, bdtype, cdtype, alayout, blayout, clayout):
        # Compile (or reuse) one cuBLASDx GEMM specialization; returns (symbol, ltoir bytes).
        (a_dtype, a_prec, a_type) = cublasdx_type_map(adtype)
        (b_dtype, b_prec, b_type) = cublasdx_type_map(bdtype)
        (c_dtype, c_prec, c_type) = cublasdx_type_map(cdtype)
        a_arrangement = cublasdx_arrangement_map(alayout)
        b_arrangement = cublasdx_arrangement_map(blayout)
        c_arrangement = cublasdx_arrangement_map(clayout)

        if a_type != b_type or a_type != c_type:
            raise RuntimeError("tile_matmul(A, B, C) requires all inputs to be real or complex")
        element_type = a_type

        lto_symbol = f"dot_{M}_{N}_{K}_{arch}_{num_threads}_{a_arrangement}_{b_arrangement}_{c_arrangement}_{a_prec}_{b_prec}_{c_prec}_{element_type}"

        # early out if LTO for this combination already exists for this module
        if lto_symbol in builder.ltoirs:
            return lto_symbol, builder.ltoirs[lto_symbol]

        # Otherwise compile LTO. The temp file only serves as an output path for the
        # native compiler. Close our handle immediately and remove the file on every
        # exit path, including when the native call or the read raises.
        with tempfile.NamedTemporaryFile(prefix="warp", delete=False) as lto_code:
            lto_code_name = lto_code.name
        lto_code_path = Path(lto_code_name)
        try:
            result = warp.context.runtime.core.cuda_compile_dot(
                lto_code_name.encode("utf-8"),
                lto_symbol.encode("utf-8"),
                0,
                None,
                None,
                arch,
                M,
                N,
                K,
                a_prec,
                b_prec,
                c_prec,
                element_type,
                a_arrangement,
                b_arrangement,
                c_arrangement,
                num_threads,
            )
            if not result:
                raise RuntimeError("Failed to compile tile_matmul")
            lto_code_data = lto_code_path.read_bytes()
        finally:
            if lto_code_path.exists():
                lto_code_path.unlink()

        builder.ltoirs[lto_symbol] = lto_code_data
        builder.ltoirs_decl[lto_symbol] = (
            f"void {lto_symbol}({c_dtype}, {a_dtype}*, {b_dtype}*, {c_dtype}, {c_dtype}*);"
        )

        return lto_symbol, lto_code_data

    def tile_flip_layout(layout):
        # Transposing a tile is modelled by reinterpreting it with the opposite layout.
        if layout == "rowmajor":
            return "colmajor"
        elif layout == "colmajor":
            return "rowmajor"

    # C += A * B
    (fun_forward, lto_forward) = make_function(
        M, N, K, a.type.dtype, b.type.dtype, out.type.dtype, a.type.layout, b.type.layout, out.type.layout
    )
    # adjA += adjC * B^T - Transpose ~= flipped layout
    (fun_backward_A, lto_backward_A) = make_function(
        M,
        K,
        N,
        out.type.dtype,
        b.type.dtype,
        a.type.dtype,
        out.type.layout,
        tile_flip_layout(b.type.layout),
        a.type.layout,
    )
    # adjB += A^T * adjC - Transpose ~= flipped layout
    (fun_backward_B, lto_backward_B) = make_function(
        K,
        N,
        M,
        a.type.dtype,
        out.type.dtype,
        b.type.dtype,
        tile_flip_layout(a.type.layout),
        out.type.layout,
        b.type.layout,
    )

    return (
        (
            Var(fun_forward, str, False, True, False),
            Var(fun_backward_A, str, False, True, False),
            Var(fun_backward_B, str, False, True, False),
            a,
            b,
            out,
        ),
        template_args,
        [lto_forward, lto_backward_A, lto_backward_B],
    )
|
|
5897
|
+
|
|
5898
|
+
|
|
5899
|
+
# Accumulating overload: tile_matmul(a, b, out) computes out += a*b and returns nothing.
add_builtin(
    "tile_matmul",
    input_types={
        "a": Tile(dtype=Any, M=Any, N=Any),
        "b": Tile(dtype=Any, M=Any, N=Any),
        "out": Tile(dtype=Any, M=Any, N=Any),
    },
    value_func=tile_matmul_generic_value_func,
    lto_dispatch_func=tile_matmul_generic_lto_dispatch_func,
    variadic=False,
    doc="""Computes the matrix product and accumulates ``out += a*b``.

    Supported datatypes are:
        * fp16, fp32, fp64 (real)
        * vec2h, vec2f, vec2d (complex)

    All input and output tiles must have the same datatype. Tile data will automatically be migrated
    to shared memory if necessary and will use TensorCore operations when available.

    :param a: A tile with ``shape=(M, K)``
    :param b: A tile with ``shape=(K, N)``
    :param out: A tile with ``shape=(M, N)``
    """,
    group="Tile Primitives",
    export=False,
)
|
|
5925
|
+
|
|
5926
|
+
# Returning overload: out = tile_matmul(a, b) produces a new (M, N) tile.
add_builtin(
    "tile_matmul",
    input_types={"a": Tile(dtype=Any, M=Any, N=Any), "b": Tile(dtype=Any, M=Any, N=Any)},
    value_func=tile_matmul_generic_value_func,
    lto_dispatch_func=tile_matmul_generic_lto_dispatch_func,
    variadic=False,
    doc="""Computes the matrix product ``out = a*b``.

    Supported datatypes are:
        * fp16, fp32, fp64 (real)
        * vec2h, vec2f, vec2d (complex)

    Both input tiles must have the same datatype. Tile data will automatically be migrated
    to shared memory if necessary and will use TensorCore operations when available.

    :param a: A tile with ``shape=(M, K)``
    :param b: A tile with ``shape=(K, N)``
    :returns: A tile with ``shape=(M, N)``
    """,
    group="Tile Primitives",
    export=False,
)
|
|
5948
|
+
|
|
5949
|
+
|
|
5950
|
+
##
|
|
5951
|
+
## FFT
|
|
5952
|
+
##
|
|
5953
|
+
def tile_fft_generic_value_func(arg_types, arg_values):
    """Validate the single ``inout`` operand of the tile FFT builtin.

    Returns a fully generic ``Tile`` when no concrete types are available. Otherwise it
    returns ``None``: the builtin works on ``inout`` and produces no return value.

    Raises:
        RuntimeError: in any of these cases:

        - the number of arguments is not exactly one;
        - ``inout`` is not a tile;
        - ``inout`` does not use register storage.
    """
    if arg_types is None:
        return Tile(dtype=Any, M=Any, N=Any)

    if len(arg_types) != 1:
        raise RuntimeError("tile_fft() requires 1 positional args")

    inout = arg_types["inout"]
    if not is_tile(inout):
        raise RuntimeError("tile_fft() argument 0 must be a tile")
    if inout.storage != "register":
        raise RuntimeError("tile_fft() input/output argument must have register memory storage")

    return None
|
|
5967
|
+
|
|
5968
|
+
|
|
5969
|
+
# Shared-memory size (bytes) reported by the native FFT compiler, keyed by LTO symbol.
# The symbol encodes every varying compile input (size, ept, arch, direction, precision),
# so a cached value stays valid whenever the same LTO-IR is requested again.
_tile_fft_shared_memory_bytes = {}


def _compile_tile_fft_ltoir(builtin_name, lto_symbol, arch, size, ept, direction_id, precision):
    """Compile one FFT configuration to LTO-IR via the native ``cuda_compile_fft`` entry point.

    Returns a ``(ltoir_bytes, shared_memory_bytes)`` pair. The temporary output file is
    always removed, including when compilation fails or the native call raises.
    """
    lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
    lto_code_path = Path(lto_code.name)
    shared_memory_size = ctypes.c_int(0)

    try:
        result = warp.context.runtime.core.cuda_compile_fft(
            lto_code.name.encode("utf-8"),
            lto_symbol.encode("utf-8"),
            0,
            None,
            None,
            arch,
            size,
            ept,
            direction_id,
            precision,
            ctypes.byref(shared_memory_size),
        )
        if not result:
            raise RuntimeError(f"Failed to compile {builtin_name}")

        # The native compiler writes the LTO-IR to the temp file path; read it back.
        with open(lto_code.name, "rb") as f:
            lto_code_data = f.read()
    finally:
        lto_code.close()
        if lto_code_path.exists():
            lto_code_path.unlink()

    return lto_code_data, shared_memory_size.value


def tile_fft_generic_lto_dispatch_func(
    arg_types: Mapping[str, type],
    return_type: Any,
    return_values: List[Var],
    arg_values: Mapping[str, Var],
    options: Mapping[str, Any],
    builder: warp.context.ModuleBuilder,
    direction: str = None,
):
    """Produce (or reuse) the LTO-IR for ``tile_fft`` / ``tile_ifft`` and build the call arguments.

    The tile is treated as ``M`` independent FFTs of size ``N`` each. The compiled LTO-IR
    is cached in ``builder.ltoirs`` under a symbol derived from the FFT configuration.

    :param arg_values: Must contain ``"inout"``, a tile of ``vec2f`` or ``vec2d`` entries.
    :param options: Read for ``"block_dim"`` (thread count) and ``"output_arch"``.
    :param builder: Module builder whose ``ltoirs`` dict caches compiled LTO-IR.
    :param direction: ``"forward"`` or ``"inverse"`` (bound via ``functools.partial``).
    :returns: ``(call_args, [], [ltoir_bytes])``. The same shape is returned on both
        the cache-hit and the freshly-compiled path.
    :raises RuntimeError: on a non-tile argument, unsupported dtype, invalid direction,
        or compilation failure.
    """
    inout = arg_values["inout"]

    if not is_tile(inout.type):
        raise RuntimeError("tile_fft() arguments must be a single tile with register storage")
    # Mark the storage only after validation, so a non-tile type object is never mutated.
    inout.type.storage = "register"

    if inout.type.dtype not in [vec2f, vec2d]:
        raise RuntimeError("tile_fft() argument must be a tile of vec2f or vec2d (interpreted as complex) entries")

    # see libcufftdx.hpp
    if direction == "forward":
        direction_id = 0  # CUFFTDX_DIRECTION_FORWARD
        builtin_name = "tile_fft"
    elif direction == "inverse":
        direction_id = 1  # CUFFTDX_DIRECTION_INVERSE
        builtin_name = "tile_ifft"
    else:
        raise RuntimeError("Invalid direction")

    if inout.type.dtype == vec2f:
        dtype = "wp::vec2f"
        precision = 5  # COMMONDX_PRECISION_F32
    elif inout.type.dtype == vec2d:
        dtype = "wp::vec2d"
        precision = 6  # COMMONDX_PRECISION_F64
    else:
        raise RuntimeError("Unsupported datatype")

    # M FFTs of size N each
    batch, size = inout.type.M, inout.type.N
    num_threads = options["block_dim"]
    arch = options["output_arch"]
    ept = size // num_threads
    lto_symbol = f"fft_{size}_{ept}_{arch}_{direction}_{precision}"

    # Reuse the LTO-IR if this module already compiled this combination. The shared-memory
    # size is needed for the call arguments too, so both must be cached to skip compilation.
    cached_shared_memory = _tile_fft_shared_memory_bytes.get(lto_symbol)
    if lto_symbol in builder.ltoirs and cached_shared_memory is not None:
        lto_code_data = builder.ltoirs[lto_symbol]
        shared_memory_bytes = cached_shared_memory
    else:
        lto_code_data, shared_memory_bytes = _compile_tile_fft_ltoir(
            builtin_name, lto_symbol, arch, size, ept, direction_id, precision
        )
        builder.ltoirs[lto_symbol] = lto_code_data
        _tile_fft_shared_memory_bytes[lto_symbol] = shared_memory_bytes

    return (
        (
            Var(lto_symbol, str, False, True, False),
            Var(dtype, str, False, True, False),
            Var(str(shared_memory_bytes), str, False, True, False),
            Var(str(batch), str, False, True, False),
            Var(str(ept), str, False, True, False),
            inout,
        ),
        [],
        [lto_code_data],
    )
|
|
6059
|
+
|
|
6060
|
+
|
|
6061
|
+
# Forward FFT builtin: the shared LTO dispatcher is bound to direction="forward".
add_builtin(
    "tile_fft",
    group="Tile Primitives",
    namespace="",
    export=False,
    variadic=True,
    input_types={"inout": Tile},
    value_func=tile_fft_generic_value_func,
    lto_dispatch_func=functools.partial(tile_fft_generic_lto_dispatch_func, direction="forward"),
    doc="""Compute the forward FFT along the second dimension of a 2D tile of data.

This function cooperatively computes the forward FFT on a tile of data inplace, treating each row individually.

Supported datatypes are:
* vec2f, vec2d

:param inout: The input/output tile""",
)
|
|
6079
|
+
|
|
6080
|
+
# Inverse FFT builtin: the shared LTO dispatcher is bound to direction="inverse".
add_builtin(
    "tile_ifft",
    group="Tile Primitives",
    namespace="",
    export=False,
    variadic=True,
    input_types={"inout": Tile},
    value_func=tile_fft_generic_value_func,
    lto_dispatch_func=functools.partial(tile_fft_generic_lto_dispatch_func, direction="inverse"),
    doc="""Compute the inverse FFT along the second dimension of a 2D tile of data.

This function cooperatively computes the inverse FFT on a tile of data inplace, treating each row individually.

Supported datatypes are:
* vec2f, vec2d

:param inout: The input/output tile""",
)
|
|
6098
|
+
|
|
4289
6099
|
# ---------------------------------
|
|
4290
6100
|
# Code Generation
|
|
4291
6101
|
|
|
@@ -4295,13 +6105,12 @@ add_builtin(
|
|
|
4295
6105
|
value_type=Any,
|
|
4296
6106
|
doc="""Evaluates a static Python expression and replaces it with its result.
|
|
4297
6107
|
|
|
4298
|
-
See the
|
|
6108
|
+
See the :ref:`code generation guide <static_expressions>` for more details.
|
|
4299
6109
|
|
|
4300
|
-
|
|
4301
|
-
|
|
4302
|
-
|
|
4303
|
-
|
|
4304
|
-
(excluding Warp arrays since they cannot be created in a Warp kernel at the moment).""",
|
|
6110
|
+
The inner expression must only reference variables that are available from the current scope where the Warp kernel or function containing the expression is defined,
|
|
6111
|
+
which includes constant variables and variables captured in the current closure in which the function or kernel is implemented.
|
|
6112
|
+
The return type of the expression must be either a Warp function, a string, or a type that is supported inside Warp kernels and functions
|
|
6113
|
+
(excluding Warp arrays since they cannot be created in a Warp kernel at the moment).""",
|
|
4305
6114
|
group="Code Generation",
|
|
4306
6115
|
)
|
|
4307
6116
|
|