warp-lang 1.5.1__py3-none-manylinux2014_x86_64.whl → 1.6.1__py3-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +5 -0
- warp/autograd.py +414 -191
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +40 -12
- warp/build_dll.py +13 -6
- warp/builtins.py +1077 -481
- warp/codegen.py +250 -122
- warp/config.py +65 -21
- warp/context.py +500 -149
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_gemm.py +27 -18
- warp/examples/benchmarks/benchmark_interop_paddle.py +3 -3
- warp/examples/benchmarks/benchmark_interop_torch.py +3 -3
- warp/examples/core/example_marching_cubes.py +1 -1
- warp/examples/core/example_mesh.py +1 -1
- warp/examples/core/example_torch.py +18 -34
- warp/examples/core/example_wave.py +1 -1
- warp/examples/fem/example_apic_fluid.py +1 -0
- warp/examples/fem/example_mixed_elasticity.py +1 -1
- warp/examples/optim/example_bounce.py +1 -1
- warp/examples/optim/example_cloth_throw.py +1 -1
- warp/examples/optim/example_diffray.py +4 -15
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/optim/example_softbody_properties.py +392 -0
- warp/examples/optim/example_trajectory.py +1 -3
- warp/examples/optim/example_walker.py +5 -0
- warp/examples/sim/example_cartpole.py +0 -2
- warp/examples/sim/example_cloth_self_contact.py +314 -0
- warp/examples/sim/example_granular_collision_sdf.py +4 -5
- warp/examples/sim/example_jacobian_ik.py +0 -2
- warp/examples/sim/example_quadruped.py +5 -2
- warp/examples/tile/example_tile_cholesky.py +79 -0
- warp/examples/tile/example_tile_convolution.py +2 -2
- warp/examples/tile/example_tile_fft.py +2 -2
- warp/examples/tile/example_tile_filtering.py +3 -3
- warp/examples/tile/example_tile_matmul.py +4 -4
- warp/examples/tile/example_tile_mlp.py +12 -12
- warp/examples/tile/example_tile_nbody.py +191 -0
- warp/examples/tile/example_tile_walker.py +319 -0
- warp/math.py +147 -0
- warp/native/array.h +12 -0
- warp/native/builtin.h +0 -1
- warp/native/bvh.cpp +149 -70
- warp/native/bvh.cu +287 -68
- warp/native/bvh.h +195 -85
- warp/native/clang/clang.cpp +6 -2
- warp/native/crt.h +1 -0
- warp/native/cuda_util.cpp +35 -0
- warp/native/cuda_util.h +5 -0
- warp/native/exports.h +40 -40
- warp/native/intersect.h +17 -0
- warp/native/mat.h +57 -3
- warp/native/mathdx.cpp +19 -0
- warp/native/mesh.cpp +25 -8
- warp/native/mesh.cu +153 -101
- warp/native/mesh.h +482 -403
- warp/native/quat.h +40 -0
- warp/native/solid_angle.h +7 -0
- warp/native/sort.cpp +85 -0
- warp/native/sort.cu +34 -0
- warp/native/sort.h +3 -1
- warp/native/spatial.h +11 -0
- warp/native/tile.h +1189 -664
- warp/native/tile_reduce.h +8 -6
- warp/native/vec.h +41 -0
- warp/native/warp.cpp +8 -1
- warp/native/warp.cu +263 -40
- warp/native/warp.h +19 -5
- warp/optim/linear.py +22 -4
- warp/render/render_opengl.py +132 -59
- warp/render/render_usd.py +10 -2
- warp/sim/__init__.py +6 -1
- warp/sim/collide.py +289 -32
- warp/sim/import_urdf.py +20 -5
- warp/sim/integrator_euler.py +25 -7
- warp/sim/integrator_featherstone.py +147 -35
- warp/sim/integrator_vbd.py +842 -40
- warp/sim/model.py +173 -112
- warp/sim/render.py +2 -2
- warp/stubs.py +249 -116
- warp/tape.py +28 -30
- warp/tests/aux_test_module_unload.py +15 -0
- warp/tests/{test_sim_grad.py → flaky_test_sim_grad.py} +104 -63
- warp/tests/test_array.py +100 -0
- warp/tests/test_assert.py +242 -0
- warp/tests/test_codegen.py +14 -61
- warp/tests/test_collision.py +8 -8
- warp/tests/test_examples.py +16 -1
- warp/tests/test_grad_debug.py +87 -2
- warp/tests/test_hash_grid.py +1 -1
- warp/tests/test_ipc.py +116 -0
- warp/tests/test_launch.py +77 -26
- warp/tests/test_mat.py +213 -168
- warp/tests/test_math.py +47 -1
- warp/tests/test_matmul.py +11 -7
- warp/tests/test_matmul_lite.py +4 -4
- warp/tests/test_mesh.py +84 -60
- warp/tests/test_mesh_query_aabb.py +165 -0
- warp/tests/test_mesh_query_point.py +328 -286
- warp/tests/test_mesh_query_ray.py +134 -121
- warp/tests/test_mlp.py +2 -2
- warp/tests/test_operators.py +43 -0
- warp/tests/test_overwrite.py +6 -5
- warp/tests/test_quat.py +77 -0
- warp/tests/test_reload.py +29 -0
- warp/tests/test_sim_grad_bounce_linear.py +204 -0
- warp/tests/test_static.py +16 -0
- warp/tests/test_tape.py +25 -0
- warp/tests/test_tile.py +134 -191
- warp/tests/test_tile_load.py +399 -0
- warp/tests/test_tile_mathdx.py +61 -8
- warp/tests/test_tile_mlp.py +17 -17
- warp/tests/test_tile_reduce.py +24 -18
- warp/tests/test_tile_shared_memory.py +66 -17
- warp/tests/test_tile_view.py +165 -0
- warp/tests/test_torch.py +35 -0
- warp/tests/test_utils.py +36 -24
- warp/tests/test_vec.py +110 -0
- warp/tests/unittest_suites.py +29 -4
- warp/tests/unittest_utils.py +30 -11
- warp/thirdparty/unittest_parallel.py +5 -2
- warp/types.py +419 -111
- warp/utils.py +9 -5
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/METADATA +86 -45
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/RECORD +129 -118
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/WHEEL +1 -1
- warp/examples/benchmarks/benchmark_tile.py +0 -179
- warp/native/tile_gemm.h +0 -341
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/top_level.txt +0 -0
warp/builtins.py
CHANGED
|
@@ -1707,64 +1707,98 @@ add_builtin(
|
|
|
1707
1707
|
|
|
1708
1708
|
# ------------------
|
|
1709
1709
|
# Tile-based primitives
|
|
1710
|
-
|
|
1710
|
+
|
|
1711
|
+
|
|
1712
|
+
def tile_unpack_shape(arg_values):
|
|
1713
|
+
shape = arg_values["shape"]
|
|
1714
|
+
|
|
1715
|
+
if not isinstance(shape, tuple):
|
|
1716
|
+
# promote to tuple
|
|
1717
|
+
shape = (shape,)
|
|
1718
|
+
|
|
1719
|
+
# check that components are constants
|
|
1720
|
+
for d in shape:
|
|
1721
|
+
if d is None:
|
|
1722
|
+
raise ValueError("Tile functions require shape to be a compile time constant.")
|
|
1723
|
+
|
|
1724
|
+
return shape
|
|
1725
|
+
|
|
1726
|
+
|
|
1727
|
+
def tile_unpack_offset(arg_values, ndim=0):
|
|
1728
|
+
if "offset" in arg_values:
|
|
1729
|
+
offset = arg_values["offset"]
|
|
1730
|
+
else:
|
|
1731
|
+
offset = (0,) * ndim
|
|
1732
|
+
|
|
1733
|
+
if isinstance(offset, tuple):
|
|
1734
|
+
return offset
|
|
1735
|
+
else:
|
|
1736
|
+
# promote to tuple
|
|
1737
|
+
return (offset,)
|
|
1711
1738
|
|
|
1712
1739
|
|
|
1713
1740
|
def tile_zeros_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1714
1741
|
# return generic type (for doc builds)
|
|
1715
1742
|
if arg_types is None:
|
|
1716
|
-
return Tile(dtype=Any,
|
|
1717
|
-
|
|
1718
|
-
if "m" not in arg_values:
|
|
1719
|
-
raise RuntimeError("'m' keyword argument must be specified when calling tile_zeros() function")
|
|
1743
|
+
return Tile(dtype=Any, shape=Any)
|
|
1720
1744
|
|
|
1721
|
-
|
|
1722
|
-
raise RuntimeError("'n' keyword argument must be specified when calling tile_zeros() function")
|
|
1745
|
+
shape = tile_unpack_shape(arg_values)
|
|
1723
1746
|
|
|
1724
1747
|
if "dtype" not in arg_values:
|
|
1725
|
-
raise
|
|
1748
|
+
raise TypeError("tile_zeros() missing required keyword argument 'dtype'")
|
|
1726
1749
|
|
|
1727
1750
|
if "storage" not in arg_values:
|
|
1728
|
-
raise
|
|
1751
|
+
raise TypeError("tile_zeros() missing required keyword argument 'storage'")
|
|
1729
1752
|
|
|
1730
1753
|
if arg_values["storage"] not in {"shared", "register"}:
|
|
1731
|
-
raise ValueError(
|
|
1732
|
-
f"'storage' keyword argument must be either 'shared' or 'register', got {arg_values['storage']}"
|
|
1733
|
-
)
|
|
1754
|
+
raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
|
|
1734
1755
|
|
|
1735
|
-
m, n = arg_values["m"], arg_values["n"]
|
|
1736
1756
|
dtype = arg_values["dtype"]
|
|
1737
1757
|
|
|
1738
|
-
return TileZeros(dtype=dtype,
|
|
1758
|
+
return TileZeros(dtype=dtype, shape=shape, storage=arg_values["storage"])
|
|
1739
1759
|
|
|
1740
1760
|
|
|
1741
1761
|
def tile_zeros_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
1742
|
-
|
|
1762
|
+
shape = tile_unpack_shape(arg_values)
|
|
1763
|
+
dtype = arg_values["dtype"]
|
|
1743
1764
|
|
|
1744
1765
|
template_args = []
|
|
1745
1766
|
template_args.append(dtype)
|
|
1746
|
-
|
|
1747
|
-
|
|
1767
|
+
for d in shape:
|
|
1768
|
+
template_args.append(d.constant)
|
|
1748
1769
|
|
|
1749
1770
|
return ([], template_args)
|
|
1750
1771
|
|
|
1751
1772
|
|
|
1752
1773
|
add_builtin(
|
|
1753
1774
|
"tile_zeros",
|
|
1754
|
-
input_types={"
|
|
1755
|
-
defaults={"storage": "register"},
|
|
1775
|
+
input_types={"shape": Tuple[int, ...], "dtype": Any, "storage": str},
|
|
1776
|
+
defaults={"storage": "register", "dtype": float},
|
|
1756
1777
|
value_func=tile_zeros_value_func,
|
|
1757
1778
|
dispatch_func=tile_zeros_dispatch_func,
|
|
1758
1779
|
variadic=False,
|
|
1759
1780
|
missing_grad=True,
|
|
1760
|
-
doc="""
|
|
1781
|
+
doc="""Allocate a tile of zero-initialized items.
|
|
1761
1782
|
|
|
1762
|
-
:param
|
|
1763
|
-
:param
|
|
1764
|
-
:param dtype: Datatype of output tile's elements
|
|
1783
|
+
:param shape: Shape of the output tile
|
|
1784
|
+
:param dtype: Data type of output tile's elements (default float)
|
|
1765
1785
|
:param storage: The storage location for the tile: ``"register"`` for registers
|
|
1766
1786
|
(default) or ``"shared"`` for shared memory.
|
|
1767
|
-
:returns: A zero-initialized tile with
|
|
1787
|
+
:returns: A zero-initialized tile with shape and data type as specified""",
|
|
1788
|
+
group="Tile Primitives",
|
|
1789
|
+
export=False,
|
|
1790
|
+
)
|
|
1791
|
+
|
|
1792
|
+
# overload for scalar shape
|
|
1793
|
+
add_builtin(
|
|
1794
|
+
"tile_zeros",
|
|
1795
|
+
input_types={"shape": int, "dtype": Any, "storage": str},
|
|
1796
|
+
defaults={"storage": "register", "dtype": float},
|
|
1797
|
+
value_func=tile_zeros_value_func,
|
|
1798
|
+
dispatch_func=tile_zeros_dispatch_func,
|
|
1799
|
+
variadic=False,
|
|
1800
|
+
missing_grad=True,
|
|
1801
|
+
hidden=True,
|
|
1768
1802
|
group="Tile Primitives",
|
|
1769
1803
|
export=False,
|
|
1770
1804
|
)
|
|
@@ -1773,54 +1807,63 @@ add_builtin(
|
|
|
1773
1807
|
def tile_ones_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1774
1808
|
# return generic type (for doc builds)
|
|
1775
1809
|
if arg_types is None:
|
|
1776
|
-
return Tile(dtype=Any,
|
|
1810
|
+
return Tile(dtype=Any, shape=Any)
|
|
1777
1811
|
|
|
1778
|
-
|
|
1779
|
-
raise RuntimeError("'m' keyword argument must be specified when calling tile_zeros() function")
|
|
1780
|
-
|
|
1781
|
-
if "n" not in arg_values:
|
|
1782
|
-
raise RuntimeError("'n' keyword argument must be specified when calling tile_zeros() function")
|
|
1812
|
+
shape = tile_unpack_shape(arg_values)
|
|
1783
1813
|
|
|
1784
1814
|
if "dtype" not in arg_values:
|
|
1785
|
-
raise
|
|
1815
|
+
raise TypeError("tile_ones() missing required keyword argument 'dtype'")
|
|
1816
|
+
|
|
1817
|
+
if "storage" not in arg_values:
|
|
1818
|
+
raise TypeError("tile_ones() missing required keyword argument 'storage'")
|
|
1786
1819
|
|
|
1787
1820
|
if arg_values["storage"] not in {"shared", "register"}:
|
|
1788
|
-
raise ValueError(
|
|
1789
|
-
f"'storage' keyword argument must be either 'shared' or 'register', got {arg_values['storage']}"
|
|
1790
|
-
)
|
|
1821
|
+
raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
|
|
1791
1822
|
|
|
1792
|
-
m, n = arg_values["m"], arg_values["n"]
|
|
1793
1823
|
dtype = arg_values["dtype"]
|
|
1794
1824
|
|
|
1795
|
-
return
|
|
1825
|
+
return TileOnes(dtype=dtype, shape=shape, storage=arg_values["storage"])
|
|
1796
1826
|
|
|
1797
1827
|
|
|
1798
1828
|
def tile_ones_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
1799
|
-
|
|
1829
|
+
shape = tile_unpack_shape(arg_values)
|
|
1830
|
+
dtype = arg_values["dtype"]
|
|
1800
1831
|
|
|
1801
1832
|
template_args = []
|
|
1802
1833
|
template_args.append(dtype)
|
|
1803
|
-
|
|
1804
|
-
|
|
1834
|
+
for d in shape:
|
|
1835
|
+
template_args.append(d.constant)
|
|
1805
1836
|
|
|
1806
1837
|
return ([], template_args)
|
|
1807
1838
|
|
|
1808
1839
|
|
|
1809
1840
|
add_builtin(
|
|
1810
1841
|
"tile_ones",
|
|
1811
|
-
input_types={"
|
|
1842
|
+
input_types={"shape": Tuple[int, ...], "dtype": Any, "storage": str},
|
|
1812
1843
|
defaults={"storage": "register"},
|
|
1813
1844
|
value_func=tile_ones_value_func,
|
|
1814
1845
|
dispatch_func=tile_ones_dispatch_func,
|
|
1815
1846
|
missing_grad=True,
|
|
1816
|
-
doc="""
|
|
1847
|
+
doc="""Allocate a tile of one-initialized items.
|
|
1817
1848
|
|
|
1818
|
-
:param
|
|
1819
|
-
:param
|
|
1820
|
-
:param dtype: Datatype of output tile's elements
|
|
1849
|
+
:param shape: Shape of the output tile
|
|
1850
|
+
:param dtype: Data type of output tile's elements
|
|
1821
1851
|
:param storage: The storage location for the tile: ``"register"`` for registers
|
|
1822
1852
|
(default) or ``"shared"`` for shared memory.
|
|
1823
|
-
:returns: A one-initialized tile with
|
|
1853
|
+
:returns: A one-initialized tile with shape and data type as specified""",
|
|
1854
|
+
group="Tile Primitives",
|
|
1855
|
+
export=False,
|
|
1856
|
+
)
|
|
1857
|
+
|
|
1858
|
+
# overload for scalar shape
|
|
1859
|
+
add_builtin(
|
|
1860
|
+
"tile_ones",
|
|
1861
|
+
input_types={"shape": int, "dtype": Any, "storage": str},
|
|
1862
|
+
defaults={"storage": "register"},
|
|
1863
|
+
value_func=tile_ones_value_func,
|
|
1864
|
+
dispatch_func=tile_ones_dispatch_func,
|
|
1865
|
+
missing_grad=True,
|
|
1866
|
+
hidden=True,
|
|
1824
1867
|
group="Tile Primitives",
|
|
1825
1868
|
export=False,
|
|
1826
1869
|
)
|
|
@@ -1829,14 +1872,16 @@ add_builtin(
|
|
|
1829
1872
|
def tile_arange_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1830
1873
|
# return generic type (for doc builds)
|
|
1831
1874
|
if arg_types is None:
|
|
1832
|
-
return Tile(dtype=Any,
|
|
1875
|
+
return Tile(dtype=Any, shape=Any)
|
|
1876
|
+
|
|
1877
|
+
if "args" not in arg_values:
|
|
1878
|
+
raise TypeError("tile_arange() requires at least one positional argument specifying the range")
|
|
1879
|
+
|
|
1880
|
+
args = arg_values["args"]
|
|
1833
1881
|
|
|
1834
1882
|
start = 0
|
|
1835
1883
|
stop = 0
|
|
1836
1884
|
step = 1
|
|
1837
|
-
dtype = int
|
|
1838
|
-
|
|
1839
|
-
args = arg_values["args"]
|
|
1840
1885
|
|
|
1841
1886
|
if len(args) == 1:
|
|
1842
1887
|
start = 0
|
|
@@ -1852,8 +1897,7 @@ def tile_arange_value_func(arg_types: Mapping[str, type], arg_values: Mapping[st
|
|
|
1852
1897
|
step = args[2]
|
|
1853
1898
|
|
|
1854
1899
|
if start is None or stop is None or step is None:
|
|
1855
|
-
|
|
1856
|
-
raise RuntimeError("wp.tile_arange() arguments must be compile time constants")
|
|
1900
|
+
raise RuntimeError("tile_arange() arguments must be compile time constants")
|
|
1857
1901
|
|
|
1858
1902
|
if "dtype" in arg_values:
|
|
1859
1903
|
dtype = arg_values["dtype"]
|
|
@@ -1861,26 +1905,37 @@ def tile_arange_value_func(arg_types: Mapping[str, type], arg_values: Mapping[st
|
|
|
1861
1905
|
dtype = float
|
|
1862
1906
|
|
|
1863
1907
|
if arg_values["storage"] not in {"shared", "register"}:
|
|
1864
|
-
raise ValueError(
|
|
1865
|
-
f"'storage' keyword argument must be either 'shared' or 'register', got {arg_values['storage']}"
|
|
1866
|
-
)
|
|
1908
|
+
raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
|
|
1867
1909
|
|
|
1868
1910
|
return TileRange(dtype=dtype, start=start, stop=stop, step=step, storage=arg_values["storage"])
|
|
1869
1911
|
|
|
1870
1912
|
|
|
1871
1913
|
def tile_arange_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
1872
|
-
|
|
1914
|
+
size, dtype = return_type.size, return_type.dtype
|
|
1873
1915
|
|
|
1874
1916
|
template_args = []
|
|
1875
1917
|
template_args.append(dtype)
|
|
1876
|
-
template_args.append(
|
|
1877
|
-
|
|
1918
|
+
template_args.append(size)
|
|
1919
|
+
|
|
1920
|
+
if "args" not in arg_values:
|
|
1921
|
+
raise TypeError("tile_arange() requires at least one positional argument specifying the range")
|
|
1922
|
+
|
|
1923
|
+
args = arg_values["args"]
|
|
1878
1924
|
|
|
1879
|
-
|
|
1880
|
-
|
|
1881
|
-
|
|
1882
|
-
|
|
1883
|
-
|
|
1925
|
+
if len(args) == 1:
|
|
1926
|
+
start = warp.codegen.Var(label=None, type=return_type.dtype, constant=return_type.start)
|
|
1927
|
+
stop = args[0]
|
|
1928
|
+
step = warp.codegen.Var(label=None, type=return_type.dtype, constant=return_type.step)
|
|
1929
|
+
elif len(args) == 2:
|
|
1930
|
+
start = args[0]
|
|
1931
|
+
stop = args[1]
|
|
1932
|
+
step = warp.codegen.Var(label=None, type=return_type.dtype, constant=return_type.step)
|
|
1933
|
+
elif len(args) == 3:
|
|
1934
|
+
start = args[0]
|
|
1935
|
+
stop = args[1]
|
|
1936
|
+
step = args[2]
|
|
1937
|
+
else:
|
|
1938
|
+
raise TypeError(f"tile_arange() accepts at most 3 positional arguments, got {len(args)}")
|
|
1884
1939
|
|
|
1885
1940
|
function_args = []
|
|
1886
1941
|
function_args.append(start)
|
|
@@ -1898,7 +1953,7 @@ add_builtin(
|
|
|
1898
1953
|
dispatch_func=tile_arange_dispatch_func,
|
|
1899
1954
|
variadic=True,
|
|
1900
1955
|
missing_grad=True,
|
|
1901
|
-
doc="""
|
|
1956
|
+
doc="""Generate a tile of linearly spaced elements.
|
|
1902
1957
|
|
|
1903
1958
|
:param args: Variable-length positional arguments, interpreted as:
|
|
1904
1959
|
|
|
@@ -1906,246 +1961,157 @@ add_builtin(
|
|
|
1906
1961
|
- ``(start, stop)``: Generates values from ``start`` to ``stop - 1``
|
|
1907
1962
|
- ``(start, stop, step)``: Generates values from ``start`` to ``stop - 1`` with a step size
|
|
1908
1963
|
|
|
1909
|
-
:param dtype:
|
|
1964
|
+
:param dtype: Data type of output tile's elements (optional, default: ``float``)
|
|
1910
1965
|
:param storage: The storage location for the tile: ``"register"`` for registers
|
|
1911
1966
|
(default) or ``"shared"`` for shared memory.
|
|
1912
|
-
:returns: A tile with ``shape=(
|
|
1967
|
+
:returns: A tile with ``shape=(n)`` with linearly spaced elements of specified data type""",
|
|
1913
1968
|
group="Tile Primitives",
|
|
1914
1969
|
export=False,
|
|
1915
1970
|
)
|
|
1916
1971
|
|
|
1917
1972
|
|
|
1918
|
-
def
|
|
1919
|
-
# return generic type (for doc builds)
|
|
1973
|
+
def tile_load_tuple_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1920
1974
|
if arg_types is None:
|
|
1921
|
-
return
|
|
1922
|
-
|
|
1923
|
-
if not is_array(arg_types["a"]):
|
|
1924
|
-
raise RuntimeError("tile_load() argument 0 must be an array")
|
|
1975
|
+
return array(dtype=Scalar)
|
|
1925
1976
|
|
|
1926
|
-
|
|
1927
|
-
raise RuntimeError(
|
|
1928
|
-
"tile_load() argument 0 must be 1-dimensional if using the ``wp.tile_load(array, i, n)`` syntax."
|
|
1929
|
-
)
|
|
1977
|
+
a = arg_types["a"]
|
|
1930
1978
|
|
|
1931
|
-
|
|
1932
|
-
|
|
1979
|
+
shape = tile_unpack_shape(arg_values)
|
|
1980
|
+
offset = tile_unpack_offset(arg_values, a.ndim)
|
|
1933
1981
|
|
|
1934
|
-
if
|
|
1935
|
-
raise
|
|
1982
|
+
if a.ndim != len(shape):
|
|
1983
|
+
raise ValueError(
|
|
1984
|
+
f"tile_load() array argument must have same number of dimensions as the tile shape, trying to perform an {len(shape)} dimensional load from an array with {a.ndim} dimensions."
|
|
1985
|
+
)
|
|
1936
1986
|
|
|
1937
|
-
if
|
|
1987
|
+
if a.ndim != len(offset):
|
|
1938
1988
|
raise ValueError(
|
|
1939
|
-
f"
|
|
1989
|
+
f"tile_load() offset argument must have the same number of dimensions as the array to load from, got {len(offset)} indices for an array with {a.ndim} dimensions"
|
|
1940
1990
|
)
|
|
1941
1991
|
|
|
1942
|
-
|
|
1943
|
-
|
|
1992
|
+
if arg_values["storage"] not in {"shared", "register"}:
|
|
1993
|
+
raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
|
|
1944
1994
|
|
|
1945
|
-
return
|
|
1995
|
+
return Tile(dtype=a.dtype, shape=shape, storage=arg_values["storage"])
|
|
1946
1996
|
|
|
1947
1997
|
|
|
1948
|
-
def
|
|
1949
|
-
|
|
1950
|
-
|
|
1951
|
-
|
|
1952
|
-
dtype = arg_values["a"].type.dtype
|
|
1998
|
+
def tile_load_tuple_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
1999
|
+
a = args["a"]
|
|
2000
|
+
shape = tile_unpack_shape(args)
|
|
2001
|
+
offset = tile_unpack_offset(args, a.type.ndim)
|
|
1953
2002
|
|
|
1954
|
-
|
|
1955
|
-
template_args.
|
|
1956
|
-
template_args.append(n)
|
|
2003
|
+
func_args = (a, *offset)
|
|
2004
|
+
template_args = (d.constant for d in shape)
|
|
1957
2005
|
|
|
1958
|
-
return (
|
|
2006
|
+
return (func_args, template_args)
|
|
1959
2007
|
|
|
1960
2008
|
|
|
1961
2009
|
add_builtin(
|
|
1962
2010
|
"tile_load",
|
|
1963
|
-
input_types={"a": array(dtype=Any), "
|
|
1964
|
-
|
|
1965
|
-
|
|
1966
|
-
|
|
2011
|
+
input_types={"a": array(dtype=Any), "shape": Tuple[int, ...], "offset": Tuple[int, ...], "storage": str},
|
|
2012
|
+
value_func=tile_load_tuple_value_func,
|
|
2013
|
+
dispatch_func=tile_load_tuple_dispatch_func,
|
|
2014
|
+
defaults={"offset": None, "storage": "register"},
|
|
1967
2015
|
variadic=False,
|
|
1968
|
-
doc="""Loads a
|
|
2016
|
+
doc="""Loads a tile from a global memory array.
|
|
1969
2017
|
|
|
1970
2018
|
This method will cooperatively load a tile from global memory using all threads in the block.
|
|
1971
2019
|
|
|
1972
2020
|
:param a: The source array in global memory
|
|
1973
|
-
:param
|
|
1974
|
-
:param
|
|
2021
|
+
:param shape: Shape of the tile to load, must have the same number of dimensions as ``a``
|
|
2022
|
+
:param offset: Offset in the source array to begin reading from (optional)
|
|
1975
2023
|
:param storage: The storage location for the tile: ``"register"`` for registers
|
|
1976
2024
|
(default) or ``"shared"`` for shared memory.
|
|
1977
|
-
:returns: A tile with
|
|
2025
|
+
:returns: A tile with shape as specified and data type the same as the source array""",
|
|
1978
2026
|
group="Tile Primitives",
|
|
1979
2027
|
export=False,
|
|
1980
2028
|
)
|
|
1981
2029
|
|
|
1982
|
-
|
|
1983
|
-
def tile_load_2d_value_func(arg_types, arg_values):
|
|
1984
|
-
# return generic type (for doc builds)
|
|
1985
|
-
if arg_types is None:
|
|
1986
|
-
return Tile(dtype=Any, M=Any, N=Any)
|
|
1987
|
-
|
|
1988
|
-
if not is_array(arg_types["a"]):
|
|
1989
|
-
raise RuntimeError("tile_load() argument 0 must be an array")
|
|
1990
|
-
|
|
1991
|
-
if arg_types["a"].ndim != 2:
|
|
1992
|
-
raise RuntimeError(
|
|
1993
|
-
"tile_load() argument 0 must be 2-dimensional if using the ``wp.tile_load(array, i, j, m, n)`` syntax."
|
|
1994
|
-
)
|
|
1995
|
-
|
|
1996
|
-
if not type_is_int(arg_types["i"]):
|
|
1997
|
-
raise RuntimeError("tile_load() argument 1 must be an integer")
|
|
1998
|
-
|
|
1999
|
-
if not type_is_int(arg_types["j"]):
|
|
2000
|
-
raise RuntimeError("tile_load() argument 1 must be an integer")
|
|
2001
|
-
|
|
2002
|
-
if "m" not in arg_values:
|
|
2003
|
-
raise RuntimeError("'m' keyword argument must be specified when calling tile_load() function")
|
|
2004
|
-
|
|
2005
|
-
if "n" not in arg_values:
|
|
2006
|
-
raise RuntimeError("'n' keyword argument must be specified when calling tile_load() function")
|
|
2007
|
-
|
|
2008
|
-
if arg_values["storage"] not in {"shared", "register"}:
|
|
2009
|
-
raise ValueError(
|
|
2010
|
-
f"'storage' keyword argument must be either 'shared' or 'register', got {arg_values['storage']}"
|
|
2011
|
-
)
|
|
2012
|
-
|
|
2013
|
-
a = arg_types["a"]
|
|
2014
|
-
m, n = arg_values["m"], arg_values["n"]
|
|
2015
|
-
|
|
2016
|
-
return TileLoad(a, m, n, arg_values["storage"])
|
|
2017
|
-
|
|
2018
|
-
|
|
2019
|
-
def tile_load_2d_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
2020
|
-
array = arg_values["a"]
|
|
2021
|
-
i, j = arg_values["i"], arg_values["j"]
|
|
2022
|
-
m, n = arg_values["m"].constant, arg_values["n"].constant
|
|
2023
|
-
dtype = arg_values["a"].type.dtype
|
|
2024
|
-
|
|
2025
|
-
template_args = []
|
|
2026
|
-
template_args.append(dtype)
|
|
2027
|
-
template_args.append(m)
|
|
2028
|
-
template_args.append(n)
|
|
2029
|
-
|
|
2030
|
-
return ((array, i, j), template_args)
|
|
2031
|
-
|
|
2032
|
-
|
|
2030
|
+
# overload for scalar shape
|
|
2033
2031
|
add_builtin(
|
|
2034
2032
|
"tile_load",
|
|
2035
|
-
input_types={"a": array(dtype=Any), "
|
|
2036
|
-
|
|
2037
|
-
|
|
2038
|
-
|
|
2039
|
-
variadic=False,
|
|
2040
|
-
doc="""Loads a 2D tile from a global memory array.
|
|
2041
|
-
|
|
2042
|
-
This method will cooperatively load a tile from global memory using all threads in the block.
|
|
2043
|
-
|
|
2044
|
-
:param a: The source array in global memory
|
|
2045
|
-
:param i: Offset in the source array measured in multiples of ``m``, i.e.: ``row=i*m``
|
|
2046
|
-
:param j: Offset in the source array measured in multiples of ``n``, i.e.; ``col=j*n``
|
|
2047
|
-
:param m: The size of the tile's first dimension
|
|
2048
|
-
:param n: The size of the tile's second dimension
|
|
2049
|
-
:param storage: The storage location for the tile: ``"register"`` for registers
|
|
2050
|
-
(default) or ``"shared"`` for shared memory.
|
|
2051
|
-
:returns: A tile with ``shape=(m,n)`` and dtype the same as the source array""",
|
|
2033
|
+
input_types={"a": array(dtype=Any), "shape": int, "offset": int, "storage": str},
|
|
2034
|
+
value_func=tile_load_tuple_value_func,
|
|
2035
|
+
dispatch_func=tile_load_tuple_dispatch_func,
|
|
2036
|
+
defaults={"offset": None, "storage": "register"},
|
|
2052
2037
|
group="Tile Primitives",
|
|
2038
|
+
hidden=True,
|
|
2053
2039
|
export=False,
|
|
2054
2040
|
)
|
|
2055
2041
|
|
|
2056
2042
|
|
|
2057
|
-
def
|
|
2043
|
+
def tile_store_value_func(arg_types, arg_values):
|
|
2058
2044
|
# return generic type (for doc builds)
|
|
2059
2045
|
if arg_types is None:
|
|
2060
2046
|
return None
|
|
2061
2047
|
|
|
2062
|
-
|
|
2063
|
-
|
|
2048
|
+
a = arg_types["a"]
|
|
2049
|
+
t = arg_types["t"]
|
|
2064
2050
|
|
|
2065
|
-
|
|
2066
|
-
raise RuntimeError("tile_store() argument 0 must be an array")
|
|
2051
|
+
c = tile_unpack_offset(arg_types, a.ndim)
|
|
2067
2052
|
|
|
2068
|
-
if
|
|
2069
|
-
raise
|
|
2070
|
-
"
|
|
2053
|
+
if len(c) != a.ndim:
|
|
2054
|
+
raise ValueError(
|
|
2055
|
+
f"tile_store() 'a' argument must have {len(c)} dimensions, "
|
|
2056
|
+
f"calculated based on the provided offset arguments, but got {a.ndim} dimensions."
|
|
2071
2057
|
)
|
|
2072
2058
|
|
|
2073
|
-
if
|
|
2074
|
-
raise
|
|
2075
|
-
|
|
2076
|
-
|
|
2077
|
-
|
|
2059
|
+
if len(t.shape) != a.ndim:
|
|
2060
|
+
raise ValueError(
|
|
2061
|
+
f"tile_store() 'a' argument must have the same number of dimensions as the 't' argument, "
|
|
2062
|
+
f"but got {a.ndim} dimensions for 'a' and {len(t.shape)} dimensions for 't'"
|
|
2063
|
+
)
|
|
2078
2064
|
|
|
2079
2065
|
if not types_equal(arg_types["a"].dtype, arg_types["t"].dtype):
|
|
2080
|
-
raise
|
|
2066
|
+
raise TypeError(
|
|
2067
|
+
f"tile_store() 'a' and 't' arguments must have the same dtype, got {arg_types['a'].dtype} and {arg_types['t'].dtype}"
|
|
2068
|
+
)
|
|
2081
2069
|
|
|
2082
2070
|
return None
|
|
2083
2071
|
|
|
2084
2072
|
|
|
2073
|
+
def tile_store_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
2074
|
+
a = args["a"]
|
|
2075
|
+
t = args["t"]
|
|
2076
|
+
|
|
2077
|
+
offset = tile_unpack_offset(args, a.type.ndim)
|
|
2078
|
+
|
|
2079
|
+
func_args = (a, *offset, t)
|
|
2080
|
+
template_args = []
|
|
2081
|
+
|
|
2082
|
+
return (func_args, template_args)
|
|
2083
|
+
|
|
2084
|
+
|
|
2085
2085
|
add_builtin(
|
|
2086
2086
|
"tile_store",
|
|
2087
|
-
input_types={"a": array(dtype=Any), "
|
|
2088
|
-
value_func=
|
|
2087
|
+
input_types={"a": array(dtype=Any), "t": Tile(dtype=Any, shape=Any), "offset": Tuple[int, ...]},
|
|
2088
|
+
value_func=tile_store_value_func,
|
|
2089
|
+
dispatch_func=tile_store_dispatch_func,
|
|
2090
|
+
defaults={"offset": None},
|
|
2089
2091
|
variadic=False,
|
|
2090
2092
|
skip_replay=True,
|
|
2091
|
-
doc="""
|
|
2093
|
+
doc="""Store a tile to a global memory array.
|
|
2092
2094
|
|
|
2093
2095
|
This method will cooperatively store a tile to global memory using all threads in the block.
|
|
2094
2096
|
|
|
2095
2097
|
:param a: The destination array in global memory
|
|
2096
|
-
:param
|
|
2097
|
-
:param
|
|
2098
|
+
:param t: The source tile to store data from, must have the same data type and number of dimensions as the destination array
|
|
2099
|
+
:param offset: Offset in the destination array (optional)""",
|
|
2098
2100
|
group="Tile Primitives",
|
|
2099
2101
|
export=False,
|
|
2100
2102
|
)
|
|
2101
2103
|
|
|
2102
|
-
|
|
2103
|
-
def tile_store_2d_value_func(arg_types, arg_values):
|
|
2104
|
-
# return generic type (for doc builds)
|
|
2105
|
-
if arg_types is None:
|
|
2106
|
-
return None
|
|
2107
|
-
|
|
2108
|
-
if len(arg_types) != 4:
|
|
2109
|
-
raise RuntimeError("tile_store() requires 4 positional args")
|
|
2110
|
-
|
|
2111
|
-
if not is_array(arg_types["a"]):
|
|
2112
|
-
raise RuntimeError("tile_store() argument 0 must be an array")
|
|
2113
|
-
|
|
2114
|
-
if arg_types["a"].ndim != 2:
|
|
2115
|
-
raise RuntimeError(
|
|
2116
|
-
"tile_load() argument 0 must be a 2-dimensional array if using the ``wp.tile_store(array, i, j, t)`` syntax."
|
|
2117
|
-
)
|
|
2118
|
-
|
|
2119
|
-
if not type_is_int(arg_types["i"]):
|
|
2120
|
-
raise RuntimeError("tile_store() argument 1 must be an integer")
|
|
2121
|
-
|
|
2122
|
-
if not type_is_int(arg_types["j"]):
|
|
2123
|
-
raise RuntimeError("tile_store() argument 2 must be an integer")
|
|
2124
|
-
|
|
2125
|
-
if not is_tile(arg_types["t"]):
|
|
2126
|
-
raise RuntimeError("tile_store() argument 3 must be a tile")
|
|
2127
|
-
|
|
2128
|
-
if not types_equal(arg_types["a"].dtype, arg_types["t"].dtype):
|
|
2129
|
-
raise RuntimeError("tile_store() destination array must have same type as source tile")
|
|
2130
|
-
|
|
2131
|
-
return None
|
|
2132
|
-
|
|
2133
|
-
|
|
2104
|
+
# overload for scalar offset
|
|
2134
2105
|
add_builtin(
|
|
2135
2106
|
"tile_store",
|
|
2136
|
-
input_types={"a": array(dtype=Any), "
|
|
2137
|
-
value_func=
|
|
2107
|
+
input_types={"a": array(dtype=Any), "t": Tile(dtype=Any, shape=Any), "offset": int},
|
|
2108
|
+
value_func=tile_store_value_func,
|
|
2109
|
+
dispatch_func=tile_store_dispatch_func,
|
|
2110
|
+
defaults={"offset": None},
|
|
2138
2111
|
variadic=False,
|
|
2139
2112
|
skip_replay=True,
|
|
2140
|
-
doc="""Stores a tile to a global memory array.
|
|
2141
|
-
|
|
2142
|
-
This method will cooperatively store a tile to global memory using all threads in the block.
|
|
2143
|
-
|
|
2144
|
-
:param a: The destination array in global memory
|
|
2145
|
-
:param i: Offset in the destination array measured in multiples of ``m``, i.e.: ``row=i*m``
|
|
2146
|
-
:param j: Offset in the destination array measured in multiples of ``n``, i.e.; ``col=j*n``
|
|
2147
|
-
:param t: The source tile to store data from, must have the same dtype as the destination array""",
|
|
2148
2113
|
group="Tile Primitives",
|
|
2114
|
+
hidden=True,
|
|
2149
2115
|
export=False,
|
|
2150
2116
|
)
|
|
2151
2117
|
|
|
@@ -2153,130 +2119,219 @@ add_builtin(
|
|
|
2153
2119
|
def tile_atomic_add_value_func(arg_types, arg_values):
|
|
2154
2120
|
# return generic type (for doc builds)
|
|
2155
2121
|
if arg_types is None:
|
|
2156
|
-
return Tile(dtype=Any,
|
|
2122
|
+
return Tile(dtype=Any, shape=Any)
|
|
2123
|
+
|
|
2124
|
+
a = arg_types["a"]
|
|
2125
|
+
t = arg_types["t"]
|
|
2126
|
+
|
|
2127
|
+
c = tile_unpack_offset(arg_types, a.ndim)
|
|
2128
|
+
if len(c) != a.ndim:
|
|
2129
|
+
raise ValueError(
|
|
2130
|
+
f"tile_atomic_add() 'a' argument must have {len(c)} dimensions, "
|
|
2131
|
+
f"calculated based on the provided offset arguments, but got {a.ndim} dimensions."
|
|
2132
|
+
)
|
|
2133
|
+
|
|
2134
|
+
if a.ndim != len(t.shape):
|
|
2135
|
+
raise ValueError(
|
|
2136
|
+
f"tile_atomic_add() 'a' argument must have the same number of dimensions as the 't' argument, "
|
|
2137
|
+
f"but got {a.ndim} dimensions for 'a' and {len(t.shape)} dimensions for 't'"
|
|
2138
|
+
)
|
|
2157
2139
|
|
|
2158
|
-
if
|
|
2159
|
-
raise
|
|
2140
|
+
if not types_equal(arg_types["a"].dtype, arg_types["t"].dtype):
|
|
2141
|
+
raise TypeError(
|
|
2142
|
+
f"tile_atomic_add() 'a' and 't' arguments must have the same dtype, got {arg_types['a'].dtype} and {arg_types['t'].dtype}"
|
|
2143
|
+
)
|
|
2160
2144
|
|
|
2161
|
-
|
|
2162
|
-
raise RuntimeError("tile_atomic_add() argument 0 must be an array")
|
|
2145
|
+
return Tile(dtype=arg_types["t"].dtype, shape=arg_types["t"].shape)
|
|
2163
2146
|
|
|
2164
|
-
if not type_is_int(arg_types["x"]):
|
|
2165
|
-
raise RuntimeError("tile_atomic_add() argument 1 must be an integer")
|
|
2166
2147
|
|
|
2167
|
-
|
|
2168
|
-
|
|
2148
|
+
def tile_atomic_add_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
2149
|
+
a = args["a"]
|
|
2150
|
+
t = args["t"]
|
|
2169
2151
|
|
|
2170
|
-
|
|
2171
|
-
raise RuntimeError("tile_atomic_add() argument 3 must be a tile")
|
|
2152
|
+
offset = tile_unpack_offset(args, a.type.ndim)
|
|
2172
2153
|
|
|
2173
|
-
|
|
2174
|
-
|
|
2154
|
+
func_args = (a, *offset, t)
|
|
2155
|
+
template_args = []
|
|
2175
2156
|
|
|
2176
|
-
return
|
|
2157
|
+
return (func_args, template_args)
|
|
2177
2158
|
|
|
2178
2159
|
|
|
2179
2160
|
add_builtin(
|
|
2180
2161
|
"tile_atomic_add",
|
|
2181
|
-
input_types={"a": array(dtype=Any), "
|
|
2162
|
+
input_types={"a": array(dtype=Any), "t": Tile(dtype=Any, shape=Any), "offset": Tuple[int, ...]},
|
|
2182
2163
|
value_func=tile_atomic_add_value_func,
|
|
2183
|
-
|
|
2164
|
+
dispatch_func=tile_atomic_add_dispatch_func,
|
|
2165
|
+
defaults={"offset": None},
|
|
2166
|
+
variadic=False,
|
|
2184
2167
|
skip_replay=True,
|
|
2185
|
-
doc="""Atomically add a tile to the array `a`, each element will be updated atomically.
|
|
2168
|
+
doc="""Atomically add a 1D tile to the array `a`, each element will be updated atomically.
|
|
2186
2169
|
|
|
2187
2170
|
:param a: Array in global memory, should have the same ``dtype`` as the input tile
|
|
2188
|
-
:param x: Offset in the destination array measured in multiples of ``m``, i.e.: ``i=x*M`` where ``M`` is the first tile dimension
|
|
2189
|
-
:param y: Offset in the destination array measured in multiples of ``n``, i.e.: ``j=y*N`` where ``N`` is the second tile dimension
|
|
2190
2171
|
:param t: Source tile to add to the destination array
|
|
2191
|
-
:
|
|
2172
|
+
:param offset: Offset in the destination array (optional)
|
|
2173
|
+
:returns: A tile with the same dimensions and data type as the source tile, holding the original value of the destination elements""",
|
|
2192
2174
|
group="Tile Primitives",
|
|
2193
2175
|
export=False,
|
|
2194
2176
|
)
|
|
2195
2177
|
|
|
2178
|
+
# overload for scalar offset
|
|
2179
|
+
add_builtin(
|
|
2180
|
+
"tile_atomic_add",
|
|
2181
|
+
input_types={"a": array(dtype=Any), "t": Tile(dtype=Any, shape=Any), "offset": int},
|
|
2182
|
+
value_func=tile_atomic_add_value_func,
|
|
2183
|
+
dispatch_func=tile_atomic_add_dispatch_func,
|
|
2184
|
+
defaults={"offset": None},
|
|
2185
|
+
variadic=False,
|
|
2186
|
+
skip_replay=True,
|
|
2187
|
+
group="Tile Primitives",
|
|
2188
|
+
hidden=True,
|
|
2189
|
+
export=False,
|
|
2190
|
+
)
|
|
2191
|
+
|
|
2196
2192
|
|
|
2197
2193
|
def tile_view_value_func(arg_types, arg_values):
|
|
2198
2194
|
# return generic type (for doc builds)
|
|
2199
2195
|
if arg_types is None:
|
|
2200
|
-
return Tile(dtype=Any,
|
|
2196
|
+
return Tile(dtype=Any, shape=Any)
|
|
2201
2197
|
|
|
2202
2198
|
tile = arg_types["t"]
|
|
2199
|
+
offset = arg_types["offset"]
|
|
2203
2200
|
|
|
2204
|
-
if
|
|
2205
|
-
|
|
2206
|
-
else:
|
|
2207
|
-
m = arg_values["m"]
|
|
2201
|
+
if len(offset) > len(tile.shape):
|
|
2202
|
+
raise ValueError(f"tile_view() specified too many offset coordinates {len(offset)} > {len(tile.shape)}")
|
|
2208
2203
|
|
|
2209
|
-
if "
|
|
2210
|
-
|
|
2204
|
+
if "shape" in arg_values:
|
|
2205
|
+
# if shape is specified take it directly, e.g.:
|
|
2206
|
+
# tile_view(t, offset=(i,j), shape=(m,n))
|
|
2207
|
+
shape = arg_values["shape"]
|
|
2208
|
+
strides = tile.strides
|
|
2209
|
+
|
|
2210
|
+
if len(shape) != len(tile.shape):
|
|
2211
|
+
raise ValueError(
|
|
2212
|
+
f"tile_view() if shape is specified it must have same number of dimensions as source tile, expected {len(tile.shape)}, got {len(shape)}"
|
|
2213
|
+
)
|
|
2211
2214
|
else:
|
|
2212
|
-
|
|
2215
|
+
# if not specified, then take output shape from unspecified src dimensions
|
|
2216
|
+
# e.g.: tile[i] will return a whole row of a 2D tile
|
|
2217
|
+
shape = tile.shape[len(offset) :]
|
|
2218
|
+
strides = tile.strides[len(offset) :]
|
|
2213
2219
|
|
|
2214
|
-
|
|
2215
|
-
raise RuntimeError(
|
|
2216
|
-
f"Trying to view a tile subrange with dimensions ({m}, {n}) which is larger than source tile with dimensions ({tile.M}, {tile.N})"
|
|
2217
|
-
)
|
|
2220
|
+
assert len(shape) == len(strides)
|
|
2218
2221
|
|
|
2219
2222
|
# force source tile to shared memory
|
|
2220
2223
|
tile.storage = "shared"
|
|
2221
2224
|
|
|
2222
|
-
output = Tile(dtype=tile.dtype,
|
|
2225
|
+
output = Tile(dtype=tile.dtype, shape=shape, strides=strides, layout=tile.layout, storage="shared", owner=False)
|
|
2223
2226
|
return output
|
|
2224
2227
|
|
|
2225
2228
|
|
|
2226
2229
|
def tile_view_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
2227
2230
|
tile = arg_values["t"]
|
|
2228
|
-
|
|
2229
|
-
|
|
2230
|
-
if "j" not in arg_values:
|
|
2231
|
-
j = warp.codegen.Var(label=None, type=int, constant=0)
|
|
2232
|
-
else:
|
|
2233
|
-
j = arg_values["j"]
|
|
2231
|
+
coord = arg_values["offset"]
|
|
2234
2232
|
|
|
2235
|
-
|
|
2236
|
-
|
|
2237
|
-
|
|
2233
|
+
# zero-pad coord to match source array
|
|
2234
|
+
view_coord = [0] * len(tile.type.shape)
|
|
2235
|
+
for i in range(len(coord)):
|
|
2236
|
+
view_coord[i] = coord[i]
|
|
2238
2237
|
|
|
2239
|
-
return ((tile,
|
|
2238
|
+
return ((tile, *view_coord), (return_type,))
|
|
2240
2239
|
|
|
2241
2240
|
|
|
2242
2241
|
add_builtin(
|
|
2243
2242
|
"tile_view",
|
|
2244
|
-
input_types={"t": Tile(dtype=Any,
|
|
2243
|
+
input_types={"t": Tile(dtype=Any, shape=Any), "offset": Tuple[int, ...], "shape": Tuple[int, ...]},
|
|
2245
2244
|
value_func=tile_view_value_func,
|
|
2246
2245
|
dispatch_func=tile_view_dispatch_func,
|
|
2247
|
-
defaults={"
|
|
2248
|
-
variadic=
|
|
2249
|
-
doc="""Return a
|
|
2246
|
+
defaults={"shape": None},
|
|
2247
|
+
variadic=False,
|
|
2248
|
+
doc="""Return a slice of a given tile [offset, offset+shape], if shape is not specified it will be inferred from the unspecified offset dimensions.
|
|
2250
2249
|
|
|
2251
2250
|
:param t: Input tile to extract a subrange from
|
|
2252
|
-
:param
|
|
2253
|
-
:param
|
|
2254
|
-
:
|
|
2255
|
-
:param n: Size of the subrange to return along the second dimension
|
|
2256
|
-
:returns: A tile with dimensions (m,n) and the same datatype as the input tile""",
|
|
2251
|
+
:param offset: Offset in the source tile
|
|
2252
|
+
:param shape: Shape of the returned slice
|
|
2253
|
+
:returns: A tile with dimensions given by the specified shape or the remaining source tile dimensions""",
|
|
2257
2254
|
group="Tile Primitives",
|
|
2255
|
+
missing_grad=True,
|
|
2258
2256
|
export=False,
|
|
2259
2257
|
)
|
|
2260
2258
|
|
|
2261
2259
|
|
|
2262
2260
|
def tile_assign_value_func(arg_types, arg_values):
|
|
2263
|
-
|
|
2261
|
+
if arg_types is None:
|
|
2262
|
+
return None
|
|
2263
|
+
|
|
2264
|
+
# force the destination tile to shared memory
|
|
2265
|
+
arg_types["dst"].storage = "shared"
|
|
2264
2266
|
return None
|
|
2265
2267
|
|
|
2266
2268
|
|
|
2269
|
+
def tile_assign_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
2270
|
+
dst = args["dst"]
|
|
2271
|
+
src = args["src"]
|
|
2272
|
+
|
|
2273
|
+
offset = tile_unpack_offset(args, len(dst.type.shape))
|
|
2274
|
+
|
|
2275
|
+
func_args = (dst, src, *offset)
|
|
2276
|
+
template_args = []
|
|
2277
|
+
|
|
2278
|
+
return (func_args, template_args)
|
|
2279
|
+
|
|
2280
|
+
|
|
2267
2281
|
add_builtin(
|
|
2268
2282
|
"tile_assign",
|
|
2269
|
-
input_types={"dst": Tile(dtype=Any,
|
|
2283
|
+
input_types={"dst": Tile(dtype=Any, shape=Any), "src": Tile(dtype=Any, shape=Any), "offset": Tuple[int, ...]},
|
|
2284
|
+
value_func=tile_assign_value_func,
|
|
2285
|
+
dispatch_func=tile_assign_dispatch_func,
|
|
2286
|
+
defaults={"offset": None},
|
|
2287
|
+
doc="""Assign a tile to a subrange of a destination tile.
|
|
2288
|
+
|
|
2289
|
+
:param dst: The destination tile to assign to
|
|
2290
|
+
:param src: The source tile to read values from
|
|
2291
|
+
:param offset: Offset in the destination tile to write to""",
|
|
2292
|
+
group="Tile Primitives",
|
|
2293
|
+
export=False,
|
|
2294
|
+
)
|
|
2295
|
+
|
|
2296
|
+
# handles expressions like tile[i,j] = 1.0
|
|
2297
|
+
add_builtin(
|
|
2298
|
+
"assign",
|
|
2299
|
+
input_types={"dst": Tile(dtype=Any, shape=Any), "i": int, "src": Scalar},
|
|
2300
|
+
value_func=tile_assign_value_func,
|
|
2301
|
+
group="Tile Primitives",
|
|
2302
|
+
export=False,
|
|
2303
|
+
hidden=True,
|
|
2304
|
+
missing_grad=True,
|
|
2305
|
+
)
|
|
2306
|
+
|
|
2307
|
+
add_builtin(
|
|
2308
|
+
"assign",
|
|
2309
|
+
input_types={"dst": Tile(dtype=Any, shape=Any), "i": int, "j": int, "src": Scalar},
|
|
2270
2310
|
value_func=tile_assign_value_func,
|
|
2271
|
-
|
|
2272
|
-
|
|
2311
|
+
group="Tile Primitives",
|
|
2312
|
+
export=False,
|
|
2313
|
+
hidden=True,
|
|
2314
|
+
missing_grad=True,
|
|
2315
|
+
)
|
|
2316
|
+
|
|
2317
|
+
add_builtin(
|
|
2318
|
+
"assign",
|
|
2319
|
+
input_types={"dst": Tile(dtype=Any, shape=Any), "i": int, "j": int, "k": int, "src": Scalar},
|
|
2320
|
+
value_func=tile_assign_value_func,
|
|
2321
|
+
group="Tile Primitives",
|
|
2322
|
+
export=False,
|
|
2323
|
+
hidden=True,
|
|
2324
|
+
missing_grad=True,
|
|
2325
|
+
)
|
|
2273
2326
|
|
|
2274
|
-
|
|
2275
|
-
|
|
2276
|
-
:
|
|
2277
|
-
|
|
2327
|
+
add_builtin(
|
|
2328
|
+
"assign",
|
|
2329
|
+
input_types={"dst": Tile(dtype=Any, shape=Any), "i": int, "j": int, "k": int, "l": int, "src": Scalar},
|
|
2330
|
+
value_func=tile_assign_value_func,
|
|
2278
2331
|
group="Tile Primitives",
|
|
2279
2332
|
export=False,
|
|
2333
|
+
hidden=True,
|
|
2334
|
+
missing_grad=True,
|
|
2280
2335
|
)
|
|
2281
2336
|
|
|
2282
2337
|
|
|
@@ -2286,7 +2341,7 @@ def tile_value_func(arg_types, arg_values):
|
|
|
2286
2341
|
return Tile
|
|
2287
2342
|
|
|
2288
2343
|
if len(arg_types) != 1:
|
|
2289
|
-
raise
|
|
2344
|
+
raise TypeError(f"tile() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
2290
2345
|
|
|
2291
2346
|
dtype = None
|
|
2292
2347
|
length = None
|
|
@@ -2294,11 +2349,12 @@ def tile_value_func(arg_types, arg_values):
|
|
|
2294
2349
|
if type_is_vector(arg_types["x"]):
|
|
2295
2350
|
dtype = arg_types["x"]._wp_scalar_type_
|
|
2296
2351
|
length = arg_types["x"]._shape_[0]
|
|
2352
|
+
shape = (length, warp.codegen.options["block_dim"])
|
|
2297
2353
|
else:
|
|
2298
2354
|
dtype = arg_types["x"]
|
|
2299
|
-
|
|
2355
|
+
shape = (warp.codegen.options["block_dim"],)
|
|
2300
2356
|
|
|
2301
|
-
return Tile(dtype=dtype,
|
|
2357
|
+
return Tile(dtype=dtype, shape=shape, op="tile")
|
|
2302
2358
|
|
|
2303
2359
|
|
|
2304
2360
|
add_builtin(
|
|
@@ -2306,14 +2362,14 @@ add_builtin(
|
|
|
2306
2362
|
input_types={"x": Any},
|
|
2307
2363
|
value_func=tile_value_func,
|
|
2308
2364
|
variadic=True,
|
|
2309
|
-
doc="""
|
|
2365
|
+
doc="""Construct a new tile from per-thread kernel values.
|
|
2310
2366
|
|
|
2311
2367
|
This function converts values computed using scalar kernel code to a tile representation for input into collective operations.
|
|
2312
2368
|
|
|
2313
2369
|
* If the input value is a scalar, then the resulting tile has ``shape=(1, block_dim)``
|
|
2314
2370
|
* If the input value is a vector, then the resulting tile has ``shape=(length(vector), block_dim)``
|
|
2315
2371
|
|
|
2316
|
-
:param x: A per-thread local value, e.g
|
|
2372
|
+
:param x: A per-thread local value, e.g. scalar, vector, or matrix.
|
|
2317
2373
|
:returns: A tile with first dimension according to the value type length and a second dimension equal to ``block_dim``
|
|
2318
2374
|
|
|
2319
2375
|
This example shows how to create a linear sequence from thread variables:
|
|
@@ -2332,7 +2388,7 @@ add_builtin(
|
|
|
2332
2388
|
|
|
2333
2389
|
.. code-block:: text
|
|
2334
2390
|
|
|
2335
|
-
|
|
2391
|
+
[0 2 4 6 8 10 12 14 16 18 20 22 24 26 28 30] = tile(shape=(16), storage=register)
|
|
2336
2392
|
|
|
2337
2393
|
""",
|
|
2338
2394
|
group="Tile Primitives",
|
|
@@ -2346,38 +2402,40 @@ def untile_value_func(arg_types, arg_values):
|
|
|
2346
2402
|
return Scalar
|
|
2347
2403
|
|
|
2348
2404
|
if len(arg_types) != 1:
|
|
2349
|
-
raise
|
|
2405
|
+
raise TypeError(f"untile() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
2350
2406
|
|
|
2351
2407
|
t = arg_types["a"]
|
|
2352
2408
|
|
|
2353
2409
|
if not is_tile(t):
|
|
2354
|
-
raise
|
|
2410
|
+
raise TypeError(f"untile() argument must be a tile, got {t!r}")
|
|
2355
2411
|
|
|
2356
|
-
if t.
|
|
2357
|
-
raise
|
|
2358
|
-
f"untile() argument
|
|
2412
|
+
if t.shape[-1] != warp.codegen.options["block_dim"]:
|
|
2413
|
+
raise ValueError(
|
|
2414
|
+
f"untile() argument last dimension {t.shape[-1]} does not match the expected block width {warp.codegen.options['block_dim']}"
|
|
2359
2415
|
)
|
|
2360
2416
|
|
|
2361
|
-
if t.
|
|
2417
|
+
if len(t.shape) == 1:
|
|
2362
2418
|
return t.dtype
|
|
2363
|
-
elif t.
|
|
2364
|
-
return warp.types.vector(t.
|
|
2419
|
+
elif len(t.shape) == 2:
|
|
2420
|
+
return warp.types.vector(t.shape[0], t.dtype)
|
|
2421
|
+
else:
|
|
2422
|
+
raise ValueError(f"untile() argument must have a positive size in dimension 0, but got {t.shape[0]}")
|
|
2365
2423
|
|
|
2366
2424
|
|
|
2367
2425
|
add_builtin(
|
|
2368
2426
|
"untile",
|
|
2369
|
-
input_types={"a": Tile(dtype=Any,
|
|
2427
|
+
input_types={"a": Tile(dtype=Any, shape=Any)},
|
|
2370
2428
|
value_func=untile_value_func,
|
|
2371
2429
|
variadic=True,
|
|
2372
|
-
doc="""Convert a
|
|
2430
|
+
doc="""Convert a tile back to per-thread values.
|
|
2373
2431
|
|
|
2374
2432
|
This function converts a block-wide tile back to per-thread values.
|
|
2375
2433
|
|
|
2376
|
-
* If the input tile is
|
|
2377
|
-
* If the input tile is
|
|
2434
|
+
* If the input tile is 1D, then the resulting value will be a per-thread scalar
|
|
2435
|
+
* If the input tile is 2D, then the resulting value will be a per-thread vector of length M
|
|
2378
2436
|
|
|
2379
2437
|
:param a: A tile with dimensions ``shape=(M, block_dim)``
|
|
2380
|
-
:returns: A single value per-thread with the same
|
|
2438
|
+
:returns: A single value per-thread with the same data type as the tile
|
|
2381
2439
|
|
|
2382
2440
|
This example shows how to create a linear sequence from thread variables:
|
|
2383
2441
|
|
|
@@ -2418,21 +2476,58 @@ def tile_extract_value_func(arg_types, arg_values):
|
|
|
2418
2476
|
if arg_types is None:
|
|
2419
2477
|
return Scalar
|
|
2420
2478
|
|
|
2421
|
-
|
|
2422
|
-
|
|
2423
|
-
|
|
2424
|
-
if not is_tile(arg_types["a"]):
|
|
2425
|
-
raise RuntimeError("tile_extract() argument 0 must be a tile")
|
|
2479
|
+
# force the input tile to shared memory
|
|
2480
|
+
arg_types["a"].storage = "shared"
|
|
2426
2481
|
|
|
2427
2482
|
return arg_types["a"].dtype
|
|
2428
2483
|
|
|
2429
2484
|
|
|
2430
2485
|
add_builtin(
|
|
2431
2486
|
"tile_extract",
|
|
2432
|
-
input_types={"a": Tile(dtype=Any,
|
|
2487
|
+
input_types={"a": Tile(dtype=Any, shape=Any), "i": int},
|
|
2433
2488
|
value_func=tile_extract_value_func,
|
|
2434
|
-
variadic=
|
|
2435
|
-
doc="""
|
|
2489
|
+
variadic=False,
|
|
2490
|
+
doc="""Extract a single element from the tile and return it as a scalar type.
|
|
2491
|
+
|
|
2492
|
+
This function will extract an element from the tile and broadcast its value to all threads in the block.
|
|
2493
|
+
|
|
2494
|
+
Note that this may incur additional synchronization if the source tile is a register tile.
|
|
2495
|
+
|
|
2496
|
+
:param a: Tile to extract the element from
|
|
2497
|
+
:param i: Coordinate of element on first dimension
|
|
2498
|
+
:returns: The value of the element at the specified tile location with the same data type as the input tile""",
|
|
2499
|
+
group="Tile Primitives",
|
|
2500
|
+
hidden=True,
|
|
2501
|
+
export=False,
|
|
2502
|
+
)
|
|
2503
|
+
|
|
2504
|
+
|
|
2505
|
+
add_builtin(
|
|
2506
|
+
"tile_extract",
|
|
2507
|
+
input_types={"a": Tile(dtype=Any, shape=Any), "i": int, "j": int},
|
|
2508
|
+
value_func=tile_extract_value_func,
|
|
2509
|
+
variadic=False,
|
|
2510
|
+
doc="""Extract a single element from the tile and return it as a scalar type.
|
|
2511
|
+
|
|
2512
|
+
This function will extract an element from the tile and broadcast its value to all threads in the block.
|
|
2513
|
+
|
|
2514
|
+
Note that this may incur additional synchronization if the source tile is a register tile.
|
|
2515
|
+
|
|
2516
|
+
:param a: Tile to extract the element from
|
|
2517
|
+
:param i: Coordinate of element on first dimension
|
|
2518
|
+
:param j: Coordinate of element on the second dimension
|
|
2519
|
+
:returns: The value of the element at the specified tile location with the same data type as the input tile""",
|
|
2520
|
+
group="Tile Primitives",
|
|
2521
|
+
hidden=True,
|
|
2522
|
+
export=False,
|
|
2523
|
+
)
|
|
2524
|
+
|
|
2525
|
+
add_builtin(
|
|
2526
|
+
"tile_extract",
|
|
2527
|
+
input_types={"a": Tile(dtype=Any, shape=Any), "i": int, "j": int, "k": int},
|
|
2528
|
+
value_func=tile_extract_value_func,
|
|
2529
|
+
variadic=False,
|
|
2530
|
+
doc="""Extract a single element from the tile and return it as a scalar type.
|
|
2436
2531
|
|
|
2437
2532
|
This function will extract an element from the tile and broadcast its value to all threads in the block.
|
|
2438
2533
|
|
|
@@ -2441,8 +2536,32 @@ add_builtin(
|
|
|
2441
2536
|
:param a: Tile to extract the element from
|
|
2442
2537
|
:param i: Coordinate of element on first dimension
|
|
2443
2538
|
:param j: Coordinate of element on the second dimension
|
|
2444
|
-
:
|
|
2539
|
+
:param k: Coordinate of element on the third dimension
|
|
2540
|
+
:returns: The value of the element at the specified tile location with the same data type as the input tile""",
|
|
2445
2541
|
group="Tile Primitives",
|
|
2542
|
+
hidden=True,
|
|
2543
|
+
export=False,
|
|
2544
|
+
)
|
|
2545
|
+
|
|
2546
|
+
add_builtin(
|
|
2547
|
+
"tile_extract",
|
|
2548
|
+
input_types={"a": Tile(dtype=Any, shape=Any), "i": int, "j": int, "k": int, "l": int},
|
|
2549
|
+
value_func=tile_extract_value_func,
|
|
2550
|
+
variadic=False,
|
|
2551
|
+
doc="""Extract a single element from the tile and return it as a scalar type.
|
|
2552
|
+
|
|
2553
|
+
This function will extract an element from the tile and broadcast its value to all threads in the block.
|
|
2554
|
+
|
|
2555
|
+
Note that this may incur additional synchronization if the source tile is a register tile.
|
|
2556
|
+
|
|
2557
|
+
:param a: Tile to extract the element from
|
|
2558
|
+
:param i: Coordinate of element on first dimension
|
|
2559
|
+
:param j: Coordinate of element on the second dimension
|
|
2560
|
+
:param k: Coordinate of element on the third dimension
|
|
2561
|
+
:param l: Coordinate of element on the fourth dimension
|
|
2562
|
+
:returns: The value of the element at the specified tile location, with the same data type as the input tile""",
|
|
2563
|
+
group="Tile Primitives",
|
|
2564
|
+
hidden=True,
|
|
2446
2565
|
export=False,
|
|
2447
2566
|
)
|
|
2448
2567
|
|
|
@@ -2453,12 +2572,12 @@ def tile_transpose_value_func(arg_types, arg_values):
|
|
|
2453
2572
|
return Tile
|
|
2454
2573
|
|
|
2455
2574
|
if len(arg_types) != 1:
|
|
2456
|
-
raise
|
|
2575
|
+
raise TypeError(f"tile_transpose() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
2457
2576
|
|
|
2458
2577
|
t = arg_types["a"]
|
|
2459
2578
|
|
|
2460
2579
|
if not is_tile(t):
|
|
2461
|
-
raise
|
|
2580
|
+
raise TypeError(f"tile_transpose() argument must be a tile, got {t!r}")
|
|
2462
2581
|
|
|
2463
2582
|
layout = None
|
|
2464
2583
|
|
|
@@ -2473,8 +2592,7 @@ def tile_transpose_value_func(arg_types, arg_values):
|
|
|
2473
2592
|
|
|
2474
2593
|
return Tile(
|
|
2475
2594
|
dtype=t.dtype,
|
|
2476
|
-
|
|
2477
|
-
N=t.M,
|
|
2595
|
+
shape=t.shape[::-1],
|
|
2478
2596
|
op="transpose",
|
|
2479
2597
|
storage=t.storage,
|
|
2480
2598
|
strides=t.strides[::-1],
|
|
@@ -2485,12 +2603,13 @@ def tile_transpose_value_func(arg_types, arg_values):
|
|
|
2485
2603
|
|
|
2486
2604
|
add_builtin(
|
|
2487
2605
|
"tile_transpose",
|
|
2488
|
-
input_types={"a": Tile(dtype=Any,
|
|
2606
|
+
input_types={"a": Tile(dtype=Any, shape=Any)},
|
|
2489
2607
|
value_func=tile_transpose_value_func,
|
|
2490
2608
|
variadic=True,
|
|
2491
2609
|
doc="""Transpose a tile.
|
|
2492
2610
|
|
|
2493
|
-
For shared memory tiles this operation will alias the input tile
|
|
2611
|
+
For shared memory tiles, this operation will alias the input tile.
|
|
2612
|
+
Register tiles will first be transferred to shared memory before transposition.
|
|
2494
2613
|
|
|
2495
2614
|
:param a: Tile to transpose with ``shape=(M,N)``
|
|
2496
2615
|
:returns: Tile with ``shape=(N,M)``""",
|
|
@@ -2504,41 +2623,36 @@ def tile_broadcast_value_func(arg_types, arg_values):
|
|
|
2504
2623
|
if arg_types is None:
|
|
2505
2624
|
return Tile
|
|
2506
2625
|
|
|
2507
|
-
if len(arg_types) != 3:
|
|
2508
|
-
raise RuntimeError("tile_broadcast() requires 1 positional args")
|
|
2509
|
-
|
|
2510
2626
|
t = arg_types["a"]
|
|
2511
|
-
m = arg_values["m"]
|
|
2512
|
-
n = arg_values["n"]
|
|
2513
2627
|
|
|
2514
|
-
|
|
2515
|
-
|
|
2628
|
+
# target shape and strides
|
|
2629
|
+
target_shape = tile_unpack_shape(arg_values)
|
|
2630
|
+
target_strides = [0] * len(target_shape)
|
|
2516
2631
|
|
|
2517
|
-
|
|
2518
|
-
if t.N == 1:
|
|
2519
|
-
stride_n = 0
|
|
2520
|
-
elif t.N == n:
|
|
2521
|
-
stride_n = t.strides[1]
|
|
2522
|
-
else:
|
|
2523
|
-
raise RuntimeError(
|
|
2524
|
-
f"Broadcast dimension must be 1 or match destination, shape(src) = {t.m, t.n}, shape(dest) = {m, n}"
|
|
2525
|
-
)
|
|
2632
|
+
offset = len(target_shape) - len(t.shape)
|
|
2526
2633
|
|
|
2527
|
-
#
|
|
2528
|
-
|
|
2529
|
-
|
|
2530
|
-
|
|
2531
|
-
|
|
2532
|
-
|
|
2533
|
-
|
|
2534
|
-
|
|
2535
|
-
|
|
2634
|
+
# compute target strides
|
|
2635
|
+
for i in reversed(range(len(target_shape))):
|
|
2636
|
+
j = i - offset
|
|
2637
|
+
|
|
2638
|
+
if j < 0:
|
|
2639
|
+
target_strides[i] = 0
|
|
2640
|
+
else:
|
|
2641
|
+
# try to broadcast each dimension
|
|
2642
|
+
if t.shape[j] == 1:
|
|
2643
|
+
target_strides[i] = 0
|
|
2644
|
+
elif t.shape[j] == target_shape[i]:
|
|
2645
|
+
target_strides[i] = t.strides[j]
|
|
2646
|
+
else:
|
|
2647
|
+
raise ValueError(
|
|
2648
|
+
f"tile_broadcast() cannot broadcast dimension {t.shape[j]} into {target_shape[i]} at index {i}"
|
|
2649
|
+
)
|
|
2536
2650
|
|
|
2537
2651
|
# force the input tile to shared memory
|
|
2538
2652
|
t.storage = "shared"
|
|
2539
2653
|
|
|
2540
2654
|
tile_type = Tile(
|
|
2541
|
-
dtype=t.dtype,
|
|
2655
|
+
dtype=t.dtype, shape=target_shape, op="broadcast", storage=t.storage, strides=target_strides, owner=False
|
|
2542
2656
|
)
|
|
2543
2657
|
return tile_type
|
|
2544
2658
|
|
|
@@ -2547,8 +2661,8 @@ def tile_broadcast_dispatch_func(arg_types: Mapping[str, type], return_type: Any
|
|
|
2547
2661
|
tile = arg_values["a"]
|
|
2548
2662
|
|
|
2549
2663
|
template_args = []
|
|
2550
|
-
template_args.append(return_type.
|
|
2551
|
-
template_args.append(return_type.
|
|
2664
|
+
template_args.append(return_type.shape[0])
|
|
2665
|
+
template_args.append(return_type.shape[1])
|
|
2552
2666
|
template_args.append(return_type.strides[0])
|
|
2553
2667
|
template_args.append(return_type.strides[1])
|
|
2554
2668
|
|
|
@@ -2557,15 +2671,18 @@ def tile_broadcast_dispatch_func(arg_types: Mapping[str, type], return_type: Any
|
|
|
2557
2671
|
|
|
2558
2672
|
add_builtin(
|
|
2559
2673
|
"tile_broadcast",
|
|
2560
|
-
input_types={"a": Tile(dtype=Any,
|
|
2674
|
+
input_types={"a": Tile(dtype=Any, shape=Any), "shape": Tuple[int, ...]},
|
|
2561
2675
|
value_func=tile_broadcast_value_func,
|
|
2562
2676
|
dispatch_func=tile_broadcast_dispatch_func,
|
|
2563
|
-
variadic=
|
|
2677
|
+
variadic=False,
|
|
2564
2678
|
doc="""Broadcast a tile.
|
|
2565
2679
|
|
|
2566
|
-
This function will attempt to broadcast the input tile ``a`` to the destination shape (m, n)
|
|
2680
|
+
This function will attempt to broadcast the input tile ``a`` to the destination shape (m, n).
|
|
2681
|
+
|
|
2682
|
+
Broadcasting follows NumPy broadcast rules.
|
|
2567
2683
|
|
|
2568
2684
|
:param a: Tile to broadcast
|
|
2685
|
+
:param shape: The shape to broadcast to
|
|
2569
2686
|
:returns: Tile with broadcast ``shape=(m, n)``""",
|
|
2570
2687
|
group="Tile Primitives",
|
|
2571
2688
|
export=False,
|
|
@@ -2575,19 +2692,10 @@ add_builtin(
|
|
|
2575
2692
|
def tile_matmul_value_func(arg_types, arg_values):
|
|
2576
2693
|
# return generic type (for doc builds)
|
|
2577
2694
|
if arg_types is None:
|
|
2578
|
-
return Tile(dtype=Any,
|
|
2695
|
+
return Tile(dtype=Any, shape=Any)
|
|
2579
2696
|
|
|
2580
2697
|
if len(arg_types) != 3:
|
|
2581
|
-
raise
|
|
2582
|
-
|
|
2583
|
-
if not is_tile(arg_types["a"]):
|
|
2584
|
-
raise RuntimeError("tile_matmul() argument 0 must be a tile")
|
|
2585
|
-
|
|
2586
|
-
if not is_tile(arg_types["b"]):
|
|
2587
|
-
raise RuntimeError("tile_matmul() argument 1 must be an tile")
|
|
2588
|
-
|
|
2589
|
-
if not isinstance(arg_types["out"], Tile):
|
|
2590
|
-
raise RuntimeError("tile_matmul() output argument must be a tile")
|
|
2698
|
+
raise TypeError(f"tile_matmul() takes exactly 3 positional arguments but {len(arg_types)} were given")
|
|
2591
2699
|
|
|
2592
2700
|
return None
|
|
2593
2701
|
|
|
@@ -2622,17 +2730,17 @@ add_builtin(
|
|
|
2622
2730
|
def tile_sum_value_func(arg_types, arg_values):
|
|
2623
2731
|
# return generic type (for doc builds)
|
|
2624
2732
|
if arg_types is None:
|
|
2625
|
-
return Tile(dtype=Any,
|
|
2733
|
+
return Tile(dtype=Any, shape=(1,))
|
|
2626
2734
|
|
|
2627
2735
|
if len(arg_types) != 1:
|
|
2628
|
-
raise
|
|
2736
|
+
raise TypeError(f"tile_sum() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
2629
2737
|
|
|
2630
2738
|
a = arg_types["a"]
|
|
2631
2739
|
|
|
2632
2740
|
if not is_tile(a):
|
|
2633
|
-
raise
|
|
2741
|
+
raise TypeError(f"tile_sum() argument must be a tile, got {a!r}")
|
|
2634
2742
|
|
|
2635
|
-
return Tile(dtype=a.dtype,
|
|
2743
|
+
return Tile(dtype=a.dtype, shape=(1,), op="sum")
|
|
2636
2744
|
|
|
2637
2745
|
|
|
2638
2746
|
add_builtin(
|
|
@@ -2643,7 +2751,7 @@ add_builtin(
|
|
|
2643
2751
|
doc="""Cooperatively compute the sum of the tile elements using all threads in the block.
|
|
2644
2752
|
|
|
2645
2753
|
:param a: The tile to compute the sum of
|
|
2646
|
-
:returns: A single-element tile
|
|
2754
|
+
:returns: A single-element tile holding the sum
|
|
2647
2755
|
|
|
2648
2756
|
Example:
|
|
2649
2757
|
|
|
@@ -2652,7 +2760,7 @@ add_builtin(
|
|
|
2652
2760
|
@wp.kernel
|
|
2653
2761
|
def compute():
|
|
2654
2762
|
|
|
2655
|
-
t = wp.tile_ones(dtype=float,
|
|
2763
|
+
t = wp.tile_ones(dtype=float, shape=(16, 16))
|
|
2656
2764
|
s = wp.tile_sum(t)
|
|
2657
2765
|
|
|
2658
2766
|
print(s)
|
|
@@ -2663,7 +2771,7 @@ add_builtin(
|
|
|
2663
2771
|
|
|
2664
2772
|
.. code-block:: text
|
|
2665
2773
|
|
|
2666
|
-
tile(
|
|
2774
|
+
[256] = tile(shape=(1), storage=register)
|
|
2667
2775
|
|
|
2668
2776
|
""",
|
|
2669
2777
|
group="Tile Primitives",
|
|
@@ -2674,17 +2782,17 @@ add_builtin(
|
|
|
2674
2782
|
def tile_min_value_func(arg_types, arg_values):
|
|
2675
2783
|
# return generic type (for doc builds)
|
|
2676
2784
|
if arg_types is None:
|
|
2677
|
-
return Tile(dtype=Any,
|
|
2785
|
+
return Tile(dtype=Any, shape=(1,))
|
|
2678
2786
|
|
|
2679
2787
|
if len(arg_types) != 1:
|
|
2680
|
-
raise
|
|
2788
|
+
raise TypeError(f"tile_min() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
2681
2789
|
|
|
2682
2790
|
a = arg_types["a"]
|
|
2683
2791
|
|
|
2684
2792
|
if not is_tile(a):
|
|
2685
|
-
raise
|
|
2793
|
+
raise TypeError(f"tile_min() argument must be a tile, got {a!r}")
|
|
2686
2794
|
|
|
2687
|
-
return Tile(dtype=a.dtype,
|
|
2795
|
+
return Tile(dtype=a.dtype, shape=(1,), op="min")
|
|
2688
2796
|
|
|
2689
2797
|
|
|
2690
2798
|
add_builtin(
|
|
@@ -2695,7 +2803,7 @@ add_builtin(
|
|
|
2695
2803
|
doc="""Cooperatively compute the minimum of the tile elements using all threads in the block.
|
|
2696
2804
|
|
|
2697
2805
|
:param a: The tile to compute the minimum of
|
|
2698
|
-
:returns: A single-element tile
|
|
2806
|
+
:returns: A single-element tile holding the minimum value
|
|
2699
2807
|
|
|
2700
2808
|
Example:
|
|
2701
2809
|
|
|
@@ -2716,7 +2824,7 @@ add_builtin(
|
|
|
2716
2824
|
|
|
2717
2825
|
.. code-block:: text
|
|
2718
2826
|
|
|
2719
|
-
tile(
|
|
2827
|
+
[64] = tile(shape=(1), storage=register)
|
|
2720
2828
|
|
|
2721
2829
|
""",
|
|
2722
2830
|
group="Tile Primitives",
|
|
@@ -2727,28 +2835,28 @@ add_builtin(
|
|
|
2727
2835
|
def tile_max_value_func(arg_types, arg_values):
|
|
2728
2836
|
# return generic type (for doc builds)
|
|
2729
2837
|
if arg_types is None:
|
|
2730
|
-
return Tile(dtype=Any,
|
|
2838
|
+
return Tile(dtype=Any, shape=(1,))
|
|
2731
2839
|
|
|
2732
2840
|
if len(arg_types) != 1:
|
|
2733
|
-
raise
|
|
2841
|
+
raise TypeError(f"tile_max() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
2734
2842
|
|
|
2735
2843
|
a = arg_types["a"]
|
|
2736
2844
|
|
|
2737
2845
|
if not is_tile(a):
|
|
2738
|
-
raise
|
|
2846
|
+
raise TypeError(f"tile_max() argument must be a tile, got {a!r}")
|
|
2739
2847
|
|
|
2740
|
-
return Tile(dtype=a.dtype,
|
|
2848
|
+
return Tile(dtype=a.dtype, shape=(1,), op="min")
|
|
2741
2849
|
|
|
2742
2850
|
|
|
2743
2851
|
add_builtin(
|
|
2744
2852
|
"tile_max",
|
|
2745
|
-
input_types={"a": Tile},
|
|
2853
|
+
input_types={"a": Tile(dtype=Any, shape=Any)},
|
|
2746
2854
|
value_func=tile_max_value_func,
|
|
2747
|
-
variadic=
|
|
2855
|
+
variadic=False,
|
|
2748
2856
|
doc="""Cooperatively compute the maximum of the tile elements using all threads in the block.
|
|
2749
2857
|
|
|
2750
2858
|
:param a: The tile to compute the maximum from
|
|
2751
|
-
:returns: A single-element tile
|
|
2859
|
+
:returns: A single-element tile holding the maximum value
|
|
2752
2860
|
|
|
2753
2861
|
Example:
|
|
2754
2862
|
|
|
@@ -2768,7 +2876,7 @@ add_builtin(
|
|
|
2768
2876
|
|
|
2769
2877
|
.. code-block:: text
|
|
2770
2878
|
|
|
2771
|
-
tile(
|
|
2879
|
+
[127] = tile(shape=(1), storage=register)
|
|
2772
2880
|
|
|
2773
2881
|
""",
|
|
2774
2882
|
group="Tile Primitives",
|
|
@@ -2779,15 +2887,14 @@ add_builtin(
|
|
|
2779
2887
|
# does type propagation for load()
|
|
2780
2888
|
def tile_reduce_value_func(arg_types, arg_values):
|
|
2781
2889
|
if arg_types is None:
|
|
2782
|
-
return Tile(dtype=Any,
|
|
2890
|
+
return Tile(dtype=Any, shape=(1,))
|
|
2783
2891
|
|
|
2784
2892
|
a = arg_types["a"]
|
|
2785
2893
|
|
|
2786
|
-
# check all args are tiles
|
|
2787
2894
|
if not is_tile(a):
|
|
2788
|
-
raise
|
|
2895
|
+
raise TypeError(f"tile_reduce() 'a' argument must be a tile, got {a!r}")
|
|
2789
2896
|
|
|
2790
|
-
return Tile(dtype=a.dtype,
|
|
2897
|
+
return Tile(dtype=a.dtype, shape=(1,), op="reduce")
|
|
2791
2898
|
|
|
2792
2899
|
|
|
2793
2900
|
def tile_reduce_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
@@ -2798,7 +2905,7 @@ def tile_reduce_dispatch_func(input_types: Mapping[str, type], return_type: Any,
|
|
|
2798
2905
|
|
|
2799
2906
|
add_builtin(
|
|
2800
2907
|
"tile_reduce",
|
|
2801
|
-
input_types={"op": Callable, "a": Tile(dtype=Any,
|
|
2908
|
+
input_types={"op": Callable, "a": Tile(dtype=Any, shape=Any)},
|
|
2802
2909
|
value_func=tile_reduce_value_func,
|
|
2803
2910
|
native_func="tile_reduce",
|
|
2804
2911
|
doc="""Apply a custom reduction operator across the tile.
|
|
@@ -2806,8 +2913,8 @@ add_builtin(
|
|
|
2806
2913
|
This function cooperatively performs a reduction using the provided operator across the tile.
|
|
2807
2914
|
|
|
2808
2915
|
:param op: A callable function that accepts two arguments and returns one argument, may be a user function or builtin
|
|
2809
|
-
:param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's
|
|
2810
|
-
:returns: A single-element tile with
|
|
2916
|
+
:param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's data type
|
|
2917
|
+
:returns: A single-element tile with the same data type as the input tile.
|
|
2811
2918
|
|
|
2812
2919
|
Example:
|
|
2813
2920
|
|
|
@@ -2827,7 +2934,7 @@ add_builtin(
|
|
|
2827
2934
|
|
|
2828
2935
|
.. code-block:: text
|
|
2829
2936
|
|
|
2830
|
-
tile(
|
|
2937
|
+
[362880] = tile(shape=(1), storage=register)
|
|
2831
2938
|
""",
|
|
2832
2939
|
group="Tile Primitives",
|
|
2833
2940
|
export=False,
|
|
@@ -2839,26 +2946,19 @@ add_builtin(
|
|
|
2839
2946
|
# does type propagation for load()
|
|
2840
2947
|
def tile_unary_map_value_func(arg_types, arg_values):
|
|
2841
2948
|
if arg_types is None:
|
|
2842
|
-
return Tile(dtype=Any,
|
|
2949
|
+
return Tile(dtype=Any, shape=Any)
|
|
2843
2950
|
|
|
2844
2951
|
a = arg_types["a"]
|
|
2845
2952
|
|
|
2846
|
-
# check all args are tiles
|
|
2847
2953
|
if not is_tile(a):
|
|
2848
|
-
raise
|
|
2954
|
+
raise TypeError(f"tile_map() 'a' argument must be a tile, got {a!r}")
|
|
2849
2955
|
|
|
2850
2956
|
return TileUnaryMap(a)
|
|
2851
2957
|
|
|
2852
2958
|
|
|
2853
|
-
# def tile_map_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
2854
|
-
# func_args = (args["op"], *args["args"])
|
|
2855
|
-
# template_args = ()
|
|
2856
|
-
# return (func_args, template_args)
|
|
2857
|
-
|
|
2858
|
-
|
|
2859
2959
|
add_builtin(
|
|
2860
2960
|
"tile_map",
|
|
2861
|
-
input_types={"op": Callable, "a": Tile(dtype=Any,
|
|
2961
|
+
input_types={"op": Callable, "a": Tile(dtype=Any, shape=Any)},
|
|
2862
2962
|
value_func=tile_unary_map_value_func,
|
|
2863
2963
|
# dispatch_func=tile_map_dispatch_func,
|
|
2864
2964
|
# variadic=True,
|
|
@@ -2868,8 +2968,8 @@ add_builtin(
|
|
|
2868
2968
|
This function cooperatively applies a unary function to each element of the tile using all threads in the block.
|
|
2869
2969
|
|
|
2870
2970
|
:param op: A callable function that accepts one argument and returns one argument, may be a user function or builtin
|
|
2871
|
-
:param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's
|
|
2872
|
-
:returns: A tile with the same dimensions and
|
|
2971
|
+
:param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's data type
|
|
2972
|
+
:returns: A tile with the same dimensions and data type as the input tile.
|
|
2873
2973
|
|
|
2874
2974
|
Example:
|
|
2875
2975
|
|
|
@@ -2889,7 +2989,7 @@ add_builtin(
|
|
|
2889
2989
|
|
|
2890
2990
|
.. code-block:: text
|
|
2891
2991
|
|
|
2892
|
-
|
|
2992
|
+
[0 0.0998334 0.198669 0.29552 0.389418 0.479426 0.564642 0.644218 0.717356 0.783327] = tile(shape=(10), storage=register)
|
|
2893
2993
|
""",
|
|
2894
2994
|
group="Tile Primitives",
|
|
2895
2995
|
export=False,
|
|
@@ -2898,34 +2998,37 @@ add_builtin(
|
|
|
2898
2998
|
|
|
2899
2999
|
def tile_binary_map_value_func(arg_types, arg_values):
|
|
2900
3000
|
if arg_types is None:
|
|
2901
|
-
return Tile(dtype=Any,
|
|
3001
|
+
return Tile(dtype=Any, shape=Any)
|
|
2902
3002
|
|
|
2903
3003
|
a = arg_types["a"]
|
|
2904
3004
|
b = arg_types["b"]
|
|
2905
3005
|
|
|
2906
3006
|
# check all args are tiles
|
|
2907
3007
|
if not is_tile(a):
|
|
2908
|
-
raise
|
|
3008
|
+
raise TypeError(f"tile_map() 'a' argument must be a tile, got {a!r}")
|
|
2909
3009
|
|
|
2910
3010
|
if not is_tile(b):
|
|
2911
|
-
raise
|
|
3011
|
+
raise TypeError(f"tile_map() 'b' argument must be a tile, got {b!r}")
|
|
2912
3012
|
|
|
2913
|
-
#
|
|
3013
|
+
# ensure types equal
|
|
2914
3014
|
if not types_equal(a.dtype, b.dtype):
|
|
2915
|
-
raise
|
|
3015
|
+
raise TypeError(f"tile_map() arguments must have the same dtype, got {a.dtype} and {b.dtype}")
|
|
2916
3016
|
|
|
2917
|
-
if a.
|
|
2918
|
-
raise
|
|
3017
|
+
if len(a.shape) != len(b.shape):
|
|
3018
|
+
raise ValueError(
|
|
3019
|
+
f"tile_map() shapes must have the same number of dimensions, got {len(a.shape)} and {len(b.shape)}"
|
|
3020
|
+
)
|
|
2919
3021
|
|
|
2920
|
-
|
|
2921
|
-
|
|
3022
|
+
for i in range(len(a.shape)):
|
|
3023
|
+
if a.shape[i] != b.shape[i]:
|
|
3024
|
+
raise ValueError(f"tile_map() shapes do not match on dimension {i}, got {a.shape[i]} and {b.shape[i]}")
|
|
2922
3025
|
|
|
2923
3026
|
return TileBinaryMap(a, b)
|
|
2924
3027
|
|
|
2925
3028
|
|
|
2926
3029
|
add_builtin(
|
|
2927
3030
|
"tile_map",
|
|
2928
|
-
input_types={"op": Callable, "a": Tile(dtype=Any,
|
|
3031
|
+
input_types={"op": Callable, "a": Tile(dtype=Any, shape=Any), "b": Tile(dtype=Any, shape=Any)},
|
|
2929
3032
|
value_func=tile_binary_map_value_func,
|
|
2930
3033
|
# dispatch_func=tile_map_dispatch_func,
|
|
2931
3034
|
# variadic=True,
|
|
@@ -2948,7 +3051,7 @@ add_builtin(
|
|
|
2948
3051
|
def compute():
|
|
2949
3052
|
|
|
2950
3053
|
a = wp.tile_arange(0.0, 1.0, 0.1, dtype=float)
|
|
2951
|
-
b = wp.tile_ones(
|
|
3054
|
+
b = wp.tile_ones(shape=10, dtype=float)
|
|
2952
3055
|
|
|
2953
3056
|
s = wp.tile_map(wp.add, a, b)
|
|
2954
3057
|
|
|
@@ -2960,7 +3063,7 @@ add_builtin(
|
|
|
2960
3063
|
|
|
2961
3064
|
.. code-block:: text
|
|
2962
3065
|
|
|
2963
|
-
|
|
3066
|
+
[1 1.1 1.2 1.3 1.4 1.5 1.6 1.7 1.8 1.9] = tile(shape=(10), storage=register)""",
|
|
2964
3067
|
group="Tile Primitives",
|
|
2965
3068
|
export=False,
|
|
2966
3069
|
)
|
|
@@ -3075,6 +3178,18 @@ add_builtin(
|
|
|
3075
3178
|
)
|
|
3076
3179
|
|
|
3077
3180
|
|
|
3181
|
+
def mlp_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
3182
|
+
warp.utils.warn(
|
|
3183
|
+
"wp.mlp() is deprecated and will be removed in a future\nversion. Use tile primitives instead.",
|
|
3184
|
+
category=DeprecationWarning,
|
|
3185
|
+
)
|
|
3186
|
+
|
|
3187
|
+
func_args = tuple(args.values())
|
|
3188
|
+
template_args = ()
|
|
3189
|
+
|
|
3190
|
+
return (func_args, template_args)
|
|
3191
|
+
|
|
3192
|
+
|
|
3078
3193
|
add_builtin(
|
|
3079
3194
|
"mlp",
|
|
3080
3195
|
input_types={
|
|
@@ -3086,9 +3201,13 @@ add_builtin(
|
|
|
3086
3201
|
"out": array(dtype=float, ndim=2),
|
|
3087
3202
|
},
|
|
3088
3203
|
value_type=None,
|
|
3204
|
+
dispatch_func=mlp_dispatch_func,
|
|
3089
3205
|
skip_replay=True,
|
|
3090
3206
|
doc="""Evaluate a multi-layer perceptron (MLP) layer in the form: ``out = act(weights*x + bias)``.
|
|
3091
3207
|
|
|
3208
|
+
.. deprecated:: 1.6
|
|
3209
|
+
Use :doc:`tile primitives </modules/tiles>` instead.
|
|
3210
|
+
|
|
3092
3211
|
:param weights: A layer's network weights with dimensions ``(m, n)``.
|
|
3093
3212
|
:param bias: An array with dimensions ``(n)``.
|
|
3094
3213
|
:param activation: A ``wp.func`` function that takes a single scalar float as input and returns a scalar float as output
|
|
@@ -4054,7 +4173,7 @@ add_builtin(
|
|
|
4054
4173
|
input_types={"state": uint32},
|
|
4055
4174
|
value_type=int,
|
|
4056
4175
|
group="Random",
|
|
4057
|
-
doc="Return a random integer in the range [
|
|
4176
|
+
doc="Return a random integer in the range [-2^31, 2^31).",
|
|
4058
4177
|
)
|
|
4059
4178
|
add_builtin(
|
|
4060
4179
|
"randi",
|
|
@@ -5008,6 +5127,43 @@ add_builtin(
|
|
|
5008
5127
|
)
|
|
5009
5128
|
|
|
5010
5129
|
|
|
5130
|
+
# implements vector[idx] += scalar
|
|
5131
|
+
add_builtin(
|
|
5132
|
+
"augassign_add",
|
|
5133
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
|
|
5134
|
+
value_type=None,
|
|
5135
|
+
hidden=True,
|
|
5136
|
+
group="Utility",
|
|
5137
|
+
)
|
|
5138
|
+
|
|
5139
|
+
# implements quaternion[idx] += scalar
|
|
5140
|
+
add_builtin(
|
|
5141
|
+
"augassign_add",
|
|
5142
|
+
input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
|
|
5143
|
+
value_type=None,
|
|
5144
|
+
hidden=True,
|
|
5145
|
+
group="Utility",
|
|
5146
|
+
)
|
|
5147
|
+
|
|
5148
|
+
# implements vector[idx] -= scalar
|
|
5149
|
+
add_builtin(
|
|
5150
|
+
"augassign_sub",
|
|
5151
|
+
input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
|
|
5152
|
+
value_type=None,
|
|
5153
|
+
hidden=True,
|
|
5154
|
+
group="Utility",
|
|
5155
|
+
)
|
|
5156
|
+
|
|
5157
|
+
# implements quaternion[idx] -= scalar
|
|
5158
|
+
add_builtin(
|
|
5159
|
+
"augassign_sub",
|
|
5160
|
+
input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
|
|
5161
|
+
value_type=None,
|
|
5162
|
+
hidden=True,
|
|
5163
|
+
group="Utility",
|
|
5164
|
+
)
|
|
5165
|
+
|
|
5166
|
+
|
|
5011
5167
|
def matrix_index_row_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
5012
5168
|
mat_type = arg_types["a"]
|
|
5013
5169
|
row_type = mat_type._wp_row_type_
|
|
@@ -5057,22 +5213,42 @@ def matrix_vector_sametype(arg_types: Mapping[str, Any]):
|
|
|
5057
5213
|
return mat_size == vec_size and mat_type == vec_type
|
|
5058
5214
|
|
|
5059
5215
|
|
|
5060
|
-
# implements matrix[i,j] = scalar
|
|
5216
|
+
# implements matrix[i,j] = scalar
|
|
5217
|
+
add_builtin(
|
|
5218
|
+
"assign",
|
|
5219
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
|
|
5220
|
+
value_func=matrix_assign_value_func,
|
|
5221
|
+
hidden=True,
|
|
5222
|
+
group="Utility",
|
|
5223
|
+
)
|
|
5224
|
+
|
|
5225
|
+
|
|
5226
|
+
# implements matrix[i] = vector
|
|
5227
|
+
add_builtin(
|
|
5228
|
+
"assign",
|
|
5229
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
|
|
5230
|
+
constraint=matrix_vector_sametype,
|
|
5231
|
+
value_func=matrix_assign_value_func,
|
|
5232
|
+
hidden=True,
|
|
5233
|
+
group="Utility",
|
|
5234
|
+
)
|
|
5235
|
+
|
|
5236
|
+
|
|
5237
|
+
# implements matrix[i,j] += scalar
|
|
5061
5238
|
add_builtin(
|
|
5062
|
-
"
|
|
5239
|
+
"augassign_add",
|
|
5063
5240
|
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
|
|
5064
|
-
|
|
5241
|
+
value_type=None,
|
|
5065
5242
|
hidden=True,
|
|
5066
5243
|
group="Utility",
|
|
5067
5244
|
)
|
|
5068
5245
|
|
|
5069
5246
|
|
|
5070
|
-
# implements matrix[i]
|
|
5247
|
+
# implements matrix[i,j] -= scalar
|
|
5071
5248
|
add_builtin(
|
|
5072
|
-
"
|
|
5073
|
-
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "
|
|
5074
|
-
|
|
5075
|
-
value_func=matrix_assign_value_func,
|
|
5249
|
+
"augassign_sub",
|
|
5250
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
|
|
5251
|
+
value_type=None,
|
|
5076
5252
|
hidden=True,
|
|
5077
5253
|
group="Utility",
|
|
5078
5254
|
)
|
|
@@ -5644,19 +5820,19 @@ add_builtin("unot", input_types={"a": array(dtype=Any)}, value_type=builtins.boo
|
|
|
5644
5820
|
# Tile operators
|
|
5645
5821
|
def tile_unary_value_func(arg_types, arg_values):
|
|
5646
5822
|
if arg_types is None:
|
|
5647
|
-
return Tile(dtype=Any,
|
|
5823
|
+
return Tile(dtype=Any, shape=Any)
|
|
5648
5824
|
|
|
5649
5825
|
t = arg_types["x"]
|
|
5650
5826
|
|
|
5651
5827
|
if not is_tile(t):
|
|
5652
|
-
raise
|
|
5828
|
+
raise TypeError(f"Expected tile for unary expression, got {t}")
|
|
5653
5829
|
|
|
5654
5830
|
return TileUnaryMap(t)
|
|
5655
5831
|
|
|
5656
5832
|
|
|
5657
5833
|
def tile_scalar_mul_value_func(arg_types, arg_values):
|
|
5658
5834
|
if arg_types is None:
|
|
5659
|
-
return Tile(dtype=Any,
|
|
5835
|
+
return Tile(dtype=Any, shape=Any)
|
|
5660
5836
|
|
|
5661
5837
|
x = arg_types["x"]
|
|
5662
5838
|
y = arg_types["y"]
|
|
@@ -5664,25 +5840,21 @@ def tile_scalar_mul_value_func(arg_types, arg_values):
|
|
|
5664
5840
|
# tile*scalar
|
|
5665
5841
|
if is_tile(x):
|
|
5666
5842
|
if x.dtype != y:
|
|
5667
|
-
raise
|
|
5668
|
-
"Scalar factor should have the same type as tile for tile*scalar, tile type: {x} scalar type: {y}"
|
|
5669
|
-
)
|
|
5843
|
+
raise TypeError(f"Scalar factor type {y} does not match tile type {x.dtype} for tile*scalar")
|
|
5670
5844
|
|
|
5671
|
-
return TileBinaryMap(x, TileConstant(y, x.
|
|
5845
|
+
return TileBinaryMap(x, TileConstant(y, x.shape))
|
|
5672
5846
|
|
|
5673
5847
|
# scalar*tile
|
|
5674
5848
|
if is_tile(y):
|
|
5675
5849
|
if y.dtype != x:
|
|
5676
|
-
raise
|
|
5677
|
-
"Scalar factor should have the same type as tile for scalar*tile, tile type: {x} scalar type: {y}"
|
|
5678
|
-
)
|
|
5850
|
+
raise TypeError(f"Scalar factor type {x} does not match tile type {y.dtype} for scalar*tile")
|
|
5679
5851
|
|
|
5680
|
-
return TileBinaryMap(TileConstant(x, y.
|
|
5852
|
+
return TileBinaryMap(TileConstant(x, y.shape), y)
|
|
5681
5853
|
|
|
5682
5854
|
|
|
5683
5855
|
add_builtin(
|
|
5684
5856
|
"neg",
|
|
5685
|
-
input_types={"x": Tile(dtype=Any,
|
|
5857
|
+
input_types={"x": Tile(dtype=Any, shape=Any)},
|
|
5686
5858
|
value_func=tile_unary_value_func,
|
|
5687
5859
|
doc="Negate each element of a tile",
|
|
5688
5860
|
export=False,
|
|
@@ -5692,7 +5864,7 @@ add_builtin(
|
|
|
5692
5864
|
|
|
5693
5865
|
add_builtin(
|
|
5694
5866
|
"add",
|
|
5695
|
-
input_types={"a": Tile(dtype=Any,
|
|
5867
|
+
input_types={"a": Tile(dtype=Any, shape=Any), "b": Tile(dtype=Any, shape=Any)},
|
|
5696
5868
|
value_func=tile_binary_map_value_func,
|
|
5697
5869
|
# dispatch_func=tile_map_dispatch_func,
|
|
5698
5870
|
# variadic=True,
|
|
@@ -5702,9 +5874,22 @@ add_builtin(
|
|
|
5702
5874
|
export=False,
|
|
5703
5875
|
)
|
|
5704
5876
|
|
|
5877
|
+
add_builtin(
|
|
5878
|
+
"sub",
|
|
5879
|
+
input_types={"a": Tile(dtype=Any, shape=Any), "b": Tile(dtype=Any, shape=Any)},
|
|
5880
|
+
value_func=tile_binary_map_value_func,
|
|
5881
|
+
# dispatch_func=tile_map_dispatch_func,
|
|
5882
|
+
# variadic=True,
|
|
5883
|
+
native_func="tile_sub",
|
|
5884
|
+
doc="Subtract each element b from a",
|
|
5885
|
+
group="Tile Primitives",
|
|
5886
|
+
export=False,
|
|
5887
|
+
)
|
|
5888
|
+
|
|
5889
|
+
|
|
5705
5890
|
add_builtin(
|
|
5706
5891
|
"mul",
|
|
5707
|
-
input_types={"x": Tile(dtype=Any,
|
|
5892
|
+
input_types={"x": Tile(dtype=Any, shape=Any), "y": Scalar},
|
|
5708
5893
|
value_func=tile_scalar_mul_value_func,
|
|
5709
5894
|
doc="Multiply each element of a tile by a scalar",
|
|
5710
5895
|
export=False,
|
|
@@ -5714,7 +5899,7 @@ add_builtin(
|
|
|
5714
5899
|
|
|
5715
5900
|
add_builtin(
|
|
5716
5901
|
"mul",
|
|
5717
|
-
input_types={"x": Scalar, "y": Tile(dtype=Any,
|
|
5902
|
+
input_types={"x": Scalar, "y": Tile(dtype=Any, shape=Any)},
|
|
5718
5903
|
value_func=tile_scalar_mul_value_func,
|
|
5719
5904
|
doc="Multiply each element of a tile by a scalar",
|
|
5720
5905
|
export=False,
|
|
@@ -5723,6 +5908,70 @@ add_builtin(
|
|
|
5723
5908
|
)
|
|
5724
5909
|
|
|
5725
5910
|
|
|
5911
|
+
def tile_diag_add_value_func(arg_types, arg_values):
|
|
5912
|
+
if arg_types is None:
|
|
5913
|
+
return Tile(dtype=Any, shape=Any)
|
|
5914
|
+
|
|
5915
|
+
a = arg_types["a"]
|
|
5916
|
+
d = arg_types["d"]
|
|
5917
|
+
|
|
5918
|
+
if not is_tile(a):
|
|
5919
|
+
raise TypeError(f"tile_diag_add() 'a' argument must be a tile, got {a!r}")
|
|
5920
|
+
|
|
5921
|
+
if not is_tile(d):
|
|
5922
|
+
raise TypeError(f"tile_diag_add() 'd' argument must be a tile, got {d!r}")
|
|
5923
|
+
|
|
5924
|
+
if not types_equal(a.dtype, d.dtype):
|
|
5925
|
+
raise TypeError(f"tile_diag_add() arguments must have the same dtype, got {a.dtype} and {d.dtype}")
|
|
5926
|
+
|
|
5927
|
+
if len(a.shape) != 2:
|
|
5928
|
+
raise TypeError("tile_diag_add() argument 'a' must be a 2D tile")
|
|
5929
|
+
|
|
5930
|
+
if len(d.shape) != 1:
|
|
5931
|
+
raise TypeError("tile_diag_add() argument 'd' must be a 1D tile")
|
|
5932
|
+
|
|
5933
|
+
if a.shape[0] != a.shape[1]:
|
|
5934
|
+
raise ValueError("tile_diag_add() 'a' argument must be square")
|
|
5935
|
+
|
|
5936
|
+
if a.shape[0] != d.shape[0]:
|
|
5937
|
+
raise ValueError(
|
|
5938
|
+
f"tile_diag_add() 'd' argument must have the same number of elements as the number of rows in 'a', "
|
|
5939
|
+
f"got {d.shape[0]} elements in 'd' and {a.shape[0]} rows in 'a'"
|
|
5940
|
+
)
|
|
5941
|
+
|
|
5942
|
+
# use first argument to define output type
|
|
5943
|
+
return Tile(dtype=a.dtype, shape=a.shape, storage="shared")
|
|
5944
|
+
|
|
5945
|
+
|
|
5946
|
+
def tile_diag_add_lto_dispatch_func(
|
|
5947
|
+
arg_types: Mapping[str, type],
|
|
5948
|
+
return_type: Any,
|
|
5949
|
+
return_values: List[Var],
|
|
5950
|
+
arg_values: Mapping[str, Var],
|
|
5951
|
+
options: Mapping[str, Any],
|
|
5952
|
+
builder: warp.context.ModuleBuilder,
|
|
5953
|
+
):
|
|
5954
|
+
a = arg_values["a"]
|
|
5955
|
+
d = arg_values["d"]
|
|
5956
|
+
# force the storage type of the input variables to shared memory
|
|
5957
|
+
a.type.storage = "shared"
|
|
5958
|
+
d.type.storage = "shared"
|
|
5959
|
+
out = return_values[0]
|
|
5960
|
+
return ((a, d, out), [], [], 0)
|
|
5961
|
+
|
|
5962
|
+
|
|
5963
|
+
add_builtin(
|
|
5964
|
+
"tile_diag_add",
|
|
5965
|
+
input_types={"a": Tile(dtype=Any, shape=Any), "d": Tile(dtype=Any, shape=Any)},
|
|
5966
|
+
value_func=tile_diag_add_value_func,
|
|
5967
|
+
lto_dispatch_func=tile_diag_add_lto_dispatch_func,
|
|
5968
|
+
native_func="tile_diag_add",
|
|
5969
|
+
doc="Add a square matrix and a diagonal matrix 'd' represented as a 1D tile",
|
|
5970
|
+
group="Tile Primitives",
|
|
5971
|
+
export=False,
|
|
5972
|
+
)
|
|
5973
|
+
|
|
5974
|
+
|
|
5726
5975
|
##
|
|
5727
5976
|
## MathDx, LTOIR-based, Tile functions
|
|
5728
5977
|
##
|
|
@@ -5734,24 +5983,25 @@ add_builtin(
|
|
|
5734
5983
|
def tile_matmul_generic_value_func(arg_types, arg_values):
|
|
5735
5984
|
# return generic type (for doc builds)
|
|
5736
5985
|
if arg_types is None:
|
|
5737
|
-
return Tile(dtype=Any,
|
|
5986
|
+
return Tile(dtype=Any, shape=Any)
|
|
5738
5987
|
|
|
5739
5988
|
a = arg_types["a"]
|
|
5740
5989
|
b = arg_types["b"]
|
|
5741
5990
|
|
|
5742
5991
|
if not is_tile(a):
|
|
5743
|
-
raise
|
|
5992
|
+
raise TypeError(f"tile_matmul() 'a' argument must be a tile, got {a!r}")
|
|
5993
|
+
|
|
5744
5994
|
if not is_tile(b):
|
|
5745
|
-
raise
|
|
5995
|
+
raise TypeError(f"tile_matmul() 'b' argument must be a tile, got {b!r}")
|
|
5746
5996
|
|
|
5747
5997
|
# out = wp.tile_matmul(a, b)
|
|
5748
5998
|
if len(arg_types) == 2:
|
|
5749
|
-
return Tile(dtype=a.dtype,
|
|
5999
|
+
return Tile(dtype=a.dtype, shape=(a.shape[0], b.shape[1]), storage="shared")
|
|
5750
6000
|
|
|
5751
6001
|
# wp.tile_matmul(a, b, out)
|
|
5752
6002
|
elif len(arg_types) == 3:
|
|
5753
6003
|
if not is_tile(arg_types["out"]):
|
|
5754
|
-
raise
|
|
6004
|
+
raise TypeError(f"tile_matmul() 'out' argument must be a tile, got {arg_types['out']!r}")
|
|
5755
6005
|
|
|
5756
6006
|
return None
|
|
5757
6007
|
|
|
@@ -5774,16 +6024,20 @@ def tile_matmul_generic_lto_dispatch_func(
|
|
|
5774
6024
|
accumulate = 1 # for tile_matmul(a,b,c) case we want to add to c value
|
|
5775
6025
|
out = arg_values["out"]
|
|
5776
6026
|
|
|
5777
|
-
if
|
|
5778
|
-
raise
|
|
6027
|
+
if not is_tile(out.type):
|
|
6028
|
+
raise TypeError(f"tile_matmul() 'out' argument must be a tile, got {out!r}")
|
|
5779
6029
|
|
|
5780
6030
|
if any(arg.type.dtype not in [float16, float32, float64, vec2h, vec2f, vec2d] for arg in [a, b, out]):
|
|
5781
|
-
raise
|
|
6031
|
+
raise TypeError(
|
|
5782
6032
|
"tile_matmul() arguments must be tiles of float16, float32 or float64, vec2h, vec2f, vec2d entries"
|
|
5783
6033
|
)
|
|
5784
6034
|
|
|
5785
|
-
if (
|
|
5786
|
-
|
|
6035
|
+
if (
|
|
6036
|
+
(a.type.shape[1] != b.type.shape[0])
|
|
6037
|
+
or (a.type.shape[0] != out.type.shape[0])
|
|
6038
|
+
or (b.type.shape[1] != out.type.shape[1])
|
|
6039
|
+
):
|
|
6040
|
+
raise ValueError("tile_matmul(A, B, C) requires sizes of A, B and C to be consistent for a matmul")
|
|
5787
6041
|
|
|
5788
6042
|
# set the storage type to the inputs to shared
|
|
5789
6043
|
a.type.storage = "shared"
|
|
@@ -5805,18 +6059,18 @@ def tile_matmul_generic_lto_dispatch_func(
|
|
|
5805
6059
|
return ("wp::vec2f", 5, 1)
|
|
5806
6060
|
if dtype == vec2d:
|
|
5807
6061
|
return ("wp::vec2d", 6, 1)
|
|
5808
|
-
raise
|
|
6062
|
+
raise TypeError("Unsupported input type in tile_matmul")
|
|
5809
6063
|
|
|
5810
6064
|
def cublasdx_arrangement_map(layout):
|
|
5811
6065
|
if layout == "colmajor":
|
|
5812
6066
|
return 0 # CUBLASDX_ARRANGEMENT_COL_MAJOR
|
|
5813
6067
|
if layout == "rowmajor":
|
|
5814
6068
|
return 1 # CUBLASDX_ARRANGEMENT_ROW_MAJOR
|
|
5815
|
-
raise
|
|
6069
|
+
raise ValueError("Unsupported layout in tile_matmul")
|
|
5816
6070
|
|
|
5817
6071
|
# generate the LTO
|
|
5818
|
-
M, K = a.type.
|
|
5819
|
-
_, N = b.type.
|
|
6072
|
+
M, K = a.type.shape[0], a.type.shape[1]
|
|
6073
|
+
_, N = b.type.shape[0], b.type.shape[1]
|
|
5820
6074
|
num_threads = options["block_dim"]
|
|
5821
6075
|
arch = options["output_arch"]
|
|
5822
6076
|
|
|
@@ -5829,7 +6083,8 @@ def tile_matmul_generic_lto_dispatch_func(
|
|
|
5829
6083
|
c_arrangement = cublasdx_arrangement_map(clayout)
|
|
5830
6084
|
|
|
5831
6085
|
if a_type != b_type or a_type != c_type:
|
|
5832
|
-
raise
|
|
6086
|
+
raise TypeError("time_matmul(A, B, C) requires all inputs to be real or complex")
|
|
6087
|
+
|
|
5833
6088
|
element_type = a_type
|
|
5834
6089
|
|
|
5835
6090
|
lto_symbol = f"dot_{M}_{N}_{K}_{arch}_{num_threads}_{a_arrangement}_{b_arrangement}_{c_arrangement}_{a_prec}_{b_prec}_{c_prec}_{element_type}"
|
|
@@ -5924,15 +6179,16 @@ def tile_matmul_generic_lto_dispatch_func(
|
|
|
5924
6179
|
),
|
|
5925
6180
|
template_args,
|
|
5926
6181
|
[lto_forward, lto_backward_A, lto_backward_B],
|
|
6182
|
+
0,
|
|
5927
6183
|
)
|
|
5928
6184
|
|
|
5929
6185
|
|
|
5930
6186
|
add_builtin(
|
|
5931
6187
|
"tile_matmul",
|
|
5932
6188
|
input_types={
|
|
5933
|
-
"a": Tile(dtype=Any,
|
|
5934
|
-
"b": Tile(dtype=Any,
|
|
5935
|
-
"out": Tile(dtype=Any,
|
|
6189
|
+
"a": Tile(dtype=Any, shape=Any),
|
|
6190
|
+
"b": Tile(dtype=Any, shape=Any),
|
|
6191
|
+
"out": Tile(dtype=Any, shape=Any),
|
|
5936
6192
|
},
|
|
5937
6193
|
value_func=tile_matmul_generic_value_func,
|
|
5938
6194
|
lto_dispatch_func=tile_matmul_generic_lto_dispatch_func,
|
|
@@ -5956,7 +6212,7 @@ add_builtin(
|
|
|
5956
6212
|
|
|
5957
6213
|
add_builtin(
|
|
5958
6214
|
"tile_matmul",
|
|
5959
|
-
input_types={"a": Tile(dtype=Any,
|
|
6215
|
+
input_types={"a": Tile(dtype=Any, shape=Any), "b": Tile(dtype=Any, shape=Any)},
|
|
5960
6216
|
value_func=tile_matmul_generic_value_func,
|
|
5961
6217
|
lto_dispatch_func=tile_matmul_generic_lto_dispatch_func,
|
|
5962
6218
|
variadic=False,
|
|
@@ -5983,16 +6239,23 @@ add_builtin(
|
|
|
5983
6239
|
##
|
|
5984
6240
|
def tile_fft_generic_value_func(arg_types, arg_values):
|
|
5985
6241
|
if arg_types is None:
|
|
5986
|
-
return Tile(dtype=Any,
|
|
6242
|
+
return Tile(dtype=Any, shape=Any)
|
|
5987
6243
|
|
|
5988
6244
|
if len(arg_types) != 1:
|
|
5989
|
-
raise
|
|
6245
|
+
raise TypeError(f"tile_fft() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
5990
6246
|
|
|
5991
|
-
|
|
5992
|
-
raise RuntimeError("tile_fft() argument 0 must be a tile")
|
|
6247
|
+
inout = arg_types["inout"]
|
|
5993
6248
|
|
|
5994
|
-
if
|
|
5995
|
-
raise
|
|
6249
|
+
if not is_tile(inout):
|
|
6250
|
+
raise TypeError(f"tile_fft() argument must be a tile, got {inout!r}")
|
|
6251
|
+
|
|
6252
|
+
if inout.storage != "register":
|
|
6253
|
+
raise ValueError(f"tile_fft() argument must have 'register' storage, got {inout.storage}")
|
|
6254
|
+
|
|
6255
|
+
if inout.dtype not in [vec2f, vec2d]:
|
|
6256
|
+
raise TypeError(
|
|
6257
|
+
f"tile_fft() argument must be a tile of vec2f or vec2d (interpreted as complex) entries, got {inout.dtype!r}"
|
|
6258
|
+
)
|
|
5996
6259
|
|
|
5997
6260
|
return None
|
|
5998
6261
|
|
|
@@ -6009,19 +6272,13 @@ def tile_fft_generic_lto_dispatch_func(
|
|
|
6009
6272
|
inout = arg_values["inout"]
|
|
6010
6273
|
inout.type.storage = "register"
|
|
6011
6274
|
|
|
6012
|
-
if not is_tile(inout.type):
|
|
6013
|
-
raise RuntimeError("tile_fft() arguments must be a single tile with register storage")
|
|
6014
|
-
|
|
6015
|
-
if inout.type.dtype not in [vec2f, vec2d]:
|
|
6016
|
-
raise RuntimeError("tile_fft() argument must be a tile of vec2f or vec2d (interpreted as complex) entries")
|
|
6017
|
-
|
|
6018
6275
|
# see libcufftdx.hpp
|
|
6019
6276
|
if direction == "forward":
|
|
6020
6277
|
dir = 0 # CUFFTDX_DIRECTION_FORWARD
|
|
6021
6278
|
elif direction == "inverse":
|
|
6022
6279
|
dir = 1 # CUFFTDX_DIRECTION_INVERSE
|
|
6023
6280
|
else:
|
|
6024
|
-
raise
|
|
6281
|
+
raise ValueError(f"Invalid direction: {direction!r}. Expected 'forward' or 'inverse'.")
|
|
6025
6282
|
|
|
6026
6283
|
if inout.type.dtype == vec2f:
|
|
6027
6284
|
dtype = "wp::vec2f"
|
|
@@ -6030,10 +6287,10 @@ def tile_fft_generic_lto_dispatch_func(
|
|
|
6030
6287
|
dtype = "wp::vec2d"
|
|
6031
6288
|
precision = 6 # COMMONDX_PRECISION_F64
|
|
6032
6289
|
else:
|
|
6033
|
-
raise
|
|
6290
|
+
raise TypeError(f"Unsupported data type, got {dtype!r}")
|
|
6034
6291
|
|
|
6035
6292
|
# M FFTs of size N each
|
|
6036
|
-
batch, size = inout.type.
|
|
6293
|
+
batch, size = inout.type.shape[0], inout.type.shape[1]
|
|
6037
6294
|
num_threads = options["block_dim"]
|
|
6038
6295
|
arch = options["output_arch"]
|
|
6039
6296
|
ept = size // num_threads
|
|
@@ -6065,7 +6322,7 @@ def tile_fft_generic_lto_dispatch_func(
|
|
|
6065
6322
|
lto_code.close()
|
|
6066
6323
|
if lto_code_path.exists():
|
|
6067
6324
|
lto_code_path.unlink()
|
|
6068
|
-
raise RuntimeError("Failed to compile
|
|
6325
|
+
raise RuntimeError("Failed to compile tile_fft")
|
|
6069
6326
|
|
|
6070
6327
|
with open(lto_code.name, "rb") as f:
|
|
6071
6328
|
lto_code_data = f.read()
|
|
@@ -6075,17 +6332,20 @@ def tile_fft_generic_lto_dispatch_func(
|
|
|
6075
6332
|
|
|
6076
6333
|
builder.ltoirs[lto_symbol] = lto_code_data
|
|
6077
6334
|
|
|
6335
|
+
shared_memory_bytes = Tile.round_up(shared_memory_size.value)
|
|
6336
|
+
|
|
6078
6337
|
return (
|
|
6079
6338
|
(
|
|
6080
6339
|
Var(lto_symbol, str, False, True, False),
|
|
6081
6340
|
Var(dtype, str, False, True, False),
|
|
6082
|
-
Var(str(
|
|
6341
|
+
Var(str(shared_memory_bytes), str, False, True, False),
|
|
6083
6342
|
Var(str(batch), str, False, True, False),
|
|
6084
6343
|
Var(str(ept), str, False, True, False),
|
|
6085
6344
|
inout,
|
|
6086
6345
|
),
|
|
6087
6346
|
[],
|
|
6088
6347
|
[lto_code_data],
|
|
6348
|
+
shared_memory_bytes,
|
|
6089
6349
|
)
|
|
6090
6350
|
|
|
6091
6351
|
|
|
@@ -6099,6 +6359,8 @@ add_builtin(
|
|
|
6099
6359
|
|
|
6100
6360
|
This function cooperatively computes the forward FFT on a tile of data inplace, treating each row individually.
|
|
6101
6361
|
|
|
6362
|
+
Note that computing the adjoint is not yet supported.
|
|
6363
|
+
|
|
6102
6364
|
Supported datatypes are:
|
|
6103
6365
|
* vec2f, vec2d
|
|
6104
6366
|
|
|
@@ -6118,6 +6380,8 @@ add_builtin(
|
|
|
6118
6380
|
|
|
6119
6381
|
This function cooperatively computes the inverse FFT on a tile of data inplace, treating each row individually.
|
|
6120
6382
|
|
|
6383
|
+
Note that computing the adjoint is not yet supported.
|
|
6384
|
+
|
|
6121
6385
|
Supported datatypes are:
|
|
6122
6386
|
* vec2f, vec2d
|
|
6123
6387
|
|
|
@@ -6127,6 +6391,283 @@ add_builtin(
|
|
|
6127
6391
|
namespace="",
|
|
6128
6392
|
)
|
|
6129
6393
|
|
|
6394
|
+
|
|
6395
|
+
##
|
|
6396
|
+
## Cholesky
|
|
6397
|
+
##
|
|
6398
|
+
def tile_cholesky_generic_value_func(arg_types, arg_values):
|
|
6399
|
+
if arg_types is None:
|
|
6400
|
+
return Tile(dtype=Any, shape=Any)
|
|
6401
|
+
|
|
6402
|
+
if len(arg_types) != 1:
|
|
6403
|
+
raise TypeError("tile_cholesky() requires 1 positional args")
|
|
6404
|
+
|
|
6405
|
+
a = arg_types["A"]
|
|
6406
|
+
|
|
6407
|
+
if not is_tile(a):
|
|
6408
|
+
raise TypeError(f"tile_cholesky() argument must be a tile, got {a!r}")
|
|
6409
|
+
|
|
6410
|
+
if len(a.shape) != 2:
|
|
6411
|
+
raise ValueError("tile_cholesky() argumust must be a 2D tile")
|
|
6412
|
+
|
|
6413
|
+
if a.shape[0] != a.shape[1]:
|
|
6414
|
+
raise ValueError("tile_cholesky() argument must be square")
|
|
6415
|
+
|
|
6416
|
+
return Tile(dtype=a.dtype, shape=a.shape, storage="shared")
|
|
6417
|
+
|
|
6418
|
+
|
|
6419
|
+
cusolver_function_map = {"getrf": 0, "getrf_no_pivot": 1, "potrf": 2, "potrs": 3}
|
|
6420
|
+
|
|
6421
|
+
cusolver_type_map = {float32: ("wp::float32", 5), float64: ("wp::float64", 6)}
|
|
6422
|
+
|
|
6423
|
+
cusolver_fill_mode_map = {"upper": 0, "lower": 1}
|
|
6424
|
+
|
|
6425
|
+
|
|
6426
|
+
def tile_cholesky_generic_lto_dispatch_func(
|
|
6427
|
+
arg_types: Mapping[str, type],
|
|
6428
|
+
return_type: Any,
|
|
6429
|
+
return_values: List[Var],
|
|
6430
|
+
arg_values: Mapping[str, Var],
|
|
6431
|
+
options: Mapping[str, Any],
|
|
6432
|
+
builder: warp.context.ModuleBuilder,
|
|
6433
|
+
):
|
|
6434
|
+
a = arg_values["A"]
|
|
6435
|
+
# force source tile to shared memory
|
|
6436
|
+
a.type.storage = "shared"
|
|
6437
|
+
|
|
6438
|
+
if a.type.dtype not in cusolver_type_map.keys():
|
|
6439
|
+
raise TypeError("tile_cholesky() argument must be a tile of float32 or float64 entries")
|
|
6440
|
+
|
|
6441
|
+
if len(return_values) != 1:
|
|
6442
|
+
raise TypeError("tile_cholesky() returns one output")
|
|
6443
|
+
out = return_values[0]
|
|
6444
|
+
|
|
6445
|
+
dtype, precision_enum = cusolver_type_map[a.type.dtype]
|
|
6446
|
+
|
|
6447
|
+
# We already ensured a is square in tile_cholesky_generic_value_func()
|
|
6448
|
+
M, N = a.type.shape[0], a.type.shape[1]
|
|
6449
|
+
if out.type.shape[0] != M or out.type.shape[1] != M:
|
|
6450
|
+
raise ValueError("tile_cholesky() output tile must be square")
|
|
6451
|
+
|
|
6452
|
+
num_threads = options["block_dim"]
|
|
6453
|
+
arch = options["output_arch"]
|
|
6454
|
+
lto_symbol = f"potrf_{M}_{N}_{arch}_{precision_enum}"
|
|
6455
|
+
|
|
6456
|
+
# early out if LTO for this combination already exists for this module
|
|
6457
|
+
if lto_symbol in builder.ltoirs:
|
|
6458
|
+
return lto_symbol, builder.ltoirs[lto_symbol]
|
|
6459
|
+
|
|
6460
|
+
# otherwise compile LTO
|
|
6461
|
+
lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
|
|
6462
|
+
universal_fatbin_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
|
|
6463
|
+
|
|
6464
|
+
# cuSOLVERDx only supports col-major inputs/outputs,
|
|
6465
|
+
# so we use upper to mimic a row-major input
|
|
6466
|
+
result = warp.context.runtime.core.cuda_compile_solver(
|
|
6467
|
+
universal_fatbin_code.name.encode("utf-8"),
|
|
6468
|
+
lto_code.name.encode("utf-8"),
|
|
6469
|
+
lto_symbol.encode("utf-8"),
|
|
6470
|
+
0,
|
|
6471
|
+
None,
|
|
6472
|
+
None,
|
|
6473
|
+
arch,
|
|
6474
|
+
M,
|
|
6475
|
+
N,
|
|
6476
|
+
cusolver_function_map["potrf"],
|
|
6477
|
+
precision_enum,
|
|
6478
|
+
cusolver_fill_mode_map["upper"],
|
|
6479
|
+
num_threads,
|
|
6480
|
+
)
|
|
6481
|
+
|
|
6482
|
+
if not result:
|
|
6483
|
+
for f in [lto_code, universal_fatbin_code]:
|
|
6484
|
+
f.close()
|
|
6485
|
+
if Path(f.name).exists():
|
|
6486
|
+
Path(f.name).unlink()
|
|
6487
|
+
raise RuntimeError("Failed to compile tile_cholesky")
|
|
6488
|
+
|
|
6489
|
+
else:
|
|
6490
|
+
with open(lto_code.name, "rb") as f:
|
|
6491
|
+
lto_code_data = f.read()
|
|
6492
|
+
with open(universal_fatbin_code.name, "rb") as f:
|
|
6493
|
+
universal_fatbin_code_data = f.read()
|
|
6494
|
+
for f in [lto_code, universal_fatbin_code]:
|
|
6495
|
+
f.close()
|
|
6496
|
+
Path(f.name).unlink()
|
|
6497
|
+
|
|
6498
|
+
builder.ltoirs[lto_symbol] = lto_code_data
|
|
6499
|
+
builder.ltoirs_decl[lto_symbol] = f"void {lto_symbol}({dtype}*, unsigned);"
|
|
6500
|
+
builder.fatbins["cholesky"] = universal_fatbin_code_data
|
|
6501
|
+
|
|
6502
|
+
return ((Var(lto_symbol, str, False, True, False), a, out), [], [lto_code_data], 0)
|
|
6503
|
+
|
|
6504
|
+
|
|
6505
|
+
add_builtin(
|
|
6506
|
+
"tile_cholesky",
|
|
6507
|
+
input_types={"A": Tile},
|
|
6508
|
+
value_func=tile_cholesky_generic_value_func,
|
|
6509
|
+
lto_dispatch_func=tile_cholesky_generic_lto_dispatch_func,
|
|
6510
|
+
variadic=True,
|
|
6511
|
+
doc="""Compute the Cholesky factorization L of a matrix A.
|
|
6512
|
+
L is lower triangular and satisfies LL^T = A.
|
|
6513
|
+
|
|
6514
|
+
Note that computing the adjoint is not yet supported.
|
|
6515
|
+
|
|
6516
|
+
Supported datatypes are:
|
|
6517
|
+
* float32
|
|
6518
|
+
* float64
|
|
6519
|
+
|
|
6520
|
+
:param A: A square, symmetric positive-definite, matrix.
|
|
6521
|
+
:returns L: A square, lower triangular, matrix, such that LL^T = A""",
|
|
6522
|
+
group="Tile Primitives",
|
|
6523
|
+
export=False,
|
|
6524
|
+
namespace="",
|
|
6525
|
+
)
|
|
6526
|
+
|
|
6527
|
+
|
|
6528
|
+
def tile_cholesky_solve_generic_value_func(arg_types, arg_values):
|
|
6529
|
+
if arg_types is None:
|
|
6530
|
+
return None
|
|
6531
|
+
|
|
6532
|
+
if len(arg_types) != 2:
|
|
6533
|
+
raise TypeError("tile_cholesky_solve() requires exactly 2 positional args")
|
|
6534
|
+
|
|
6535
|
+
l = arg_types["L"]
|
|
6536
|
+
x = arg_types["x"]
|
|
6537
|
+
|
|
6538
|
+
if not is_tile(l):
|
|
6539
|
+
raise TypeError(f"tile_cholesky_solve() 'L' argument must be a tile, got {l!r}")
|
|
6540
|
+
|
|
6541
|
+
if not is_tile(x):
|
|
6542
|
+
raise TypeError(f"tile_cholesky_solve() 'x' argument must be a tile, got {x!r}")
|
|
6543
|
+
|
|
6544
|
+
if not types_equal(l.dtype, x.dtype):
|
|
6545
|
+
raise TypeError(f"tile_cholesky_solve() arguments must have the same dtype, got {l.dtype} and {x.dtype}")
|
|
6546
|
+
|
|
6547
|
+
if l.shape[0] != l.shape[1]:
|
|
6548
|
+
raise ValueError("tile_cholesky_solve() 'L' argument must be square")
|
|
6549
|
+
|
|
6550
|
+
if len(x.shape) != 1:
|
|
6551
|
+
raise TypeError("tile_cholesky_solve() 'x' argument must be a 1D tile")
|
|
6552
|
+
|
|
6553
|
+
if x.shape[0] != l.shape[0]:
|
|
6554
|
+
raise ValueError(
|
|
6555
|
+
f"tile_cholesky_solve() 'x' argument must have the same number of elements as the number of rows in 'L', "
|
|
6556
|
+
f"got {x.shape[0]} elements in 'x' and {l.shape[0]} rows in 'L'"
|
|
6557
|
+
)
|
|
6558
|
+
|
|
6559
|
+
return Tile(dtype=l.dtype, shape=x.shape, storage="shared")
|
|
6560
|
+
|
|
6561
|
+
|
|
6562
|
+
def tile_cholesky_solve_generic_lto_dispatch_func(
|
|
6563
|
+
arg_types: Mapping[str, type],
|
|
6564
|
+
return_type: Any,
|
|
6565
|
+
return_values: List[Var],
|
|
6566
|
+
arg_values: Mapping[str, Var],
|
|
6567
|
+
options: Mapping[str, Any],
|
|
6568
|
+
builder: warp.context.ModuleBuilder,
|
|
6569
|
+
):
|
|
6570
|
+
L = arg_values["L"]
|
|
6571
|
+
x = arg_values["x"]
|
|
6572
|
+
# force the storage type of the input variables to shared memory
|
|
6573
|
+
L.type.storage = "shared"
|
|
6574
|
+
x.type.storage = "shared"
|
|
6575
|
+
|
|
6576
|
+
if len(return_values) != 1:
|
|
6577
|
+
raise TypeError(f"tile_cholesky_solve() must return exactly one value, got {len(return_values)}")
|
|
6578
|
+
|
|
6579
|
+
y = return_values[0]
|
|
6580
|
+
|
|
6581
|
+
if any(T not in cusolver_type_map.keys() for T in [x.type.dtype, L.type.dtype]):
|
|
6582
|
+
raise TypeError("tile_cholesky_solve() arguments must be tiles of float64 or float32")
|
|
6583
|
+
|
|
6584
|
+
dtype, precision_enum = cusolver_type_map[L.type.dtype]
|
|
6585
|
+
M, N = L.type.shape[0], L.type.shape[1]
|
|
6586
|
+
|
|
6587
|
+
if len(y.type.shape) != 1:
|
|
6588
|
+
raise TypeError("tile_cholesky_solve() output vector must be 1D")
|
|
6589
|
+
|
|
6590
|
+
if y.type.shape[0] != M:
|
|
6591
|
+
raise ValueError(
|
|
6592
|
+
"tile_cholesky_solve() output vector must have same number of elements as the number of rows in 'L' "
|
|
6593
|
+
f"got {y.type.shape[0]} elements in output and {M} rows in 'L'"
|
|
6594
|
+
)
|
|
6595
|
+
|
|
6596
|
+
num_threads = options["block_dim"]
|
|
6597
|
+
arch = options["output_arch"]
|
|
6598
|
+
lto_symbol = f"potrs_{M}_{N}_{arch}_{precision_enum}"
|
|
6599
|
+
|
|
6600
|
+
# early out if LTO for this combination already exists for this module
|
|
6601
|
+
if lto_symbol in builder.ltoirs:
|
|
6602
|
+
return lto_symbol, builder.ltoirs[lto_symbol]
|
|
6603
|
+
|
|
6604
|
+
# otherwise compile LTO
|
|
6605
|
+
lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
|
|
6606
|
+
universal_fatbin_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
|
|
6607
|
+
|
|
6608
|
+
# cuSOLVERDx only supports col-major inputs/outputs,
|
|
6609
|
+
# so we use upper to mimic a row-major input
|
|
6610
|
+
result = warp.context.runtime.core.cuda_compile_solver(
|
|
6611
|
+
universal_fatbin_code.name.encode("utf-8"),
|
|
6612
|
+
lto_code.name.encode("utf-8"),
|
|
6613
|
+
lto_symbol.encode("utf-8"),
|
|
6614
|
+
0,
|
|
6615
|
+
None,
|
|
6616
|
+
None,
|
|
6617
|
+
arch,
|
|
6618
|
+
M,
|
|
6619
|
+
N,
|
|
6620
|
+
cusolver_function_map["potrs"],
|
|
6621
|
+
precision_enum,
|
|
6622
|
+
cusolver_fill_mode_map["upper"],
|
|
6623
|
+
num_threads,
|
|
6624
|
+
)
|
|
6625
|
+
|
|
6626
|
+
if not result:
|
|
6627
|
+
for f in [lto_code, universal_fatbin_code]:
|
|
6628
|
+
f.close()
|
|
6629
|
+
if Path(f.name).exists():
|
|
6630
|
+
Path(f.name).unlink()
|
|
6631
|
+
raise RuntimeError("Failed to compile tile_cholesky_solve")
|
|
6632
|
+
|
|
6633
|
+
else:
|
|
6634
|
+
with open(lto_code.name, "rb") as f:
|
|
6635
|
+
lto_code_data = f.read()
|
|
6636
|
+
with open(universal_fatbin_code.name, "rb") as f:
|
|
6637
|
+
universal_fatbin_code_data = f.read()
|
|
6638
|
+
for f in [lto_code, universal_fatbin_code]:
|
|
6639
|
+
f.close()
|
|
6640
|
+
Path(f.name).unlink()
|
|
6641
|
+
|
|
6642
|
+
builder.ltoirs[lto_symbol] = lto_code_data
|
|
6643
|
+
builder.ltoirs_decl[lto_symbol] = f"void {lto_symbol}({dtype}*, {dtype}*);"
|
|
6644
|
+
builder.fatbins["cholesky"] = universal_fatbin_code_data
|
|
6645
|
+
|
|
6646
|
+
return ((Var(lto_symbol, str, False, True, False), L, x, y), [], [lto_code_data], 0)
|
|
6647
|
+
|
|
6648
|
+
|
|
6649
|
+
add_builtin(
|
|
6650
|
+
"tile_cholesky_solve",
|
|
6651
|
+
input_types={"L": Tile, "x": Tile},
|
|
6652
|
+
value_func=tile_cholesky_solve_generic_value_func,
|
|
6653
|
+
lto_dispatch_func=tile_cholesky_solve_generic_lto_dispatch_func,
|
|
6654
|
+
variadic=True,
|
|
6655
|
+
doc="""With L such that LL^T = A, solve for y in Ay = x
|
|
6656
|
+
|
|
6657
|
+
Note that computing the adjoint is not yet supported.
|
|
6658
|
+
|
|
6659
|
+
Supported datatypes are:
|
|
6660
|
+
* float32
|
|
6661
|
+
* float64
|
|
6662
|
+
|
|
6663
|
+
:param L: A square, lower triangular, matrix, such that LL^T = A
|
|
6664
|
+
:param x: A 1D tile of length M
|
|
6665
|
+
:returns y: A 1D tile of length M such that LL^T y = x""",
|
|
6666
|
+
group="Tile Primitives",
|
|
6667
|
+
export=False,
|
|
6668
|
+
namespace="",
|
|
6669
|
+
)
|
|
6670
|
+
|
|
6130
6671
|
# ---------------------------------
|
|
6131
6672
|
# Code Generation
|
|
6132
6673
|
|
|
@@ -6134,7 +6675,7 @@ add_builtin(
|
|
|
6134
6675
|
"static",
|
|
6135
6676
|
input_types={"expr": Any},
|
|
6136
6677
|
value_type=Any,
|
|
6137
|
-
doc="""
|
|
6678
|
+
doc="""Evaluate a static Python expression and replace it with its result.
|
|
6138
6679
|
|
|
6139
6680
|
See the :ref:`code generation guide <static_expressions>` for more details.
|
|
6140
6681
|
|
|
@@ -6158,3 +6699,58 @@ def static(expr):
|
|
|
6158
6699
|
which includes constant variables and variables captured in the current closure in which the function or kernel is implemented.
|
|
6159
6700
|
"""
|
|
6160
6701
|
return expr
|
|
6702
|
+
|
|
6703
|
+
|
|
6704
|
+
add_builtin(
|
|
6705
|
+
"len",
|
|
6706
|
+
input_types={"a": vector(length=Any, dtype=Scalar)},
|
|
6707
|
+
value_type=int,
|
|
6708
|
+
doc="Return the number of elements in a vector.",
|
|
6709
|
+
group="Utility",
|
|
6710
|
+
export=False,
|
|
6711
|
+
)
|
|
6712
|
+
|
|
6713
|
+
add_builtin(
|
|
6714
|
+
"len",
|
|
6715
|
+
input_types={"a": quaternion(dtype=Scalar)},
|
|
6716
|
+
value_type=int,
|
|
6717
|
+
doc="Return the number of elements in a quaternion.",
|
|
6718
|
+
group="Utility",
|
|
6719
|
+
export=False,
|
|
6720
|
+
)
|
|
6721
|
+
|
|
6722
|
+
add_builtin(
|
|
6723
|
+
"len",
|
|
6724
|
+
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar)},
|
|
6725
|
+
value_type=int,
|
|
6726
|
+
doc="Return the number of rows in a matrix.",
|
|
6727
|
+
group="Utility",
|
|
6728
|
+
export=False,
|
|
6729
|
+
)
|
|
6730
|
+
|
|
6731
|
+
add_builtin(
|
|
6732
|
+
"len",
|
|
6733
|
+
input_types={"a": transformation(dtype=Float)},
|
|
6734
|
+
value_type=int,
|
|
6735
|
+
doc="Return the number of elements in a transformation.",
|
|
6736
|
+
group="Utility",
|
|
6737
|
+
export=False,
|
|
6738
|
+
)
|
|
6739
|
+
|
|
6740
|
+
add_builtin(
|
|
6741
|
+
"len",
|
|
6742
|
+
input_types={"a": array(dtype=Any)},
|
|
6743
|
+
value_type=int,
|
|
6744
|
+
doc="Return the size of the first dimension in an array.",
|
|
6745
|
+
group="Utility",
|
|
6746
|
+
export=False,
|
|
6747
|
+
)
|
|
6748
|
+
|
|
6749
|
+
add_builtin(
|
|
6750
|
+
"len",
|
|
6751
|
+
input_types={"a": Tile(dtype=Any, shape=Any)},
|
|
6752
|
+
value_type=int,
|
|
6753
|
+
doc="Return the number of rows in a tile.",
|
|
6754
|
+
group="Utility",
|
|
6755
|
+
export=False,
|
|
6756
|
+
)
|