warp-lang 1.5.0__py3-none-macosx_10_13_universal2.whl → 1.6.0__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +5 -0
- warp/autograd.py +414 -191
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +40 -12
- warp/build_dll.py +13 -6
- warp/builtins.py +1124 -497
- warp/codegen.py +261 -136
- warp/config.py +1 -1
- warp/context.py +357 -119
- warp/examples/assets/square_cloth.usd +0 -0
- warp/examples/benchmarks/benchmark_gemm.py +27 -18
- warp/examples/benchmarks/benchmark_interop_paddle.py +3 -3
- warp/examples/benchmarks/benchmark_interop_torch.py +3 -3
- warp/examples/core/example_torch.py +18 -34
- warp/examples/fem/example_apic_fluid.py +1 -0
- warp/examples/fem/example_mixed_elasticity.py +1 -1
- warp/examples/optim/example_bounce.py +1 -1
- warp/examples/optim/example_cloth_throw.py +1 -1
- warp/examples/optim/example_diffray.py +4 -15
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/optim/example_softbody_properties.py +392 -0
- warp/examples/optim/example_trajectory.py +1 -3
- warp/examples/optim/example_walker.py +5 -0
- warp/examples/sim/example_cartpole.py +0 -2
- warp/examples/sim/example_cloth.py +3 -1
- warp/examples/sim/example_cloth_self_contact.py +260 -0
- warp/examples/sim/example_granular_collision_sdf.py +4 -5
- warp/examples/sim/example_jacobian_ik.py +0 -2
- warp/examples/sim/example_quadruped.py +5 -2
- warp/examples/tile/example_tile_cholesky.py +79 -0
- warp/examples/tile/example_tile_convolution.py +2 -2
- warp/examples/tile/example_tile_fft.py +2 -2
- warp/examples/tile/example_tile_filtering.py +3 -3
- warp/examples/tile/example_tile_matmul.py +4 -4
- warp/examples/tile/example_tile_mlp.py +12 -12
- warp/examples/tile/example_tile_nbody.py +180 -0
- warp/examples/tile/example_tile_walker.py +319 -0
- warp/fem/geometry/geometry.py +0 -2
- warp/math.py +147 -0
- warp/native/array.h +12 -0
- warp/native/builtin.h +0 -1
- warp/native/bvh.cpp +149 -70
- warp/native/bvh.cu +287 -68
- warp/native/bvh.h +195 -85
- warp/native/clang/clang.cpp +5 -1
- warp/native/coloring.cpp +5 -1
- warp/native/cuda_util.cpp +91 -53
- warp/native/cuda_util.h +5 -0
- warp/native/exports.h +40 -40
- warp/native/intersect.h +17 -0
- warp/native/mat.h +41 -0
- warp/native/mathdx.cpp +19 -0
- warp/native/mesh.cpp +25 -8
- warp/native/mesh.cu +153 -101
- warp/native/mesh.h +482 -403
- warp/native/quat.h +40 -0
- warp/native/solid_angle.h +7 -0
- warp/native/sort.cpp +85 -0
- warp/native/sort.cu +34 -0
- warp/native/sort.h +3 -1
- warp/native/spatial.h +11 -0
- warp/native/tile.h +1187 -669
- warp/native/tile_reduce.h +8 -6
- warp/native/vec.h +41 -0
- warp/native/warp.cpp +8 -1
- warp/native/warp.cu +263 -40
- warp/native/warp.h +19 -5
- warp/optim/linear.py +22 -4
- warp/render/render_opengl.py +130 -64
- warp/sim/__init__.py +6 -1
- warp/sim/collide.py +270 -26
- warp/sim/import_urdf.py +8 -8
- warp/sim/integrator_euler.py +25 -7
- warp/sim/integrator_featherstone.py +154 -35
- warp/sim/integrator_vbd.py +842 -40
- warp/sim/model.py +134 -72
- warp/sparse.py +1 -1
- warp/stubs.py +265 -132
- warp/tape.py +28 -30
- warp/tests/aux_test_module_unload.py +15 -0
- warp/tests/{test_sim_grad.py → flaky_test_sim_grad.py} +104 -63
- warp/tests/test_array.py +74 -0
- warp/tests/test_assert.py +242 -0
- warp/tests/test_codegen.py +14 -61
- warp/tests/test_collision.py +2 -2
- warp/tests/test_coloring.py +12 -2
- warp/tests/test_examples.py +12 -1
- warp/tests/test_func.py +21 -4
- warp/tests/test_grad_debug.py +87 -2
- warp/tests/test_hash_grid.py +1 -1
- warp/tests/test_ipc.py +116 -0
- warp/tests/test_lerp.py +13 -87
- warp/tests/test_mat.py +138 -167
- warp/tests/test_math.py +47 -1
- warp/tests/test_matmul.py +17 -16
- warp/tests/test_matmul_lite.py +10 -15
- warp/tests/test_mesh.py +84 -60
- warp/tests/test_mesh_query_aabb.py +165 -0
- warp/tests/test_mesh_query_point.py +328 -286
- warp/tests/test_mesh_query_ray.py +134 -121
- warp/tests/test_mlp.py +2 -2
- warp/tests/test_operators.py +43 -0
- warp/tests/test_overwrite.py +47 -2
- warp/tests/test_quat.py +77 -0
- warp/tests/test_reload.py +29 -0
- warp/tests/test_sim_grad_bounce_linear.py +204 -0
- warp/tests/test_smoothstep.py +17 -83
- warp/tests/test_static.py +19 -3
- warp/tests/test_tape.py +25 -0
- warp/tests/test_tile.py +178 -191
- warp/tests/test_tile_load.py +356 -0
- warp/tests/test_tile_mathdx.py +61 -8
- warp/tests/test_tile_mlp.py +17 -17
- warp/tests/test_tile_reduce.py +24 -18
- warp/tests/test_tile_shared_memory.py +66 -17
- warp/tests/test_tile_view.py +165 -0
- warp/tests/test_torch.py +35 -0
- warp/tests/test_utils.py +36 -24
- warp/tests/test_vec.py +110 -0
- warp/tests/unittest_suites.py +29 -4
- warp/tests/unittest_utils.py +30 -13
- warp/thirdparty/unittest_parallel.py +2 -2
- warp/types.py +411 -101
- warp/utils.py +10 -7
- {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/METADATA +92 -69
- {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/RECORD +130 -119
- {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/WHEEL +1 -1
- warp/examples/benchmarks/benchmark_tile.py +0 -179
- warp/native/tile_gemm.h +0 -341
- {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/top_level.txt +0 -0
warp/builtins.py
CHANGED
|
@@ -399,11 +399,11 @@ def scalar_infer_type(arg_types: Mapping[str, type]):
|
|
|
399
399
|
|
|
400
400
|
scalar_types = set()
|
|
401
401
|
for t in arg_types:
|
|
402
|
-
|
|
403
|
-
if hasattr(
|
|
404
|
-
scalar_types.add(
|
|
405
|
-
elif
|
|
406
|
-
scalar_types.add(
|
|
402
|
+
t_val = strip_reference(t)
|
|
403
|
+
if hasattr(t_val, "_wp_scalar_type_"):
|
|
404
|
+
scalar_types.add(t_val._wp_scalar_type_)
|
|
405
|
+
elif t_val in scalar_and_bool_types:
|
|
406
|
+
scalar_types.add(t_val)
|
|
407
407
|
|
|
408
408
|
if len(scalar_types) > 1:
|
|
409
409
|
raise RuntimeError(
|
|
@@ -1707,64 +1707,98 @@ add_builtin(
|
|
|
1707
1707
|
|
|
1708
1708
|
# ------------------
|
|
1709
1709
|
# Tile-based primitives
|
|
1710
|
-
|
|
1710
|
+
|
|
1711
|
+
|
|
1712
|
+
def tile_unpack_shape(arg_values):
|
|
1713
|
+
shape = arg_values["shape"]
|
|
1714
|
+
|
|
1715
|
+
if not isinstance(shape, tuple):
|
|
1716
|
+
# promote to tuple
|
|
1717
|
+
shape = (shape,)
|
|
1718
|
+
|
|
1719
|
+
# check that components are constants
|
|
1720
|
+
for d in shape:
|
|
1721
|
+
if d is None:
|
|
1722
|
+
raise ValueError("Tile functions require shape to be a compile time constant.")
|
|
1723
|
+
|
|
1724
|
+
return shape
|
|
1725
|
+
|
|
1726
|
+
|
|
1727
|
+
def tile_unpack_offset(arg_values, ndim=0):
|
|
1728
|
+
if "offset" in arg_values:
|
|
1729
|
+
offset = arg_values["offset"]
|
|
1730
|
+
else:
|
|
1731
|
+
offset = (0,) * ndim
|
|
1732
|
+
|
|
1733
|
+
if isinstance(offset, tuple):
|
|
1734
|
+
return offset
|
|
1735
|
+
else:
|
|
1736
|
+
# promote to tuple
|
|
1737
|
+
return (offset,)
|
|
1711
1738
|
|
|
1712
1739
|
|
|
1713
1740
|
def tile_zeros_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1714
1741
|
# return generic type (for doc builds)
|
|
1715
1742
|
if arg_types is None:
|
|
1716
|
-
return Tile(dtype=Any,
|
|
1717
|
-
|
|
1718
|
-
if "m" not in arg_values:
|
|
1719
|
-
raise RuntimeError("'m' keyword argument must be specified when calling tile_zeros() function")
|
|
1743
|
+
return Tile(dtype=Any, shape=Any)
|
|
1720
1744
|
|
|
1721
|
-
|
|
1722
|
-
raise RuntimeError("'n' keyword argument must be specified when calling tile_zeros() function")
|
|
1745
|
+
shape = tile_unpack_shape(arg_values)
|
|
1723
1746
|
|
|
1724
1747
|
if "dtype" not in arg_values:
|
|
1725
|
-
raise
|
|
1748
|
+
raise TypeError("tile_zeros() missing required keyword argument 'dtype'")
|
|
1726
1749
|
|
|
1727
1750
|
if "storage" not in arg_values:
|
|
1728
|
-
raise
|
|
1751
|
+
raise TypeError("tile_zeros() missing required keyword argument 'storage'")
|
|
1729
1752
|
|
|
1730
1753
|
if arg_values["storage"] not in {"shared", "register"}:
|
|
1731
|
-
raise ValueError(
|
|
1732
|
-
f"'storage' keyword argument must be either 'shared' or 'register', got {arg_values['storage']}"
|
|
1733
|
-
)
|
|
1754
|
+
raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
|
|
1734
1755
|
|
|
1735
|
-
m, n = arg_values["m"], arg_values["n"]
|
|
1736
1756
|
dtype = arg_values["dtype"]
|
|
1737
1757
|
|
|
1738
|
-
return TileZeros(dtype=dtype,
|
|
1758
|
+
return TileZeros(dtype=dtype, shape=shape, storage=arg_values["storage"])
|
|
1739
1759
|
|
|
1740
1760
|
|
|
1741
1761
|
def tile_zeros_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
1742
|
-
|
|
1762
|
+
shape = tile_unpack_shape(arg_values)
|
|
1763
|
+
dtype = arg_values["dtype"]
|
|
1743
1764
|
|
|
1744
1765
|
template_args = []
|
|
1745
1766
|
template_args.append(dtype)
|
|
1746
|
-
|
|
1747
|
-
|
|
1767
|
+
for d in shape:
|
|
1768
|
+
template_args.append(d.constant)
|
|
1748
1769
|
|
|
1749
1770
|
return ([], template_args)
|
|
1750
1771
|
|
|
1751
1772
|
|
|
1752
1773
|
add_builtin(
|
|
1753
1774
|
"tile_zeros",
|
|
1754
|
-
input_types={"
|
|
1755
|
-
defaults={"storage": "register"},
|
|
1775
|
+
input_types={"shape": Tuple[int, ...], "dtype": Any, "storage": str},
|
|
1776
|
+
defaults={"storage": "register", "dtype": float},
|
|
1756
1777
|
value_func=tile_zeros_value_func,
|
|
1757
1778
|
dispatch_func=tile_zeros_dispatch_func,
|
|
1758
1779
|
variadic=False,
|
|
1759
1780
|
missing_grad=True,
|
|
1760
|
-
doc="""
|
|
1781
|
+
doc="""Allocate a tile of zero-initialized items.
|
|
1761
1782
|
|
|
1762
|
-
:param
|
|
1763
|
-
:param
|
|
1764
|
-
:param dtype: Datatype of output tile's elements
|
|
1783
|
+
:param shape: Shape of the output tile
|
|
1784
|
+
:param dtype: Data type of output tile's elements (default float)
|
|
1765
1785
|
:param storage: The storage location for the tile: ``"register"`` for registers
|
|
1766
1786
|
(default) or ``"shared"`` for shared memory.
|
|
1767
|
-
:returns: A zero-initialized tile with
|
|
1787
|
+
:returns: A zero-initialized tile with shape and data type as specified""",
|
|
1788
|
+
group="Tile Primitives",
|
|
1789
|
+
export=False,
|
|
1790
|
+
)
|
|
1791
|
+
|
|
1792
|
+
# overload for scalar shape
|
|
1793
|
+
add_builtin(
|
|
1794
|
+
"tile_zeros",
|
|
1795
|
+
input_types={"shape": int, "dtype": Any, "storage": str},
|
|
1796
|
+
defaults={"storage": "register", "dtype": float},
|
|
1797
|
+
value_func=tile_zeros_value_func,
|
|
1798
|
+
dispatch_func=tile_zeros_dispatch_func,
|
|
1799
|
+
variadic=False,
|
|
1800
|
+
missing_grad=True,
|
|
1801
|
+
hidden=True,
|
|
1768
1802
|
group="Tile Primitives",
|
|
1769
1803
|
export=False,
|
|
1770
1804
|
)
|
|
@@ -1773,54 +1807,63 @@ add_builtin(
|
|
|
1773
1807
|
def tile_ones_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1774
1808
|
# return generic type (for doc builds)
|
|
1775
1809
|
if arg_types is None:
|
|
1776
|
-
return Tile(dtype=Any,
|
|
1810
|
+
return Tile(dtype=Any, shape=Any)
|
|
1777
1811
|
|
|
1778
|
-
|
|
1779
|
-
raise RuntimeError("'m' keyword argument must be specified when calling tile_zeros() function")
|
|
1780
|
-
|
|
1781
|
-
if "n" not in arg_values:
|
|
1782
|
-
raise RuntimeError("'n' keyword argument must be specified when calling tile_zeros() function")
|
|
1812
|
+
shape = tile_unpack_shape(arg_values)
|
|
1783
1813
|
|
|
1784
1814
|
if "dtype" not in arg_values:
|
|
1785
|
-
raise
|
|
1815
|
+
raise TypeError("tile_ones() missing required keyword argument 'dtype'")
|
|
1816
|
+
|
|
1817
|
+
if "storage" not in arg_values:
|
|
1818
|
+
raise TypeError("tile_ones() missing required keyword argument 'storage'")
|
|
1786
1819
|
|
|
1787
1820
|
if arg_values["storage"] not in {"shared", "register"}:
|
|
1788
|
-
raise ValueError(
|
|
1789
|
-
f"'storage' keyword argument must be either 'shared' or 'register', got {arg_values['storage']}"
|
|
1790
|
-
)
|
|
1821
|
+
raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
|
|
1791
1822
|
|
|
1792
|
-
m, n = arg_values["m"], arg_values["n"]
|
|
1793
1823
|
dtype = arg_values["dtype"]
|
|
1794
1824
|
|
|
1795
|
-
return
|
|
1825
|
+
return TileOnes(dtype=dtype, shape=shape, storage=arg_values["storage"])
|
|
1796
1826
|
|
|
1797
1827
|
|
|
1798
1828
|
def tile_ones_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
1799
|
-
|
|
1829
|
+
shape = tile_unpack_shape(arg_values)
|
|
1830
|
+
dtype = arg_values["dtype"]
|
|
1800
1831
|
|
|
1801
1832
|
template_args = []
|
|
1802
1833
|
template_args.append(dtype)
|
|
1803
|
-
|
|
1804
|
-
|
|
1834
|
+
for d in shape:
|
|
1835
|
+
template_args.append(d.constant)
|
|
1805
1836
|
|
|
1806
1837
|
return ([], template_args)
|
|
1807
1838
|
|
|
1808
1839
|
|
|
1809
1840
|
add_builtin(
|
|
1810
1841
|
"tile_ones",
|
|
1811
|
-
input_types={"
|
|
1842
|
+
input_types={"shape": Tuple[int, ...], "dtype": Any, "storage": str},
|
|
1812
1843
|
defaults={"storage": "register"},
|
|
1813
1844
|
value_func=tile_ones_value_func,
|
|
1814
1845
|
dispatch_func=tile_ones_dispatch_func,
|
|
1815
1846
|
missing_grad=True,
|
|
1816
|
-
doc="""
|
|
1847
|
+
doc="""Allocate a tile of one-initialized items.
|
|
1817
1848
|
|
|
1818
|
-
:param
|
|
1819
|
-
:param
|
|
1820
|
-
:param dtype: Datatype of output tile's elements
|
|
1849
|
+
:param shape: Shape of the output tile
|
|
1850
|
+
:param dtype: Data type of output tile's elements
|
|
1821
1851
|
:param storage: The storage location for the tile: ``"register"`` for registers
|
|
1822
1852
|
(default) or ``"shared"`` for shared memory.
|
|
1823
|
-
:returns: A one-initialized tile with
|
|
1853
|
+
:returns: A one-initialized tile with shape and data type as specified""",
|
|
1854
|
+
group="Tile Primitives",
|
|
1855
|
+
export=False,
|
|
1856
|
+
)
|
|
1857
|
+
|
|
1858
|
+
# overload for scalar shape
|
|
1859
|
+
add_builtin(
|
|
1860
|
+
"tile_ones",
|
|
1861
|
+
input_types={"shape": int, "dtype": Any, "storage": str},
|
|
1862
|
+
defaults={"storage": "register"},
|
|
1863
|
+
value_func=tile_ones_value_func,
|
|
1864
|
+
dispatch_func=tile_ones_dispatch_func,
|
|
1865
|
+
missing_grad=True,
|
|
1866
|
+
hidden=True,
|
|
1824
1867
|
group="Tile Primitives",
|
|
1825
1868
|
export=False,
|
|
1826
1869
|
)
|
|
@@ -1829,14 +1872,16 @@ add_builtin(
|
|
|
1829
1872
|
def tile_arange_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1830
1873
|
# return generic type (for doc builds)
|
|
1831
1874
|
if arg_types is None:
|
|
1832
|
-
return Tile(dtype=Any,
|
|
1875
|
+
return Tile(dtype=Any, shape=Any)
|
|
1876
|
+
|
|
1877
|
+
if "args" not in arg_values:
|
|
1878
|
+
raise TypeError("tile_arange() requires at least one positional argument specifying the range")
|
|
1879
|
+
|
|
1880
|
+
args = arg_values["args"]
|
|
1833
1881
|
|
|
1834
1882
|
start = 0
|
|
1835
1883
|
stop = 0
|
|
1836
1884
|
step = 1
|
|
1837
|
-
dtype = int
|
|
1838
|
-
|
|
1839
|
-
args = arg_values["args"]
|
|
1840
1885
|
|
|
1841
1886
|
if len(args) == 1:
|
|
1842
1887
|
start = 0
|
|
@@ -1852,7 +1897,7 @@ def tile_arange_value_func(arg_types: Mapping[str, type], arg_values: Mapping[st
|
|
|
1852
1897
|
step = args[2]
|
|
1853
1898
|
|
|
1854
1899
|
if start is None or stop is None or step is None:
|
|
1855
|
-
raise RuntimeError("
|
|
1900
|
+
raise RuntimeError("tile_arange() arguments must be compile time constants")
|
|
1856
1901
|
|
|
1857
1902
|
if "dtype" in arg_values:
|
|
1858
1903
|
dtype = arg_values["dtype"]
|
|
@@ -1860,26 +1905,37 @@ def tile_arange_value_func(arg_types: Mapping[str, type], arg_values: Mapping[st
|
|
|
1860
1905
|
dtype = float
|
|
1861
1906
|
|
|
1862
1907
|
if arg_values["storage"] not in {"shared", "register"}:
|
|
1863
|
-
raise ValueError(
|
|
1864
|
-
f"'storage' keyword argument must be either 'shared' or 'register', got {arg_values['storage']}"
|
|
1865
|
-
)
|
|
1908
|
+
raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
|
|
1866
1909
|
|
|
1867
1910
|
return TileRange(dtype=dtype, start=start, stop=stop, step=step, storage=arg_values["storage"])
|
|
1868
1911
|
|
|
1869
1912
|
|
|
1870
1913
|
def tile_arange_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
1871
|
-
|
|
1914
|
+
size, dtype = return_type.size, return_type.dtype
|
|
1872
1915
|
|
|
1873
1916
|
template_args = []
|
|
1874
1917
|
template_args.append(dtype)
|
|
1875
|
-
template_args.append(
|
|
1876
|
-
|
|
1918
|
+
template_args.append(size)
|
|
1919
|
+
|
|
1920
|
+
if "args" not in arg_values:
|
|
1921
|
+
raise TypeError("tile_arange() requires at least one positional argument specifying the range")
|
|
1877
1922
|
|
|
1878
|
-
|
|
1879
|
-
|
|
1880
|
-
|
|
1881
|
-
|
|
1882
|
-
|
|
1923
|
+
args = arg_values["args"]
|
|
1924
|
+
|
|
1925
|
+
if len(args) == 1:
|
|
1926
|
+
start = warp.codegen.Var(label=None, type=return_type.dtype, constant=return_type.start)
|
|
1927
|
+
stop = args[0]
|
|
1928
|
+
step = warp.codegen.Var(label=None, type=return_type.dtype, constant=return_type.step)
|
|
1929
|
+
elif len(args) == 2:
|
|
1930
|
+
start = args[0]
|
|
1931
|
+
stop = args[1]
|
|
1932
|
+
step = warp.codegen.Var(label=None, type=return_type.dtype, constant=return_type.step)
|
|
1933
|
+
elif len(args) == 3:
|
|
1934
|
+
start = args[0]
|
|
1935
|
+
stop = args[1]
|
|
1936
|
+
step = args[2]
|
|
1937
|
+
else:
|
|
1938
|
+
raise TypeError(f"tile_arange() accepts at most 3 positional arguments, got {len(args)}")
|
|
1883
1939
|
|
|
1884
1940
|
function_args = []
|
|
1885
1941
|
function_args.append(start)
|
|
@@ -1897,7 +1953,7 @@ add_builtin(
|
|
|
1897
1953
|
dispatch_func=tile_arange_dispatch_func,
|
|
1898
1954
|
variadic=True,
|
|
1899
1955
|
missing_grad=True,
|
|
1900
|
-
doc="""
|
|
1956
|
+
doc="""Generate a tile of linearly spaced elements.
|
|
1901
1957
|
|
|
1902
1958
|
:param args: Variable-length positional arguments, interpreted as:
|
|
1903
1959
|
|
|
@@ -1905,246 +1961,157 @@ add_builtin(
|
|
|
1905
1961
|
- ``(start, stop)``: Generates values from ``start`` to ``stop - 1``
|
|
1906
1962
|
- ``(start, stop, step)``: Generates values from ``start`` to ``stop - 1`` with a step size
|
|
1907
1963
|
|
|
1908
|
-
:param dtype:
|
|
1964
|
+
:param dtype: Data type of output tile's elements (optional, default: ``float``)
|
|
1909
1965
|
:param storage: The storage location for the tile: ``"register"`` for registers
|
|
1910
1966
|
(default) or ``"shared"`` for shared memory.
|
|
1911
|
-
:returns: A tile with ``shape=(
|
|
1967
|
+
:returns: A tile with ``shape=(n)`` with linearly spaced elements of specified data type""",
|
|
1912
1968
|
group="Tile Primitives",
|
|
1913
1969
|
export=False,
|
|
1914
1970
|
)
|
|
1915
1971
|
|
|
1916
1972
|
|
|
1917
|
-
def
|
|
1918
|
-
# return generic type (for doc builds)
|
|
1973
|
+
def tile_load_tuple_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
1919
1974
|
if arg_types is None:
|
|
1920
|
-
return
|
|
1921
|
-
|
|
1922
|
-
if not is_array(arg_types["a"]):
|
|
1923
|
-
raise RuntimeError("tile_load() argument 0 must be an array")
|
|
1975
|
+
return array(dtype=Scalar)
|
|
1924
1976
|
|
|
1925
|
-
|
|
1926
|
-
raise RuntimeError(
|
|
1927
|
-
"tile_load() argument 0 must be 1-dimensional if using the ``wp.tile_load(array, i, n)`` syntax."
|
|
1928
|
-
)
|
|
1977
|
+
a = arg_types["a"]
|
|
1929
1978
|
|
|
1930
|
-
|
|
1931
|
-
|
|
1979
|
+
shape = tile_unpack_shape(arg_values)
|
|
1980
|
+
offset = tile_unpack_offset(arg_values, a.ndim)
|
|
1932
1981
|
|
|
1933
|
-
if
|
|
1934
|
-
raise
|
|
1982
|
+
if a.ndim != len(shape):
|
|
1983
|
+
raise ValueError(
|
|
1984
|
+
f"tile_load() array argument must have same number of dimensions as the tile shape, trying to perform an {len(shape)} dimensional load from an array with {a.ndim} dimensions."
|
|
1985
|
+
)
|
|
1935
1986
|
|
|
1936
|
-
if
|
|
1987
|
+
if a.ndim != len(offset):
|
|
1937
1988
|
raise ValueError(
|
|
1938
|
-
f"
|
|
1989
|
+
f"tile_load() offset argument must have the same number of dimensions as the array to load from, got {len(offset)} indices for an array with {a.ndim} dimensions"
|
|
1939
1990
|
)
|
|
1940
1991
|
|
|
1941
|
-
|
|
1942
|
-
|
|
1992
|
+
if arg_values["storage"] not in {"shared", "register"}:
|
|
1993
|
+
raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
|
|
1943
1994
|
|
|
1944
|
-
return
|
|
1995
|
+
return Tile(dtype=a.dtype, shape=shape, storage=arg_values["storage"])
|
|
1945
1996
|
|
|
1946
1997
|
|
|
1947
|
-
def
|
|
1948
|
-
|
|
1949
|
-
|
|
1950
|
-
|
|
1951
|
-
dtype = arg_values["a"].type.dtype
|
|
1998
|
+
def tile_load_tuple_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
1999
|
+
a = args["a"]
|
|
2000
|
+
shape = tile_unpack_shape(args)
|
|
2001
|
+
offset = tile_unpack_offset(args, a.type.ndim)
|
|
1952
2002
|
|
|
1953
|
-
|
|
1954
|
-
template_args.
|
|
1955
|
-
template_args.append(n)
|
|
2003
|
+
func_args = (a, *offset)
|
|
2004
|
+
template_args = (d.constant for d in shape)
|
|
1956
2005
|
|
|
1957
|
-
return (
|
|
2006
|
+
return (func_args, template_args)
|
|
1958
2007
|
|
|
1959
2008
|
|
|
1960
2009
|
add_builtin(
|
|
1961
2010
|
"tile_load",
|
|
1962
|
-
input_types={"a": array(dtype=Any), "
|
|
1963
|
-
|
|
1964
|
-
|
|
1965
|
-
|
|
2011
|
+
input_types={"a": array(dtype=Any), "shape": Tuple[int, ...], "offset": Tuple[int, ...], "storage": str},
|
|
2012
|
+
value_func=tile_load_tuple_value_func,
|
|
2013
|
+
dispatch_func=tile_load_tuple_dispatch_func,
|
|
2014
|
+
defaults={"offset": None, "storage": "register"},
|
|
1966
2015
|
variadic=False,
|
|
1967
|
-
doc="""Loads a
|
|
2016
|
+
doc="""Loads a tile from a global memory array.
|
|
1968
2017
|
|
|
1969
2018
|
This method will cooperatively load a tile from global memory using all threads in the block.
|
|
1970
2019
|
|
|
1971
2020
|
:param a: The source array in global memory
|
|
1972
|
-
:param
|
|
1973
|
-
:param
|
|
2021
|
+
:param shape: Shape of the tile to load, must have the same number of dimensions as ``a``
|
|
2022
|
+
:param offset: Offset in the source array to begin reading from (optional)
|
|
1974
2023
|
:param storage: The storage location for the tile: ``"register"`` for registers
|
|
1975
2024
|
(default) or ``"shared"`` for shared memory.
|
|
1976
|
-
:returns: A tile with
|
|
2025
|
+
:returns: A tile with shape as specified and data type the same as the source array""",
|
|
1977
2026
|
group="Tile Primitives",
|
|
1978
2027
|
export=False,
|
|
1979
2028
|
)
|
|
1980
2029
|
|
|
1981
|
-
|
|
1982
|
-
def tile_load_2d_value_func(arg_types, arg_values):
|
|
1983
|
-
# return generic type (for doc builds)
|
|
1984
|
-
if arg_types is None:
|
|
1985
|
-
return Tile(dtype=Any, M=Any, N=Any)
|
|
1986
|
-
|
|
1987
|
-
if not is_array(arg_types["a"]):
|
|
1988
|
-
raise RuntimeError("tile_load() argument 0 must be an array")
|
|
1989
|
-
|
|
1990
|
-
if arg_types["a"].ndim != 2:
|
|
1991
|
-
raise RuntimeError(
|
|
1992
|
-
"tile_load() argument 0 must be 2-dimensional if using the ``wp.tile_load(array, i, j, m, n)`` syntax."
|
|
1993
|
-
)
|
|
1994
|
-
|
|
1995
|
-
if not type_is_int(arg_types["i"]):
|
|
1996
|
-
raise RuntimeError("tile_load() argument 1 must be an integer")
|
|
1997
|
-
|
|
1998
|
-
if not type_is_int(arg_types["j"]):
|
|
1999
|
-
raise RuntimeError("tile_load() argument 1 must be an integer")
|
|
2000
|
-
|
|
2001
|
-
if "m" not in arg_values:
|
|
2002
|
-
raise RuntimeError("'m' keyword argument must be specified when calling tile_load() function")
|
|
2003
|
-
|
|
2004
|
-
if "n" not in arg_values:
|
|
2005
|
-
raise RuntimeError("'n' keyword argument must be specified when calling tile_load() function")
|
|
2006
|
-
|
|
2007
|
-
if arg_values["storage"] not in {"shared", "register"}:
|
|
2008
|
-
raise ValueError(
|
|
2009
|
-
f"'storage' keyword argument must be either 'shared' or 'register', got {arg_values['storage']}"
|
|
2010
|
-
)
|
|
2011
|
-
|
|
2012
|
-
a = arg_types["a"]
|
|
2013
|
-
m, n = arg_values["m"], arg_values["n"]
|
|
2014
|
-
|
|
2015
|
-
return TileLoad(a, m, n, arg_values["storage"])
|
|
2016
|
-
|
|
2017
|
-
|
|
2018
|
-
def tile_load_2d_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
2019
|
-
array = arg_values["a"]
|
|
2020
|
-
i, j = arg_values["i"], arg_values["j"]
|
|
2021
|
-
m, n = arg_values["m"].constant, arg_values["n"].constant
|
|
2022
|
-
dtype = arg_values["a"].type.dtype
|
|
2023
|
-
|
|
2024
|
-
template_args = []
|
|
2025
|
-
template_args.append(dtype)
|
|
2026
|
-
template_args.append(m)
|
|
2027
|
-
template_args.append(n)
|
|
2028
|
-
|
|
2029
|
-
return ((array, i, j), template_args)
|
|
2030
|
-
|
|
2031
|
-
|
|
2030
|
+
# overload for scalar shape
|
|
2032
2031
|
add_builtin(
|
|
2033
2032
|
"tile_load",
|
|
2034
|
-
input_types={"a": array(dtype=Any), "
|
|
2035
|
-
|
|
2036
|
-
|
|
2037
|
-
|
|
2038
|
-
variadic=False,
|
|
2039
|
-
doc="""Loads a 2D tile from a global memory array.
|
|
2040
|
-
|
|
2041
|
-
This method will cooperatively load a tile from global memory using all threads in the block.
|
|
2042
|
-
|
|
2043
|
-
:param a: The source array in global memory
|
|
2044
|
-
:param i: Offset in the source array measured in multiples of ``m``, i.e.: ``row=i*m``
|
|
2045
|
-
:param j: Offset in the source array measured in multiples of ``n``, i.e.; ``col=j*n``
|
|
2046
|
-
:param m: The size of the tile's first dimension
|
|
2047
|
-
:param n: The size of the tile's second dimension
|
|
2048
|
-
:param storage: The storage location for the tile: ``"register"`` for registers
|
|
2049
|
-
(default) or ``"shared"`` for shared memory.
|
|
2050
|
-
:returns: A tile with ``shape=(m,n)`` and dtype the same as the source array""",
|
|
2033
|
+
input_types={"a": array(dtype=Any), "shape": int, "offset": int, "storage": str},
|
|
2034
|
+
value_func=tile_load_tuple_value_func,
|
|
2035
|
+
dispatch_func=tile_load_tuple_dispatch_func,
|
|
2036
|
+
defaults={"offset": None, "storage": "register"},
|
|
2051
2037
|
group="Tile Primitives",
|
|
2038
|
+
hidden=True,
|
|
2052
2039
|
export=False,
|
|
2053
2040
|
)
|
|
2054
2041
|
|
|
2055
2042
|
|
|
2056
|
-
def
|
|
2043
|
+
def tile_store_value_func(arg_types, arg_values):
|
|
2057
2044
|
# return generic type (for doc builds)
|
|
2058
2045
|
if arg_types is None:
|
|
2059
2046
|
return None
|
|
2060
2047
|
|
|
2061
|
-
|
|
2062
|
-
|
|
2048
|
+
a = arg_types["a"]
|
|
2049
|
+
t = arg_types["t"]
|
|
2063
2050
|
|
|
2064
|
-
|
|
2065
|
-
raise RuntimeError("tile_store() argument 0 must be an array")
|
|
2051
|
+
c = tile_unpack_offset(arg_types, a.ndim)
|
|
2066
2052
|
|
|
2067
|
-
if
|
|
2068
|
-
raise
|
|
2069
|
-
"
|
|
2053
|
+
if len(c) != a.ndim:
|
|
2054
|
+
raise ValueError(
|
|
2055
|
+
f"tile_store() 'a' argument must have {len(c)} dimensions, "
|
|
2056
|
+
f"calculated based on the provided offset arguments, but got {a.ndim} dimensions."
|
|
2070
2057
|
)
|
|
2071
2058
|
|
|
2072
|
-
if
|
|
2073
|
-
raise
|
|
2074
|
-
|
|
2075
|
-
|
|
2076
|
-
|
|
2059
|
+
if len(t.shape) != a.ndim:
|
|
2060
|
+
raise ValueError(
|
|
2061
|
+
f"tile_store() 'a' argument must have the same number of dimensions as the 't' argument, "
|
|
2062
|
+
f"but got {a.ndim} dimensions for 'a' and {len(t.shape)} dimensions for 't'"
|
|
2063
|
+
)
|
|
2077
2064
|
|
|
2078
2065
|
if not types_equal(arg_types["a"].dtype, arg_types["t"].dtype):
|
|
2079
|
-
raise
|
|
2066
|
+
raise TypeError(
|
|
2067
|
+
f"tile_store() 'a' and 't' arguments must have the same dtype, got {arg_types['a'].dtype} and {arg_types['t'].dtype}"
|
|
2068
|
+
)
|
|
2080
2069
|
|
|
2081
2070
|
return None
|
|
2082
2071
|
|
|
2083
2072
|
|
|
2073
|
+
def tile_store_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
2074
|
+
a = args["a"]
|
|
2075
|
+
t = args["t"]
|
|
2076
|
+
|
|
2077
|
+
offset = tile_unpack_offset(args, a.type.ndim)
|
|
2078
|
+
|
|
2079
|
+
func_args = (a, *offset, t)
|
|
2080
|
+
template_args = []
|
|
2081
|
+
|
|
2082
|
+
return (func_args, template_args)
|
|
2083
|
+
|
|
2084
|
+
|
|
2084
2085
|
add_builtin(
|
|
2085
2086
|
"tile_store",
|
|
2086
|
-
input_types={"a": array(dtype=Any), "
|
|
2087
|
-
value_func=
|
|
2087
|
+
input_types={"a": array(dtype=Any), "t": Tile(dtype=Any, shape=Any), "offset": Tuple[int, ...]},
|
|
2088
|
+
value_func=tile_store_value_func,
|
|
2089
|
+
dispatch_func=tile_store_dispatch_func,
|
|
2090
|
+
defaults={"offset": None},
|
|
2088
2091
|
variadic=False,
|
|
2089
2092
|
skip_replay=True,
|
|
2090
|
-
doc="""
|
|
2093
|
+
doc="""Store a tile to a global memory array.
|
|
2091
2094
|
|
|
2092
2095
|
This method will cooperatively store a tile to global memory using all threads in the block.
|
|
2093
2096
|
|
|
2094
2097
|
:param a: The destination array in global memory
|
|
2095
|
-
:param
|
|
2096
|
-
:param
|
|
2098
|
+
:param t: The source tile to store data from, must have the same data type and number of dimensions as the destination array
|
|
2099
|
+
:param offset: Offset in the destination array (optional)""",
|
|
2097
2100
|
group="Tile Primitives",
|
|
2098
2101
|
export=False,
|
|
2099
2102
|
)
|
|
2100
2103
|
|
|
2101
|
-
|
|
2102
|
-
def tile_store_2d_value_func(arg_types, arg_values):
|
|
2103
|
-
# return generic type (for doc builds)
|
|
2104
|
-
if arg_types is None:
|
|
2105
|
-
return None
|
|
2106
|
-
|
|
2107
|
-
if len(arg_types) != 4:
|
|
2108
|
-
raise RuntimeError("tile_store() requires 4 positional args")
|
|
2109
|
-
|
|
2110
|
-
if not is_array(arg_types["a"]):
|
|
2111
|
-
raise RuntimeError("tile_store() argument 0 must be an array")
|
|
2112
|
-
|
|
2113
|
-
if arg_types["a"].ndim != 2:
|
|
2114
|
-
raise RuntimeError(
|
|
2115
|
-
"tile_load() argument 0 must be a 2-dimensional array if using the ``wp.tile_store(array, i, j, t)`` syntax."
|
|
2116
|
-
)
|
|
2117
|
-
|
|
2118
|
-
if not type_is_int(arg_types["i"]):
|
|
2119
|
-
raise RuntimeError("tile_store() argument 1 must be an integer")
|
|
2120
|
-
|
|
2121
|
-
if not type_is_int(arg_types["j"]):
|
|
2122
|
-
raise RuntimeError("tile_store() argument 2 must be an integer")
|
|
2123
|
-
|
|
2124
|
-
if not is_tile(arg_types["t"]):
|
|
2125
|
-
raise RuntimeError("tile_store() argument 3 must be a tile")
|
|
2126
|
-
|
|
2127
|
-
if not types_equal(arg_types["a"].dtype, arg_types["t"].dtype):
|
|
2128
|
-
raise RuntimeError("tile_store() destination array must have same type as source tile")
|
|
2129
|
-
|
|
2130
|
-
return None
|
|
2131
|
-
|
|
2132
|
-
|
|
2104
|
+
# overload for scalar offset
|
|
2133
2105
|
add_builtin(
|
|
2134
2106
|
"tile_store",
|
|
2135
|
-
input_types={"a": array(dtype=Any), "
|
|
2136
|
-
value_func=
|
|
2107
|
+
input_types={"a": array(dtype=Any), "t": Tile(dtype=Any, shape=Any), "offset": int},
|
|
2108
|
+
value_func=tile_store_value_func,
|
|
2109
|
+
dispatch_func=tile_store_dispatch_func,
|
|
2110
|
+
defaults={"offset": None},
|
|
2137
2111
|
variadic=False,
|
|
2138
2112
|
skip_replay=True,
|
|
2139
|
-
doc="""Stores a tile to a global memory array.
|
|
2140
|
-
|
|
2141
|
-
This method will cooperatively store a tile to global memory using all threads in the block.
|
|
2142
|
-
|
|
2143
|
-
:param a: The destination array in global memory
|
|
2144
|
-
:param i: Offset in the destination array measured in multiples of ``m``, i.e.: ``row=i*m``
|
|
2145
|
-
:param j: Offset in the destination array measured in multiples of ``n``, i.e.; ``col=j*n``
|
|
2146
|
-
:param t: The source tile to store data from, must have the same dtype as the destination array""",
|
|
2147
2113
|
group="Tile Primitives",
|
|
2114
|
+
hidden=True,
|
|
2148
2115
|
export=False,
|
|
2149
2116
|
)
|
|
2150
2117
|
|
|
@@ -2152,132 +2119,221 @@ add_builtin(
|
|
|
2152
2119
|
def tile_atomic_add_value_func(arg_types, arg_values):
|
|
2153
2120
|
# return generic type (for doc builds)
|
|
2154
2121
|
if arg_types is None:
|
|
2155
|
-
return Tile(dtype=Any,
|
|
2122
|
+
return Tile(dtype=Any, shape=Any)
|
|
2156
2123
|
|
|
2157
|
-
|
|
2158
|
-
|
|
2124
|
+
a = arg_types["a"]
|
|
2125
|
+
t = arg_types["t"]
|
|
2159
2126
|
|
|
2160
|
-
|
|
2161
|
-
|
|
2127
|
+
c = tile_unpack_offset(arg_types, a.ndim)
|
|
2128
|
+
if len(c) != a.ndim:
|
|
2129
|
+
raise ValueError(
|
|
2130
|
+
f"tile_atomic_add() 'a' argument must have {len(c)} dimensions, "
|
|
2131
|
+
f"calculated based on the provided offset arguments, but got {a.ndim} dimensions."
|
|
2132
|
+
)
|
|
2162
2133
|
|
|
2163
|
-
if
|
|
2164
|
-
raise
|
|
2134
|
+
if a.ndim != len(t.shape):
|
|
2135
|
+
raise ValueError(
|
|
2136
|
+
f"tile_atomic_add() 'a' argument must have the same number of dimensions as the 't' argument, "
|
|
2137
|
+
f"but got {a.ndim} dimensions for 'a' and {len(t.shape)} dimensions for 't'"
|
|
2138
|
+
)
|
|
2139
|
+
|
|
2140
|
+
if not types_equal(arg_types["a"].dtype, arg_types["t"].dtype):
|
|
2141
|
+
raise TypeError(
|
|
2142
|
+
f"tile_atomic_add() 'a' and 't' arguments must have the same dtype, got {arg_types['a'].dtype} and {arg_types['t'].dtype}"
|
|
2143
|
+
)
|
|
2165
2144
|
|
|
2166
|
-
|
|
2167
|
-
raise RuntimeError("tile_atomic_add() argument 2 must be an integer")
|
|
2145
|
+
return Tile(dtype=arg_types["t"].dtype, shape=arg_types["t"].shape)
|
|
2168
2146
|
|
|
2169
|
-
if not is_tile(arg_types["t"]):
|
|
2170
|
-
raise RuntimeError("tile_atomic_add() argument 3 must be a tile")
|
|
2171
2147
|
|
|
2172
|
-
|
|
2173
|
-
|
|
2148
|
+
def tile_atomic_add_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
2149
|
+
a = args["a"]
|
|
2150
|
+
t = args["t"]
|
|
2174
2151
|
|
|
2175
|
-
|
|
2152
|
+
offset = tile_unpack_offset(args, a.type.ndim)
|
|
2153
|
+
|
|
2154
|
+
func_args = (a, *offset, t)
|
|
2155
|
+
template_args = []
|
|
2156
|
+
|
|
2157
|
+
return (func_args, template_args)
|
|
2176
2158
|
|
|
2177
2159
|
|
|
2178
2160
|
add_builtin(
|
|
2179
2161
|
"tile_atomic_add",
|
|
2180
|
-
input_types={"a": array(dtype=Any), "
|
|
2162
|
+
input_types={"a": array(dtype=Any), "t": Tile(dtype=Any, shape=Any), "offset": Tuple[int, ...]},
|
|
2181
2163
|
value_func=tile_atomic_add_value_func,
|
|
2182
|
-
|
|
2164
|
+
dispatch_func=tile_atomic_add_dispatch_func,
|
|
2165
|
+
defaults={"offset": None},
|
|
2166
|
+
variadic=False,
|
|
2183
2167
|
skip_replay=True,
|
|
2184
|
-
doc="""Atomically add a tile to the array `a`, each element will be updated atomically.
|
|
2168
|
+
doc="""Atomically add a 1D tile to the array `a`, each element will be updated atomically.
|
|
2185
2169
|
|
|
2186
2170
|
:param a: Array in global memory, should have the same ``dtype`` as the input tile
|
|
2187
|
-
:param x: Offset in the destination array measured in multiples of ``m``, i.e.: ``i=x*M`` where ``M`` is the first tile dimension
|
|
2188
|
-
:param y: Offset in the destination array measured in multiples of ``n``, i.e.: ``j=y*N`` where ``N`` is the second tile dimension
|
|
2189
2171
|
:param t: Source tile to add to the destination array
|
|
2190
|
-
:
|
|
2172
|
+
:param offset: Offset in the destination array (optional)
|
|
2173
|
+
:returns: A tile with the same dimensions and data type as the source tile, holding the original value of the destination elements""",
|
|
2191
2174
|
group="Tile Primitives",
|
|
2192
2175
|
export=False,
|
|
2193
2176
|
)
|
|
2194
2177
|
|
|
2178
|
+
# overload for scalar offset
|
|
2179
|
+
add_builtin(
|
|
2180
|
+
"tile_atomic_add",
|
|
2181
|
+
input_types={"a": array(dtype=Any), "t": Tile(dtype=Any, shape=Any), "offset": int},
|
|
2182
|
+
value_func=tile_atomic_add_value_func,
|
|
2183
|
+
dispatch_func=tile_atomic_add_dispatch_func,
|
|
2184
|
+
defaults={"offset": None},
|
|
2185
|
+
variadic=False,
|
|
2186
|
+
skip_replay=True,
|
|
2187
|
+
group="Tile Primitives",
|
|
2188
|
+
hidden=True,
|
|
2189
|
+
export=False,
|
|
2190
|
+
)
|
|
2191
|
+
|
|
2195
2192
|
|
|
2196
2193
|
def tile_view_value_func(arg_types, arg_values):
|
|
2197
2194
|
# return generic type (for doc builds)
|
|
2198
2195
|
if arg_types is None:
|
|
2199
|
-
return Tile(dtype=Any,
|
|
2196
|
+
return Tile(dtype=Any, shape=Any)
|
|
2200
2197
|
|
|
2201
2198
|
tile = arg_types["t"]
|
|
2199
|
+
offset = arg_types["offset"]
|
|
2202
2200
|
|
|
2203
|
-
if
|
|
2204
|
-
|
|
2205
|
-
else:
|
|
2206
|
-
m = arg_values["m"]
|
|
2201
|
+
if len(offset) > len(tile.shape):
|
|
2202
|
+
raise ValueError(f"tile_view() specified too many offset coordinates {len(offset)} > {len(tile.shape)}")
|
|
2207
2203
|
|
|
2208
|
-
if "
|
|
2209
|
-
|
|
2204
|
+
if "shape" in arg_values:
|
|
2205
|
+
# if shape is specified take it directly, e.g.:
|
|
2206
|
+
# tile_view(t, offset=(i,j), shape=(m,n))
|
|
2207
|
+
shape = arg_values["shape"]
|
|
2208
|
+
strides = tile.strides
|
|
2209
|
+
|
|
2210
|
+
if len(shape) != len(tile.shape):
|
|
2211
|
+
raise ValueError(
|
|
2212
|
+
f"tile_view() if shape is specified it must have same number of dimensions as source tile, expected {len(tile.shape)}, got {len(shape)}"
|
|
2213
|
+
)
|
|
2210
2214
|
else:
|
|
2211
|
-
|
|
2215
|
+
# if not specified, then take output shape from unspecified src dimensions
|
|
2216
|
+
# e.g.: tile[i] will return a whole row of a 2D tile
|
|
2217
|
+
shape = tile.shape[len(offset) :]
|
|
2218
|
+
strides = tile.strides[len(offset) :]
|
|
2212
2219
|
|
|
2213
|
-
|
|
2214
|
-
raise RuntimeError(
|
|
2215
|
-
f"Trying to view a tile subrange with dimensions ({m}, {n}) which is larger than source tile with dimensions ({tile.M}, {tile.N})"
|
|
2216
|
-
)
|
|
2220
|
+
assert len(shape) == len(strides)
|
|
2217
2221
|
|
|
2218
2222
|
# force source tile to shared memory
|
|
2219
2223
|
tile.storage = "shared"
|
|
2220
2224
|
|
|
2221
|
-
output = Tile(dtype=tile.dtype,
|
|
2225
|
+
output = Tile(dtype=tile.dtype, shape=shape, strides=strides, layout=tile.layout, storage="shared", owner=False)
|
|
2222
2226
|
return output
|
|
2223
2227
|
|
|
2224
2228
|
|
|
2225
2229
|
def tile_view_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
|
|
2226
2230
|
tile = arg_values["t"]
|
|
2227
|
-
|
|
2228
|
-
|
|
2229
|
-
if "j" not in arg_values:
|
|
2230
|
-
j = warp.codegen.Var(label=None, type=int, constant=0)
|
|
2231
|
-
else:
|
|
2232
|
-
j = arg_values["j"]
|
|
2231
|
+
coord = arg_values["offset"]
|
|
2233
2232
|
|
|
2234
|
-
|
|
2235
|
-
|
|
2236
|
-
|
|
2233
|
+
# zero-pad coord to match source array
|
|
2234
|
+
view_coord = [0] * len(tile.type.shape)
|
|
2235
|
+
for i in range(len(coord)):
|
|
2236
|
+
view_coord[i] = coord[i]
|
|
2237
2237
|
|
|
2238
|
-
return ((tile,
|
|
2238
|
+
return ((tile, *view_coord), (return_type,))
|
|
2239
2239
|
|
|
2240
2240
|
|
|
2241
2241
|
add_builtin(
|
|
2242
2242
|
"tile_view",
|
|
2243
|
-
input_types={"t": Tile(dtype=Any,
|
|
2243
|
+
input_types={"t": Tile(dtype=Any, shape=Any), "offset": Tuple[int, ...], "shape": Tuple[int, ...]},
|
|
2244
2244
|
value_func=tile_view_value_func,
|
|
2245
2245
|
dispatch_func=tile_view_dispatch_func,
|
|
2246
|
-
defaults={"
|
|
2247
|
-
variadic=
|
|
2248
|
-
doc="""Return a
|
|
2246
|
+
defaults={"shape": None},
|
|
2247
|
+
variadic=False,
|
|
2248
|
+
doc="""Return a slice of a given tile [offset, offset+shape], if shape is not specified it will be inferred from the unspecified offset dimensions.
|
|
2249
2249
|
|
|
2250
2250
|
:param t: Input tile to extract a subrange from
|
|
2251
|
-
:param
|
|
2252
|
-
:param
|
|
2253
|
-
:
|
|
2254
|
-
:param n: Size of the subrange to return along the second dimension
|
|
2255
|
-
:returns: A tile with dimensions (m,n) and the same datatype as the input tile""",
|
|
2251
|
+
:param offset: Offset in the source tile
|
|
2252
|
+
:param shape: Shape of the returned slice
|
|
2253
|
+
:returns: A tile with dimensions given by the specified shape or the remaining source tile dimensions""",
|
|
2256
2254
|
group="Tile Primitives",
|
|
2255
|
+
missing_grad=True,
|
|
2257
2256
|
export=False,
|
|
2258
2257
|
)
|
|
2259
2258
|
|
|
2260
2259
|
|
|
2261
2260
|
def tile_assign_value_func(arg_types, arg_values):
|
|
2262
|
-
|
|
2261
|
+
if arg_types is None:
|
|
2262
|
+
return None
|
|
2263
|
+
|
|
2264
|
+
# force the destination tile to shared memory
|
|
2265
|
+
arg_types["dst"].storage = "shared"
|
|
2263
2266
|
return None
|
|
2264
2267
|
|
|
2265
2268
|
|
|
2269
|
+
def tile_assign_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
2270
|
+
dst = args["dst"]
|
|
2271
|
+
src = args["src"]
|
|
2272
|
+
|
|
2273
|
+
offset = tile_unpack_offset(args, len(dst.type.shape))
|
|
2274
|
+
|
|
2275
|
+
func_args = (dst, src, *offset)
|
|
2276
|
+
template_args = []
|
|
2277
|
+
|
|
2278
|
+
return (func_args, template_args)
|
|
2279
|
+
|
|
2280
|
+
|
|
2266
2281
|
add_builtin(
|
|
2267
2282
|
"tile_assign",
|
|
2268
|
-
input_types={"dst": Tile(dtype=Any,
|
|
2283
|
+
input_types={"dst": Tile(dtype=Any, shape=Any), "src": Tile(dtype=Any, shape=Any), "offset": Tuple[int, ...]},
|
|
2269
2284
|
value_func=tile_assign_value_func,
|
|
2270
|
-
|
|
2271
|
-
|
|
2285
|
+
dispatch_func=tile_assign_dispatch_func,
|
|
2286
|
+
defaults={"offset": None},
|
|
2287
|
+
doc="""Assign a tile to a subrange of a destination tile.
|
|
2272
2288
|
|
|
2273
|
-
:param
|
|
2274
|
-
:param
|
|
2275
|
-
:param
|
|
2276
|
-
:param src: The source tile to read values from""",
|
|
2289
|
+
:param dst: The destination tile to assign to
|
|
2290
|
+
:param src: The source tile to read values from
|
|
2291
|
+
:param offset: Offset in the destination tile to write to""",
|
|
2277
2292
|
group="Tile Primitives",
|
|
2278
2293
|
export=False,
|
|
2279
2294
|
)
|
|
2280
2295
|
|
|
2296
|
+
# handles expressions like tile[i,j] = 1.0
|
|
2297
|
+
add_builtin(
|
|
2298
|
+
"assign",
|
|
2299
|
+
input_types={"dst": Tile(dtype=Any, shape=Any), "i": int, "src": Scalar},
|
|
2300
|
+
value_func=tile_assign_value_func,
|
|
2301
|
+
group="Tile Primitives",
|
|
2302
|
+
export=False,
|
|
2303
|
+
hidden=True,
|
|
2304
|
+
missing_grad=True,
|
|
2305
|
+
)
|
|
2306
|
+
|
|
2307
|
+
add_builtin(
|
|
2308
|
+
"assign",
|
|
2309
|
+
input_types={"dst": Tile(dtype=Any, shape=Any), "i": int, "j": int, "src": Scalar},
|
|
2310
|
+
value_func=tile_assign_value_func,
|
|
2311
|
+
group="Tile Primitives",
|
|
2312
|
+
export=False,
|
|
2313
|
+
hidden=True,
|
|
2314
|
+
missing_grad=True,
|
|
2315
|
+
)
|
|
2316
|
+
|
|
2317
|
+
add_builtin(
|
|
2318
|
+
"assign",
|
|
2319
|
+
input_types={"dst": Tile(dtype=Any, shape=Any), "i": int, "j": int, "k": int, "src": Scalar},
|
|
2320
|
+
value_func=tile_assign_value_func,
|
|
2321
|
+
group="Tile Primitives",
|
|
2322
|
+
export=False,
|
|
2323
|
+
hidden=True,
|
|
2324
|
+
missing_grad=True,
|
|
2325
|
+
)
|
|
2326
|
+
|
|
2327
|
+
add_builtin(
|
|
2328
|
+
"assign",
|
|
2329
|
+
input_types={"dst": Tile(dtype=Any, shape=Any), "i": int, "j": int, "k": int, "l": int, "src": Scalar},
|
|
2330
|
+
value_func=tile_assign_value_func,
|
|
2331
|
+
group="Tile Primitives",
|
|
2332
|
+
export=False,
|
|
2333
|
+
hidden=True,
|
|
2334
|
+
missing_grad=True,
|
|
2335
|
+
)
|
|
2336
|
+
|
|
2281
2337
|
|
|
2282
2338
|
def tile_value_func(arg_types, arg_values):
|
|
2283
2339
|
# return generic type (for doc builds)
|
|
@@ -2285,7 +2341,7 @@ def tile_value_func(arg_types, arg_values):
|
|
|
2285
2341
|
return Tile
|
|
2286
2342
|
|
|
2287
2343
|
if len(arg_types) != 1:
|
|
2288
|
-
raise
|
|
2344
|
+
raise TypeError(f"tile() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
2289
2345
|
|
|
2290
2346
|
dtype = None
|
|
2291
2347
|
length = None
|
|
@@ -2293,11 +2349,12 @@ def tile_value_func(arg_types, arg_values):
|
|
|
2293
2349
|
if type_is_vector(arg_types["x"]):
|
|
2294
2350
|
dtype = arg_types["x"]._wp_scalar_type_
|
|
2295
2351
|
length = arg_types["x"]._shape_[0]
|
|
2352
|
+
shape = (length, warp.codegen.options["block_dim"])
|
|
2296
2353
|
else:
|
|
2297
2354
|
dtype = arg_types["x"]
|
|
2298
|
-
|
|
2355
|
+
shape = (warp.codegen.options["block_dim"],)
|
|
2299
2356
|
|
|
2300
|
-
return Tile(dtype=dtype,
|
|
2357
|
+
return Tile(dtype=dtype, shape=shape, op="tile")
|
|
2301
2358
|
|
|
2302
2359
|
|
|
2303
2360
|
add_builtin(
|
|
@@ -2305,14 +2362,14 @@ add_builtin(
|
|
|
2305
2362
|
input_types={"x": Any},
|
|
2306
2363
|
value_func=tile_value_func,
|
|
2307
2364
|
variadic=True,
|
|
2308
|
-
doc="""
|
|
2365
|
+
doc="""Construct a new tile from per-thread kernel values.
|
|
2309
2366
|
|
|
2310
2367
|
This function converts values computed using scalar kernel code to a tile representation for input into collective operations.
|
|
2311
2368
|
|
|
2312
2369
|
* If the input value is a scalar, then the resulting tile has ``shape=(1, block_dim)``
|
|
2313
2370
|
* If the input value is a vector, then the resulting tile has ``shape=(length(vector), block_dim)``
|
|
2314
2371
|
|
|
2315
|
-
:param x: A per-thread local value, e.g
|
|
2372
|
+
:param x: A per-thread local value, e.g. scalar, vector, or matrix.
|
|
2316
2373
|
:returns: A tile with first dimension according to the value type length and a second dimension equal to ``block_dim``
|
|
2317
2374
|
|
|
2318
2375
|
This example shows how to create a linear sequence from thread variables:
|
|
@@ -2331,7 +2388,7 @@ add_builtin(
|
|
|
2331
2388
|
|
|
2332
2389
|
.. code-block:: text
|
|
2333
2390
|
|
|
2334
|
-
|
|
2391
|
+
[0 2 4 6 8 10 12 14 16 18 20 22 24 26 28 30] = tile(shape=(16), storage=register)
|
|
2335
2392
|
|
|
2336
2393
|
""",
|
|
2337
2394
|
group="Tile Primitives",
|
|
@@ -2345,38 +2402,40 @@ def untile_value_func(arg_types, arg_values):
|
|
|
2345
2402
|
return Scalar
|
|
2346
2403
|
|
|
2347
2404
|
if len(arg_types) != 1:
|
|
2348
|
-
raise
|
|
2405
|
+
raise TypeError(f"untile() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
2349
2406
|
|
|
2350
2407
|
t = arg_types["a"]
|
|
2351
2408
|
|
|
2352
2409
|
if not is_tile(t):
|
|
2353
|
-
raise
|
|
2410
|
+
raise TypeError(f"untile() argument must be a tile, got {t!r}")
|
|
2354
2411
|
|
|
2355
|
-
if t.
|
|
2356
|
-
raise
|
|
2357
|
-
f"untile() argument
|
|
2412
|
+
if t.shape[-1] != warp.codegen.options["block_dim"]:
|
|
2413
|
+
raise ValueError(
|
|
2414
|
+
f"untile() argument last dimension {t.shape[-1]} does not match the expected block width {warp.codegen.options['block_dim']}"
|
|
2358
2415
|
)
|
|
2359
2416
|
|
|
2360
|
-
if t.
|
|
2417
|
+
if len(t.shape) == 1:
|
|
2361
2418
|
return t.dtype
|
|
2362
|
-
elif t.
|
|
2363
|
-
return warp.types.vector(t.
|
|
2419
|
+
elif len(t.shape) == 2:
|
|
2420
|
+
return warp.types.vector(t.shape[0], t.dtype)
|
|
2421
|
+
else:
|
|
2422
|
+
raise ValueError(f"untile() argument must have a positive size in dimension 0, but got {t.shape[0]}")
|
|
2364
2423
|
|
|
2365
2424
|
|
|
2366
2425
|
add_builtin(
|
|
2367
2426
|
"untile",
|
|
2368
|
-
input_types={"a": Any},
|
|
2427
|
+
input_types={"a": Tile(dtype=Any, shape=Any)},
|
|
2369
2428
|
value_func=untile_value_func,
|
|
2370
2429
|
variadic=True,
|
|
2371
|
-
doc="""Convert a
|
|
2430
|
+
doc="""Convert a tile back to per-thread values.
|
|
2372
2431
|
|
|
2373
2432
|
This function converts a block-wide tile back to per-thread values.
|
|
2374
2433
|
|
|
2375
|
-
* If the input tile is
|
|
2376
|
-
* If the input tile is
|
|
2434
|
+
* If the input tile is 1D, then the resulting value will be a per-thread scalar
|
|
2435
|
+
* If the input tile is 2D, then the resulting value will be a per-thread vector of length M
|
|
2377
2436
|
|
|
2378
2437
|
:param a: A tile with dimensions ``shape=(M, block_dim)``
|
|
2379
|
-
:returns: A single value per-thread with the same
|
|
2438
|
+
:returns: A single value per-thread with the same data type as the tile
|
|
2380
2439
|
|
|
2381
2440
|
This example shows how to create a linear sequence from thread variables:
|
|
2382
2441
|
|
|
@@ -2390,7 +2449,7 @@ add_builtin(
|
|
|
2390
2449
|
t = wp.tile(i)*2
|
|
2391
2450
|
|
|
2392
2451
|
# convert back to per-thread values
|
|
2393
|
-
s = wp.untile()
|
|
2452
|
+
s = wp.untile(t)
|
|
2394
2453
|
|
|
2395
2454
|
print(s)
|
|
2396
2455
|
|
|
@@ -2417,21 +2476,38 @@ def tile_extract_value_func(arg_types, arg_values):
|
|
|
2417
2476
|
if arg_types is None:
|
|
2418
2477
|
return Scalar
|
|
2419
2478
|
|
|
2420
|
-
|
|
2421
|
-
|
|
2422
|
-
|
|
2423
|
-
if not is_tile(arg_types["a"]):
|
|
2424
|
-
raise RuntimeError("tile_extract() argument 0 must be a tile")
|
|
2479
|
+
# force the input tile to shared memory
|
|
2480
|
+
arg_types["a"].storage = "shared"
|
|
2425
2481
|
|
|
2426
2482
|
return arg_types["a"].dtype
|
|
2427
2483
|
|
|
2428
2484
|
|
|
2429
2485
|
add_builtin(
|
|
2430
2486
|
"tile_extract",
|
|
2431
|
-
input_types={"a": Tile(dtype=Any,
|
|
2487
|
+
input_types={"a": Tile(dtype=Any, shape=Any), "i": int},
|
|
2432
2488
|
value_func=tile_extract_value_func,
|
|
2433
|
-
variadic=
|
|
2434
|
-
doc="""
|
|
2489
|
+
variadic=False,
|
|
2490
|
+
doc="""Extract a single element from the tile and return it as a scalar type.
|
|
2491
|
+
|
|
2492
|
+
This function will extract an element from the tile and broadcast its value to all threads in the block.
|
|
2493
|
+
|
|
2494
|
+
Note that this may incur additional synchronization if the source tile is a register tile.
|
|
2495
|
+
|
|
2496
|
+
:param a: Tile to extract the element from
|
|
2497
|
+
:param i: Coordinate of element on first dimension
|
|
2498
|
+
:returns: The value of the element at the specified tile location with the same data type as the input tile""",
|
|
2499
|
+
group="Tile Primitives",
|
|
2500
|
+
hidden=True,
|
|
2501
|
+
export=False,
|
|
2502
|
+
)
|
|
2503
|
+
|
|
2504
|
+
|
|
2505
|
+
add_builtin(
|
|
2506
|
+
"tile_extract",
|
|
2507
|
+
input_types={"a": Tile(dtype=Any, shape=Any), "i": int, "j": int},
|
|
2508
|
+
value_func=tile_extract_value_func,
|
|
2509
|
+
variadic=False,
|
|
2510
|
+
doc="""Extract a single element from the tile and return it as a scalar type.
|
|
2435
2511
|
|
|
2436
2512
|
This function will extract an element from the tile and broadcast its value to all threads in the block.
|
|
2437
2513
|
|
|
@@ -2440,8 +2516,52 @@ add_builtin(
|
|
|
2440
2516
|
:param a: Tile to extract the element from
|
|
2441
2517
|
:param i: Coordinate of element on first dimension
|
|
2442
2518
|
:param j: Coordinate of element on the second dimension
|
|
2443
|
-
:returns: The value of the element at the specified tile location
|
|
2519
|
+
:returns: The value of the element at the specified tile location with the same data type as the input tile""",
|
|
2444
2520
|
group="Tile Primitives",
|
|
2521
|
+
hidden=True,
|
|
2522
|
+
export=False,
|
|
2523
|
+
)
|
|
2524
|
+
|
|
2525
|
+
add_builtin(
|
|
2526
|
+
"tile_extract",
|
|
2527
|
+
input_types={"a": Tile(dtype=Any, shape=Any), "i": int, "j": int, "k": int},
|
|
2528
|
+
value_func=tile_extract_value_func,
|
|
2529
|
+
variadic=False,
|
|
2530
|
+
doc="""Extract a single element from the tile and return it as a scalar type.
|
|
2531
|
+
|
|
2532
|
+
This function will extract an element from the tile and broadcast its value to all threads in the block.
|
|
2533
|
+
|
|
2534
|
+
Note that this may incur additional synchronization if the source tile is a register tile.
|
|
2535
|
+
|
|
2536
|
+
:param a: Tile to extract the element from
|
|
2537
|
+
:param i: Coordinate of element on first dimension
|
|
2538
|
+
:param j: Coordinate of element on the second dimension
|
|
2539
|
+
:param k: Coordinate of element on the third dimension
|
|
2540
|
+
:returns: The value of the element at the specified tile location with the same data type as the input tile""",
|
|
2541
|
+
group="Tile Primitives",
|
|
2542
|
+
hidden=True,
|
|
2543
|
+
export=False,
|
|
2544
|
+
)
|
|
2545
|
+
|
|
2546
|
+
add_builtin(
|
|
2547
|
+
"tile_extract",
|
|
2548
|
+
input_types={"a": Tile(dtype=Any, shape=Any), "i": int, "j": int, "k": int, "l": int},
|
|
2549
|
+
value_func=tile_extract_value_func,
|
|
2550
|
+
variadic=False,
|
|
2551
|
+
doc="""Extract a single element from the tile and return it as a scalar type.
|
|
2552
|
+
|
|
2553
|
+
This function will extract an element from the tile and broadcast its value to all threads in the block.
|
|
2554
|
+
|
|
2555
|
+
Note that this may incur additional synchronization if the source tile is a register tile.
|
|
2556
|
+
|
|
2557
|
+
:param a: Tile to extract the element from
|
|
2558
|
+
:param i: Coordinate of element on first dimension
|
|
2559
|
+
:param j: Coordinate of element on the second dimension
|
|
2560
|
+
:param k: Coordinate of element on the third dimension
|
|
2561
|
+
:param l: Coordinate of element on the fourth dimension
|
|
2562
|
+
:returns: The value of the element at the specified tile location, with the same data type as the input tile""",
|
|
2563
|
+
group="Tile Primitives",
|
|
2564
|
+
hidden=True,
|
|
2445
2565
|
export=False,
|
|
2446
2566
|
)
|
|
2447
2567
|
|
|
@@ -2452,12 +2572,12 @@ def tile_transpose_value_func(arg_types, arg_values):
|
|
|
2452
2572
|
return Tile
|
|
2453
2573
|
|
|
2454
2574
|
if len(arg_types) != 1:
|
|
2455
|
-
raise
|
|
2575
|
+
raise TypeError(f"tile_transpose() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
2456
2576
|
|
|
2457
2577
|
t = arg_types["a"]
|
|
2458
2578
|
|
|
2459
2579
|
if not is_tile(t):
|
|
2460
|
-
raise
|
|
2580
|
+
raise TypeError(f"tile_transpose() argument must be a tile, got {t!r}")
|
|
2461
2581
|
|
|
2462
2582
|
layout = None
|
|
2463
2583
|
|
|
@@ -2472,8 +2592,7 @@ def tile_transpose_value_func(arg_types, arg_values):
|
|
|
2472
2592
|
|
|
2473
2593
|
return Tile(
|
|
2474
2594
|
dtype=t.dtype,
|
|
2475
|
-
|
|
2476
|
-
N=t.M,
|
|
2595
|
+
shape=t.shape[::-1],
|
|
2477
2596
|
op="transpose",
|
|
2478
2597
|
storage=t.storage,
|
|
2479
2598
|
strides=t.strides[::-1],
|
|
@@ -2484,12 +2603,13 @@ def tile_transpose_value_func(arg_types, arg_values):
|
|
|
2484
2603
|
|
|
2485
2604
|
add_builtin(
|
|
2486
2605
|
"tile_transpose",
|
|
2487
|
-
input_types={"a": Tile(dtype=Any,
|
|
2606
|
+
input_types={"a": Tile(dtype=Any, shape=Any)},
|
|
2488
2607
|
value_func=tile_transpose_value_func,
|
|
2489
2608
|
variadic=True,
|
|
2490
2609
|
doc="""Transpose a tile.
|
|
2491
2610
|
|
|
2492
|
-
For shared memory tiles this operation will alias the input tile
|
|
2611
|
+
For shared memory tiles, this operation will alias the input tile.
|
|
2612
|
+
Register tiles will first be transferred to shared memory before transposition.
|
|
2493
2613
|
|
|
2494
2614
|
:param a: Tile to transpose with ``shape=(M,N)``
|
|
2495
2615
|
:returns: Tile with ``shape=(N,M)``""",
|
|
@@ -2503,41 +2623,36 @@ def tile_broadcast_value_func(arg_types, arg_values):
|
|
|
2503
2623
|
if arg_types is None:
|
|
2504
2624
|
return Tile
|
|
2505
2625
|
|
|
2506
|
-
if len(arg_types) != 3:
|
|
2507
|
-
raise RuntimeError("tile_broadcast() requires 1 positional args")
|
|
2508
|
-
|
|
2509
2626
|
t = arg_types["a"]
|
|
2510
|
-
m = arg_values["m"]
|
|
2511
|
-
n = arg_values["n"]
|
|
2512
2627
|
|
|
2513
|
-
|
|
2514
|
-
|
|
2628
|
+
# target shape and strides
|
|
2629
|
+
target_shape = tile_unpack_shape(arg_values)
|
|
2630
|
+
target_strides = [0] * len(target_shape)
|
|
2515
2631
|
|
|
2516
|
-
|
|
2517
|
-
if t.N == 1:
|
|
2518
|
-
stride_n = 0
|
|
2519
|
-
elif t.N == n:
|
|
2520
|
-
stride_n = t.strides[1]
|
|
2521
|
-
else:
|
|
2522
|
-
raise RuntimeError(
|
|
2523
|
-
f"Broadcast dimension must be 1 or match destination, shape(src) = {t.m, t.n}, shape(dest) = {m, n}"
|
|
2524
|
-
)
|
|
2632
|
+
offset = len(target_shape) - len(t.shape)
|
|
2525
2633
|
|
|
2526
|
-
#
|
|
2527
|
-
|
|
2528
|
-
|
|
2529
|
-
|
|
2530
|
-
|
|
2531
|
-
|
|
2532
|
-
|
|
2533
|
-
|
|
2534
|
-
|
|
2634
|
+
# compute target strides
|
|
2635
|
+
for i in reversed(range(len(target_shape))):
|
|
2636
|
+
j = i - offset
|
|
2637
|
+
|
|
2638
|
+
if j < 0:
|
|
2639
|
+
target_strides[i] = 0
|
|
2640
|
+
else:
|
|
2641
|
+
# try to broadcast each dimension
|
|
2642
|
+
if t.shape[j] == 1:
|
|
2643
|
+
target_strides[i] = 0
|
|
2644
|
+
elif t.shape[j] == target_shape[i]:
|
|
2645
|
+
target_strides[i] = t.strides[j]
|
|
2646
|
+
else:
|
|
2647
|
+
raise ValueError(
|
|
2648
|
+
f"tile_broadcast() cannot broadcast dimension {t.shape[j]} into {target_shape[i]} at index {i}"
|
|
2649
|
+
)
|
|
2535
2650
|
|
|
2536
2651
|
# force the input tile to shared memory
|
|
2537
2652
|
t.storage = "shared"
|
|
2538
2653
|
|
|
2539
2654
|
tile_type = Tile(
|
|
2540
|
-
dtype=t.dtype,
|
|
2655
|
+
dtype=t.dtype, shape=target_shape, op="broadcast", storage=t.storage, strides=target_strides, owner=False
|
|
2541
2656
|
)
|
|
2542
2657
|
return tile_type
|
|
2543
2658
|
|
|
@@ -2546,8 +2661,8 @@ def tile_broadcast_dispatch_func(arg_types: Mapping[str, type], return_type: Any
|
|
|
2546
2661
|
tile = arg_values["a"]
|
|
2547
2662
|
|
|
2548
2663
|
template_args = []
|
|
2549
|
-
template_args.append(return_type.
|
|
2550
|
-
template_args.append(return_type.
|
|
2664
|
+
template_args.append(return_type.shape[0])
|
|
2665
|
+
template_args.append(return_type.shape[1])
|
|
2551
2666
|
template_args.append(return_type.strides[0])
|
|
2552
2667
|
template_args.append(return_type.strides[1])
|
|
2553
2668
|
|
|
@@ -2556,15 +2671,18 @@ def tile_broadcast_dispatch_func(arg_types: Mapping[str, type], return_type: Any
|
|
|
2556
2671
|
|
|
2557
2672
|
add_builtin(
|
|
2558
2673
|
"tile_broadcast",
|
|
2559
|
-
input_types={"a": Tile(dtype=Any,
|
|
2674
|
+
input_types={"a": Tile(dtype=Any, shape=Any), "shape": Tuple[int, ...]},
|
|
2560
2675
|
value_func=tile_broadcast_value_func,
|
|
2561
2676
|
dispatch_func=tile_broadcast_dispatch_func,
|
|
2562
|
-
variadic=
|
|
2677
|
+
variadic=False,
|
|
2563
2678
|
doc="""Broadcast a tile.
|
|
2564
2679
|
|
|
2565
|
-
This
|
|
2680
|
+
This function will attempt to broadcast the input tile ``a`` to the destination shape (m, n).
|
|
2681
|
+
|
|
2682
|
+
Broadcasting follows NumPy broadcast rules.
|
|
2566
2683
|
|
|
2567
2684
|
:param a: Tile to broadcast
|
|
2685
|
+
:param shape: The shape to broadcast to
|
|
2568
2686
|
:returns: Tile with broadcast ``shape=(m, n)``""",
|
|
2569
2687
|
group="Tile Primitives",
|
|
2570
2688
|
export=False,
|
|
@@ -2574,19 +2692,10 @@ add_builtin(
|
|
|
2574
2692
|
def tile_matmul_value_func(arg_types, arg_values):
|
|
2575
2693
|
# return generic type (for doc builds)
|
|
2576
2694
|
if arg_types is None:
|
|
2577
|
-
return Tile(dtype=Any,
|
|
2695
|
+
return Tile(dtype=Any, shape=Any)
|
|
2578
2696
|
|
|
2579
2697
|
if len(arg_types) != 3:
|
|
2580
|
-
raise
|
|
2581
|
-
|
|
2582
|
-
if not is_tile(arg_types["a"]):
|
|
2583
|
-
raise RuntimeError("tile_matmul() argument 0 must be a tile")
|
|
2584
|
-
|
|
2585
|
-
if not is_tile(arg_types["b"]):
|
|
2586
|
-
raise RuntimeError("tile_matmul() argument 1 must be an tile")
|
|
2587
|
-
|
|
2588
|
-
if not isinstance(arg_types["out"], Tile):
|
|
2589
|
-
raise RuntimeError("tile_matmul() output argument must be a tile")
|
|
2698
|
+
raise TypeError(f"tile_matmul() takes exactly 3 positional arguments but {len(arg_types)} were given")
|
|
2590
2699
|
|
|
2591
2700
|
return None
|
|
2592
2701
|
|
|
@@ -2621,17 +2730,17 @@ add_builtin(
|
|
|
2621
2730
|
def tile_sum_value_func(arg_types, arg_values):
|
|
2622
2731
|
# return generic type (for doc builds)
|
|
2623
2732
|
if arg_types is None:
|
|
2624
|
-
return Tile(dtype=Any,
|
|
2733
|
+
return Tile(dtype=Any, shape=(1,))
|
|
2625
2734
|
|
|
2626
2735
|
if len(arg_types) != 1:
|
|
2627
|
-
raise
|
|
2736
|
+
raise TypeError(f"tile_sum() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
2628
2737
|
|
|
2629
2738
|
a = arg_types["a"]
|
|
2630
2739
|
|
|
2631
2740
|
if not is_tile(a):
|
|
2632
|
-
raise
|
|
2741
|
+
raise TypeError(f"tile_sum() argument must be a tile, got {a!r}")
|
|
2633
2742
|
|
|
2634
|
-
return Tile(dtype=a.dtype,
|
|
2743
|
+
return Tile(dtype=a.dtype, shape=(1,), op="sum")
|
|
2635
2744
|
|
|
2636
2745
|
|
|
2637
2746
|
add_builtin(
|
|
@@ -2642,7 +2751,7 @@ add_builtin(
|
|
|
2642
2751
|
doc="""Cooperatively compute the sum of the tile elements using all threads in the block.
|
|
2643
2752
|
|
|
2644
2753
|
:param a: The tile to compute the sum of
|
|
2645
|
-
:returns: A single-element tile
|
|
2754
|
+
:returns: A single-element tile holding the sum
|
|
2646
2755
|
|
|
2647
2756
|
Example:
|
|
2648
2757
|
|
|
@@ -2651,18 +2760,18 @@ add_builtin(
|
|
|
2651
2760
|
@wp.kernel
|
|
2652
2761
|
def compute():
|
|
2653
2762
|
|
|
2654
|
-
t = wp.tile_ones(dtype=float,
|
|
2763
|
+
t = wp.tile_ones(dtype=float, shape=(16, 16))
|
|
2655
2764
|
s = wp.tile_sum(t)
|
|
2656
2765
|
|
|
2657
|
-
print(
|
|
2766
|
+
print(s)
|
|
2658
2767
|
|
|
2659
|
-
wp.
|
|
2768
|
+
wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
|
|
2660
2769
|
|
|
2661
2770
|
Prints:
|
|
2662
2771
|
|
|
2663
2772
|
.. code-block:: text
|
|
2664
2773
|
|
|
2665
|
-
tile(
|
|
2774
|
+
[256] = tile(shape=(1), storage=register)
|
|
2666
2775
|
|
|
2667
2776
|
""",
|
|
2668
2777
|
group="Tile Primitives",
|
|
@@ -2673,17 +2782,17 @@ add_builtin(
|
|
|
2673
2782
|
def tile_min_value_func(arg_types, arg_values):
|
|
2674
2783
|
# return generic type (for doc builds)
|
|
2675
2784
|
if arg_types is None:
|
|
2676
|
-
return Tile(dtype=Any,
|
|
2785
|
+
return Tile(dtype=Any, shape=(1,))
|
|
2677
2786
|
|
|
2678
2787
|
if len(arg_types) != 1:
|
|
2679
|
-
raise
|
|
2788
|
+
raise TypeError(f"tile_min() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
2680
2789
|
|
|
2681
2790
|
a = arg_types["a"]
|
|
2682
2791
|
|
|
2683
2792
|
if not is_tile(a):
|
|
2684
|
-
raise
|
|
2793
|
+
raise TypeError(f"tile_min() argument must be a tile, got {a!r}")
|
|
2685
2794
|
|
|
2686
|
-
return Tile(dtype=a.dtype,
|
|
2795
|
+
return Tile(dtype=a.dtype, shape=(1,), op="min")
|
|
2687
2796
|
|
|
2688
2797
|
|
|
2689
2798
|
add_builtin(
|
|
@@ -2694,7 +2803,7 @@ add_builtin(
|
|
|
2694
2803
|
doc="""Cooperatively compute the minimum of the tile elements using all threads in the block.
|
|
2695
2804
|
|
|
2696
2805
|
:param a: The tile to compute the minimum of
|
|
2697
|
-
:returns: A single-element tile
|
|
2806
|
+
:returns: A single-element tile holding the minimum value
|
|
2698
2807
|
|
|
2699
2808
|
Example:
|
|
2700
2809
|
|
|
@@ -2703,18 +2812,19 @@ add_builtin(
|
|
|
2703
2812
|
@wp.kernel
|
|
2704
2813
|
def compute():
|
|
2705
2814
|
|
|
2706
|
-
t = wp.tile_arange(
|
|
2815
|
+
t = wp.tile_arange(64, 128)
|
|
2707
2816
|
s = wp.tile_min(t)
|
|
2708
2817
|
|
|
2709
|
-
print(
|
|
2818
|
+
print(s)
|
|
2819
|
+
|
|
2710
2820
|
|
|
2711
|
-
wp.
|
|
2821
|
+
wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
|
|
2712
2822
|
|
|
2713
2823
|
Prints:
|
|
2714
2824
|
|
|
2715
2825
|
.. code-block:: text
|
|
2716
2826
|
|
|
2717
|
-
tile(
|
|
2827
|
+
[64] = tile(shape=(1), storage=register)
|
|
2718
2828
|
|
|
2719
2829
|
""",
|
|
2720
2830
|
group="Tile Primitives",
|
|
@@ -2725,28 +2835,28 @@ add_builtin(
|
|
|
2725
2835
|
def tile_max_value_func(arg_types, arg_values):
    """Resolve the return type of ``tile_max()``.

    :param arg_types: Mapping of argument name to type, or ``None`` when the
        generic signature is requested (documentation builds).
    :param arg_values: Unused here.
    :returns: A single-element ``Tile`` type carrying the dtype of ``a``.
    :raises TypeError: If not exactly one argument is given, or ``a`` is not a tile.
    """
    # return generic type (for doc builds)
    if arg_types is None:
        return Tile(dtype=Any, shape=(1,))

    if len(arg_types) != 1:
        raise TypeError(f"tile_max() takes exactly 1 positional argument but {len(arg_types)} were given")

    a = arg_types["a"]

    if not is_tile(a):
        raise TypeError(f"tile_max() argument must be a tile, got {a!r}")

    # Tag the result with the operation that produced it. This previously read
    # op="min", copied from tile_min_value_func.
    # NOTE(review): confirm nothing downstream relied on the old "min" label.
    return Tile(dtype=a.dtype, shape=(1,), op="max")
|
|
2739
2849
|
|
|
2740
2850
|
|
|
2741
2851
|
add_builtin(
|
|
2742
2852
|
"tile_max",
|
|
2743
|
-
input_types={"a": Tile},
|
|
2853
|
+
input_types={"a": Tile(dtype=Any, shape=Any)},
|
|
2744
2854
|
value_func=tile_max_value_func,
|
|
2745
|
-
variadic=
|
|
2855
|
+
variadic=False,
|
|
2746
2856
|
doc="""Cooperatively compute the maximum of the tile elements using all threads in the block.
|
|
2747
2857
|
|
|
2748
2858
|
:param a: The tile to compute the maximum from
|
|
2749
|
-
:returns: A single-element tile
|
|
2859
|
+
:returns: A single-element tile holding the maximum value
|
|
2750
2860
|
|
|
2751
2861
|
Example:
|
|
2752
2862
|
|
|
@@ -2755,18 +2865,18 @@ add_builtin(
|
|
|
2755
2865
|
@wp.kernel
|
|
2756
2866
|
def compute():
|
|
2757
2867
|
|
|
2758
|
-
t = wp.tile_arange(
|
|
2759
|
-
s = wp.
|
|
2868
|
+
t = wp.tile_arange(64, 128)
|
|
2869
|
+
s = wp.tile_max(t)
|
|
2760
2870
|
|
|
2761
|
-
print(
|
|
2871
|
+
print(s)
|
|
2762
2872
|
|
|
2763
|
-
wp.
|
|
2873
|
+
wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
|
|
2764
2874
|
|
|
2765
2875
|
Prints:
|
|
2766
2876
|
|
|
2767
2877
|
.. code-block:: text
|
|
2768
2878
|
|
|
2769
|
-
tile(
|
|
2879
|
+
[127] = tile(shape=(1), storage=register)
|
|
2770
2880
|
|
|
2771
2881
|
""",
|
|
2772
2882
|
group="Tile Primitives",
|
|
@@ -2777,15 +2887,14 @@ add_builtin(
|
|
|
2777
2887
|
# resolves the return type for tile_reduce()
def tile_reduce_value_func(arg_types, arg_values):
    """Resolve the return type of ``tile_reduce()``.

    :param arg_types: Mapping of argument name to type, or ``None`` when the
        generic signature is requested.
    :param arg_values: Unused here.
    :returns: A single-element ``Tile`` type carrying the dtype of ``a``.
    :raises TypeError: If ``a`` is not a tile.
    """
    if arg_types is None:
        return Tile(dtype=Any, shape=(1,))

    a = arg_types["a"]
    if is_tile(a):
        return Tile(dtype=a.dtype, shape=(1,), op="reduce")

    raise TypeError(f"tile_reduce() 'a' argument must be a tile, got {a!r}")
|
|
2789
2898
|
|
|
2790
2899
|
|
|
2791
2900
|
def tile_reduce_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
@@ -2796,7 +2905,7 @@ def tile_reduce_dispatch_func(input_types: Mapping[str, type], return_type: Any,
|
|
|
2796
2905
|
|
|
2797
2906
|
add_builtin(
|
|
2798
2907
|
"tile_reduce",
|
|
2799
|
-
input_types={"op": Callable, "a": Any},
|
|
2908
|
+
input_types={"op": Callable, "a": Tile(dtype=Any, shape=Any)},
|
|
2800
2909
|
value_func=tile_reduce_value_func,
|
|
2801
2910
|
native_func="tile_reduce",
|
|
2802
2911
|
doc="""Apply a custom reduction operator across the tile.
|
|
@@ -2804,8 +2913,8 @@ add_builtin(
|
|
|
2804
2913
|
This function cooperatively performs a reduction using the provided operator across the tile.
|
|
2805
2914
|
|
|
2806
2915
|
:param op: A callable function that accepts two arguments and returns one argument, may be a user function or builtin
|
|
2807
|
-
:param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's
|
|
2808
|
-
:returns: A single-element tile with
|
|
2916
|
+
:param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's data type
|
|
2917
|
+
:returns: A single-element tile with the same data type as the input tile.
|
|
2809
2918
|
|
|
2810
2919
|
Example:
|
|
2811
2920
|
|
|
@@ -2819,13 +2928,13 @@ add_builtin(
|
|
|
2819
2928
|
|
|
2820
2929
|
print(s)
|
|
2821
2930
|
|
|
2822
|
-
wp.
|
|
2931
|
+
wp.launch_tiled(factorial, dim=[1], inputs=[], block_dim=16)
|
|
2823
2932
|
|
|
2824
2933
|
Prints:
|
|
2825
2934
|
|
|
2826
2935
|
.. code-block:: text
|
|
2827
2936
|
|
|
2828
|
-
tile(
|
|
2937
|
+
[362880] = tile(shape=(1), storage=register)
|
|
2829
2938
|
""",
|
|
2830
2939
|
group="Tile Primitives",
|
|
2831
2940
|
export=False,
|
|
@@ -2837,26 +2946,19 @@ add_builtin(
|
|
|
2837
2946
|
# resolves the return type for the unary form of tile_map()
def tile_unary_map_value_func(arg_types, arg_values):
    """Resolve the return type of the single-tile ``tile_map()`` overload.

    :param arg_types: Mapping of argument name to type, or ``None`` when the
        generic signature is requested.
    :param arg_values: Unused here.
    :returns: ``TileUnaryMap(a)`` for a tile argument ``a``.
    :raises TypeError: If ``a`` is not a tile.
    """
    if arg_types is None:
        return Tile(dtype=Any, shape=Any)

    a = arg_types["a"]
    if is_tile(a):
        return TileUnaryMap(a)

    raise TypeError(f"tile_map() 'a' argument must be a tile, got {a!r}")
|
|
2849
2957
|
|
|
2850
2958
|
|
|
2851
|
-
# def tile_map_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
2852
|
-
# func_args = (args["op"], *args["args"])
|
|
2853
|
-
# template_args = ()
|
|
2854
|
-
# return (func_args, template_args)
|
|
2855
|
-
|
|
2856
|
-
|
|
2857
2959
|
add_builtin(
|
|
2858
2960
|
"tile_map",
|
|
2859
|
-
input_types={"op": Callable, "a": Any},
|
|
2961
|
+
input_types={"op": Callable, "a": Tile(dtype=Any, shape=Any)},
|
|
2860
2962
|
value_func=tile_unary_map_value_func,
|
|
2861
2963
|
# dispatch_func=tile_map_dispatch_func,
|
|
2862
2964
|
# variadic=True,
|
|
@@ -2866,8 +2968,8 @@ add_builtin(
|
|
|
2866
2968
|
This function cooperatively applies a unary function to each element of the tile using all threads in the block.
|
|
2867
2969
|
|
|
2868
2970
|
:param op: A callable function that accepts one argument and returns one argument, may be a user function or builtin
|
|
2869
|
-
:param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's
|
|
2870
|
-
:returns: A tile with the same dimensions and
|
|
2971
|
+
:param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's data type
|
|
2972
|
+
:returns: A tile with the same dimensions and data type as the input tile.
|
|
2871
2973
|
|
|
2872
2974
|
Example:
|
|
2873
2975
|
|
|
@@ -2881,13 +2983,13 @@ add_builtin(
|
|
|
2881
2983
|
|
|
2882
2984
|
print(s)
|
|
2883
2985
|
|
|
2884
|
-
wp.
|
|
2986
|
+
wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=16)
|
|
2885
2987
|
|
|
2886
2988
|
Prints:
|
|
2887
2989
|
|
|
2888
2990
|
.. code-block:: text
|
|
2889
2991
|
|
|
2890
|
-
|
|
2992
|
+
[0 0.0998334 0.198669 0.29552 0.389418 0.479426 0.564642 0.644218 0.717356 0.783327] = tile(shape=(10), storage=register)
|
|
2891
2993
|
""",
|
|
2892
2994
|
group="Tile Primitives",
|
|
2893
2995
|
export=False,
|
|
@@ -2896,34 +2998,37 @@ add_builtin(
|
|
|
2896
2998
|
|
|
2897
2999
|
def tile_binary_map_value_func(arg_types, arg_values):
    """Resolve the return type of a two-tile element-wise map.

    Both operands must be tiles with equal dtype, equal rank and equal extent
    along every dimension.

    :param arg_types: Mapping of argument name to type, or ``None`` when the
        generic signature is requested.
    :param arg_values: Unused here.
    :returns: ``TileBinaryMap(a, b)``.
    :raises TypeError: If an operand is not a tile or the dtypes differ.
    :raises ValueError: If the shapes differ in rank or in any dimension.
    """
    if arg_types is None:
        return Tile(dtype=Any, shape=Any)

    a, b = arg_types["a"], arg_types["b"]

    # operand kind checks
    if not is_tile(a):
        raise TypeError(f"tile_map() 'a' argument must be a tile, got {a!r}")
    if not is_tile(b):
        raise TypeError(f"tile_map() 'b' argument must be a tile, got {b!r}")

    # element types must agree
    if not types_equal(a.dtype, b.dtype):
        raise TypeError(f"tile_map() arguments must have the same dtype, got {a.dtype} and {b.dtype}")

    # ranks must agree before comparing extents
    if len(a.shape) != len(b.shape):
        raise ValueError(
            f"tile_map() shapes must have the same number of dimensions, got {len(a.shape)} and {len(b.shape)}"
        )

    for i, (extent_a, extent_b) in enumerate(zip(a.shape, b.shape)):
        if extent_a != extent_b:
            raise ValueError(f"tile_map() shapes do not match on dimension {i}, got {a.shape[i]} and {b.shape[i]}")

    return TileBinaryMap(a, b)
|
|
2922
3027
|
|
|
2923
3028
|
|
|
2924
3029
|
add_builtin(
|
|
2925
3030
|
"tile_map",
|
|
2926
|
-
input_types={"op": Callable, "a": Any, "b": Any},
|
|
3031
|
+
input_types={"op": Callable, "a": Tile(dtype=Any, shape=Any), "b": Tile(dtype=Any, shape=Any)},
|
|
2927
3032
|
value_func=tile_binary_map_value_func,
|
|
2928
3033
|
# dispatch_func=tile_map_dispatch_func,
|
|
2929
3034
|
# variadic=True,
|
|
@@ -2946,19 +3051,19 @@ add_builtin(
|
|
|
2946
3051
|
def compute():
|
|
2947
3052
|
|
|
2948
3053
|
a = wp.tile_arange(0.0, 1.0, 0.1, dtype=float)
|
|
2949
|
-
b = wp.tile_ones(
|
|
3054
|
+
b = wp.tile_ones(shape=10, dtype=float)
|
|
2950
3055
|
|
|
2951
3056
|
s = wp.tile_map(wp.add, a, b)
|
|
2952
3057
|
|
|
2953
3058
|
print(s)
|
|
2954
3059
|
|
|
2955
|
-
wp.
|
|
3060
|
+
wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=16)
|
|
2956
3061
|
|
|
2957
3062
|
Prints:
|
|
2958
3063
|
|
|
2959
3064
|
.. code-block:: text
|
|
2960
3065
|
|
|
2961
|
-
|
|
3066
|
+
[1 1.1 1.2 1.3 1.4 1.5 1.6 1.7 1.8 1.9] = tile(shape=(10), storage=register)""",
|
|
2962
3067
|
group="Tile Primitives",
|
|
2963
3068
|
export=False,
|
|
2964
3069
|
)
|
|
@@ -3073,6 +3178,18 @@ add_builtin(
|
|
|
3073
3178
|
)
|
|
3074
3179
|
|
|
3075
3180
|
|
|
3181
|
+
def mlp_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
    """Dispatch callback for ``wp.mlp()`` that emits a deprecation warning.

    :param input_types: Unused here.
    :param return_type: Unused here.
    :param args: Mapping of argument name to variable for the call.
    :returns: ``(func_args, template_args)``: every value of ``args`` in
        mapping order, and an empty tuple of template arguments.
    """
    warp.utils.warn(
        "wp.mlp() is deprecated and will be removed in a future\nversion. Use tile primitives instead.",
        category=DeprecationWarning,
    )

    # arguments are forwarded untouched; there are no template arguments
    return (tuple(args.values()), ())
|
|
3191
|
+
|
|
3192
|
+
|
|
3076
3193
|
add_builtin(
|
|
3077
3194
|
"mlp",
|
|
3078
3195
|
input_types={
|
|
@@ -3084,9 +3201,13 @@ add_builtin(
|
|
|
3084
3201
|
"out": array(dtype=float, ndim=2),
|
|
3085
3202
|
},
|
|
3086
3203
|
value_type=None,
|
|
3204
|
+
dispatch_func=mlp_dispatch_func,
|
|
3087
3205
|
skip_replay=True,
|
|
3088
3206
|
doc="""Evaluate a multi-layer perceptron (MLP) layer in the form: ``out = act(weights*x + bias)``.
|
|
3089
3207
|
|
|
3208
|
+
.. deprecated:: 1.6
|
|
3209
|
+
Use :doc:`tile primitives </modules/tiles>` instead.
|
|
3210
|
+
|
|
3090
3211
|
:param weights: A layer's network weights with dimensions ``(m, n)``.
|
|
3091
3212
|
:param bias: An array with dimensions ``(n)``.
|
|
3092
3213
|
:param activation: A ``wp.func`` function that takes a single scalar float as input and returns a scalar float as output
|
|
@@ -4665,6 +4786,19 @@ def atomic_op_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str,
|
|
|
4665
4786
|
return arr_type.dtype
|
|
4666
4787
|
|
|
4667
4788
|
|
|
4789
|
+
def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
    """Dispatch callback shared by the atomic array builtins.

    When ``warp.config.verify_autograd_array_access`` is enabled, the ``arr``
    argument is marked as written before the call arguments are forwarded.

    :param input_types: Unused here.
    :param return_type: Unused here.
    :param args: Mapping of argument name to variable for the call.
    :returns: ``(func_args, template_args)``: every value of ``args`` in
        mapping order, and an empty tuple (atomic ops take no template arguments).
    """
    # this runs during codegen, so the write to the array can be recorded here
    if warp.config.verify_autograd_array_access:
        args["arr"].mark_write()

    return (tuple(args.values()), ())
|
|
4800
|
+
|
|
4801
|
+
|
|
4668
4802
|
for array_type in array_types:
|
|
4669
4803
|
# don't list indexed array operations explicitly in docs
|
|
4670
4804
|
hidden = array_type == indexedarray
|
|
@@ -4675,6 +4809,7 @@ for array_type in array_types:
|
|
|
4675
4809
|
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
|
|
4676
4810
|
constraint=atomic_op_constraint,
|
|
4677
4811
|
value_func=atomic_op_value_func,
|
|
4812
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
4678
4813
|
doc="Atomically add ``value`` onto ``arr[i]`` and return the old value.",
|
|
4679
4814
|
group="Utility",
|
|
4680
4815
|
skip_replay=True,
|
|
@@ -4685,6 +4820,7 @@ for array_type in array_types:
|
|
|
4685
4820
|
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
|
|
4686
4821
|
constraint=atomic_op_constraint,
|
|
4687
4822
|
value_func=atomic_op_value_func,
|
|
4823
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
4688
4824
|
doc="Atomically add ``value`` onto ``arr[i,j]`` and return the old value.",
|
|
4689
4825
|
group="Utility",
|
|
4690
4826
|
skip_replay=True,
|
|
@@ -4695,6 +4831,7 @@ for array_type in array_types:
|
|
|
4695
4831
|
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
|
|
4696
4832
|
constraint=atomic_op_constraint,
|
|
4697
4833
|
value_func=atomic_op_value_func,
|
|
4834
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
4698
4835
|
doc="Atomically add ``value`` onto ``arr[i,j,k]`` and return the old value.",
|
|
4699
4836
|
group="Utility",
|
|
4700
4837
|
skip_replay=True,
|
|
@@ -4705,6 +4842,7 @@ for array_type in array_types:
|
|
|
4705
4842
|
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
|
|
4706
4843
|
constraint=atomic_op_constraint,
|
|
4707
4844
|
value_func=atomic_op_value_func,
|
|
4845
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
4708
4846
|
doc="Atomically add ``value`` onto ``arr[i,j,k,l]`` and return the old value.",
|
|
4709
4847
|
group="Utility",
|
|
4710
4848
|
skip_replay=True,
|
|
@@ -4716,6 +4854,7 @@ for array_type in array_types:
|
|
|
4716
4854
|
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
|
|
4717
4855
|
constraint=atomic_op_constraint,
|
|
4718
4856
|
value_func=atomic_op_value_func,
|
|
4857
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
4719
4858
|
doc="Atomically subtract ``value`` onto ``arr[i]`` and return the old value.",
|
|
4720
4859
|
group="Utility",
|
|
4721
4860
|
skip_replay=True,
|
|
@@ -4726,6 +4865,7 @@ for array_type in array_types:
|
|
|
4726
4865
|
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
|
|
4727
4866
|
constraint=atomic_op_constraint,
|
|
4728
4867
|
value_func=atomic_op_value_func,
|
|
4868
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
4729
4869
|
doc="Atomically subtract ``value`` onto ``arr[i,j]`` and return the old value.",
|
|
4730
4870
|
group="Utility",
|
|
4731
4871
|
skip_replay=True,
|
|
@@ -4736,6 +4876,7 @@ for array_type in array_types:
|
|
|
4736
4876
|
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
|
|
4737
4877
|
constraint=atomic_op_constraint,
|
|
4738
4878
|
value_func=atomic_op_value_func,
|
|
4879
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
4739
4880
|
doc="Atomically subtract ``value`` onto ``arr[i,j,k]`` and return the old value.",
|
|
4740
4881
|
group="Utility",
|
|
4741
4882
|
skip_replay=True,
|
|
@@ -4746,6 +4887,7 @@ for array_type in array_types:
|
|
|
4746
4887
|
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
|
|
4747
4888
|
constraint=atomic_op_constraint,
|
|
4748
4889
|
value_func=atomic_op_value_func,
|
|
4890
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
4749
4891
|
doc="Atomically subtract ``value`` onto ``arr[i,j,k,l]`` and return the old value.",
|
|
4750
4892
|
group="Utility",
|
|
4751
4893
|
skip_replay=True,
|
|
@@ -4757,6 +4899,7 @@ for array_type in array_types:
|
|
|
4757
4899
|
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
|
|
4758
4900
|
constraint=atomic_op_constraint,
|
|
4759
4901
|
value_func=atomic_op_value_func,
|
|
4902
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
4760
4903
|
doc="""Compute the minimum of ``value`` and ``arr[i]``, atomically update the array, and return the old value.
|
|
4761
4904
|
|
|
4762
4905
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
@@ -4769,6 +4912,7 @@ for array_type in array_types:
|
|
|
4769
4912
|
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
|
|
4770
4913
|
constraint=atomic_op_constraint,
|
|
4771
4914
|
value_func=atomic_op_value_func,
|
|
4915
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
4772
4916
|
doc="""Compute the minimum of ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
|
|
4773
4917
|
|
|
4774
4918
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
@@ -4781,6 +4925,7 @@ for array_type in array_types:
|
|
|
4781
4925
|
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
|
|
4782
4926
|
constraint=atomic_op_constraint,
|
|
4783
4927
|
value_func=atomic_op_value_func,
|
|
4928
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
4784
4929
|
doc="""Compute the minimum of ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
|
|
4785
4930
|
|
|
4786
4931
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
@@ -4793,6 +4938,7 @@ for array_type in array_types:
|
|
|
4793
4938
|
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
|
|
4794
4939
|
constraint=atomic_op_constraint,
|
|
4795
4940
|
value_func=atomic_op_value_func,
|
|
4941
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
4796
4942
|
doc="""Compute the minimum of ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
|
|
4797
4943
|
|
|
4798
4944
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
@@ -4806,6 +4952,7 @@ for array_type in array_types:
|
|
|
4806
4952
|
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
|
|
4807
4953
|
constraint=atomic_op_constraint,
|
|
4808
4954
|
value_func=atomic_op_value_func,
|
|
4955
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
4809
4956
|
doc="""Compute the maximum of ``value`` and ``arr[i]``, atomically update the array, and return the old value.
|
|
4810
4957
|
|
|
4811
4958
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
@@ -4818,6 +4965,7 @@ for array_type in array_types:
|
|
|
4818
4965
|
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
|
|
4819
4966
|
constraint=atomic_op_constraint,
|
|
4820
4967
|
value_func=atomic_op_value_func,
|
|
4968
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
4821
4969
|
doc="""Compute the maximum of ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
|
|
4822
4970
|
|
|
4823
4971
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
@@ -4830,6 +4978,7 @@ for array_type in array_types:
|
|
|
4830
4978
|
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
|
|
4831
4979
|
constraint=atomic_op_constraint,
|
|
4832
4980
|
value_func=atomic_op_value_func,
|
|
4981
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
4833
4982
|
doc="""Compute the maximum of ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
|
|
4834
4983
|
|
|
4835
4984
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
@@ -4842,6 +4991,7 @@ for array_type in array_types:
|
|
|
4842
4991
|
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
|
|
4843
4992
|
constraint=atomic_op_constraint,
|
|
4844
4993
|
value_func=atomic_op_value_func,
|
|
4994
|
+
dispatch_func=atomic_op_dispatch_func,
|
|
4845
4995
|
doc="""Compute the maximum of ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
|
|
4846
4996
|
|
|
4847
4997
|
The operation is only atomic on a per-component basis for vectors and matrices.""",
|
|
@@ -4977,6 +5127,43 @@ add_builtin(
|
|
|
4977
5127
|
)
|
|
4978
5128
|
|
|
4979
5129
|
|
|
5130
|
+
# In-place component updates for vectors and quaternions. Each overload takes
# the container ``a``, a component index ``i`` and a scalar ``value``; all are
# registered as hidden builtins with no value type.

# implements vector[idx] += scalar
add_builtin(
    "augassign_add",
    input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
    value_type=None,
    hidden=True,
    group="Utility",
)

# implements quaternion[idx] += scalar
add_builtin(
    "augassign_add",
    input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
    value_type=None,
    hidden=True,
    group="Utility",
)

# implements vector[idx] -= scalar
add_builtin(
    "augassign_sub",
    input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
    value_type=None,
    hidden=True,
    group="Utility",
)

# implements quaternion[idx] -= scalar
add_builtin(
    "augassign_sub",
    input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
    value_type=None,
    hidden=True,
    group="Utility",
)
|
|
5165
|
+
|
|
5166
|
+
|
|
4980
5167
|
def matrix_index_row_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
4981
5168
|
mat_type = arg_types["a"]
|
|
4982
5169
|
row_type = mat_type._wp_row_type_
|
|
@@ -5026,22 +5213,42 @@ def matrix_vector_sametype(arg_types: Mapping[str, Any]):
|
|
|
5026
5213
|
return mat_size == vec_size and mat_type == vec_type
|
|
5027
5214
|
|
|
5028
5215
|
|
|
5029
|
-
# implements matrix[i,j] = scalar
|
|
5216
|
+
# implements matrix[i,j] = scalar
# (return type is resolved by matrix_assign_value_func)
add_builtin(
    "assign",
    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
    value_func=matrix_assign_value_func,
    hidden=True,
    group="Utility",
)


# implements matrix[i] = vector
# (the matrix_vector_sametype constraint restricts which matrix/vector pairs match)
add_builtin(
    "assign",
    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
    constraint=matrix_vector_sametype,
    value_func=matrix_assign_value_func,
    hidden=True,
    group="Utility",
)
|
|
5235
|
+
|
|
5236
|
+
|
|
5237
|
+
# In-place element updates for matrices. Both overloads take the matrix ``a``,
# a row index ``i``, a column index ``j`` and a scalar ``value``, and are
# registered as hidden builtins with no value type.

# implements matrix[i,j] += scalar
add_builtin(
    "augassign_add",
    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
    value_type=None,
    hidden=True,
    group="Utility",
)


# implements matrix[i,j] -= scalar
add_builtin(
    "augassign_sub",
    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
    value_type=None,
    hidden=True,
    group="Utility",
)
|
|
@@ -5613,19 +5820,19 @@ add_builtin("unot", input_types={"a": array(dtype=Any)}, value_type=builtins.boo
|
|
|
5613
5820
|
# Tile operators
|
|
5614
5821
|
def tile_unary_value_func(arg_types, arg_values):
    """Resolve the return type of a unary tile operator.

    :param arg_types: Mapping of argument name to type, or ``None`` when the
        generic signature is requested.
    :param arg_values: Unused here.
    :returns: ``TileUnaryMap(t)`` for the tile operand ``x``.
    :raises TypeError: If ``x`` is not a tile.
    """
    if arg_types is None:
        return Tile(dtype=Any, shape=Any)

    t = arg_types["x"]
    if is_tile(t):
        return TileUnaryMap(t)

    raise TypeError(f"Expected tile for unary expression, got {t}")
|
|
5624
5831
|
|
|
5625
5832
|
|
|
5626
5833
|
def tile_scalar_mul_value_func(arg_types, arg_values):
|
|
5627
5834
|
if arg_types is None:
|
|
5628
|
-
return Tile(dtype=Any,
|
|
5835
|
+
return Tile(dtype=Any, shape=Any)
|
|
5629
5836
|
|
|
5630
5837
|
x = arg_types["x"]
|
|
5631
5838
|
y = arg_types["y"]
|
|
@@ -5633,25 +5840,21 @@ def tile_scalar_mul_value_func(arg_types, arg_values):
|
|
|
5633
5840
|
# tile*scalar
|
|
5634
5841
|
if is_tile(x):
|
|
5635
5842
|
if x.dtype != y:
|
|
5636
|
-
raise
|
|
5637
|
-
"Scalar factor should have the same type as tile for tile*scalar, tile type: {x} scalar type: {y}"
|
|
5638
|
-
)
|
|
5843
|
+
raise TypeError(f"Scalar factor type {y} does not match tile type {x.dtype} for tile*scalar")
|
|
5639
5844
|
|
|
5640
|
-
return TileBinaryMap(x, TileConstant(y, x.
|
|
5845
|
+
return TileBinaryMap(x, TileConstant(y, x.shape))
|
|
5641
5846
|
|
|
5642
5847
|
# scalar*tile
|
|
5643
5848
|
if is_tile(y):
|
|
5644
5849
|
if y.dtype != x:
|
|
5645
|
-
raise
|
|
5646
|
-
"Scalar factor should have the same type as tile for scalar*tile, tile type: {x} scalar type: {y}"
|
|
5647
|
-
)
|
|
5850
|
+
raise TypeError(f"Scalar factor type {x} does not match tile type {y.dtype} for scalar*tile")
|
|
5648
5851
|
|
|
5649
|
-
return TileBinaryMap(TileConstant(x, y.
|
|
5852
|
+
return TileBinaryMap(TileConstant(x, y.shape), y)
|
|
5650
5853
|
|
|
5651
5854
|
|
|
5652
5855
|
add_builtin(
|
|
5653
5856
|
"neg",
|
|
5654
|
-
input_types={"x": Tile(dtype=Any,
|
|
5857
|
+
input_types={"x": Tile(dtype=Any, shape=Any)},
|
|
5655
5858
|
value_func=tile_unary_value_func,
|
|
5656
5859
|
doc="Negate each element of a tile",
|
|
5657
5860
|
export=False,
|
|
@@ -5661,7 +5864,7 @@ add_builtin(
|
|
|
5661
5864
|
|
|
5662
5865
|
add_builtin(
|
|
5663
5866
|
"add",
|
|
5664
|
-
input_types={"a": Tile(dtype=Any,
|
|
5867
|
+
input_types={"a": Tile(dtype=Any, shape=Any), "b": Tile(dtype=Any, shape=Any)},
|
|
5665
5868
|
value_func=tile_binary_map_value_func,
|
|
5666
5869
|
# dispatch_func=tile_map_dispatch_func,
|
|
5667
5870
|
# variadic=True,
|
|
@@ -5671,9 +5874,22 @@ add_builtin(
|
|
|
5671
5874
|
export=False,
|
|
5672
5875
|
)
|
|
5673
5876
|
|
|
5877
|
+
# Tile overload of "sub": element-wise a - b for two tiles. The return type is
# resolved by tile_binary_map_value_func, the same value function used for
# the two-tile tile_map() overload.
add_builtin(
    "sub",
    input_types={"a": Tile(dtype=Any, shape=Any), "b": Tile(dtype=Any, shape=Any)},
    value_func=tile_binary_map_value_func,
    # dispatch_func=tile_map_dispatch_func,
    # variadic=True,
    native_func="tile_sub",
    doc="Subtract each element b from a",
    group="Tile Primitives",
    export=False,
)
|
|
5888
|
+
|
|
5889
|
+
|
|
5674
5890
|
add_builtin(
|
|
5675
5891
|
"mul",
|
|
5676
|
-
input_types={"x": Tile(dtype=Any,
|
|
5892
|
+
input_types={"x": Tile(dtype=Any, shape=Any), "y": Scalar},
|
|
5677
5893
|
value_func=tile_scalar_mul_value_func,
|
|
5678
5894
|
doc="Multiply each element of a tile by a scalar",
|
|
5679
5895
|
export=False,
|
|
@@ -5683,7 +5899,7 @@ add_builtin(
|
|
|
5683
5899
|
|
|
5684
5900
|
add_builtin(
|
|
5685
5901
|
"mul",
|
|
5686
|
-
input_types={"x": Scalar, "y": Tile(dtype=Any,
|
|
5902
|
+
input_types={"x": Scalar, "y": Tile(dtype=Any, shape=Any)},
|
|
5687
5903
|
value_func=tile_scalar_mul_value_func,
|
|
5688
5904
|
doc="Multiply each element of a tile by a scalar",
|
|
5689
5905
|
export=False,
|
|
@@ -5692,6 +5908,70 @@ add_builtin(
|
|
|
5692
5908
|
)
|
|
5693
5909
|
|
|
5694
5910
|
|
|
5911
|
+
def tile_diag_add_value_func(arg_types, arg_values):
    """Resolve the return type of ``tile_diag_add()``.

    ``a`` must be a square 2D tile and ``d`` a 1D tile of the same dtype whose
    length equals the number of rows of ``a``.

    :param arg_types: Mapping of argument name to type, or ``None`` when the
        generic signature is requested.
    :param arg_values: Unused here.
    :returns: A ``Tile`` type with the dtype and shape of ``a`` and
        ``storage="shared"``.
    :raises TypeError: If an argument is not a tile, the dtypes differ, or the
        tiles do not have the required number of dimensions.
    :raises ValueError: If ``a`` is not square or the length of ``d`` does not
        match the number of rows of ``a``.
    """
    if arg_types is None:
        return Tile(dtype=Any, shape=Any)

    a, d = arg_types["a"], arg_types["d"]

    # argument kinds
    if not is_tile(a):
        raise TypeError(f"tile_diag_add() 'a' argument must be a tile, got {a!r}")
    if not is_tile(d):
        raise TypeError(f"tile_diag_add() 'd' argument must be a tile, got {d!r}")

    # element types
    if not types_equal(a.dtype, d.dtype):
        raise TypeError(f"tile_diag_add() arguments must have the same dtype, got {a.dtype} and {d.dtype}")

    # ranks
    if len(a.shape) != 2:
        raise TypeError("tile_diag_add() argument 'a' must be a 2D tile")
    if len(d.shape) != 1:
        raise TypeError("tile_diag_add() argument 'd' must be a 1D tile")

    # extents
    num_rows = a.shape[0]
    if num_rows != a.shape[1]:
        raise ValueError("tile_diag_add() 'a' argument must be square")
    if num_rows != d.shape[0]:
        raise ValueError(
            f"tile_diag_add() 'd' argument must have the same number of elements as the number of rows in 'a', "
            f"got {d.shape[0]} elements in 'd' and {a.shape[0]} rows in 'a'"
        )

    # the output type is derived from the first argument
    return Tile(dtype=a.dtype, shape=a.shape, storage="shared")
|
|
5944
|
+
|
|
5945
|
+
|
|
5946
|
+
def tile_diag_add_lto_dispatch_func(
    arg_types: Mapping[str, type],
    return_type: Any,
    return_values: List[Var],
    arg_values: Mapping[str, Var],
    options: Mapping[str, Any],
    builder: warp.context.ModuleBuilder,
):
    """LTO dispatch callback for ``tile_diag_add()``.

    Sets the storage of both input tile types to ``"shared"`` and forwards the
    inputs together with the first return value as the call arguments.

    :param arg_values: Mapping holding the ``a`` and ``d`` variables.
    :param return_values: List whose first entry is the output variable.
    :returns: ``((a, d, out), [], [], 0)``. The first element holds the call
        arguments; the meaning of the remaining three entries is set by the
        caller of this callback and is not visible here.
    """
    a = arg_values["a"]
    d = arg_values["d"]

    # both inputs are forced into shared memory storage
    for operand in (a, d):
        operand.type.storage = "shared"

    return ((a, d, return_values[0]), [], [], 0)
|
|
5961
|
+
|
|
5962
|
+
|
|
5963
|
+
# Registers tile_diag_add(a, d). Type checking and the return type come from
# tile_diag_add_value_func; tile_diag_add_lto_dispatch_func prepares the call
# arguments.
add_builtin(
    "tile_diag_add",
    input_types={"a": Tile(dtype=Any, shape=Any), "d": Tile(dtype=Any, shape=Any)},
    value_func=tile_diag_add_value_func,
    lto_dispatch_func=tile_diag_add_lto_dispatch_func,
    native_func="tile_diag_add",
    doc="Add a square matrix and a diagonal matrix 'd' represented as a 1D tile",
    group="Tile Primitives",
    export=False,
)
|
|
5973
|
+
|
|
5974
|
+
|
|
5695
5975
|
##
|
|
5696
5976
|
## MathDx, LTOIR-based, Tile functions
|
|
5697
5977
|
##
|
|
@@ -5703,24 +5983,25 @@ add_builtin(
|
|
|
5703
5983
|
def tile_matmul_generic_value_func(arg_types, arg_values):
|
|
5704
5984
|
# return generic type (for doc builds)
|
|
5705
5985
|
if arg_types is None:
|
|
5706
|
-
return Tile(dtype=Any,
|
|
5986
|
+
return Tile(dtype=Any, shape=Any)
|
|
5707
5987
|
|
|
5708
5988
|
a = arg_types["a"]
|
|
5709
5989
|
b = arg_types["b"]
|
|
5710
5990
|
|
|
5711
5991
|
if not is_tile(a):
|
|
5712
|
-
raise
|
|
5992
|
+
raise TypeError(f"tile_matmul() 'a' argument must be a tile, got {a!r}")
|
|
5993
|
+
|
|
5713
5994
|
if not is_tile(b):
|
|
5714
|
-
raise
|
|
5995
|
+
raise TypeError(f"tile_matmul() 'b' argument must be a tile, got {b!r}")
|
|
5715
5996
|
|
|
5716
5997
|
# out = wp.tile_matmul(a, b)
|
|
5717
5998
|
if len(arg_types) == 2:
|
|
5718
|
-
return Tile(dtype=a.dtype,
|
|
5999
|
+
return Tile(dtype=a.dtype, shape=(a.shape[0], b.shape[1]), storage="shared")
|
|
5719
6000
|
|
|
5720
6001
|
# wp.tile_matmul(a, b, out)
|
|
5721
6002
|
elif len(arg_types) == 3:
|
|
5722
6003
|
if not is_tile(arg_types["out"]):
|
|
5723
|
-
raise
|
|
6004
|
+
raise TypeError(f"tile_matmul() 'out' argument must be a tile, got {arg_types['out']!r}")
|
|
5724
6005
|
|
|
5725
6006
|
return None
|
|
5726
6007
|
|
|
@@ -5743,16 +6024,20 @@ def tile_matmul_generic_lto_dispatch_func(
|
|
|
5743
6024
|
accumulate = 1 # for tile_matmul(a,b,c) case we want to add to c value
|
|
5744
6025
|
out = arg_values["out"]
|
|
5745
6026
|
|
|
5746
|
-
if
|
|
5747
|
-
raise
|
|
6027
|
+
if not is_tile(out.type):
|
|
6028
|
+
raise TypeError(f"tile_matmul() 'out' argument must be a tile, got {out!r}")
|
|
5748
6029
|
|
|
5749
6030
|
if any(arg.type.dtype not in [float16, float32, float64, vec2h, vec2f, vec2d] for arg in [a, b, out]):
|
|
5750
|
-
raise
|
|
6031
|
+
raise TypeError(
|
|
5751
6032
|
"tile_matmul() arguments must be tiles of float16, float32 or float64, vec2h, vec2f, vec2d entries"
|
|
5752
6033
|
)
|
|
5753
6034
|
|
|
5754
|
-
if (
|
|
5755
|
-
|
|
6035
|
+
if (
|
|
6036
|
+
(a.type.shape[1] != b.type.shape[0])
|
|
6037
|
+
or (a.type.shape[0] != out.type.shape[0])
|
|
6038
|
+
or (b.type.shape[1] != out.type.shape[1])
|
|
6039
|
+
):
|
|
6040
|
+
raise ValueError("tile_matmul(A, B, C) requires sizes of A, B and C to be consistent for a matmul")
|
|
5756
6041
|
|
|
5757
6042
|
# set the storage type to the inputs to shared
|
|
5758
6043
|
a.type.storage = "shared"
|
|
@@ -5774,18 +6059,18 @@ def tile_matmul_generic_lto_dispatch_func(
|
|
|
5774
6059
|
return ("wp::vec2f", 5, 1)
|
|
5775
6060
|
if dtype == vec2d:
|
|
5776
6061
|
return ("wp::vec2d", 6, 1)
|
|
5777
|
-
raise
|
|
6062
|
+
raise TypeError("Unsupported input type in tile_matmul")
|
|
5778
6063
|
|
|
5779
6064
|
def cublasdx_arrangement_map(layout):
|
|
5780
6065
|
if layout == "colmajor":
|
|
5781
6066
|
return 0 # CUBLASDX_ARRANGEMENT_COL_MAJOR
|
|
5782
6067
|
if layout == "rowmajor":
|
|
5783
6068
|
return 1 # CUBLASDX_ARRANGEMENT_ROW_MAJOR
|
|
5784
|
-
raise
|
|
6069
|
+
raise ValueError("Unsupported layout in tile_matmul")
|
|
5785
6070
|
|
|
5786
6071
|
# generate the LTO
|
|
5787
|
-
M, K = a.type.
|
|
5788
|
-
_, N = b.type.
|
|
6072
|
+
M, K = a.type.shape[0], a.type.shape[1]
|
|
6073
|
+
_, N = b.type.shape[0], b.type.shape[1]
|
|
5789
6074
|
num_threads = options["block_dim"]
|
|
5790
6075
|
arch = options["output_arch"]
|
|
5791
6076
|
|
|
@@ -5798,7 +6083,8 @@ def tile_matmul_generic_lto_dispatch_func(
|
|
|
5798
6083
|
c_arrangement = cublasdx_arrangement_map(clayout)
|
|
5799
6084
|
|
|
5800
6085
|
if a_type != b_type or a_type != c_type:
|
|
5801
|
-
raise
|
|
6086
|
+
raise TypeError("time_matmul(A, B, C) requires all inputs to be real or complex")
|
|
6087
|
+
|
|
5802
6088
|
element_type = a_type
|
|
5803
6089
|
|
|
5804
6090
|
lto_symbol = f"dot_{M}_{N}_{K}_{arch}_{num_threads}_{a_arrangement}_{b_arrangement}_{c_arrangement}_{a_prec}_{b_prec}_{c_prec}_{element_type}"
|
|
@@ -5893,15 +6179,16 @@ def tile_matmul_generic_lto_dispatch_func(
|
|
|
5893
6179
|
),
|
|
5894
6180
|
template_args,
|
|
5895
6181
|
[lto_forward, lto_backward_A, lto_backward_B],
|
|
6182
|
+
0,
|
|
5896
6183
|
)
|
|
5897
6184
|
|
|
5898
6185
|
|
|
5899
6186
|
add_builtin(
|
|
5900
6187
|
"tile_matmul",
|
|
5901
6188
|
input_types={
|
|
5902
|
-
"a": Tile(dtype=Any,
|
|
5903
|
-
"b": Tile(dtype=Any,
|
|
5904
|
-
"out": Tile(dtype=Any,
|
|
6189
|
+
"a": Tile(dtype=Any, shape=Any),
|
|
6190
|
+
"b": Tile(dtype=Any, shape=Any),
|
|
6191
|
+
"out": Tile(dtype=Any, shape=Any),
|
|
5905
6192
|
},
|
|
5906
6193
|
value_func=tile_matmul_generic_value_func,
|
|
5907
6194
|
lto_dispatch_func=tile_matmul_generic_lto_dispatch_func,
|
|
@@ -5925,7 +6212,7 @@ add_builtin(
|
|
|
5925
6212
|
|
|
5926
6213
|
add_builtin(
|
|
5927
6214
|
"tile_matmul",
|
|
5928
|
-
input_types={"a": Tile(dtype=Any,
|
|
6215
|
+
input_types={"a": Tile(dtype=Any, shape=Any), "b": Tile(dtype=Any, shape=Any)},
|
|
5929
6216
|
value_func=tile_matmul_generic_value_func,
|
|
5930
6217
|
lto_dispatch_func=tile_matmul_generic_lto_dispatch_func,
|
|
5931
6218
|
variadic=False,
|
|
@@ -5952,16 +6239,23 @@ add_builtin(
|
|
|
5952
6239
|
##
|
|
5953
6240
|
def tile_fft_generic_value_func(arg_types, arg_values):
|
|
5954
6241
|
if arg_types is None:
|
|
5955
|
-
return Tile(dtype=Any,
|
|
6242
|
+
return Tile(dtype=Any, shape=Any)
|
|
5956
6243
|
|
|
5957
6244
|
if len(arg_types) != 1:
|
|
5958
|
-
raise
|
|
6245
|
+
raise TypeError(f"tile_fft() takes exactly 1 positional argument but {len(arg_types)} were given")
|
|
5959
6246
|
|
|
5960
|
-
|
|
5961
|
-
raise RuntimeError("tile_fft() argument 0 must be a tile")
|
|
6247
|
+
inout = arg_types["inout"]
|
|
5962
6248
|
|
|
5963
|
-
if
|
|
5964
|
-
raise
|
|
6249
|
+
if not is_tile(inout):
|
|
6250
|
+
raise TypeError(f"tile_fft() argument must be a tile, got {inout!r}")
|
|
6251
|
+
|
|
6252
|
+
if inout.storage != "register":
|
|
6253
|
+
raise ValueError(f"tile_fft() argument must have 'register' storage, got {inout.storage}")
|
|
6254
|
+
|
|
6255
|
+
if inout.dtype not in [vec2f, vec2d]:
|
|
6256
|
+
raise TypeError(
|
|
6257
|
+
f"tile_fft() argument must be a tile of vec2f or vec2d (interpreted as complex) entries, got {inout.dtype!r}"
|
|
6258
|
+
)
|
|
5965
6259
|
|
|
5966
6260
|
return None
|
|
5967
6261
|
|
|
@@ -5978,19 +6272,13 @@ def tile_fft_generic_lto_dispatch_func(
|
|
|
5978
6272
|
inout = arg_values["inout"]
|
|
5979
6273
|
inout.type.storage = "register"
|
|
5980
6274
|
|
|
5981
|
-
if not is_tile(inout.type):
|
|
5982
|
-
raise RuntimeError("tile_fft() arguments must be a single tile with register storage")
|
|
5983
|
-
|
|
5984
|
-
if inout.type.dtype not in [vec2f, vec2d]:
|
|
5985
|
-
raise RuntimeError("tile_fft() argument must be a tile of vec2f or vec2d (interpreted as complex) entries")
|
|
5986
|
-
|
|
5987
6275
|
# see libcufftdx.hpp
|
|
5988
6276
|
if direction == "forward":
|
|
5989
6277
|
dir = 0 # CUFFTDX_DIRECTION_FORWARD
|
|
5990
6278
|
elif direction == "inverse":
|
|
5991
6279
|
dir = 1 # CUFFTDX_DIRECTION_INVERSE
|
|
5992
6280
|
else:
|
|
5993
|
-
raise
|
|
6281
|
+
raise ValueError(f"Invalid direction: {direction!r}. Expected 'forward' or 'inverse'.")
|
|
5994
6282
|
|
|
5995
6283
|
if inout.type.dtype == vec2f:
|
|
5996
6284
|
dtype = "wp::vec2f"
|
|
@@ -5999,10 +6287,10 @@ def tile_fft_generic_lto_dispatch_func(
|
|
|
5999
6287
|
dtype = "wp::vec2d"
|
|
6000
6288
|
precision = 6 # COMMONDX_PRECISION_F64
|
|
6001
6289
|
else:
|
|
6002
|
-
raise
|
|
6290
|
+
raise TypeError(f"Unsupported data type, got {dtype!r}")
|
|
6003
6291
|
|
|
6004
6292
|
# M FFTs of size N each
|
|
6005
|
-
batch, size = inout.type.
|
|
6293
|
+
batch, size = inout.type.shape[0], inout.type.shape[1]
|
|
6006
6294
|
num_threads = options["block_dim"]
|
|
6007
6295
|
arch = options["output_arch"]
|
|
6008
6296
|
ept = size // num_threads
|
|
@@ -6034,7 +6322,7 @@ def tile_fft_generic_lto_dispatch_func(
|
|
|
6034
6322
|
lto_code.close()
|
|
6035
6323
|
if lto_code_path.exists():
|
|
6036
6324
|
lto_code_path.unlink()
|
|
6037
|
-
raise RuntimeError("Failed to compile
|
|
6325
|
+
raise RuntimeError("Failed to compile tile_fft")
|
|
6038
6326
|
|
|
6039
6327
|
with open(lto_code.name, "rb") as f:
|
|
6040
6328
|
lto_code_data = f.read()
|
|
@@ -6044,17 +6332,20 @@ def tile_fft_generic_lto_dispatch_func(
|
|
|
6044
6332
|
|
|
6045
6333
|
builder.ltoirs[lto_symbol] = lto_code_data
|
|
6046
6334
|
|
|
6335
|
+
shared_memory_bytes = Tile.round_up(shared_memory_size.value)
|
|
6336
|
+
|
|
6047
6337
|
return (
|
|
6048
6338
|
(
|
|
6049
6339
|
Var(lto_symbol, str, False, True, False),
|
|
6050
6340
|
Var(dtype, str, False, True, False),
|
|
6051
|
-
Var(str(
|
|
6341
|
+
Var(str(shared_memory_bytes), str, False, True, False),
|
|
6052
6342
|
Var(str(batch), str, False, True, False),
|
|
6053
6343
|
Var(str(ept), str, False, True, False),
|
|
6054
6344
|
inout,
|
|
6055
6345
|
),
|
|
6056
6346
|
[],
|
|
6057
6347
|
[lto_code_data],
|
|
6348
|
+
shared_memory_bytes,
|
|
6058
6349
|
)
|
|
6059
6350
|
|
|
6060
6351
|
|
|
@@ -6068,6 +6359,8 @@ add_builtin(
|
|
|
6068
6359
|
|
|
6069
6360
|
This function cooperatively computes the forward FFT on a tile of data inplace, treating each row individually.
|
|
6070
6361
|
|
|
6362
|
+
Note that computing the adjoint is not yet supported.
|
|
6363
|
+
|
|
6071
6364
|
Supported datatypes are:
|
|
6072
6365
|
* vec2f, vec2d
|
|
6073
6366
|
|
|
@@ -6087,6 +6380,8 @@ add_builtin(
|
|
|
6087
6380
|
|
|
6088
6381
|
This function cooperatively computes the inverse FFT on a tile of data inplace, treating each row individually.
|
|
6089
6382
|
|
|
6383
|
+
Note that computing the adjoint is not yet supported.
|
|
6384
|
+
|
|
6090
6385
|
Supported datatypes are:
|
|
6091
6386
|
* vec2f, vec2d
|
|
6092
6387
|
|
|
@@ -6096,6 +6391,283 @@ add_builtin(
|
|
|
6096
6391
|
namespace="",
|
|
6097
6392
|
)
|
|
6098
6393
|
|
|
6394
|
+
|
|
6395
|
+
##
|
|
6396
|
+
## Cholesky
|
|
6397
|
+
##
|
|
6398
|
+
def tile_cholesky_generic_value_func(arg_types, arg_values):
    """Infer the return type of ``tile_cholesky(A)``.

    Args:
        arg_types: Mapping from argument name to argument type, or ``None``
            to request the generic (unspecialized) return type.
        arg_values: Not used by this function.

    Returns:
        A generic ``Tile`` when ``arg_types`` is ``None``; otherwise a tile
        with the dtype and shape of ``A`` and ``"shared"`` storage.

    Raises:
        TypeError: If the number of arguments is not 1 or ``A`` is not a tile.
        ValueError: If ``A`` is not a 2D tile or is not square.
    """
    if arg_types is None:
        return Tile(dtype=Any, shape=Any)

    if len(arg_types) != 1:
        raise TypeError("tile_cholesky() requires 1 positional args")

    a = arg_types["A"]

    if not is_tile(a):
        raise TypeError(f"tile_cholesky() argument must be a tile, got {a!r}")

    # The rank check must come before the square check, which indexes shape[1].
    if len(a.shape) != 2:
        raise ValueError("tile_cholesky() argument must be a 2D tile")

    if a.shape[0] != a.shape[1]:
        raise ValueError("tile_cholesky() argument must be square")

    return Tile(dtype=a.dtype, shape=a.shape, storage="shared")
|
|
6417
|
+
|
|
6418
|
+
|
|
6419
|
+
# Integer code selecting the solver routine handed to cuda_compile_solver().
cusolver_function_map = {
    "getrf": 0,
    "getrf_no_pivot": 1,
    "potrf": 2,
    "potrs": 3,
}

# Supported scalar types -> (native type name used in the generated
# declaration, precision code handed to cuda_compile_solver()).
cusolver_type_map = {
    float32: ("wp::float32", 5),
    float64: ("wp::float64", 6),
}

# Integer code for the triangular fill mode handed to cuda_compile_solver().
cusolver_fill_mode_map = {
    "upper": 0,
    "lower": 1,
}
|
|
6424
|
+
|
|
6425
|
+
|
|
6426
|
+
def tile_cholesky_generic_lto_dispatch_func(
    arg_types: Mapping[str, type],
    return_type: Any,
    return_values: List[Var],
    arg_values: Mapping[str, Var],
    options: Mapping[str, Any],
    builder: warp.context.ModuleBuilder,
):
    """Build (or reuse) the LTO object for ``tile_cholesky`` and return the call arguments.

    Args:
        arg_types: Declared input types of the builtin (unused here).
        return_type: Inferred return type (unused here).
        return_values: Output variables; exactly one output tile is expected.
        arg_values: Bound argument variables; ``"A"`` is the matrix to factorize.
        options: Builder options; ``"block_dim"`` and ``"output_arch"`` are read.
        builder: Module builder whose ``ltoirs``, ``ltoirs_decl`` and ``fatbins``
            caches are consulted and updated.

    Returns:
        A 4-tuple ``(func_args, template_args, ltoirs, shared_memory_bytes)``:
        the function arguments ``(symbol, A, out)``, an empty template-argument
        list, a one-element list with the LTO code, and ``0``.

    Raises:
        TypeError: If ``A`` is not a float32/float64 tile or there is not
            exactly one return value.
        ValueError: If the output tile does not match the size of ``A``.
        RuntimeError: If the native solver compilation fails.
    """
    a = arg_values["A"]
    # force source tile to shared memory
    a.type.storage = "shared"

    if a.type.dtype not in cusolver_type_map:
        raise TypeError("tile_cholesky() argument must be a tile of float32 or float64 entries")

    if len(return_values) != 1:
        raise TypeError("tile_cholesky() returns one output")
    out = return_values[0]

    dtype, precision_enum = cusolver_type_map[a.type.dtype]

    # We already ensured a is square in tile_cholesky_generic_value_func()
    M, N = a.type.shape[0], a.type.shape[1]
    if out.type.shape[0] != M or out.type.shape[1] != M:
        raise ValueError("tile_cholesky() output tile must be square")

    num_threads = options["block_dim"]
    arch = options["output_arch"]
    lto_symbol = f"potrf_{M}_{N}_{arch}_{precision_enum}"

    if lto_symbol in builder.ltoirs:
        # Reuse the LTO already compiled for this module. The result must have
        # the same 4-tuple layout as the freshly-compiled path below; the
        # declaration and fatbin were registered when the entry was created.
        lto_code_data = builder.ltoirs[lto_symbol]
    else:
        # otherwise compile LTO
        temp_files = []
        try:
            lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
            temp_files.append(lto_code)
            universal_fatbin_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
            temp_files.append(universal_fatbin_code)

            # cuSOLVERDx only support col-major input/outputs,
            # so we use upper to mimic a row-major input
            result = warp.context.runtime.core.cuda_compile_solver(
                universal_fatbin_code.name.encode("utf-8"),
                lto_code.name.encode("utf-8"),
                lto_symbol.encode("utf-8"),
                0,
                None,
                None,
                arch,
                M,
                N,
                cusolver_function_map["potrf"],
                precision_enum,
                cusolver_fill_mode_map["upper"],
                num_threads,
            )

            if not result:
                raise RuntimeError("Failed to compile tile_cholesky")

            with open(lto_code.name, "rb") as f:
                lto_code_data = f.read()
            with open(universal_fatbin_code.name, "rb") as f:
                universal_fatbin_code_data = f.read()
        finally:
            # The files were created with delete=False, so remove them on
            # every exit path (success, compile failure, or read error).
            for tmp in temp_files:
                tmp.close()
                tmp_path = Path(tmp.name)
                if tmp_path.exists():
                    tmp_path.unlink()

        builder.ltoirs[lto_symbol] = lto_code_data
        builder.ltoirs_decl[lto_symbol] = f"void {lto_symbol}({dtype}*, unsigned);"
        builder.fatbins["cholesky"] = universal_fatbin_code_data

    return ((Var(lto_symbol, str, False, True, False), a, out), [], [lto_code_data], 0)
|
|
6503
|
+
|
|
6504
|
+
|
|
6505
|
+
# Register the tile_cholesky builtin; type inference and LTO generation are
# delegated to the two tile_cholesky_generic_* hooks.
add_builtin(
    "tile_cholesky",
    input_types={"A": Tile},
    value_func=tile_cholesky_generic_value_func,
    lto_dispatch_func=tile_cholesky_generic_lto_dispatch_func,
    variadic=True,
    doc="""Compute the Cholesky factorization L of a matrix A.
L is lower triangular and satisfies LL^T = A.

Note that computing the adjoint is not yet supported.

Supported datatypes are:
* float32
* float64

:param A: A square, symmetric positive-definite, matrix.
:returns L: A square, lower triangular, matrix, such that LL^T = A""",
    group="Tile Primitives",
    export=False,
    namespace="",
)
|
|
6526
|
+
|
|
6527
|
+
|
|
6528
|
+
def tile_cholesky_solve_generic_value_func(arg_types, arg_values):
    """Infer the return type of ``tile_cholesky_solve(L, x)``.

    Args:
        arg_types: Mapping from argument name to argument type, or ``None``
            to request the generic (unspecialized) return type.
        arg_values: Not used by this function.

    Returns:
        A generic ``Tile`` when ``arg_types`` is ``None``; otherwise a tile
        with the dtype of ``L``, the shape of ``x`` and ``"shared"`` storage.

    Raises:
        TypeError: If the number of arguments is not 2, an argument is not a
            tile, the dtypes differ, or ``x`` is not 1D.
        ValueError: If ``L`` is not a square 2D tile or the length of ``x``
            does not match the number of rows of ``L``.
    """
    if arg_types is None:
        # The builtin returns a tile, so report a generic tile here, as the
        # other tile value functions in this file do.
        return Tile(dtype=Any, shape=Any)

    if len(arg_types) != 2:
        raise TypeError("tile_cholesky_solve() requires exactly 2 positional args")

    l_type = arg_types["L"]
    x_type = arg_types["x"]

    if not is_tile(l_type):
        raise TypeError(f"tile_cholesky_solve() 'L' argument must be a tile, got {l_type!r}")

    if not is_tile(x_type):
        # Report the offending 'x' argument (previously 'L' was printed here).
        raise TypeError(f"tile_cholesky_solve() 'x' argument must be a tile, got {x_type!r}")

    if not types_equal(l_type.dtype, x_type.dtype):
        raise TypeError(
            f"tile_cholesky_solve() arguments must have the same dtype, got {l_type.dtype} and {x_type.dtype}"
        )

    # Check the rank before indexing shape[1] in the square check below.
    if len(l_type.shape) != 2:
        raise ValueError("tile_cholesky_solve() 'L' argument must be a 2D tile")

    if l_type.shape[0] != l_type.shape[1]:
        raise ValueError("tile_cholesky_solve() 'L' argument must be square")

    if len(x_type.shape) != 1:
        raise TypeError("tile_cholesky_solve() 'x' argument must be a 1D tile")

    if x_type.shape[0] != l_type.shape[0]:
        raise ValueError(
            f"tile_cholesky_solve() 'x' argument must have the same number of elements as the number of rows in 'L', "
            f"got {x_type.shape[0]} elements in 'x' and {l_type.shape[0]} rows in 'L'"
        )

    return Tile(dtype=l_type.dtype, shape=x_type.shape, storage="shared")
|
|
6560
|
+
|
|
6561
|
+
|
|
6562
|
+
def tile_cholesky_solve_generic_lto_dispatch_func(
    arg_types: Mapping[str, type],
    return_type: Any,
    return_values: List[Var],
    arg_values: Mapping[str, Var],
    options: Mapping[str, Any],
    builder: warp.context.ModuleBuilder,
):
    """Build (or reuse) the LTO object for ``tile_cholesky_solve`` and return the call arguments.

    Args:
        arg_types: Declared input types of the builtin (unused here).
        return_type: Inferred return type (unused here).
        return_values: Output variables; exactly one 1D output tile is expected.
        arg_values: Bound argument variables; ``"L"`` and ``"x"`` are read.
        options: Builder options; ``"block_dim"`` and ``"output_arch"`` are read.
        builder: Module builder whose ``ltoirs``, ``ltoirs_decl`` and ``fatbins``
            caches are consulted and updated.

    Returns:
        A 4-tuple ``(func_args, template_args, ltoirs, shared_memory_bytes)``:
        the function arguments ``(symbol, L, x, y)``, an empty template-argument
        list, a one-element list with the LTO code, and ``0``.

    Raises:
        TypeError: If there is not exactly one return value, the inputs are
            not float32/float64 tiles, or the output is not 1D.
        ValueError: If the output length differs from the number of rows of ``L``.
        RuntimeError: If the native solver compilation fails.
    """
    L = arg_values["L"]
    x = arg_values["x"]
    # force the storage type of the input variables to shared memory
    L.type.storage = "shared"
    x.type.storage = "shared"

    if len(return_values) != 1:
        raise TypeError(f"tile_cholesky_solve() must return exactly one value, got {len(return_values)}")

    y = return_values[0]

    if any(T not in cusolver_type_map for T in [x.type.dtype, L.type.dtype]):
        raise TypeError("tile_cholesky_solve() arguments must be tiles of float64 or float32")

    dtype, precision_enum = cusolver_type_map[L.type.dtype]
    M, N = L.type.shape[0], L.type.shape[1]

    if len(y.type.shape) != 1:
        raise TypeError("tile_cholesky_solve() output vector must be 1D")

    if y.type.shape[0] != M:
        raise ValueError(
            "tile_cholesky_solve() output vector must have same number of elements as the number of rows in 'L' "
            f"got {y.type.shape[0]} elements in output and {M} rows in 'L'"
        )

    num_threads = options["block_dim"]
    arch = options["output_arch"]
    lto_symbol = f"potrs_{M}_{N}_{arch}_{precision_enum}"

    if lto_symbol in builder.ltoirs:
        # Reuse the LTO already compiled for this module. The result must have
        # the same 4-tuple layout as the freshly-compiled path below; the
        # declaration and fatbin were registered when the entry was created.
        lto_code_data = builder.ltoirs[lto_symbol]
    else:
        # otherwise compile LTO
        temp_files = []
        try:
            lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
            temp_files.append(lto_code)
            universal_fatbin_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
            temp_files.append(universal_fatbin_code)

            # cuSOLVERDx only support col-major input/outputs,
            # so we use upper to mimic a row-major input
            result = warp.context.runtime.core.cuda_compile_solver(
                universal_fatbin_code.name.encode("utf-8"),
                lto_code.name.encode("utf-8"),
                lto_symbol.encode("utf-8"),
                0,
                None,
                None,
                arch,
                M,
                N,
                cusolver_function_map["potrs"],
                precision_enum,
                cusolver_fill_mode_map["upper"],
                num_threads,
            )

            if not result:
                raise RuntimeError("Failed to compile tile_cholesky_solve")

            with open(lto_code.name, "rb") as f:
                lto_code_data = f.read()
            with open(universal_fatbin_code.name, "rb") as f:
                universal_fatbin_code_data = f.read()
        finally:
            # The files were created with delete=False, so remove them on
            # every exit path (success, compile failure, or read error).
            for tmp in temp_files:
                tmp.close()
                tmp_path = Path(tmp.name)
                if tmp_path.exists():
                    tmp_path.unlink()

        builder.ltoirs[lto_symbol] = lto_code_data
        builder.ltoirs_decl[lto_symbol] = f"void {lto_symbol}({dtype}*, {dtype}*);"
        builder.fatbins["cholesky"] = universal_fatbin_code_data

    return ((Var(lto_symbol, str, False, True, False), L, x, y), [], [lto_code_data], 0)
|
|
6647
|
+
|
|
6648
|
+
|
|
6649
|
+
# Register the tile_cholesky_solve builtin; type inference and LTO generation
# are delegated to the two tile_cholesky_solve_generic_* hooks.
# The doc summary names the unknown as y so that it agrees with the
# :param x: (input) and :returns y: (solution) entries below it.
add_builtin(
    "tile_cholesky_solve",
    input_types={"L": Tile, "x": Tile},
    value_func=tile_cholesky_solve_generic_value_func,
    lto_dispatch_func=tile_cholesky_solve_generic_lto_dispatch_func,
    variadic=True,
    doc="""With L such that LL^T = A, solve for y in Ay = x

Note that computing the adjoint is not yet supported.

Supported datatypes are:
* float32
* float64

:param L: A square, lower triangular, matrix, such that LL^T = A
:param x: A 1D tile of length M
:returns y: A 1D tile of length M such that LL^T y = x""",
    group="Tile Primitives",
    export=False,
    namespace="",
)
|
|
6670
|
+
|
|
6099
6671
|
# ---------------------------------
|
|
6100
6672
|
# Code Generation
|
|
6101
6673
|
|
|
@@ -6103,7 +6675,7 @@ add_builtin(
|
|
|
6103
6675
|
"static",
|
|
6104
6676
|
input_types={"expr": Any},
|
|
6105
6677
|
value_type=Any,
|
|
6106
|
-
doc="""
|
|
6678
|
+
doc="""Evaluate a static Python expression and replaces it with its result.
|
|
6107
6679
|
|
|
6108
6680
|
See the :ref:`code generation guide <static_expressions>` for more details.
|
|
6109
6681
|
|
|
@@ -6127,3 +6699,58 @@ def static(expr):
|
|
|
6127
6699
|
which includes constant variables and variables captured in the current closure in which the function or kernel is implemented.
|
|
6128
6700
|
"""
|
|
6129
6701
|
return expr
|
|
6702
|
+
|
|
6703
|
+
|
|
6704
|
+
# Overloads of the "len" builtin, one per supported container kind.
# Every overload returns an int and is registered with export=False.

# vectors
add_builtin(
    "len",
    input_types={"a": vector(length=Any, dtype=Scalar)},
    value_type=int,
    doc="Return the number of elements in a vector.",
    group="Utility",
    export=False,
)

# quaternions
add_builtin(
    "len",
    input_types={"a": quaternion(dtype=Scalar)},
    value_type=int,
    doc="Return the number of elements in a quaternion.",
    group="Utility",
    export=False,
)

# matrices
add_builtin(
    "len",
    input_types={"a": matrix(shape=(Any, Any), dtype=Scalar)},
    value_type=int,
    doc="Return the number of rows in a matrix.",
    group="Utility",
    export=False,
)

# transformations
add_builtin(
    "len",
    input_types={"a": transformation(dtype=Float)},
    value_type=int,
    doc="Return the number of elements in a transformation.",
    group="Utility",
    export=False,
)

# arrays
add_builtin(
    "len",
    input_types={"a": array(dtype=Any)},
    value_type=int,
    doc="Return the size of the first dimension in an array.",
    group="Utility",
    export=False,
)

# tiles
add_builtin(
    "len",
    input_types={"a": Tile(dtype=Any, shape=Any)},
    value_type=int,
    doc="Return the number of rows in a tile.",
    group="Utility",
    export=False,
)
|