warp-lang 1.5.1__py3-none-manylinux2014_aarch64.whl → 1.6.0__py3-none-manylinux2014_aarch64.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the registry's advisory page for this release for more details.

Files changed (123)
  1. warp/__init__.py +5 -0
  2. warp/autograd.py +414 -191
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +40 -12
  6. warp/build_dll.py +13 -6
  7. warp/builtins.py +1076 -480
  8. warp/codegen.py +240 -119
  9. warp/config.py +1 -1
  10. warp/context.py +298 -84
  11. warp/examples/assets/square_cloth.usd +0 -0
  12. warp/examples/benchmarks/benchmark_gemm.py +27 -18
  13. warp/examples/benchmarks/benchmark_interop_paddle.py +3 -3
  14. warp/examples/benchmarks/benchmark_interop_torch.py +3 -3
  15. warp/examples/core/example_torch.py +18 -34
  16. warp/examples/fem/example_apic_fluid.py +1 -0
  17. warp/examples/fem/example_mixed_elasticity.py +1 -1
  18. warp/examples/optim/example_bounce.py +1 -1
  19. warp/examples/optim/example_cloth_throw.py +1 -1
  20. warp/examples/optim/example_diffray.py +4 -15
  21. warp/examples/optim/example_drone.py +1 -1
  22. warp/examples/optim/example_softbody_properties.py +392 -0
  23. warp/examples/optim/example_trajectory.py +1 -3
  24. warp/examples/optim/example_walker.py +5 -0
  25. warp/examples/sim/example_cartpole.py +0 -2
  26. warp/examples/sim/example_cloth_self_contact.py +260 -0
  27. warp/examples/sim/example_granular_collision_sdf.py +4 -5
  28. warp/examples/sim/example_jacobian_ik.py +0 -2
  29. warp/examples/sim/example_quadruped.py +5 -2
  30. warp/examples/tile/example_tile_cholesky.py +79 -0
  31. warp/examples/tile/example_tile_convolution.py +2 -2
  32. warp/examples/tile/example_tile_fft.py +2 -2
  33. warp/examples/tile/example_tile_filtering.py +3 -3
  34. warp/examples/tile/example_tile_matmul.py +4 -4
  35. warp/examples/tile/example_tile_mlp.py +12 -12
  36. warp/examples/tile/example_tile_nbody.py +180 -0
  37. warp/examples/tile/example_tile_walker.py +319 -0
  38. warp/math.py +147 -0
  39. warp/native/array.h +12 -0
  40. warp/native/builtin.h +0 -1
  41. warp/native/bvh.cpp +149 -70
  42. warp/native/bvh.cu +287 -68
  43. warp/native/bvh.h +195 -85
  44. warp/native/clang/clang.cpp +5 -1
  45. warp/native/cuda_util.cpp +35 -0
  46. warp/native/cuda_util.h +5 -0
  47. warp/native/exports.h +40 -40
  48. warp/native/intersect.h +17 -0
  49. warp/native/mat.h +41 -0
  50. warp/native/mathdx.cpp +19 -0
  51. warp/native/mesh.cpp +25 -8
  52. warp/native/mesh.cu +153 -101
  53. warp/native/mesh.h +482 -403
  54. warp/native/quat.h +40 -0
  55. warp/native/solid_angle.h +7 -0
  56. warp/native/sort.cpp +85 -0
  57. warp/native/sort.cu +34 -0
  58. warp/native/sort.h +3 -1
  59. warp/native/spatial.h +11 -0
  60. warp/native/tile.h +1185 -664
  61. warp/native/tile_reduce.h +8 -6
  62. warp/native/vec.h +41 -0
  63. warp/native/warp.cpp +8 -1
  64. warp/native/warp.cu +263 -40
  65. warp/native/warp.h +19 -5
  66. warp/optim/linear.py +22 -4
  67. warp/render/render_opengl.py +124 -59
  68. warp/sim/__init__.py +6 -1
  69. warp/sim/collide.py +270 -26
  70. warp/sim/integrator_euler.py +25 -7
  71. warp/sim/integrator_featherstone.py +154 -35
  72. warp/sim/integrator_vbd.py +842 -40
  73. warp/sim/model.py +111 -53
  74. warp/stubs.py +248 -115
  75. warp/tape.py +28 -30
  76. warp/tests/aux_test_module_unload.py +15 -0
  77. warp/tests/{test_sim_grad.py → flaky_test_sim_grad.py} +104 -63
  78. warp/tests/test_array.py +74 -0
  79. warp/tests/test_assert.py +242 -0
  80. warp/tests/test_codegen.py +14 -61
  81. warp/tests/test_collision.py +2 -2
  82. warp/tests/test_examples.py +9 -0
  83. warp/tests/test_grad_debug.py +87 -2
  84. warp/tests/test_hash_grid.py +1 -1
  85. warp/tests/test_ipc.py +116 -0
  86. warp/tests/test_mat.py +138 -167
  87. warp/tests/test_math.py +47 -1
  88. warp/tests/test_matmul.py +11 -7
  89. warp/tests/test_matmul_lite.py +4 -4
  90. warp/tests/test_mesh.py +84 -60
  91. warp/tests/test_mesh_query_aabb.py +165 -0
  92. warp/tests/test_mesh_query_point.py +328 -286
  93. warp/tests/test_mesh_query_ray.py +134 -121
  94. warp/tests/test_mlp.py +2 -2
  95. warp/tests/test_operators.py +43 -0
  96. warp/tests/test_overwrite.py +2 -2
  97. warp/tests/test_quat.py +77 -0
  98. warp/tests/test_reload.py +29 -0
  99. warp/tests/test_sim_grad_bounce_linear.py +204 -0
  100. warp/tests/test_static.py +16 -0
  101. warp/tests/test_tape.py +25 -0
  102. warp/tests/test_tile.py +134 -191
  103. warp/tests/test_tile_load.py +356 -0
  104. warp/tests/test_tile_mathdx.py +61 -8
  105. warp/tests/test_tile_mlp.py +17 -17
  106. warp/tests/test_tile_reduce.py +24 -18
  107. warp/tests/test_tile_shared_memory.py +66 -17
  108. warp/tests/test_tile_view.py +165 -0
  109. warp/tests/test_torch.py +35 -0
  110. warp/tests/test_utils.py +36 -24
  111. warp/tests/test_vec.py +110 -0
  112. warp/tests/unittest_suites.py +29 -4
  113. warp/tests/unittest_utils.py +30 -11
  114. warp/thirdparty/unittest_parallel.py +2 -2
  115. warp/types.py +409 -99
  116. warp/utils.py +9 -5
  117. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/METADATA +68 -44
  118. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/RECORD +121 -110
  119. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/WHEEL +1 -1
  120. warp/examples/benchmarks/benchmark_tile.py +0 -179
  121. warp/native/tile_gemm.h +0 -341
  122. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/LICENSE.md +0 -0
  123. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/top_level.txt +0 -0
warp/builtins.py CHANGED
@@ -1707,64 +1707,98 @@ add_builtin(
1707
1707
 
1708
1708
  # ------------------
1709
1709
  # Tile-based primitives
1710
- shared_memory_id = 0
1710
+
1711
+
1712
+ def tile_unpack_shape(arg_values):
1713
+ shape = arg_values["shape"]
1714
+
1715
+ if not isinstance(shape, tuple):
1716
+ # promote to tuple
1717
+ shape = (shape,)
1718
+
1719
+ # check that components are constants
1720
+ for d in shape:
1721
+ if d is None:
1722
+ raise ValueError("Tile functions require shape to be a compile time constant.")
1723
+
1724
+ return shape
1725
+
1726
+
1727
+ def tile_unpack_offset(arg_values, ndim=0):
1728
+ if "offset" in arg_values:
1729
+ offset = arg_values["offset"]
1730
+ else:
1731
+ offset = (0,) * ndim
1732
+
1733
+ if isinstance(offset, tuple):
1734
+ return offset
1735
+ else:
1736
+ # promote to tuple
1737
+ return (offset,)
1711
1738
 
1712
1739
 
1713
1740
  def tile_zeros_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1714
1741
  # return generic type (for doc builds)
1715
1742
  if arg_types is None:
1716
- return Tile(dtype=Any, M=Any, N=Any)
1717
-
1718
- if "m" not in arg_values:
1719
- raise RuntimeError("'m' keyword argument must be specified when calling tile_zeros() function")
1743
+ return Tile(dtype=Any, shape=Any)
1720
1744
 
1721
- if "n" not in arg_values:
1722
- raise RuntimeError("'n' keyword argument must be specified when calling tile_zeros() function")
1745
+ shape = tile_unpack_shape(arg_values)
1723
1746
 
1724
1747
  if "dtype" not in arg_values:
1725
- raise RuntimeError("'dtype' keyword argument must be specified when calling tile_zeros() function")
1748
+ raise TypeError("tile_zeros() missing required keyword argument 'dtype'")
1726
1749
 
1727
1750
  if "storage" not in arg_values:
1728
- raise ValueError("'storage' keyword not provided for tile_zeros")
1751
+ raise TypeError("tile_zeros() missing required keyword argument 'storage'")
1729
1752
 
1730
1753
  if arg_values["storage"] not in {"shared", "register"}:
1731
- raise ValueError(
1732
- f"'storage' keyword argument must be either 'shared' or 'register', got {arg_values['storage']}"
1733
- )
1754
+ raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
1734
1755
 
1735
- m, n = arg_values["m"], arg_values["n"]
1736
1756
  dtype = arg_values["dtype"]
1737
1757
 
1738
- return TileZeros(dtype=dtype, M=m, N=n, storage=arg_values["storage"])
1758
+ return TileZeros(dtype=dtype, shape=shape, storage=arg_values["storage"])
1739
1759
 
1740
1760
 
1741
1761
  def tile_zeros_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
1742
- m, n, dtype = arg_values["m"], arg_values["n"], arg_values["dtype"]
1762
+ shape = tile_unpack_shape(arg_values)
1763
+ dtype = arg_values["dtype"]
1743
1764
 
1744
1765
  template_args = []
1745
1766
  template_args.append(dtype)
1746
- template_args.append(m.constant)
1747
- template_args.append(n.constant)
1767
+ for d in shape:
1768
+ template_args.append(d.constant)
1748
1769
 
1749
1770
  return ([], template_args)
1750
1771
 
1751
1772
 
1752
1773
  add_builtin(
1753
1774
  "tile_zeros",
1754
- input_types={"m": int, "n": int, "dtype": Any, "storage": str},
1755
- defaults={"storage": "register"},
1775
+ input_types={"shape": Tuple[int, ...], "dtype": Any, "storage": str},
1776
+ defaults={"storage": "register", "dtype": float},
1756
1777
  value_func=tile_zeros_value_func,
1757
1778
  dispatch_func=tile_zeros_dispatch_func,
1758
1779
  variadic=False,
1759
1780
  missing_grad=True,
1760
- doc="""Allocates a tile of zero-initialized items.
1781
+ doc="""Allocate a tile of zero-initialized items.
1761
1782
 
1762
- :param m: Size of the first dimension of the output tile
1763
- :param n: Size of the second dimension of the output tile
1764
- :param dtype: Datatype of output tile's elements
1783
+ :param shape: Shape of the output tile
1784
+ :param dtype: Data type of output tile's elements (default float)
1765
1785
  :param storage: The storage location for the tile: ``"register"`` for registers
1766
1786
  (default) or ``"shared"`` for shared memory.
1767
- :returns: A zero-initialized tile with ``shape=(m,n)`` and the specified datatype""",
1787
+ :returns: A zero-initialized tile with shape and data type as specified""",
1788
+ group="Tile Primitives",
1789
+ export=False,
1790
+ )
1791
+
1792
+ # overload for scalar shape
1793
+ add_builtin(
1794
+ "tile_zeros",
1795
+ input_types={"shape": int, "dtype": Any, "storage": str},
1796
+ defaults={"storage": "register", "dtype": float},
1797
+ value_func=tile_zeros_value_func,
1798
+ dispatch_func=tile_zeros_dispatch_func,
1799
+ variadic=False,
1800
+ missing_grad=True,
1801
+ hidden=True,
1768
1802
  group="Tile Primitives",
1769
1803
  export=False,
1770
1804
  )
@@ -1773,54 +1807,63 @@ add_builtin(
1773
1807
  def tile_ones_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1774
1808
  # return generic type (for doc builds)
1775
1809
  if arg_types is None:
1776
- return Tile(dtype=Any, M=Any, N=Any)
1810
+ return Tile(dtype=Any, shape=Any)
1777
1811
 
1778
- if "m" not in arg_values:
1779
- raise RuntimeError("'m' keyword argument must be specified when calling tile_zeros() function")
1780
-
1781
- if "n" not in arg_values:
1782
- raise RuntimeError("'n' keyword argument must be specified when calling tile_zeros() function")
1812
+ shape = tile_unpack_shape(arg_values)
1783
1813
 
1784
1814
  if "dtype" not in arg_values:
1785
- raise RuntimeError("'dtype' keyword argument must be specified when calling tile_zeros() function")
1815
+ raise TypeError("tile_ones() missing required keyword argument 'dtype'")
1816
+
1817
+ if "storage" not in arg_values:
1818
+ raise TypeError("tile_ones() missing required keyword argument 'storage'")
1786
1819
 
1787
1820
  if arg_values["storage"] not in {"shared", "register"}:
1788
- raise ValueError(
1789
- f"'storage' keyword argument must be either 'shared' or 'register', got {arg_values['storage']}"
1790
- )
1821
+ raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
1791
1822
 
1792
- m, n = arg_values["m"], arg_values["n"]
1793
1823
  dtype = arg_values["dtype"]
1794
1824
 
1795
- return TileZeros(dtype=dtype, M=m, N=n, storage=arg_values["storage"])
1825
+ return TileOnes(dtype=dtype, shape=shape, storage=arg_values["storage"])
1796
1826
 
1797
1827
 
1798
1828
  def tile_ones_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
1799
- m, n, dtype = arg_values["m"], arg_values["n"], arg_values["dtype"]
1829
+ shape = tile_unpack_shape(arg_values)
1830
+ dtype = arg_values["dtype"]
1800
1831
 
1801
1832
  template_args = []
1802
1833
  template_args.append(dtype)
1803
- template_args.append(m.constant)
1804
- template_args.append(n.constant)
1834
+ for d in shape:
1835
+ template_args.append(d.constant)
1805
1836
 
1806
1837
  return ([], template_args)
1807
1838
 
1808
1839
 
1809
1840
  add_builtin(
1810
1841
  "tile_ones",
1811
- input_types={"m": int, "n": int, "dtype": Any, "storage": str},
1842
+ input_types={"shape": Tuple[int, ...], "dtype": Any, "storage": str},
1812
1843
  defaults={"storage": "register"},
1813
1844
  value_func=tile_ones_value_func,
1814
1845
  dispatch_func=tile_ones_dispatch_func,
1815
1846
  missing_grad=True,
1816
- doc="""Allocates a tile of one-initialized items.
1847
+ doc="""Allocate a tile of one-initialized items.
1817
1848
 
1818
- :param m: Size of the first dimension of the output tile
1819
- :param n: Size of the second dimension of the output tile
1820
- :param dtype: Datatype of output tile's elements
1849
+ :param shape: Shape of the output tile
1850
+ :param dtype: Data type of output tile's elements
1821
1851
  :param storage: The storage location for the tile: ``"register"`` for registers
1822
1852
  (default) or ``"shared"`` for shared memory.
1823
- :returns: A one-initialized tile with ``shape=(m,n)`` and the specified dtype""",
1853
+ :returns: A one-initialized tile with shape and data type as specified""",
1854
+ group="Tile Primitives",
1855
+ export=False,
1856
+ )
1857
+
1858
+ # overload for scalar shape
1859
+ add_builtin(
1860
+ "tile_ones",
1861
+ input_types={"shape": int, "dtype": Any, "storage": str},
1862
+ defaults={"storage": "register"},
1863
+ value_func=tile_ones_value_func,
1864
+ dispatch_func=tile_ones_dispatch_func,
1865
+ missing_grad=True,
1866
+ hidden=True,
1824
1867
  group="Tile Primitives",
1825
1868
  export=False,
1826
1869
  )
@@ -1829,14 +1872,16 @@ add_builtin(
1829
1872
  def tile_arange_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1830
1873
  # return generic type (for doc builds)
1831
1874
  if arg_types is None:
1832
- return Tile(dtype=Any, M=Any, N=Any)
1875
+ return Tile(dtype=Any, shape=Any)
1876
+
1877
+ if "args" not in arg_values:
1878
+ raise TypeError("tile_arange() requires at least one positional argument specifying the range")
1879
+
1880
+ args = arg_values["args"]
1833
1881
 
1834
1882
  start = 0
1835
1883
  stop = 0
1836
1884
  step = 1
1837
- dtype = int
1838
-
1839
- args = arg_values["args"]
1840
1885
 
1841
1886
  if len(args) == 1:
1842
1887
  start = 0
@@ -1852,8 +1897,7 @@ def tile_arange_value_func(arg_types: Mapping[str, type], arg_values: Mapping[st
1852
1897
  step = args[2]
1853
1898
 
1854
1899
  if start is None or stop is None or step is None:
1855
- print(args)
1856
- raise RuntimeError("wp.tile_arange() arguments must be compile time constants")
1900
+ raise RuntimeError("tile_arange() arguments must be compile time constants")
1857
1901
 
1858
1902
  if "dtype" in arg_values:
1859
1903
  dtype = arg_values["dtype"]
@@ -1861,26 +1905,37 @@ def tile_arange_value_func(arg_types: Mapping[str, type], arg_values: Mapping[st
1861
1905
  dtype = float
1862
1906
 
1863
1907
  if arg_values["storage"] not in {"shared", "register"}:
1864
- raise ValueError(
1865
- f"'storage' keyword argument must be either 'shared' or 'register', got {arg_values['storage']}"
1866
- )
1908
+ raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
1867
1909
 
1868
1910
  return TileRange(dtype=dtype, start=start, stop=stop, step=step, storage=arg_values["storage"])
1869
1911
 
1870
1912
 
1871
1913
  def tile_arange_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
1872
- m, n, dtype = return_type.M, return_type.N, return_type.dtype
1914
+ size, dtype = return_type.size, return_type.dtype
1873
1915
 
1874
1916
  template_args = []
1875
1917
  template_args.append(dtype)
1876
- template_args.append(m)
1877
- template_args.append(n)
1918
+ template_args.append(size)
1919
+
1920
+ if "args" not in arg_values:
1921
+ raise TypeError("tile_arange() requires at least one positional argument specifying the range")
1922
+
1923
+ args = arg_values["args"]
1878
1924
 
1879
- # todo: it is somewhat redundant to create new vars here since some of start,stop,step
1880
- # already exist depending on which form the function was called by the user
1881
- start = warp.codegen.Var(label=None, type=return_type.dtype, constant=return_type.start)
1882
- stop = warp.codegen.Var(label=None, type=return_type.dtype, constant=return_type.stop)
1883
- step = warp.codegen.Var(label=None, type=return_type.dtype, constant=return_type.step)
1925
+ if len(args) == 1:
1926
+ start = warp.codegen.Var(label=None, type=return_type.dtype, constant=return_type.start)
1927
+ stop = args[0]
1928
+ step = warp.codegen.Var(label=None, type=return_type.dtype, constant=return_type.step)
1929
+ elif len(args) == 2:
1930
+ start = args[0]
1931
+ stop = args[1]
1932
+ step = warp.codegen.Var(label=None, type=return_type.dtype, constant=return_type.step)
1933
+ elif len(args) == 3:
1934
+ start = args[0]
1935
+ stop = args[1]
1936
+ step = args[2]
1937
+ else:
1938
+ raise TypeError(f"tile_arange() accepts at most 3 positional arguments, got {len(args)}")
1884
1939
 
1885
1940
  function_args = []
1886
1941
  function_args.append(start)
@@ -1898,7 +1953,7 @@ add_builtin(
1898
1953
  dispatch_func=tile_arange_dispatch_func,
1899
1954
  variadic=True,
1900
1955
  missing_grad=True,
1901
- doc="""Generates a tile of linearly spaced elements.
1956
+ doc="""Generate a tile of linearly spaced elements.
1902
1957
 
1903
1958
  :param args: Variable-length positional arguments, interpreted as:
1904
1959
 
@@ -1906,246 +1961,157 @@ add_builtin(
1906
1961
  - ``(start, stop)``: Generates values from ``start`` to ``stop - 1``
1907
1962
  - ``(start, stop, step)``: Generates values from ``start`` to ``stop - 1`` with a step size
1908
1963
 
1909
- :param dtype: Datatype of output tile's elements (optional, default: int)
1964
+ :param dtype: Data type of output tile's elements (optional, default: ``float``)
1910
1965
  :param storage: The storage location for the tile: ``"register"`` for registers
1911
1966
  (default) or ``"shared"`` for shared memory.
1912
- :returns: A tile with ``shape=(1,n)`` with linearly spaced elements of specified dtype""",
1967
+ :returns: A tile with ``shape=(n)`` with linearly spaced elements of specified data type""",
1913
1968
  group="Tile Primitives",
1914
1969
  export=False,
1915
1970
  )
1916
1971
 
1917
1972
 
1918
- def tile_load_1d_value_func(arg_types, arg_values):
1919
- # return generic type (for doc builds)
1973
+ def tile_load_tuple_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1920
1974
  if arg_types is None:
1921
- return Tile(dtype=Any, M=Any, N=Any)
1922
-
1923
- if not is_array(arg_types["a"]):
1924
- raise RuntimeError("tile_load() argument 0 must be an array")
1975
+ return array(dtype=Scalar)
1925
1976
 
1926
- if arg_types["a"].ndim != 1:
1927
- raise RuntimeError(
1928
- "tile_load() argument 0 must be 1-dimensional if using the ``wp.tile_load(array, i, n)`` syntax."
1929
- )
1977
+ a = arg_types["a"]
1930
1978
 
1931
- if not type_is_int(arg_types["i"]):
1932
- raise RuntimeError("tile_load() argument 1 must be an integer")
1979
+ shape = tile_unpack_shape(arg_values)
1980
+ offset = tile_unpack_offset(arg_values, a.ndim)
1933
1981
 
1934
- if "n" not in arg_values:
1935
- raise RuntimeError("'n' keyword argument must be specified when calling tile_load() function")
1982
+ if a.ndim != len(shape):
1983
+ raise ValueError(
1984
+ f"tile_load() array argument must have same number of dimensions as the tile shape, trying to perform an {len(shape)} dimensional load from an array with {a.ndim} dimensions."
1985
+ )
1936
1986
 
1937
- if arg_values["storage"] not in {"shared", "register"}:
1987
+ if a.ndim != len(offset):
1938
1988
  raise ValueError(
1939
- f"'storage' keyword argument must be either 'shared' or 'register', got {arg_values['storage']}"
1989
+ f"tile_load() offset argument must have the same number of dimensions as the array to load from, got {len(offset)} indices for an array with {a.ndim} dimensions"
1940
1990
  )
1941
1991
 
1942
- a = arg_types["a"]
1943
- _m, n = 1, arg_values["n"]
1992
+ if arg_values["storage"] not in {"shared", "register"}:
1993
+ raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
1944
1994
 
1945
- return TileLoad(a, 1, n, arg_values["storage"])
1995
+ return Tile(dtype=a.dtype, shape=shape, storage=arg_values["storage"])
1946
1996
 
1947
1997
 
1948
- def tile_load_1d_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
1949
- array = arg_values["a"]
1950
- i = arg_values["i"]
1951
- n = arg_values["n"].constant
1952
- dtype = arg_values["a"].type.dtype
1998
+ def tile_load_tuple_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
1999
+ a = args["a"]
2000
+ shape = tile_unpack_shape(args)
2001
+ offset = tile_unpack_offset(args, a.type.ndim)
1953
2002
 
1954
- template_args = []
1955
- template_args.append(dtype)
1956
- template_args.append(n)
2003
+ func_args = (a, *offset)
2004
+ template_args = (d.constant for d in shape)
1957
2005
 
1958
- return ((array, i), template_args)
2006
+ return (func_args, template_args)
1959
2007
 
1960
2008
 
1961
2009
  add_builtin(
1962
2010
  "tile_load",
1963
- input_types={"a": array(dtype=Any), "i": int, "n": int, "storage": str},
1964
- defaults={"storage": "register"},
1965
- value_func=tile_load_1d_value_func,
1966
- dispatch_func=tile_load_1d_dispatch_func,
2011
+ input_types={"a": array(dtype=Any), "shape": Tuple[int, ...], "offset": Tuple[int, ...], "storage": str},
2012
+ value_func=tile_load_tuple_value_func,
2013
+ dispatch_func=tile_load_tuple_dispatch_func,
2014
+ defaults={"offset": None, "storage": "register"},
1967
2015
  variadic=False,
1968
- doc="""Loads a 1D tile from a global memory array.
2016
+ doc="""Loads a tile from a global memory array.
1969
2017
 
1970
2018
  This method will cooperatively load a tile from global memory using all threads in the block.
1971
2019
 
1972
2020
  :param a: The source array in global memory
1973
- :param i: Offset in the source array measured in multiples of ``n``, i.e.: ``offset=i*n``
1974
- :param n: The number of elements in the tile
2021
+ :param shape: Shape of the tile to load, must have the same number of dimensions as ``a``
2022
+ :param offset: Offset in the source array to begin reading from (optional)
1975
2023
  :param storage: The storage location for the tile: ``"register"`` for registers
1976
2024
  (default) or ``"shared"`` for shared memory.
1977
- :returns: A tile with ``shape=(1,n)`` and dtype the same as the source array""",
2025
+ :returns: A tile with shape as specified and data type the same as the source array""",
1978
2026
  group="Tile Primitives",
1979
2027
  export=False,
1980
2028
  )
1981
2029
 
1982
-
1983
- def tile_load_2d_value_func(arg_types, arg_values):
1984
- # return generic type (for doc builds)
1985
- if arg_types is None:
1986
- return Tile(dtype=Any, M=Any, N=Any)
1987
-
1988
- if not is_array(arg_types["a"]):
1989
- raise RuntimeError("tile_load() argument 0 must be an array")
1990
-
1991
- if arg_types["a"].ndim != 2:
1992
- raise RuntimeError(
1993
- "tile_load() argument 0 must be 2-dimensional if using the ``wp.tile_load(array, i, j, m, n)`` syntax."
1994
- )
1995
-
1996
- if not type_is_int(arg_types["i"]):
1997
- raise RuntimeError("tile_load() argument 1 must be an integer")
1998
-
1999
- if not type_is_int(arg_types["j"]):
2000
- raise RuntimeError("tile_load() argument 1 must be an integer")
2001
-
2002
- if "m" not in arg_values:
2003
- raise RuntimeError("'m' keyword argument must be specified when calling tile_load() function")
2004
-
2005
- if "n" not in arg_values:
2006
- raise RuntimeError("'n' keyword argument must be specified when calling tile_load() function")
2007
-
2008
- if arg_values["storage"] not in {"shared", "register"}:
2009
- raise ValueError(
2010
- f"'storage' keyword argument must be either 'shared' or 'register', got {arg_values['storage']}"
2011
- )
2012
-
2013
- a = arg_types["a"]
2014
- m, n = arg_values["m"], arg_values["n"]
2015
-
2016
- return TileLoad(a, m, n, arg_values["storage"])
2017
-
2018
-
2019
- def tile_load_2d_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
2020
- array = arg_values["a"]
2021
- i, j = arg_values["i"], arg_values["j"]
2022
- m, n = arg_values["m"].constant, arg_values["n"].constant
2023
- dtype = arg_values["a"].type.dtype
2024
-
2025
- template_args = []
2026
- template_args.append(dtype)
2027
- template_args.append(m)
2028
- template_args.append(n)
2029
-
2030
- return ((array, i, j), template_args)
2031
-
2032
-
2030
+ # overload for scalar shape
2033
2031
  add_builtin(
2034
2032
  "tile_load",
2035
- input_types={"a": array(dtype=Any), "i": int, "j": int, "m": int, "n": int, "storage": str},
2036
- defaults={"storage": "register"},
2037
- value_func=tile_load_2d_value_func,
2038
- dispatch_func=tile_load_2d_dispatch_func,
2039
- variadic=False,
2040
- doc="""Loads a 2D tile from a global memory array.
2041
-
2042
- This method will cooperatively load a tile from global memory using all threads in the block.
2043
-
2044
- :param a: The source array in global memory
2045
- :param i: Offset in the source array measured in multiples of ``m``, i.e.: ``row=i*m``
2046
- :param j: Offset in the source array measured in multiples of ``n``, i.e.; ``col=j*n``
2047
- :param m: The size of the tile's first dimension
2048
- :param n: The size of the tile's second dimension
2049
- :param storage: The storage location for the tile: ``"register"`` for registers
2050
- (default) or ``"shared"`` for shared memory.
2051
- :returns: A tile with ``shape=(m,n)`` and dtype the same as the source array""",
2033
+ input_types={"a": array(dtype=Any), "shape": int, "offset": int, "storage": str},
2034
+ value_func=tile_load_tuple_value_func,
2035
+ dispatch_func=tile_load_tuple_dispatch_func,
2036
+ defaults={"offset": None, "storage": "register"},
2052
2037
  group="Tile Primitives",
2038
+ hidden=True,
2053
2039
  export=False,
2054
2040
  )
2055
2041
 
2056
2042
 
2057
- def tile_store_1d_value_func(arg_types, arg_values):
2043
+ def tile_store_value_func(arg_types, arg_values):
2058
2044
  # return generic type (for doc builds)
2059
2045
  if arg_types is None:
2060
2046
  return None
2061
2047
 
2062
- if len(arg_types) != 3:
2063
- raise RuntimeError("tile_store() requires 3 positional args")
2048
+ a = arg_types["a"]
2049
+ t = arg_types["t"]
2064
2050
 
2065
- if not is_array(arg_types["a"]):
2066
- raise RuntimeError("tile_store() argument 0 must be an array")
2051
+ c = tile_unpack_offset(arg_types, a.ndim)
2067
2052
 
2068
- if arg_types["a"].ndim != 1:
2069
- raise RuntimeError(
2070
- "tile_load() argument 0 must be a 1-dimensional array if using the ``wp.tile_store(array, i, t)`` syntax."
2053
+ if len(c) != a.ndim:
2054
+ raise ValueError(
2055
+ f"tile_store() 'a' argument must have {len(c)} dimensions, "
2056
+ f"calculated based on the provided offset arguments, but got {a.ndim} dimensions."
2071
2057
  )
2072
2058
 
2073
- if not type_is_int(arg_types["i"]):
2074
- raise RuntimeError("tile_store() argument 1 must be an integer")
2075
-
2076
- if not is_tile(arg_types["t"]):
2077
- raise RuntimeError("tile_store() argument 2 must be a tile")
2059
+ if len(t.shape) != a.ndim:
2060
+ raise ValueError(
2061
+ f"tile_store() 'a' argument must have the same number of dimensions as the 't' argument, "
2062
+ f"but got {a.ndim} dimensions for 'a' and {len(t.shape)} dimensions for 't'"
2063
+ )
2078
2064
 
2079
2065
  if not types_equal(arg_types["a"].dtype, arg_types["t"].dtype):
2080
- raise RuntimeError("tile_store() destination array must have same type as source tile")
2066
+ raise TypeError(
2067
+ f"tile_store() 'a' and 't' arguments must have the same dtype, got {arg_types['a'].dtype} and {arg_types['t'].dtype}"
2068
+ )
2081
2069
 
2082
2070
  return None
2083
2071
 
2084
2072
 
2073
+ def tile_store_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
2074
+ a = args["a"]
2075
+ t = args["t"]
2076
+
2077
+ offset = tile_unpack_offset(args, a.type.ndim)
2078
+
2079
+ func_args = (a, *offset, t)
2080
+ template_args = []
2081
+
2082
+ return (func_args, template_args)
2083
+
2084
+
2085
2085
  add_builtin(
2086
2086
  "tile_store",
2087
- input_types={"a": array(dtype=Any), "i": int, "t": Tile(dtype=Any, M=Any, N=Any)},
2088
- value_func=tile_store_1d_value_func,
2087
+ input_types={"a": array(dtype=Any), "t": Tile(dtype=Any, shape=Any), "offset": Tuple[int, ...]},
2088
+ value_func=tile_store_value_func,
2089
+ dispatch_func=tile_store_dispatch_func,
2090
+ defaults={"offset": None},
2089
2091
  variadic=False,
2090
2092
  skip_replay=True,
2091
- doc="""Stores a 1D tile to a global memory array.
2093
+ doc="""Store a tile to a global memory array.
2092
2094
 
2093
2095
  This method will cooperatively store a tile to global memory using all threads in the block.
2094
2096
 
2095
2097
  :param a: The destination array in global memory
2096
- :param i: Offset in the destination array measured in multiples of ``n``, i.e.: ``offset=i*n``
2097
- :param t: The source tile to store data from, must have the same dtype as the destination array""",
2098
+ :param t: The source tile to store data from, must have the same data type and number of dimensions as the destination array
2099
+ :param offset: Offset in the destination array (optional)""",
2098
2100
  group="Tile Primitives",
2099
2101
  export=False,
2100
2102
  )
2101
2103
 
2102
-
2103
- def tile_store_2d_value_func(arg_types, arg_values):
2104
- # return generic type (for doc builds)
2105
- if arg_types is None:
2106
- return None
2107
-
2108
- if len(arg_types) != 4:
2109
- raise RuntimeError("tile_store() requires 4 positional args")
2110
-
2111
- if not is_array(arg_types["a"]):
2112
- raise RuntimeError("tile_store() argument 0 must be an array")
2113
-
2114
- if arg_types["a"].ndim != 2:
2115
- raise RuntimeError(
2116
- "tile_load() argument 0 must be a 2-dimensional array if using the ``wp.tile_store(array, i, j, t)`` syntax."
2117
- )
2118
-
2119
- if not type_is_int(arg_types["i"]):
2120
- raise RuntimeError("tile_store() argument 1 must be an integer")
2121
-
2122
- if not type_is_int(arg_types["j"]):
2123
- raise RuntimeError("tile_store() argument 2 must be an integer")
2124
-
2125
- if not is_tile(arg_types["t"]):
2126
- raise RuntimeError("tile_store() argument 3 must be a tile")
2127
-
2128
- if not types_equal(arg_types["a"].dtype, arg_types["t"].dtype):
2129
- raise RuntimeError("tile_store() destination array must have same type as source tile")
2130
-
2131
- return None
2132
-
2133
-
2104
+ # overload for scalar offset
2134
2105
  add_builtin(
2135
2106
  "tile_store",
2136
- input_types={"a": array(dtype=Any), "i": int, "j": int, "t": Tile(dtype=Any, M=Any, N=Any)},
2137
- value_func=tile_store_2d_value_func,
2107
+ input_types={"a": array(dtype=Any), "t": Tile(dtype=Any, shape=Any), "offset": int},
2108
+ value_func=tile_store_value_func,
2109
+ dispatch_func=tile_store_dispatch_func,
2110
+ defaults={"offset": None},
2138
2111
  variadic=False,
2139
2112
  skip_replay=True,
2140
- doc="""Stores a tile to a global memory array.
2141
-
2142
- This method will cooperatively store a tile to global memory using all threads in the block.
2143
-
2144
- :param a: The destination array in global memory
2145
- :param i: Offset in the destination array measured in multiples of ``m``, i.e.: ``row=i*m``
2146
- :param j: Offset in the destination array measured in multiples of ``n``, i.e.; ``col=j*n``
2147
- :param t: The source tile to store data from, must have the same dtype as the destination array""",
2148
2113
  group="Tile Primitives",
2114
+ hidden=True,
2149
2115
  export=False,
2150
2116
  )
2151
2117
 
@@ -2153,130 +2119,219 @@ add_builtin(
2153
2119
  def tile_atomic_add_value_func(arg_types, arg_values):
2154
2120
  # return generic type (for doc builds)
2155
2121
  if arg_types is None:
2156
- return Tile(dtype=Any, M=Any, N=Any)
2122
+ return Tile(dtype=Any, shape=Any)
2123
+
2124
+ a = arg_types["a"]
2125
+ t = arg_types["t"]
2126
+
2127
+ c = tile_unpack_offset(arg_types, a.ndim)
2128
+ if len(c) != a.ndim:
2129
+ raise ValueError(
2130
+ f"tile_atomic_add() 'a' argument must have {len(c)} dimensions, "
2131
+ f"calculated based on the provided offset arguments, but got {a.ndim} dimensions."
2132
+ )
2133
+
2134
+ if a.ndim != len(t.shape):
2135
+ raise ValueError(
2136
+ f"tile_atomic_add() 'a' argument must have the same number of dimensions as the 't' argument, "
2137
+ f"but got {a.ndim} dimensions for 'a' and {len(t.shape)} dimensions for 't'"
2138
+ )
2157
2139
 
2158
- if len(arg_types) != 4:
2159
- raise RuntimeError("tile_atomic_add() requires 4 positional args")
2140
+ if not types_equal(arg_types["a"].dtype, arg_types["t"].dtype):
2141
+ raise TypeError(
2142
+ f"tile_atomic_add() 'a' and 't' arguments must have the same dtype, got {arg_types['a'].dtype} and {arg_types['t'].dtype}"
2143
+ )
2160
2144
 
2161
- if not is_array(arg_types["a"]):
2162
- raise RuntimeError("tile_atomic_add() argument 0 must be an array")
2145
+ return Tile(dtype=arg_types["t"].dtype, shape=arg_types["t"].shape)
2163
2146
 
2164
- if not type_is_int(arg_types["x"]):
2165
- raise RuntimeError("tile_atomic_add() argument 1 must be an integer")
2166
2147
 
2167
- if not type_is_int(arg_types["y"]):
2168
- raise RuntimeError("tile_atomic_add() argument 2 must be an integer")
2148
+ def tile_atomic_add_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
2149
+ a = args["a"]
2150
+ t = args["t"]
2169
2151
 
2170
- if not is_tile(arg_types["t"]):
2171
- raise RuntimeError("tile_atomic_add() argument 3 must be a tile")
2152
+ offset = tile_unpack_offset(args, a.type.ndim)
2172
2153
 
2173
- if arg_types["a"].dtype != arg_types["t"].dtype:
2174
- raise RuntimeError("tile_atomic_add() tile dtype and array dtype must match")
2154
+ func_args = (a, *offset, t)
2155
+ template_args = []
2175
2156
 
2176
- return Tile(dtype=arg_types["t"].dtype, M=arg_types["t"].M, N=arg_types["t"].N)
2157
+ return (func_args, template_args)
2177
2158
 
2178
2159
 
2179
2160
  add_builtin(
2180
2161
  "tile_atomic_add",
2181
- input_types={"a": array(dtype=Any), "x": int, "y": int, "t": Tile(dtype=Any, M=Any, N=Any)},
2162
+ input_types={"a": array(dtype=Any), "t": Tile(dtype=Any, shape=Any), "offset": Tuple[int, ...]},
2182
2163
  value_func=tile_atomic_add_value_func,
2183
- variadic=True,
2164
+ dispatch_func=tile_atomic_add_dispatch_func,
2165
+ defaults={"offset": None},
2166
+ variadic=False,
2184
2167
  skip_replay=True,
2185
- doc="""Atomically add a tile to the array `a`, each element will be updated atomically.
2168
+ doc="""Atomically add a 1D tile to the array `a`, each element will be updated atomically.
2186
2169
 
2187
2170
  :param a: Array in global memory, should have the same ``dtype`` as the input tile
2188
- :param x: Offset in the destination array measured in multiples of ``m``, i.e.: ``i=x*M`` where ``M`` is the first tile dimension
2189
- :param y: Offset in the destination array measured in multiples of ``n``, i.e.: ``j=y*N`` where ``N`` is the second tile dimension
2190
2171
  :param t: Source tile to add to the destination array
2191
- :returns: A tile with the same dimensions and type as the source tile, holding the original value of the destination elements""",
2172
+ :param offset: Offset in the destination array (optional)
2173
+ :returns: A tile with the same dimensions and data type as the source tile, holding the original value of the destination elements""",
2192
2174
  group="Tile Primitives",
2193
2175
  export=False,
2194
2176
  )
2195
2177
 
2178
+ # overload for scalar offset
2179
+ add_builtin(
2180
+ "tile_atomic_add",
2181
+ input_types={"a": array(dtype=Any), "t": Tile(dtype=Any, shape=Any), "offset": int},
2182
+ value_func=tile_atomic_add_value_func,
2183
+ dispatch_func=tile_atomic_add_dispatch_func,
2184
+ defaults={"offset": None},
2185
+ variadic=False,
2186
+ skip_replay=True,
2187
+ group="Tile Primitives",
2188
+ hidden=True,
2189
+ export=False,
2190
+ )
2191
+
2196
2192
 
2197
2193
  def tile_view_value_func(arg_types, arg_values):
2198
2194
  # return generic type (for doc builds)
2199
2195
  if arg_types is None:
2200
- return Tile(dtype=Any, M=Any, N=Any)
2196
+ return Tile(dtype=Any, shape=Any)
2201
2197
 
2202
2198
  tile = arg_types["t"]
2199
+ offset = arg_types["offset"]
2203
2200
 
2204
- if "m" not in arg_values:
2205
- m = 1
2206
- else:
2207
- m = arg_values["m"]
2201
+ if len(offset) > len(tile.shape):
2202
+ raise ValueError(f"tile_view() specified too many offset coordinates {len(offset)} > {len(tile.shape)}")
2208
2203
 
2209
- if "n" not in arg_values:
2210
- n = tile.N
2204
+ if "shape" in arg_values:
2205
+ # if shape is specified take it directly, e.g.:
2206
+ # tile_view(t, offset=(i,j), shape=(m,n))
2207
+ shape = arg_values["shape"]
2208
+ strides = tile.strides
2209
+
2210
+ if len(shape) != len(tile.shape):
2211
+ raise ValueError(
2212
+ f"tile_view() if shape is specified it must have same number of dimensions as source tile, expected {len(tile.shape)}, got {len(shape)}"
2213
+ )
2211
2214
  else:
2212
- n = arg_values["n"]
2215
+ # if not specified, then take output shape from unspecified src dimensions
2216
+ # e.g.: tile[i] will return a whole row of a 2D tile
2217
+ shape = tile.shape[len(offset) :]
2218
+ strides = tile.strides[len(offset) :]
2213
2219
 
2214
- if m > tile.M or n > tile.N:
2215
- raise RuntimeError(
2216
- f"Trying to view a tile subrange with dimensions ({m}, {n}) which is larger than source tile with dimensions ({tile.M}, {tile.N})"
2217
- )
2220
+ assert len(shape) == len(strides)
2218
2221
 
2219
2222
  # force source tile to shared memory
2220
2223
  tile.storage = "shared"
2221
2224
 
2222
- output = Tile(dtype=tile.dtype, M=m, N=n, strides=tile.strides, layout=tile.layout, storage="shared", owner=False)
2225
+ output = Tile(dtype=tile.dtype, shape=shape, strides=strides, layout=tile.layout, storage="shared", owner=False)
2223
2226
  return output
2224
2227
 
2225
2228
 
2226
2229
  def tile_view_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
2227
2230
  tile = arg_values["t"]
2228
- i = arg_values["i"]
2229
-
2230
- if "j" not in arg_values:
2231
- j = warp.codegen.Var(label=None, type=int, constant=0)
2232
- else:
2233
- j = arg_values["j"]
2231
+ coord = arg_values["offset"]
2234
2232
 
2235
- template_args = []
2236
- template_args.append(return_type.M)
2237
- template_args.append(return_type.N)
2233
+ # zero-pad coord to match source array
2234
+ view_coord = [0] * len(tile.type.shape)
2235
+ for i in range(len(coord)):
2236
+ view_coord[i] = coord[i]
2238
2237
 
2239
- return ((tile, i, j), template_args)
2238
+ return ((tile, *view_coord), (return_type,))
2240
2239
 
2241
2240
 
2242
2241
  add_builtin(
2243
2242
  "tile_view",
2244
- input_types={"t": Tile(dtype=Any, M=Any, N=Any), "i": int, "j": int, "m": int, "n": int},
2243
+ input_types={"t": Tile(dtype=Any, shape=Any), "offset": Tuple[int, ...], "shape": Tuple[int, ...]},
2245
2244
  value_func=tile_view_value_func,
2246
2245
  dispatch_func=tile_view_dispatch_func,
2247
- defaults={"j": None, "m": None, "n": None},
2248
- variadic=True,
2249
- doc="""Return a subrange of a given tile from coordinates (i,j) to (i+m, j+n).
2246
+ defaults={"shape": None},
2247
+ variadic=False,
2248
+ doc="""Return a slice of a given tile [offset, offset+shape], if shape is not specified it will be inferred from the unspecified offset dimensions.
2250
2249
 
2251
2250
  :param t: Input tile to extract a subrange from
2252
- :param i: Offset in the source tile along the first dimension
2253
- :param j: Offset in the source tile along the second dimensions
2254
- :param m: Size of the subrange to return along the first dimension
2255
- :param n: Size of the subrange to return along the second dimension
2256
- :returns: A tile with dimensions (m,n) and the same datatype as the input tile""",
2251
+ :param offset: Offset in the source tile
2252
+ :param shape: Shape of the returned slice
2253
+ :returns: A tile with dimensions given by the specified shape or the remaining source tile dimensions""",
2257
2254
  group="Tile Primitives",
2255
+ missing_grad=True,
2258
2256
  export=False,
2259
2257
  )
2260
2258
 
2261
2259
 
2262
2260
  def tile_assign_value_func(arg_types, arg_values):
2263
- # return generic type (for doc builds)
2261
+ if arg_types is None:
2262
+ return None
2263
+
2264
+ # force the destination tile to shared memory
2265
+ arg_types["dst"].storage = "shared"
2264
2266
  return None
2265
2267
 
2266
2268
 
2269
+ def tile_assign_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
2270
+ dst = args["dst"]
2271
+ src = args["src"]
2272
+
2273
+ offset = tile_unpack_offset(args, len(dst.type.shape))
2274
+
2275
+ func_args = (dst, src, *offset)
2276
+ template_args = []
2277
+
2278
+ return (func_args, template_args)
2279
+
2280
+
2267
2281
  add_builtin(
2268
2282
  "tile_assign",
2269
- input_types={"dst": Tile(dtype=Any, M=Any, N=Any), "i": int, "j": int, "src": Tile(dtype=Any, M=Any, N=Any)},
2283
+ input_types={"dst": Tile(dtype=Any, shape=Any), "src": Tile(dtype=Any, shape=Any), "offset": Tuple[int, ...]},
2284
+ value_func=tile_assign_value_func,
2285
+ dispatch_func=tile_assign_dispatch_func,
2286
+ defaults={"offset": None},
2287
+ doc="""Assign a tile to a subrange of a destination tile.
2288
+
2289
+ :param dst: The destination tile to assign to
2290
+ :param src: The source tile to read values from
2291
+ :param offset: Offset in the destination tile to write to""",
2292
+ group="Tile Primitives",
2293
+ export=False,
2294
+ )
2295
+
2296
+ # handles expressions like tile[i,j] = 1.0
2297
+ add_builtin(
2298
+ "assign",
2299
+ input_types={"dst": Tile(dtype=Any, shape=Any), "i": int, "src": Scalar},
2300
+ value_func=tile_assign_value_func,
2301
+ group="Tile Primitives",
2302
+ export=False,
2303
+ hidden=True,
2304
+ missing_grad=True,
2305
+ )
2306
+
2307
+ add_builtin(
2308
+ "assign",
2309
+ input_types={"dst": Tile(dtype=Any, shape=Any), "i": int, "j": int, "src": Scalar},
2270
2310
  value_func=tile_assign_value_func,
2271
- # dispatch_func=tile_assign_dispatch_func,
2272
- doc="""Assign a tile to a subrange of a destination tile at coordinates (i,j).
2311
+ group="Tile Primitives",
2312
+ export=False,
2313
+ hidden=True,
2314
+ missing_grad=True,
2315
+ )
2316
+
2317
+ add_builtin(
2318
+ "assign",
2319
+ input_types={"dst": Tile(dtype=Any, shape=Any), "i": int, "j": int, "k": int, "src": Scalar},
2320
+ value_func=tile_assign_value_func,
2321
+ group="Tile Primitives",
2322
+ export=False,
2323
+ hidden=True,
2324
+ missing_grad=True,
2325
+ )
2273
2326
 
2274
- :param t: The destination tile to assign to
2275
- :param i: Offset in the source tile along the first dimension
2276
- :param j: Offset in the source tile along the second dimensions
2277
- :param src: The source tile to read values from""",
2327
+ add_builtin(
2328
+ "assign",
2329
+ input_types={"dst": Tile(dtype=Any, shape=Any), "i": int, "j": int, "k": int, "l": int, "src": Scalar},
2330
+ value_func=tile_assign_value_func,
2278
2331
  group="Tile Primitives",
2279
2332
  export=False,
2333
+ hidden=True,
2334
+ missing_grad=True,
2280
2335
  )
2281
2336
 
2282
2337
 
@@ -2286,7 +2341,7 @@ def tile_value_func(arg_types, arg_values):
2286
2341
  return Tile
2287
2342
 
2288
2343
  if len(arg_types) != 1:
2289
- raise RuntimeError("tile() requires 1 positional arg")
2344
+ raise TypeError(f"tile() takes exactly 1 positional argument but {len(arg_types)} were given")
2290
2345
 
2291
2346
  dtype = None
2292
2347
  length = None
@@ -2294,11 +2349,12 @@ def tile_value_func(arg_types, arg_values):
2294
2349
  if type_is_vector(arg_types["x"]):
2295
2350
  dtype = arg_types["x"]._wp_scalar_type_
2296
2351
  length = arg_types["x"]._shape_[0]
2352
+ shape = (length, warp.codegen.options["block_dim"])
2297
2353
  else:
2298
2354
  dtype = arg_types["x"]
2299
- length = 1
2355
+ shape = (warp.codegen.options["block_dim"],)
2300
2356
 
2301
- return Tile(dtype=dtype, M=length, N=warp.codegen.options["block_dim"], op="tile")
2357
+ return Tile(dtype=dtype, shape=shape, op="tile")
2302
2358
 
2303
2359
 
2304
2360
  add_builtin(
@@ -2306,14 +2362,14 @@ add_builtin(
2306
2362
  input_types={"x": Any},
2307
2363
  value_func=tile_value_func,
2308
2364
  variadic=True,
2309
- doc="""Constructs a new Tile from per-thread kernel values.
2365
+ doc="""Construct a new tile from per-thread kernel values.
2310
2366
 
2311
2367
  This function converts values computed using scalar kernel code to a tile representation for input into collective operations.
2312
2368
 
2313
2369
  * If the input value is a scalar, then the resulting tile has ``shape=(1, block_dim)``
2314
2370
  * If the input value is a vector, then the resulting tile has ``shape=(length(vector), block_dim)``
2315
2371
 
2316
- :param x: A per-thread local value, e.g.: scalar, vector, or matrix.
2372
+ :param x: A per-thread local value, e.g. scalar, vector, or matrix.
2317
2373
  :returns: A tile with first dimension according to the value type length and a second dimension equal to ``block_dim``
2318
2374
 
2319
2375
  This example shows how to create a linear sequence from thread variables:
@@ -2332,7 +2388,7 @@ add_builtin(
2332
2388
 
2333
2389
  .. code-block:: text
2334
2390
 
2335
- tile(m=1, n=16, storage=register) = [[0 2 4 6 8 ...]]
2391
+ [0 2 4 6 8 10 12 14 16 18 20 22 24 26 28 30] = tile(shape=(16), storage=register)
2336
2392
 
2337
2393
  """,
2338
2394
  group="Tile Primitives",
@@ -2346,38 +2402,40 @@ def untile_value_func(arg_types, arg_values):
2346
2402
  return Scalar
2347
2403
 
2348
2404
  if len(arg_types) != 1:
2349
- raise RuntimeError("untile() requires 1 positional arg")
2405
+ raise TypeError(f"untile() takes exactly 1 positional argument but {len(arg_types)} were given")
2350
2406
 
2351
2407
  t = arg_types["a"]
2352
2408
 
2353
2409
  if not is_tile(t):
2354
- raise RuntimeError(f"untile() accepts arguments of type tile only, got {arg_types[0]}")
2410
+ raise TypeError(f"untile() argument must be a tile, got {t!r}")
2355
2411
 
2356
- if t.N != warp.codegen.options["block_dim"]:
2357
- raise RuntimeError(
2358
- f"untile() argument must have the same length as the block width, got {t.N}, expected {warp.codegen.options['block_dim']}"
2412
+ if t.shape[-1] != warp.codegen.options["block_dim"]:
2413
+ raise ValueError(
2414
+ f"untile() argument last dimension {t.shape[-1]} does not match the expected block width {warp.codegen.options['block_dim']}"
2359
2415
  )
2360
2416
 
2361
- if t.M == 1:
2417
+ if len(t.shape) == 1:
2362
2418
  return t.dtype
2363
- elif t.M > 1:
2364
- return warp.types.vector(t.M, t.dtype)
2419
+ elif len(t.shape) == 2:
2420
+ return warp.types.vector(t.shape[0], t.dtype)
2421
+ else:
2422
+ raise ValueError(f"untile() argument must have a positive size in dimension 0, but got {t.shape[0]}")
2365
2423
 
2366
2424
 
2367
2425
  add_builtin(
2368
2426
  "untile",
2369
- input_types={"a": Tile(dtype=Any, M=Any, N=Any)},
2427
+ input_types={"a": Tile(dtype=Any, shape=Any)},
2370
2428
  value_func=untile_value_func,
2371
2429
  variadic=True,
2372
- doc="""Convert a Tile back to per-thread values.
2430
+ doc="""Convert a tile back to per-thread values.
2373
2431
 
2374
2432
  This function converts a block-wide tile back to per-thread values.
2375
2433
 
2376
- * If the input tile is 1-dimensional then the resulting value will be a per-thread scalar
2377
- * If the input tile is 2-dimensional then the resulting value will be a per-thread vector of length M
2434
+ * If the input tile is 1D, then the resulting value will be a per-thread scalar
2435
+ * If the input tile is 2D, then the resulting value will be a per-thread vector of length M
2378
2436
 
2379
2437
  :param a: A tile with dimensions ``shape=(M, block_dim)``
2380
- :returns: A single value per-thread with the same dtype as the tile
2438
+ :returns: A single value per-thread with the same data type as the tile
2381
2439
 
2382
2440
  This example shows how to create a linear sequence from thread variables:
2383
2441
 
@@ -2418,21 +2476,58 @@ def tile_extract_value_func(arg_types, arg_values):
2418
2476
  if arg_types is None:
2419
2477
  return Scalar
2420
2478
 
2421
- if len(arg_types) != 3:
2422
- raise RuntimeError("tile_extract() requires 3 positional args")
2423
-
2424
- if not is_tile(arg_types["a"]):
2425
- raise RuntimeError("tile_extract() argument 0 must be a tile")
2479
+ # force the input tile to shared memory
2480
+ arg_types["a"].storage = "shared"
2426
2481
 
2427
2482
  return arg_types["a"].dtype
2428
2483
 
2429
2484
 
2430
2485
  add_builtin(
2431
2486
  "tile_extract",
2432
- input_types={"a": Tile(dtype=Any, M=Any, N=Any), "i": int, "j": int},
2487
+ input_types={"a": Tile(dtype=Any, shape=Any), "i": int},
2433
2488
  value_func=tile_extract_value_func,
2434
- variadic=True,
2435
- doc="""Extracts a single element from the tile and returns it as a scalar type.
2489
+ variadic=False,
2490
+ doc="""Extract a single element from the tile and return it as a scalar type.
2491
+
2492
+ This function will extract an element from the tile and broadcast its value to all threads in the block.
2493
+
2494
+ Note that this may incur additional synchronization if the source tile is a register tile.
2495
+
2496
+ :param a: Tile to extract the element from
2497
+ :param i: Coordinate of element on first dimension
2498
+ :returns: The value of the element at the specified tile location with the same data type as the input tile""",
2499
+ group="Tile Primitives",
2500
+ hidden=True,
2501
+ export=False,
2502
+ )
2503
+
2504
+
2505
+ add_builtin(
2506
+ "tile_extract",
2507
+ input_types={"a": Tile(dtype=Any, shape=Any), "i": int, "j": int},
2508
+ value_func=tile_extract_value_func,
2509
+ variadic=False,
2510
+ doc="""Extract a single element from the tile and return it as a scalar type.
2511
+
2512
+ This function will extract an element from the tile and broadcast its value to all threads in the block.
2513
+
2514
+ Note that this may incur additional synchronization if the source tile is a register tile.
2515
+
2516
+ :param a: Tile to extract the element from
2517
+ :param i: Coordinate of element on first dimension
2518
+ :param j: Coordinate of element on the second dimension
2519
+ :returns: The value of the element at the specified tile location with the same data type as the input tile""",
2520
+ group="Tile Primitives",
2521
+ hidden=True,
2522
+ export=False,
2523
+ )
2524
+
2525
+ add_builtin(
2526
+ "tile_extract",
2527
+ input_types={"a": Tile(dtype=Any, shape=Any), "i": int, "j": int, "k": int},
2528
+ value_func=tile_extract_value_func,
2529
+ variadic=False,
2530
+ doc="""Extract a single element from the tile and return it as a scalar type.
2436
2531
 
2437
2532
  This function will extract an element from the tile and broadcast its value to all threads in the block.
2438
2533
 
@@ -2441,8 +2536,32 @@ add_builtin(
2441
2536
  :param a: Tile to extract the element from
2442
2537
  :param i: Coordinate of element on first dimension
2443
2538
  :param j: Coordinate of element on the second dimension
2444
- :returns: The value of the element at the specified tile location, with the same type as the input tile's per-element dtype""",
2539
+ :param k: Coordinate of element on the third dimension
2540
+ :returns: The value of the element at the specified tile location with the same data type as the input tile""",
2445
2541
  group="Tile Primitives",
2542
+ hidden=True,
2543
+ export=False,
2544
+ )
2545
+
2546
+ add_builtin(
2547
+ "tile_extract",
2548
+ input_types={"a": Tile(dtype=Any, shape=Any), "i": int, "j": int, "k": int, "l": int},
2549
+ value_func=tile_extract_value_func,
2550
+ variadic=False,
2551
+ doc="""Extract a single element from the tile and return it as a scalar type.
2552
+
2553
+ This function will extract an element from the tile and broadcast its value to all threads in the block.
2554
+
2555
+ Note that this may incur additional synchronization if the source tile is a register tile.
2556
+
2557
+ :param a: Tile to extract the element from
2558
+ :param i: Coordinate of element on first dimension
2559
+ :param j: Coordinate of element on the second dimension
2560
+ :param k: Coordinate of element on the third dimension
2561
+ :param l: Coordinate of element on the fourth dimension
2562
+ :returns: The value of the element at the specified tile location, with the same data type as the input tile""",
2563
+ group="Tile Primitives",
2564
+ hidden=True,
2446
2565
  export=False,
2447
2566
  )
2448
2567
 
@@ -2453,12 +2572,12 @@ def tile_transpose_value_func(arg_types, arg_values):
2453
2572
  return Tile
2454
2573
 
2455
2574
  if len(arg_types) != 1:
2456
- raise RuntimeError("tile_transpose() requires 1 positional args")
2575
+ raise TypeError(f"tile_transpose() takes exactly 1 positional argument but {len(arg_types)} were given")
2457
2576
 
2458
2577
  t = arg_types["a"]
2459
2578
 
2460
2579
  if not is_tile(t):
2461
- raise RuntimeError("tile_transpose() argument 0 must be a tile")
2580
+ raise TypeError(f"tile_transpose() argument must be a tile, got {t!r}")
2462
2581
 
2463
2582
  layout = None
2464
2583
 
@@ -2473,8 +2592,7 @@ def tile_transpose_value_func(arg_types, arg_values):
2473
2592
 
2474
2593
  return Tile(
2475
2594
  dtype=t.dtype,
2476
- M=t.N,
2477
- N=t.M,
2595
+ shape=t.shape[::-1],
2478
2596
  op="transpose",
2479
2597
  storage=t.storage,
2480
2598
  strides=t.strides[::-1],
@@ -2485,12 +2603,13 @@ def tile_transpose_value_func(arg_types, arg_values):
2485
2603
 
2486
2604
  add_builtin(
2487
2605
  "tile_transpose",
2488
- input_types={"a": Tile(dtype=Any, M=Any, N=Any)},
2606
+ input_types={"a": Tile(dtype=Any, shape=Any)},
2489
2607
  value_func=tile_transpose_value_func,
2490
2608
  variadic=True,
2491
2609
  doc="""Transpose a tile.
2492
2610
 
2493
- For shared memory tiles this operation will alias the input tile, register tiles will first be transferred to shared memory before transposition.
2611
+ For shared memory tiles, this operation will alias the input tile.
2612
+ Register tiles will first be transferred to shared memory before transposition.
2494
2613
 
2495
2614
  :param a: Tile to transpose with ``shape=(M,N)``
2496
2615
  :returns: Tile with ``shape=(N,M)``""",
@@ -2504,41 +2623,36 @@ def tile_broadcast_value_func(arg_types, arg_values):
2504
2623
  if arg_types is None:
2505
2624
  return Tile
2506
2625
 
2507
- if len(arg_types) != 3:
2508
- raise RuntimeError("tile_broadcast() requires 1 positional args")
2509
-
2510
2626
  t = arg_types["a"]
2511
- m = arg_values["m"]
2512
- n = arg_values["n"]
2513
2627
 
2514
- if not is_tile(t):
2515
- raise RuntimeError("tile_broadcast() argument 0 must be a tile")
2628
+ # target shape and strides
2629
+ target_shape = tile_unpack_shape(arg_values)
2630
+ target_strides = [0] * len(target_shape)
2516
2631
 
2517
- # try to broadcast last dimension
2518
- if t.N == 1:
2519
- stride_n = 0
2520
- elif t.N == n:
2521
- stride_n = t.strides[1]
2522
- else:
2523
- raise RuntimeError(
2524
- f"Broadcast dimension must be 1 or match destination, shape(src) = {t.m, t.n}, shape(dest) = {m, n}"
2525
- )
2632
+ offset = len(target_shape) - len(t.shape)
2526
2633
 
2527
- # try to broadcast first dimension
2528
- if t.M == 1:
2529
- stride_m = 0
2530
- elif t.M == m:
2531
- stride_m = t.strides[0]
2532
- else:
2533
- raise RuntimeError(
2534
- f"Broadcast dimension must be 1 or match destination, shape(src) = {t.m, t.n}, shape(dest) = {m, n}"
2535
- )
2634
+ # compute target strides
2635
+ for i in reversed(range(len(target_shape))):
2636
+ j = i - offset
2637
+
2638
+ if j < 0:
2639
+ target_strides[i] = 0
2640
+ else:
2641
+ # try to broadcast each dimension
2642
+ if t.shape[j] == 1:
2643
+ target_strides[i] = 0
2644
+ elif t.shape[j] == target_shape[i]:
2645
+ target_strides[i] = t.strides[j]
2646
+ else:
2647
+ raise ValueError(
2648
+ f"tile_broadcast() cannot broadcast dimension {t.shape[j]} into {target_shape[i]} at index {i}"
2649
+ )
2536
2650
 
2537
2651
  # force the input tile to shared memory
2538
2652
  t.storage = "shared"
2539
2653
 
2540
2654
  tile_type = Tile(
2541
- dtype=t.dtype, M=m, N=n, op="broadcast", storage=t.storage, strides=(stride_m, stride_n), owner=False
2655
+ dtype=t.dtype, shape=target_shape, op="broadcast", storage=t.storage, strides=target_strides, owner=False
2542
2656
  )
2543
2657
  return tile_type
2544
2658
 
@@ -2547,8 +2661,8 @@ def tile_broadcast_dispatch_func(arg_types: Mapping[str, type], return_type: Any
2547
2661
  tile = arg_values["a"]
2548
2662
 
2549
2663
  template_args = []
2550
- template_args.append(return_type.M)
2551
- template_args.append(return_type.N)
2664
+ template_args.append(return_type.shape[0])
2665
+ template_args.append(return_type.shape[1])
2552
2666
  template_args.append(return_type.strides[0])
2553
2667
  template_args.append(return_type.strides[1])
2554
2668
 
@@ -2557,15 +2671,18 @@ def tile_broadcast_dispatch_func(arg_types: Mapping[str, type], return_type: Any
2557
2671
 
2558
2672
  add_builtin(
2559
2673
  "tile_broadcast",
2560
- input_types={"a": Tile(dtype=Any, M=Any, N=Any), "m": int, "n": int},
2674
+ input_types={"a": Tile(dtype=Any, shape=Any), "shape": Tuple[int, ...]},
2561
2675
  value_func=tile_broadcast_value_func,
2562
2676
  dispatch_func=tile_broadcast_dispatch_func,
2563
- variadic=True,
2677
+ variadic=False,
2564
2678
  doc="""Broadcast a tile.
2565
2679
 
2566
- This function will attempt to broadcast the input tile ``a`` to the destination shape (m, n), broadcasting follows NumPy broadcast rules.
2680
+ This function will attempt to broadcast the input tile ``a`` to the destination shape (m, n).
2681
+
2682
+ Broadcasting follows NumPy broadcast rules.
2567
2683
 
2568
2684
  :param a: Tile to broadcast
2685
+ :param shape: The shape to broadcast to
2569
2686
  :returns: Tile with broadcast ``shape=(m, n)``""",
2570
2687
  group="Tile Primitives",
2571
2688
  export=False,
@@ -2575,19 +2692,10 @@ add_builtin(
2575
2692
  def tile_matmul_value_func(arg_types, arg_values):
2576
2693
  # return generic type (for doc builds)
2577
2694
  if arg_types is None:
2578
- return Tile(dtype=Any, M=Any, N=Any)
2695
+ return Tile(dtype=Any, shape=Any)
2579
2696
 
2580
2697
  if len(arg_types) != 3:
2581
- raise RuntimeError("tile_matmul() requires 4 positional args")
2582
-
2583
- if not is_tile(arg_types["a"]):
2584
- raise RuntimeError("tile_matmul() argument 0 must be a tile")
2585
-
2586
- if not is_tile(arg_types["b"]):
2587
- raise RuntimeError("tile_matmul() argument 1 must be an tile")
2588
-
2589
- if not isinstance(arg_types["out"], Tile):
2590
- raise RuntimeError("tile_matmul() output argument must be a tile")
2698
+ raise TypeError(f"tile_matmul() takes exactly 3 positional arguments but {len(arg_types)} were given")
2591
2699
 
2592
2700
  return None
2593
2701
 
@@ -2622,17 +2730,17 @@ add_builtin(
2622
2730
  def tile_sum_value_func(arg_types, arg_values):
2623
2731
  # return generic type (for doc builds)
2624
2732
  if arg_types is None:
2625
- return Tile(dtype=Any, M=1, N=1)
2733
+ return Tile(dtype=Any, shape=(1,))
2626
2734
 
2627
2735
  if len(arg_types) != 1:
2628
- raise RuntimeError("tile_sum() requires 1 positional args")
2736
+ raise TypeError(f"tile_sum() takes exactly 1 positional argument but {len(arg_types)} were given")
2629
2737
 
2630
2738
  a = arg_types["a"]
2631
2739
 
2632
2740
  if not is_tile(a):
2633
- raise RuntimeError("tile_sum() argument 0 must be a tile")
2741
+ raise TypeError(f"tile_sum() argument must be a tile, got {a!r}")
2634
2742
 
2635
- return Tile(dtype=a.dtype, M=1, N=1, op="sum")
2743
+ return Tile(dtype=a.dtype, shape=(1,), op="sum")
2636
2744
 
2637
2745
 
2638
2746
  add_builtin(
@@ -2643,7 +2751,7 @@ add_builtin(
2643
2751
  doc="""Cooperatively compute the sum of the tile elements using all threads in the block.
2644
2752
 
2645
2753
  :param a: The tile to compute the sum of
2646
- :returns: A single-element tile with dimensions of (1,1) holding the sum
2754
+ :returns: A single-element tile holding the sum
2647
2755
 
2648
2756
  Example:
2649
2757
 
@@ -2652,7 +2760,7 @@ add_builtin(
2652
2760
  @wp.kernel
2653
2761
  def compute():
2654
2762
 
2655
- t = wp.tile_ones(dtype=float, m=16, n=16)
2763
+ t = wp.tile_ones(dtype=float, shape=(16, 16))
2656
2764
  s = wp.tile_sum(t)
2657
2765
 
2658
2766
  print(s)
@@ -2663,7 +2771,7 @@ add_builtin(
2663
2771
 
2664
2772
  .. code-block:: text
2665
2773
 
2666
- tile(m=1, n=1, storage=register) = [[256]]
2774
+ [256] = tile(shape=(1), storage=register)
2667
2775
 
2668
2776
  """,
2669
2777
  group="Tile Primitives",
@@ -2674,17 +2782,17 @@ add_builtin(
2674
2782
  def tile_min_value_func(arg_types, arg_values):
2675
2783
  # return generic type (for doc builds)
2676
2784
  if arg_types is None:
2677
- return Tile(dtype=Any, M=1, N=1)
2785
+ return Tile(dtype=Any, shape=(1,))
2678
2786
 
2679
2787
  if len(arg_types) != 1:
2680
- raise RuntimeError("tile_min() requires 1 positional args")
2788
+ raise TypeError(f"tile_min() takes exactly 1 positional argument but {len(arg_types)} were given")
2681
2789
 
2682
2790
  a = arg_types["a"]
2683
2791
 
2684
2792
  if not is_tile(a):
2685
- raise RuntimeError("tile_min() argument 0 must be a tile")
2793
+ raise TypeError(f"tile_min() argument must be a tile, got {a!r}")
2686
2794
 
2687
- return Tile(dtype=a.dtype, M=1, N=1, op="min")
2795
+ return Tile(dtype=a.dtype, shape=(1,), op="min")
2688
2796
 
2689
2797
 
2690
2798
  add_builtin(
@@ -2695,7 +2803,7 @@ add_builtin(
2695
2803
  doc="""Cooperatively compute the minimum of the tile elements using all threads in the block.
2696
2804
 
2697
2805
  :param a: The tile to compute the minimum of
2698
- :returns: A single-element tile with dimensions of (1,1) holding the minimum value
2806
+ :returns: A single-element tile holding the minimum value
2699
2807
 
2700
2808
  Example:
2701
2809
 
@@ -2716,7 +2824,7 @@ add_builtin(
2716
2824
 
2717
2825
  .. code-block:: text
2718
2826
 
2719
- tile(m=1, n=1, storage=register) = [[64 ]]
2827
+ [64] = tile(shape=(1), storage=register)
2720
2828
 
2721
2829
  """,
2722
2830
  group="Tile Primitives",
@@ -2727,28 +2835,28 @@ add_builtin(
2727
2835
  def tile_max_value_func(arg_types, arg_values):
2728
2836
  # return generic type (for doc builds)
2729
2837
  if arg_types is None:
2730
- return Tile(dtype=Any, M=1, N=1)
2838
+ return Tile(dtype=Any, shape=(1,))
2731
2839
 
2732
2840
  if len(arg_types) != 1:
2733
- raise RuntimeError("tile_max() requires 1 positional args")
2841
+ raise TypeError(f"tile_max() takes exactly 1 positional argument but {len(arg_types)} were given")
2734
2842
 
2735
2843
  a = arg_types["a"]
2736
2844
 
2737
2845
  if not is_tile(a):
2738
- raise RuntimeError("tile_max() argument 0 must be a tile")
2846
+ raise TypeError(f"tile_max() argument must be a tile, got {a!r}")
2739
2847
 
2740
- return Tile(dtype=a.dtype, M=1, N=1, op="min")
2848
+ return Tile(dtype=a.dtype, shape=(1,), op="min")
2741
2849
 
2742
2850
 
2743
2851
  add_builtin(
2744
2852
  "tile_max",
2745
- input_types={"a": Tile},
2853
+ input_types={"a": Tile(dtype=Any, shape=Any)},
2746
2854
  value_func=tile_max_value_func,
2747
- variadic=True,
2855
+ variadic=False,
2748
2856
  doc="""Cooperatively compute the maximum of the tile elements using all threads in the block.
2749
2857
 
2750
2858
  :param a: The tile to compute the maximum from
2751
- :returns: A single-element tile with dimensions of (1,1) holding the maximum value
2859
+ :returns: A single-element tile holding the maximum value
2752
2860
 
2753
2861
  Example:
2754
2862
 
@@ -2768,7 +2876,7 @@ add_builtin(
2768
2876
 
2769
2877
  .. code-block:: text
2770
2878
 
2771
- tile(m=1, n=1, storage=register) = [[127 ]]
2879
+ [127] = tile(shape=(1), storage=register)
2772
2880
 
2773
2881
  """,
2774
2882
  group="Tile Primitives",
@@ -2779,15 +2887,14 @@ add_builtin(
2779
2887
  # does type propagation for load()
2780
2888
  def tile_reduce_value_func(arg_types, arg_values):
2781
2889
  if arg_types is None:
2782
- return Tile(dtype=Any, M=Any, N=Any)
2890
+ return Tile(dtype=Any, shape=(1,))
2783
2891
 
2784
2892
  a = arg_types["a"]
2785
2893
 
2786
- # check all args are tiles
2787
2894
  if not is_tile(a):
2788
- raise RuntimeError(f"tile_reduce() arguments must be tiles, got type {a}")
2895
+ raise TypeError(f"tile_reduce() 'a' argument must be a tile, got {a!r}")
2789
2896
 
2790
- return Tile(dtype=a.dtype, M=1, N=1, op="reduce")
2897
+ return Tile(dtype=a.dtype, shape=(1,), op="reduce")
2791
2898
 
2792
2899
 
2793
2900
  def tile_reduce_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
@@ -2798,7 +2905,7 @@ def tile_reduce_dispatch_func(input_types: Mapping[str, type], return_type: Any,
2798
2905
 
2799
2906
  add_builtin(
2800
2907
  "tile_reduce",
2801
- input_types={"op": Callable, "a": Tile(dtype=Any, M=Any, N=Any)},
2908
+ input_types={"op": Callable, "a": Tile(dtype=Any, shape=Any)},
2802
2909
  value_func=tile_reduce_value_func,
2803
2910
  native_func="tile_reduce",
2804
2911
  doc="""Apply a custom reduction operator across the tile.
@@ -2806,8 +2913,8 @@ add_builtin(
2806
2913
  This function cooperatively performs a reduction using the provided operator across the tile.
2807
2914
 
2808
2915
  :param op: A callable function that accepts two arguments and returns one argument, may be a user function or builtin
2809
- :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
2810
- :returns: A single-element tile with ``shape=(1,1)`` with the same datatype as the input tile.
2916
+ :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's data type
2917
+ :returns: A single-element tile with the same data type as the input tile.
2811
2918
 
2812
2919
  Example:
2813
2920
 
@@ -2827,7 +2934,7 @@ add_builtin(
2827
2934
 
2828
2935
  .. code-block:: text
2829
2936
 
2830
- tile(m=1, n=1, storage=register) = [[362880]]
2937
+ [362880] = tile(shape=(1), storage=register)
2831
2938
  """,
2832
2939
  group="Tile Primitives",
2833
2940
  export=False,
@@ -2839,26 +2946,19 @@ add_builtin(
2839
2946
  # does type propagation for load()
2840
2947
  def tile_unary_map_value_func(arg_types, arg_values):
2841
2948
  if arg_types is None:
2842
- return Tile(dtype=Any, M=Any, N=Any)
2949
+ return Tile(dtype=Any, shape=Any)
2843
2950
 
2844
2951
  a = arg_types["a"]
2845
2952
 
2846
- # check all args are tiles
2847
2953
  if not is_tile(a):
2848
- raise RuntimeError(f"tile_map() arguments must be tiles, got type {a}")
2954
+ raise TypeError(f"tile_map() 'a' argument must be a tile, got {a!r}")
2849
2955
 
2850
2956
  return TileUnaryMap(a)
2851
2957
 
2852
2958
 
2853
- # def tile_map_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
2854
- # func_args = (args["op"], *args["args"])
2855
- # template_args = ()
2856
- # return (func_args, template_args)
2857
-
2858
-
2859
2959
  add_builtin(
2860
2960
  "tile_map",
2861
- input_types={"op": Callable, "a": Tile(dtype=Any, M=Any, N=Any)},
2961
+ input_types={"op": Callable, "a": Tile(dtype=Any, shape=Any)},
2862
2962
  value_func=tile_unary_map_value_func,
2863
2963
  # dispatch_func=tile_map_dispatch_func,
2864
2964
  # variadic=True,
@@ -2868,8 +2968,8 @@ add_builtin(
2868
2968
  This function cooperatively applies a unary function to each element of the tile using all threads in the block.
2869
2969
 
2870
2970
  :param op: A callable function that accepts one argument and returns one argument, may be a user function or builtin
2871
- :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
2872
- :returns: A tile with the same dimensions and datatype as the input tile.
2971
+ :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's data type
2972
+ :returns: A tile with the same dimensions and data type as the input tile.
2873
2973
 
2874
2974
  Example:
2875
2975
 
@@ -2889,7 +2989,7 @@ add_builtin(
2889
2989
 
2890
2990
  .. code-block:: text
2891
2991
 
2892
- tile(m=1, n=10, storage=register) = [[0 0.0998334 0.198669 0.29552 ...]]
2992
+ [0 0.0998334 0.198669 0.29552 0.389418 0.479426 0.564642 0.644218 0.717356 0.783327] = tile(shape=(10), storage=register)
2893
2993
  """,
2894
2994
  group="Tile Primitives",
2895
2995
  export=False,
@@ -2898,34 +2998,37 @@ add_builtin(
2898
2998
 
2899
2999
  def tile_binary_map_value_func(arg_types, arg_values):
2900
3000
  if arg_types is None:
2901
- return Tile(dtype=Any, M=Any, N=Any)
3001
+ return Tile(dtype=Any, shape=Any)
2902
3002
 
2903
3003
  a = arg_types["a"]
2904
3004
  b = arg_types["b"]
2905
3005
 
2906
3006
  # check all args are tiles
2907
3007
  if not is_tile(a):
2908
- raise RuntimeError(f"tile_map() arguments must be tiles, got type {a}")
3008
+ raise TypeError(f"tile_map() 'a' argument must be a tile, got {a!r}")
2909
3009
 
2910
3010
  if not is_tile(b):
2911
- raise RuntimeError(f"tile_map() arguments must be tiles, got type {b}")
3011
+ raise TypeError(f"tile_map() 'b' argument must be a tile, got {b!r}")
2912
3012
 
2913
- # use first argument to define output type
3013
+ # ensure types equal
2914
3014
  if not types_equal(a.dtype, b.dtype):
2915
- raise RuntimeError(f"tile_map() arguments must all have the same type {a.dtype} != {b.dtype}")
3015
+ raise TypeError(f"tile_map() arguments must have the same dtype, got {a.dtype} and {b.dtype}")
2916
3016
 
2917
- if a.M != b.M:
2918
- raise RuntimeError(f"tile_map() arguments must all have the same m dimension {a.M} != {b.M}")
3017
+ if len(a.shape) != len(b.shape):
3018
+ raise ValueError(
3019
+ f"tile_map() shapes must have the same number of dimensions, got {len(a.shape)} and {len(b.shape)}"
3020
+ )
2919
3021
 
2920
- if a.N != b.N:
2921
- raise RuntimeError(f"tile_map() arguments must all have the same n dimension {a.N} != {b.N}")
3022
+ for i in range(len(a.shape)):
3023
+ if a.shape[i] != b.shape[i]:
3024
+ raise ValueError(f"tile_map() shapes do not match on dimension {i}, got {a.shape[i]} and {b.shape[i]}")
2922
3025
 
2923
3026
  return TileBinaryMap(a, b)
2924
3027
 
2925
3028
 
2926
3029
  add_builtin(
2927
3030
  "tile_map",
2928
- input_types={"op": Callable, "a": Tile(dtype=Any, M=Any, N=Any), "b": Tile(dtype=Any, M=Any, N=Any)},
3031
+ input_types={"op": Callable, "a": Tile(dtype=Any, shape=Any), "b": Tile(dtype=Any, shape=Any)},
2929
3032
  value_func=tile_binary_map_value_func,
2930
3033
  # dispatch_func=tile_map_dispatch_func,
2931
3034
  # variadic=True,
@@ -2948,7 +3051,7 @@ add_builtin(
2948
3051
  def compute():
2949
3052
 
2950
3053
  a = wp.tile_arange(0.0, 1.0, 0.1, dtype=float)
2951
- b = wp.tile_ones(m=1, n=10, dtype=float)
3054
+ b = wp.tile_ones(shape=10, dtype=float)
2952
3055
 
2953
3056
  s = wp.tile_map(wp.add, a, b)
2954
3057
 
@@ -2960,7 +3063,7 @@ add_builtin(
2960
3063
 
2961
3064
  .. code-block:: text
2962
3065
 
2963
- tile(m=1, n=10, storage=register) = [[1 1.1 1.2 1.3 1.4 1.5 1.6 1.7 1.8 1.9]]""",
3066
+ [1 1.1 1.2 1.3 1.4 1.5 1.6 1.7 1.8 1.9] = tile(shape=(10), storage=register)""",
2964
3067
  group="Tile Primitives",
2965
3068
  export=False,
2966
3069
  )
@@ -3075,6 +3178,18 @@ add_builtin(
3075
3178
  )
3076
3179
 
3077
3180
 
3181
+ def mlp_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
3182
+ warp.utils.warn(
3183
+ "wp.mlp() is deprecated and will be removed in a future\nversion. Use tile primitives instead.",
3184
+ category=DeprecationWarning,
3185
+ )
3186
+
3187
+ func_args = tuple(args.values())
3188
+ template_args = ()
3189
+
3190
+ return (func_args, template_args)
3191
+
3192
+
3078
3193
  add_builtin(
3079
3194
  "mlp",
3080
3195
  input_types={
@@ -3086,9 +3201,13 @@ add_builtin(
3086
3201
  "out": array(dtype=float, ndim=2),
3087
3202
  },
3088
3203
  value_type=None,
3204
+ dispatch_func=mlp_dispatch_func,
3089
3205
  skip_replay=True,
3090
3206
  doc="""Evaluate a multi-layer perceptron (MLP) layer in the form: ``out = act(weights*x + bias)``.
3091
3207
 
3208
+ .. deprecated:: 1.6
3209
+ Use :doc:`tile primitives </modules/tiles>` instead.
3210
+
3092
3211
  :param weights: A layer's network weights with dimensions ``(m, n)``.
3093
3212
  :param bias: An array with dimensions ``(n)``.
3094
3213
  :param activation: A ``wp.func`` function that takes a single scalar float as input and returns a scalar float as output
@@ -5008,6 +5127,43 @@ add_builtin(
5008
5127
  )
5009
5128
 
5010
5129
 
5130
+ # implements vector[idx] += scalar
5131
+ add_builtin(
5132
+ "augassign_add",
5133
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
5134
+ value_type=None,
5135
+ hidden=True,
5136
+ group="Utility",
5137
+ )
5138
+
5139
+ # implements quaternion[idx] += scalar
5140
+ add_builtin(
5141
+ "augassign_add",
5142
+ input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
5143
+ value_type=None,
5144
+ hidden=True,
5145
+ group="Utility",
5146
+ )
5147
+
5148
+ # implements vector[idx] -= scalar
5149
+ add_builtin(
5150
+ "augassign_sub",
5151
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
5152
+ value_type=None,
5153
+ hidden=True,
5154
+ group="Utility",
5155
+ )
5156
+
5157
+ # implements quaternion[idx] -= scalar
5158
+ add_builtin(
5159
+ "augassign_sub",
5160
+ input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
5161
+ value_type=None,
5162
+ hidden=True,
5163
+ group="Utility",
5164
+ )
5165
+
5166
+
5011
5167
  def matrix_index_row_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
5012
5168
  mat_type = arg_types["a"]
5013
5169
  row_type = mat_type._wp_row_type_
@@ -5057,22 +5213,42 @@ def matrix_vector_sametype(arg_types: Mapping[str, Any]):
5057
5213
  return mat_size == vec_size and mat_type == vec_type
5058
5214
 
5059
5215
 
5060
- # implements matrix[i,j] = scalar
5216
+ # implements matrix[i,j] = scalar
5217
+ add_builtin(
5218
+ "assign",
5219
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
5220
+ value_func=matrix_assign_value_func,
5221
+ hidden=True,
5222
+ group="Utility",
5223
+ )
5224
+
5225
+
5226
+ # implements matrix[i] = vector
5227
+ add_builtin(
5228
+ "assign",
5229
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
5230
+ constraint=matrix_vector_sametype,
5231
+ value_func=matrix_assign_value_func,
5232
+ hidden=True,
5233
+ group="Utility",
5234
+ )
5235
+
5236
+
5237
+ # implements matrix[i,j] += scalar
5061
5238
  add_builtin(
5062
- "assign",
5239
+ "augassign_add",
5063
5240
  input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
5064
- value_func=matrix_assign_value_func,
5241
+ value_type=None,
5065
5242
  hidden=True,
5066
5243
  group="Utility",
5067
5244
  )
5068
5245
 
5069
5246
 
5070
- # implements matrix[i] = vector
5247
+ # implements matrix[i,j] -= scalar
5071
5248
  add_builtin(
5072
- "assign",
5073
- input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
5074
- constraint=matrix_vector_sametype,
5075
- value_func=matrix_assign_value_func,
5249
+ "augassign_sub",
5250
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
5251
+ value_type=None,
5076
5252
  hidden=True,
5077
5253
  group="Utility",
5078
5254
  )
@@ -5644,19 +5820,19 @@ add_builtin("unot", input_types={"a": array(dtype=Any)}, value_type=builtins.boo
5644
5820
  # Tile operators
5645
5821
  def tile_unary_value_func(arg_types, arg_values):
5646
5822
  if arg_types is None:
5647
- return Tile(dtype=Any, M=Any, N=Any)
5823
+ return Tile(dtype=Any, shape=Any)
5648
5824
 
5649
5825
  t = arg_types["x"]
5650
5826
 
5651
5827
  if not is_tile(t):
5652
- raise RuntimeError("Expected tile for unary expression")
5828
+ raise TypeError(f"Expected tile for unary expression, got {t}")
5653
5829
 
5654
5830
  return TileUnaryMap(t)
5655
5831
 
5656
5832
 
5657
5833
  def tile_scalar_mul_value_func(arg_types, arg_values):
5658
5834
  if arg_types is None:
5659
- return Tile(dtype=Any, M=Any, N=Any)
5835
+ return Tile(dtype=Any, shape=Any)
5660
5836
 
5661
5837
  x = arg_types["x"]
5662
5838
  y = arg_types["y"]
@@ -5664,25 +5840,21 @@ def tile_scalar_mul_value_func(arg_types, arg_values):
5664
5840
  # tile*scalar
5665
5841
  if is_tile(x):
5666
5842
  if x.dtype != y:
5667
- raise RuntimeError(
5668
- "Scalar factor should have the same type as tile for tile*scalar, tile type: {x} scalar type: {y}"
5669
- )
5843
+ raise TypeError(f"Scalar factor type {y} does not match tile type {x.dtype} for tile*scalar")
5670
5844
 
5671
- return TileBinaryMap(x, TileConstant(y, x.M, x.N))
5845
+ return TileBinaryMap(x, TileConstant(y, x.shape))
5672
5846
 
5673
5847
  # scalar*tile
5674
5848
  if is_tile(y):
5675
5849
  if y.dtype != x:
5676
- raise RuntimeError(
5677
- "Scalar factor should have the same type as tile for scalar*tile, tile type: {x} scalar type: {y}"
5678
- )
5850
+ raise TypeError(f"Scalar factor type {x} does not match tile type {y.dtype} for scalar*tile")
5679
5851
 
5680
- return TileBinaryMap(TileConstant(x, y.M, y.N), y)
5852
+ return TileBinaryMap(TileConstant(x, y.shape), y)
5681
5853
 
5682
5854
 
5683
5855
  add_builtin(
5684
5856
  "neg",
5685
- input_types={"x": Tile(dtype=Any, M=Any, N=Any)},
5857
+ input_types={"x": Tile(dtype=Any, shape=Any)},
5686
5858
  value_func=tile_unary_value_func,
5687
5859
  doc="Negate each element of a tile",
5688
5860
  export=False,
@@ -5692,7 +5864,7 @@ add_builtin(
5692
5864
 
5693
5865
  add_builtin(
5694
5866
  "add",
5695
- input_types={"a": Tile(dtype=Any, M=Any, N=Any), "b": Tile(dtype=Any, M=Any, N=Any)},
5867
+ input_types={"a": Tile(dtype=Any, shape=Any), "b": Tile(dtype=Any, shape=Any)},
5696
5868
  value_func=tile_binary_map_value_func,
5697
5869
  # dispatch_func=tile_map_dispatch_func,
5698
5870
  # variadic=True,
@@ -5702,9 +5874,22 @@ add_builtin(
5702
5874
  export=False,
5703
5875
  )
5704
5876
 
5877
+ add_builtin(
5878
+ "sub",
5879
+ input_types={"a": Tile(dtype=Any, shape=Any), "b": Tile(dtype=Any, shape=Any)},
5880
+ value_func=tile_binary_map_value_func,
5881
+ # dispatch_func=tile_map_dispatch_func,
5882
+ # variadic=True,
5883
+ native_func="tile_sub",
5884
+ doc="Subtract each element b from a",
5885
+ group="Tile Primitives",
5886
+ export=False,
5887
+ )
5888
+
5889
+
5705
5890
  add_builtin(
5706
5891
  "mul",
5707
- input_types={"x": Tile(dtype=Any, M=Any, N=Any), "y": Scalar},
5892
+ input_types={"x": Tile(dtype=Any, shape=Any), "y": Scalar},
5708
5893
  value_func=tile_scalar_mul_value_func,
5709
5894
  doc="Multiply each element of a tile by a scalar",
5710
5895
  export=False,
@@ -5714,7 +5899,7 @@ add_builtin(
5714
5899
 
5715
5900
  add_builtin(
5716
5901
  "mul",
5717
- input_types={"x": Scalar, "y": Tile(dtype=Any, M=Any, N=Any)},
5902
+ input_types={"x": Scalar, "y": Tile(dtype=Any, shape=Any)},
5718
5903
  value_func=tile_scalar_mul_value_func,
5719
5904
  doc="Multiply each element of a tile by a scalar",
5720
5905
  export=False,
@@ -5723,6 +5908,70 @@ add_builtin(
5723
5908
  )
5724
5909
 
5725
5910
 
5911
+ def tile_diag_add_value_func(arg_types, arg_values):
5912
+ if arg_types is None:
5913
+ return Tile(dtype=Any, shape=Any)
5914
+
5915
+ a = arg_types["a"]
5916
+ d = arg_types["d"]
5917
+
5918
+ if not is_tile(a):
5919
+ raise TypeError(f"tile_diag_add() 'a' argument must be a tile, got {a!r}")
5920
+
5921
+ if not is_tile(d):
5922
+ raise TypeError(f"tile_diag_add() 'd' argument must be a tile, got {d!r}")
5923
+
5924
+ if not types_equal(a.dtype, d.dtype):
5925
+ raise TypeError(f"tile_diag_add() arguments must have the same dtype, got {a.dtype} and {d.dtype}")
5926
+
5927
+ if len(a.shape) != 2:
5928
+ raise TypeError("tile_diag_add() argument 'a' must be a 2D tile")
5929
+
5930
+ if len(d.shape) != 1:
5931
+ raise TypeError("tile_diag_add() argument 'd' must be a 1D tile")
5932
+
5933
+ if a.shape[0] != a.shape[1]:
5934
+ raise ValueError("tile_diag_add() 'a' argument must be square")
5935
+
5936
+ if a.shape[0] != d.shape[0]:
5937
+ raise ValueError(
5938
+ f"tile_diag_add() 'd' argument must have the same number of elements as the number of rows in 'a', "
5939
+ f"got {d.shape[0]} elements in 'd' and {a.shape[0]} rows in 'a'"
5940
+ )
5941
+
5942
+ # use first argument to define output type
5943
+ return Tile(dtype=a.dtype, shape=a.shape, storage="shared")
5944
+
5945
+
5946
+ def tile_diag_add_lto_dispatch_func(
5947
+ arg_types: Mapping[str, type],
5948
+ return_type: Any,
5949
+ return_values: List[Var],
5950
+ arg_values: Mapping[str, Var],
5951
+ options: Mapping[str, Any],
5952
+ builder: warp.context.ModuleBuilder,
5953
+ ):
5954
+ a = arg_values["a"]
5955
+ d = arg_values["d"]
5956
+ # force the storage type of the input variables to shared memory
5957
+ a.type.storage = "shared"
5958
+ d.type.storage = "shared"
5959
+ out = return_values[0]
5960
+ return ((a, d, out), [], [], 0)
5961
+
5962
+
5963
+ add_builtin(
5964
+ "tile_diag_add",
5965
+ input_types={"a": Tile(dtype=Any, shape=Any), "d": Tile(dtype=Any, shape=Any)},
5966
+ value_func=tile_diag_add_value_func,
5967
+ lto_dispatch_func=tile_diag_add_lto_dispatch_func,
5968
+ native_func="tile_diag_add",
5969
+ doc="Add a square matrix and a diagonal matrix 'd' represented as a 1D tile",
5970
+ group="Tile Primitives",
5971
+ export=False,
5972
+ )
5973
+
5974
+
5726
5975
  ##
5727
5976
  ## MathDx, LTOIR-based, Tile functions
5728
5977
  ##
@@ -5734,24 +5983,25 @@ add_builtin(
5734
5983
  def tile_matmul_generic_value_func(arg_types, arg_values):
5735
5984
  # return generic type (for doc builds)
5736
5985
  if arg_types is None:
5737
- return Tile(dtype=Any, M=Any, N=Any)
5986
+ return Tile(dtype=Any, shape=Any)
5738
5987
 
5739
5988
  a = arg_types["a"]
5740
5989
  b = arg_types["b"]
5741
5990
 
5742
5991
  if not is_tile(a):
5743
- raise RuntimeError("tile_matmul() argument 0 must be a tile")
5992
+ raise TypeError(f"tile_matmul() 'a' argument must be a tile, got {a!r}")
5993
+
5744
5994
  if not is_tile(b):
5745
- raise RuntimeError("tile_matmul() argument 1 must be an tile")
5995
+ raise TypeError(f"tile_matmul() 'b' argument must be a tile, got {b!r}")
5746
5996
 
5747
5997
  # out = wp.tile_matmul(a, b)
5748
5998
  if len(arg_types) == 2:
5749
- return Tile(dtype=a.dtype, M=a.M, N=b.N, storage="shared")
5999
+ return Tile(dtype=a.dtype, shape=(a.shape[0], b.shape[1]), storage="shared")
5750
6000
 
5751
6001
  # wp.tile_matmul(a, b, out)
5752
6002
  elif len(arg_types) == 3:
5753
6003
  if not is_tile(arg_types["out"]):
5754
- raise RuntimeError("tile_matmul() output argument must be a tile")
6004
+ raise TypeError(f"tile_matmul() 'out' argument must be a tile, got {arg_types['out']!r}")
5755
6005
 
5756
6006
  return None
5757
6007
 
@@ -5774,16 +6024,20 @@ def tile_matmul_generic_lto_dispatch_func(
5774
6024
  accumulate = 1 # for tile_matmul(a,b,c) case we want to add to c value
5775
6025
  out = arg_values["out"]
5776
6026
 
5777
- if any(not is_tile(arg.type) for arg in [a, b, out]):
5778
- raise RuntimeError("tile_matmul() requires three Tile arguments")
6027
+ if not is_tile(out.type):
6028
+ raise TypeError(f"tile_matmul() 'out' argument must be a tile, got {out!r}")
5779
6029
 
5780
6030
  if any(arg.type.dtype not in [float16, float32, float64, vec2h, vec2f, vec2d] for arg in [a, b, out]):
5781
- raise RuntimeError(
6031
+ raise TypeError(
5782
6032
  "tile_matmul() arguments must be tiles of float16, float32 or float64, vec2h, vec2f, vec2d entries"
5783
6033
  )
5784
6034
 
5785
- if (a.type.N != b.type.M) or (a.type.M != out.type.M) or (b.type.N != out.type.N):
5786
- raise RuntimeError("tile_matmul(A, B, C) requires sizes of A, B and C to be consistent for a matmul")
6035
+ if (
6036
+ (a.type.shape[1] != b.type.shape[0])
6037
+ or (a.type.shape[0] != out.type.shape[0])
6038
+ or (b.type.shape[1] != out.type.shape[1])
6039
+ ):
6040
+ raise ValueError("tile_matmul(A, B, C) requires sizes of A, B and C to be consistent for a matmul")
5787
6041
 
5788
6042
  # set the storage type to the inputs to shared
5789
6043
  a.type.storage = "shared"
@@ -5805,18 +6059,18 @@ def tile_matmul_generic_lto_dispatch_func(
5805
6059
  return ("wp::vec2f", 5, 1)
5806
6060
  if dtype == vec2d:
5807
6061
  return ("wp::vec2d", 6, 1)
5808
- raise RuntimeError("Unsupported input type in tile_matmul")
6062
+ raise TypeError("Unsupported input type in tile_matmul")
5809
6063
 
5810
6064
  def cublasdx_arrangement_map(layout):
5811
6065
  if layout == "colmajor":
5812
6066
  return 0 # CUBLASDX_ARRANGEMENT_COL_MAJOR
5813
6067
  if layout == "rowmajor":
5814
6068
  return 1 # CUBLASDX_ARRANGEMENT_ROW_MAJOR
5815
- raise RuntimeError("Unsupported layout in tile_matmul")
6069
+ raise ValueError("Unsupported layout in tile_matmul")
5816
6070
 
5817
6071
  # generate the LTO
5818
- M, K = a.type.M, a.type.N
5819
- _, N = b.type.M, b.type.N
6072
+ M, K = a.type.shape[0], a.type.shape[1]
6073
+ _, N = b.type.shape[0], b.type.shape[1]
5820
6074
  num_threads = options["block_dim"]
5821
6075
  arch = options["output_arch"]
5822
6076
 
@@ -5829,7 +6083,8 @@ def tile_matmul_generic_lto_dispatch_func(
5829
6083
  c_arrangement = cublasdx_arrangement_map(clayout)
5830
6084
 
5831
6085
  if a_type != b_type or a_type != c_type:
5832
- raise RuntimeError("time_matmul(A, B, C) requires all inputs to be real or complex")
6086
+ raise TypeError("time_matmul(A, B, C) requires all inputs to be real or complex")
6087
+
5833
6088
  element_type = a_type
5834
6089
 
5835
6090
  lto_symbol = f"dot_{M}_{N}_{K}_{arch}_{num_threads}_{a_arrangement}_{b_arrangement}_{c_arrangement}_{a_prec}_{b_prec}_{c_prec}_{element_type}"
@@ -5924,15 +6179,16 @@ def tile_matmul_generic_lto_dispatch_func(
5924
6179
  ),
5925
6180
  template_args,
5926
6181
  [lto_forward, lto_backward_A, lto_backward_B],
6182
+ 0,
5927
6183
  )
5928
6184
 
5929
6185
 
5930
6186
  add_builtin(
5931
6187
  "tile_matmul",
5932
6188
  input_types={
5933
- "a": Tile(dtype=Any, M=Any, N=Any),
5934
- "b": Tile(dtype=Any, M=Any, N=Any),
5935
- "out": Tile(dtype=Any, M=Any, N=Any),
6189
+ "a": Tile(dtype=Any, shape=Any),
6190
+ "b": Tile(dtype=Any, shape=Any),
6191
+ "out": Tile(dtype=Any, shape=Any),
5936
6192
  },
5937
6193
  value_func=tile_matmul_generic_value_func,
5938
6194
  lto_dispatch_func=tile_matmul_generic_lto_dispatch_func,
@@ -5956,7 +6212,7 @@ add_builtin(
5956
6212
 
5957
6213
  add_builtin(
5958
6214
  "tile_matmul",
5959
- input_types={"a": Tile(dtype=Any, M=Any, N=Any), "b": Tile(dtype=Any, M=Any, N=Any)},
6215
+ input_types={"a": Tile(dtype=Any, shape=Any), "b": Tile(dtype=Any, shape=Any)},
5960
6216
  value_func=tile_matmul_generic_value_func,
5961
6217
  lto_dispatch_func=tile_matmul_generic_lto_dispatch_func,
5962
6218
  variadic=False,
@@ -5983,16 +6239,23 @@ add_builtin(
5983
6239
  ##
5984
6240
  def tile_fft_generic_value_func(arg_types, arg_values):
5985
6241
  if arg_types is None:
5986
- return Tile(dtype=Any, M=Any, N=Any)
6242
+ return Tile(dtype=Any, shape=Any)
5987
6243
 
5988
6244
  if len(arg_types) != 1:
5989
- raise RuntimeError("tile_fft() requires 1 positional args")
6245
+ raise TypeError(f"tile_fft() takes exactly 1 positional argument but {len(arg_types)} were given")
5990
6246
 
5991
- if not is_tile(arg_types["inout"]):
5992
- raise RuntimeError("tile_fft() argument 0 must be a tile")
6247
+ inout = arg_types["inout"]
5993
6248
 
5994
- if arg_types["inout"].storage != "register":
5995
- raise RuntimeError("tile_fft() input/output argument must have register memory storage")
6249
+ if not is_tile(inout):
6250
+ raise TypeError(f"tile_fft() argument must be a tile, got {inout!r}")
6251
+
6252
+ if inout.storage != "register":
6253
+ raise ValueError(f"tile_fft() argument must have 'register' storage, got {inout.storage}")
6254
+
6255
+ if inout.dtype not in [vec2f, vec2d]:
6256
+ raise TypeError(
6257
+ f"tile_fft() argument must be a tile of vec2f or vec2d (interpreted as complex) entries, got {inout.dtype!r}"
6258
+ )
5996
6259
 
5997
6260
  return None
5998
6261
 
@@ -6009,19 +6272,13 @@ def tile_fft_generic_lto_dispatch_func(
6009
6272
  inout = arg_values["inout"]
6010
6273
  inout.type.storage = "register"
6011
6274
 
6012
- if not is_tile(inout.type):
6013
- raise RuntimeError("tile_fft() arguments must be a single tile with register storage")
6014
-
6015
- if inout.type.dtype not in [vec2f, vec2d]:
6016
- raise RuntimeError("tile_fft() argument must be a tile of vec2f or vec2d (interpreted as complex) entries")
6017
-
6018
6275
  # see libcufftdx.hpp
6019
6276
  if direction == "forward":
6020
6277
  dir = 0 # CUFFTDX_DIRECTION_FORWARD
6021
6278
  elif direction == "inverse":
6022
6279
  dir = 1 # CUFFTDX_DIRECTION_INVERSE
6023
6280
  else:
6024
- raise RuntimeError("Invalid direction")
6281
+ raise ValueError(f"Invalid direction: {direction!r}. Expected 'forward' or 'inverse'.")
6025
6282
 
6026
6283
  if inout.type.dtype == vec2f:
6027
6284
  dtype = "wp::vec2f"
@@ -6030,10 +6287,10 @@ def tile_fft_generic_lto_dispatch_func(
6030
6287
  dtype = "wp::vec2d"
6031
6288
  precision = 6 # COMMONDX_PRECISION_F64
6032
6289
  else:
6033
- raise RuntimeError("Unsupported datatype")
6290
+ raise TypeError(f"Unsupported data type, got {dtype!r}")
6034
6291
 
6035
6292
  # M FFTs of size N each
6036
- batch, size = inout.type.M, inout.type.N
6293
+ batch, size = inout.type.shape[0], inout.type.shape[1]
6037
6294
  num_threads = options["block_dim"]
6038
6295
  arch = options["output_arch"]
6039
6296
  ept = size // num_threads
@@ -6065,7 +6322,7 @@ def tile_fft_generic_lto_dispatch_func(
6065
6322
  lto_code.close()
6066
6323
  if lto_code_path.exists():
6067
6324
  lto_code_path.unlink()
6068
- raise RuntimeError("Failed to compile tile_matmul")
6325
+ raise RuntimeError("Failed to compile tile_fft")
6069
6326
 
6070
6327
  with open(lto_code.name, "rb") as f:
6071
6328
  lto_code_data = f.read()
@@ -6075,17 +6332,20 @@ def tile_fft_generic_lto_dispatch_func(
6075
6332
 
6076
6333
  builder.ltoirs[lto_symbol] = lto_code_data
6077
6334
 
6335
+ shared_memory_bytes = Tile.round_up(shared_memory_size.value)
6336
+
6078
6337
  return (
6079
6338
  (
6080
6339
  Var(lto_symbol, str, False, True, False),
6081
6340
  Var(dtype, str, False, True, False),
6082
- Var(str(shared_memory_size.value), str, False, True, False),
6341
+ Var(str(shared_memory_bytes), str, False, True, False),
6083
6342
  Var(str(batch), str, False, True, False),
6084
6343
  Var(str(ept), str, False, True, False),
6085
6344
  inout,
6086
6345
  ),
6087
6346
  [],
6088
6347
  [lto_code_data],
6348
+ shared_memory_bytes,
6089
6349
  )
6090
6350
 
6091
6351
 
@@ -6099,6 +6359,8 @@ add_builtin(
6099
6359
 
6100
6360
  This function cooperatively computes the forward FFT on a tile of data inplace, treating each row individually.
6101
6361
 
6362
+ Note that computing the adjoint is not yet supported.
6363
+
6102
6364
  Supported datatypes are:
6103
6365
  * vec2f, vec2d
6104
6366
 
@@ -6118,6 +6380,8 @@ add_builtin(
6118
6380
 
6119
6381
  This function cooperatively computes the inverse FFT on a tile of data inplace, treating each row individually.
6120
6382
 
6383
+ Note that computing the adjoint is not yet supported.
6384
+
6121
6385
  Supported datatypes are:
6122
6386
  * vec2f, vec2d
6123
6387
 
@@ -6127,6 +6391,283 @@ add_builtin(
6127
6391
  namespace="",
6128
6392
  )
6129
6393
 
6394
+
6395
+ ##
6396
+ ## Cholesky
6397
+ ##
6398
def tile_cholesky_generic_value_func(arg_types, arg_values):
    """Infer the return type of ``tile_cholesky()``.

    Args:
        arg_types: Mapping from argument name to argument type, or ``None`` when
            the generic return type of the builtin is being queried.
        arg_values: Unused.

    Returns:
        ``Tile(dtype=Any, shape=Any)`` when ``arg_types`` is ``None``; otherwise a
        tile with the same dtype and shape as ``A`` and ``storage="shared"``.

    Raises:
        TypeError: If not exactly one argument is given, or ``A`` is not a tile.
        ValueError: If ``A`` is not a 2D tile or is not square.
    """
    if arg_types is None:
        return Tile(dtype=Any, shape=Any)

    if len(arg_types) != 1:
        raise TypeError("tile_cholesky() requires exactly 1 positional argument")

    a = arg_types["A"]

    if not is_tile(a):
        raise TypeError(f"tile_cholesky() argument must be a tile, got {a!r}")

    if len(a.shape) != 2:
        raise ValueError("tile_cholesky() argument must be a 2D tile")

    # A Cholesky factorization is only defined for square matrices.
    if a.shape[0] != a.shape[1]:
        raise ValueError("tile_cholesky() argument must be square")

    # The result is placed in shared memory; the LTO dispatch function likewise
    # forces the input tile into shared storage.
    return Tile(dtype=a.dtype, shape=a.shape, storage="shared")
6417
+
6418
+
6419
# Function selector passed to ``runtime.core.cuda_compile_solver`` by the tile
# Cholesky dispatch functions ("potrf" for factorization, "potrs" for solving).
cusolver_function_map = dict(getrf=0, getrf_no_pivot=1, potrf=2, potrs=3)

# Warp scalar type -> (C++ type name used in the generated LTO declaration,
# precision enum passed to ``cuda_compile_solver``).
cusolver_type_map = {
    float32: ("wp::float32", 5),
    float64: ("wp::float64", 6),
}

# Fill-mode selector passed to ``cuda_compile_solver``.
cusolver_fill_mode_map = dict(upper=0, lower=1)
6424
+
6425
+
6426
def tile_cholesky_generic_lto_dispatch_func(
    arg_types: Mapping[str, type],
    return_type: Any,
    return_values: List[Var],
    arg_values: Mapping[str, Var],
    options: Mapping[str, Any],
    builder: warp.context.ModuleBuilder,
):
    """Build the native call for ``tile_cholesky()`` backed by a cuSOLVERDx ``potrf`` routine.

    The routine is compiled once per ``builder`` for each combination of tile
    size, target architecture and precision (the cache key is the LTO symbol).
    The LTO-IR, its C++ declaration and the solver fatbin are registered on
    ``builder``.

    Args:
        arg_types: Unused; types are read from ``arg_values``.
        return_type: Unused.
        return_values: Must contain exactly one output tile variable.
        arg_values: Must contain the input tile under key ``"A"``.
        options: Module options; ``"block_dim"`` and ``"output_arch"`` are read.
        builder: Module builder holding ``ltoirs``, ``ltoirs_decl`` and ``fatbins``.

    Returns:
        A 4-tuple ``((symbol, A, out), [], [ltoir], 0)``. The same shape is
        returned on both the compile path and the cache-hit path.

    Raises:
        TypeError: If ``A`` is not float32/float64, or there is not exactly one output.
        ValueError: If the output tile is not square with the same size as ``A``.
        RuntimeError: If native compilation of the solver fails.
    """
    a = arg_values["A"]
    # force source tile to shared memory
    a.type.storage = "shared"

    if a.type.dtype not in cusolver_type_map:
        raise TypeError("tile_cholesky() argument must be a tile of float32 or float64 entries")

    if len(return_values) != 1:
        raise TypeError("tile_cholesky() returns one output")
    out = return_values[0]

    dtype, precision_enum = cusolver_type_map[a.type.dtype]

    # We already ensured a is square in tile_cholesky_generic_value_func()
    M, N = a.type.shape[0], a.type.shape[1]
    if out.type.shape[0] != M or out.type.shape[1] != M:
        raise ValueError("tile_cholesky() output tile must be square")

    num_threads = options["block_dim"]
    arch = options["output_arch"]
    # NOTE(review): num_threads is passed to the compiler but is not part of the
    # cache key; presumably block_dim is constant per module builder - confirm.
    lto_symbol = f"potrf_{M}_{N}_{arch}_{precision_enum}"

    call_args = (Var(lto_symbol, str, False, True, False), a, out)

    # Early out if LTO for this combination already exists for this module.
    # Return the same 4-tuple as the compile path so callers can unpack uniformly.
    if lto_symbol in builder.ltoirs:
        return (call_args, [], [builder.ltoirs[lto_symbol]], 0)

    # otherwise compile LTO
    lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
    universal_fatbin_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
    try:
        # cuSOLVERDx only support col-major input/outputs,
        # so we use upper to mimic a row-major input
        result = warp.context.runtime.core.cuda_compile_solver(
            universal_fatbin_code.name.encode("utf-8"),
            lto_code.name.encode("utf-8"),
            lto_symbol.encode("utf-8"),
            0,
            None,
            None,
            arch,
            M,
            N,
            cusolver_function_map["potrf"],
            precision_enum,
            cusolver_fill_mode_map["upper"],
            num_threads,
        )
        if not result:
            raise RuntimeError("Failed to compile tile_cholesky")

        with open(lto_code.name, "rb") as f:
            lto_code_data = f.read()
        with open(universal_fatbin_code.name, "rb") as f:
            universal_fatbin_code_data = f.read()
    finally:
        # Remove the temporary files on every path, including read failures.
        for tmp in (lto_code, universal_fatbin_code):
            tmp.close()
            tmp_path = Path(tmp.name)
            if tmp_path.exists():
                tmp_path.unlink()

    builder.ltoirs[lto_symbol] = lto_code_data
    builder.ltoirs_decl[lto_symbol] = f"void {lto_symbol}({dtype}*, unsigned);"
    builder.fatbins["cholesky"] = universal_fatbin_code_data

    return (call_args, [], [lto_code_data], 0)
6503
+
6504
+
6505
# Register ``tile_cholesky(A)``: the return type is inferred by
# tile_cholesky_generic_value_func and code is emitted via the LTO dispatch function.
add_builtin(
    "tile_cholesky",
    namespace="",
    export=False,
    group="Tile Primitives",
    variadic=True,
    input_types={"A": Tile},
    lto_dispatch_func=tile_cholesky_generic_lto_dispatch_func,
    value_func=tile_cholesky_generic_value_func,
    doc="""Compute the Cholesky factorization L of a matrix A.
    L is lower triangular and satisfies LL^T = A.

    Note that computing the adjoint is not yet supported.

    Supported datatypes are:
        * float32
        * float64

    :param A: A square, symmetric positive-definite, matrix.
    :returns L: A square, lower triangular, matrix, such that LL^T = A""",
)
6526
+
6527
+
6528
def tile_cholesky_solve_generic_value_func(arg_types, arg_values):
    """Infer the return type of ``tile_cholesky_solve(L, x)``.

    Args:
        arg_types: Mapping from argument name to argument type, or ``None`` when
            the generic return type of the builtin is being queried.
        arg_values: Unused.

    Returns:
        ``None`` when ``arg_types`` is ``None``; otherwise a 1D tile with
        ``L``'s dtype, ``x``'s shape and ``storage="shared"``.

    Raises:
        TypeError: On wrong arity, non-tile arguments, mismatched dtypes, or a
            non-1D ``x``.
        ValueError: If ``L`` is not a square 2D tile, or ``x``'s length does not
            match ``L``'s row count.
    """
    if arg_types is None:
        # NOTE(review): tile_cholesky's value func returns Tile(dtype=Any, shape=Any)
        # here; confirm whether None is intended for this builtin.
        return None

    if len(arg_types) != 2:
        raise TypeError("tile_cholesky_solve() requires exactly 2 positional args")

    l = arg_types["L"]
    x = arg_types["x"]

    if not is_tile(l):
        raise TypeError(f"tile_cholesky_solve() 'L' argument must be a tile, got {l!r}")

    if not is_tile(x):
        raise TypeError(f"tile_cholesky_solve() 'x' argument must be a tile, got {x!r}")

    if not types_equal(l.dtype, x.dtype):
        raise TypeError(f"tile_cholesky_solve() arguments must have the same dtype, got {l.dtype} and {x.dtype}")

    # Check rank before indexing shape[1], otherwise a 1D 'L' raises IndexError.
    if len(l.shape) != 2:
        raise ValueError("tile_cholesky_solve() 'L' argument must be a 2D tile")

    if l.shape[0] != l.shape[1]:
        raise ValueError("tile_cholesky_solve() 'L' argument must be square")

    if len(x.shape) != 1:
        raise TypeError("tile_cholesky_solve() 'x' argument must be a 1D tile")

    if x.shape[0] != l.shape[0]:
        raise ValueError(
            f"tile_cholesky_solve() 'x' argument must have the same number of elements as the number of rows in 'L', "
            f"got {x.shape[0]} elements in 'x' and {l.shape[0]} rows in 'L'"
        )

    return Tile(dtype=l.dtype, shape=x.shape, storage="shared")
6560
+
6561
+
6562
def tile_cholesky_solve_generic_lto_dispatch_func(
    arg_types: Mapping[str, type],
    return_type: Any,
    return_values: List[Var],
    arg_values: Mapping[str, Var],
    options: Mapping[str, Any],
    builder: warp.context.ModuleBuilder,
):
    """Build the native call for ``tile_cholesky_solve()`` backed by a cuSOLVERDx ``potrs`` routine.

    The routine is compiled once per ``builder`` for each combination of tile
    size, target architecture and precision (the cache key is the LTO symbol).
    The LTO-IR, its C++ declaration and the solver fatbin are registered on
    ``builder``.

    Args:
        arg_types: Unused; types are read from ``arg_values``.
        return_type: Unused.
        return_values: Must contain exactly one 1D output tile variable.
        arg_values: Must contain the factor tile ``"L"`` and right-hand side ``"x"``.
        options: Module options; ``"block_dim"`` and ``"output_arch"`` are read.
        builder: Module builder holding ``ltoirs``, ``ltoirs_decl`` and ``fatbins``.

    Returns:
        A 4-tuple ``((symbol, L, x, y), [], [ltoir], 0)``. The same shape is
        returned on both the compile path and the cache-hit path.

    Raises:
        TypeError: On a wrong output count, unsupported dtype, or non-1D output.
        ValueError: If the output length differs from ``L``'s row count.
        RuntimeError: If native compilation of the solver fails.
    """
    L = arg_values["L"]
    x = arg_values["x"]
    # force the storage type of the input variables to shared memory
    L.type.storage = "shared"
    x.type.storage = "shared"

    if len(return_values) != 1:
        raise TypeError(f"tile_cholesky_solve() must return exactly one value, got {len(return_values)}")

    y = return_values[0]

    if any(T not in cusolver_type_map for T in (x.type.dtype, L.type.dtype)):
        raise TypeError("tile_cholesky_solve() arguments must be tiles of float64 or float32")

    dtype, precision_enum = cusolver_type_map[L.type.dtype]
    M, N = L.type.shape[0], L.type.shape[1]

    if len(y.type.shape) != 1:
        raise TypeError("tile_cholesky_solve() output vector must be 1D")

    if y.type.shape[0] != M:
        raise ValueError(
            "tile_cholesky_solve() output vector must have same number of elements as the number of rows in 'L', "
            f"got {y.type.shape[0]} elements in output and {M} rows in 'L'"
        )

    num_threads = options["block_dim"]
    arch = options["output_arch"]
    # NOTE(review): num_threads is passed to the compiler but is not part of the
    # cache key; presumably block_dim is constant per module builder - confirm.
    lto_symbol = f"potrs_{M}_{N}_{arch}_{precision_enum}"

    call_args = (Var(lto_symbol, str, False, True, False), L, x, y)

    # Early out if LTO for this combination already exists for this module.
    # Return the same 4-tuple as the compile path so callers can unpack uniformly.
    if lto_symbol in builder.ltoirs:
        return (call_args, [], [builder.ltoirs[lto_symbol]], 0)

    # otherwise compile LTO
    lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
    universal_fatbin_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
    try:
        # cuSOLVERDx only support col-major input/outputs,
        # so we use upper to mimic a row-major input
        result = warp.context.runtime.core.cuda_compile_solver(
            universal_fatbin_code.name.encode("utf-8"),
            lto_code.name.encode("utf-8"),
            lto_symbol.encode("utf-8"),
            0,
            None,
            None,
            arch,
            M,
            N,
            cusolver_function_map["potrs"],
            precision_enum,
            cusolver_fill_mode_map["upper"],
            num_threads,
        )
        if not result:
            raise RuntimeError("Failed to compile tile_cholesky_solve")

        with open(lto_code.name, "rb") as f:
            lto_code_data = f.read()
        with open(universal_fatbin_code.name, "rb") as f:
            universal_fatbin_code_data = f.read()
    finally:
        # Remove the temporary files on every path, including read failures.
        for tmp in (lto_code, universal_fatbin_code):
            tmp.close()
            tmp_path = Path(tmp.name)
            if tmp_path.exists():
                tmp_path.unlink()

    builder.ltoirs[lto_symbol] = lto_code_data
    builder.ltoirs_decl[lto_symbol] = f"void {lto_symbol}({dtype}*, {dtype}*);"
    # NOTE(review): shares the "cholesky" fatbin key with tile_cholesky, so the last
    # compiled fatbin wins; presumably the universal fatbin is identical - confirm.
    builder.fatbins["cholesky"] = universal_fatbin_code_data

    return (call_args, [], [lto_code_data], 0)
6647
+
6648
+
6649
# Register ``tile_cholesky_solve(L, x)``. The summary line states the solve
# direction consistently with the :param:/:returns: entries (input x, output y).
add_builtin(
    "tile_cholesky_solve",
    input_types={"L": Tile, "x": Tile},
    value_func=tile_cholesky_solve_generic_value_func,
    lto_dispatch_func=tile_cholesky_solve_generic_lto_dispatch_func,
    variadic=True,
    doc="""With L such that LL^T = A, solve for y in Ay = x.

    Note that computing the adjoint is not yet supported.

    Supported datatypes are:
        * float32
        * float64

    :param L: A square, lower triangular, matrix, such that LL^T = A
    :param x: A 1D tile of length M
    :returns y: A 1D tile of length M such that LL^T y = x""",
    group="Tile Primitives",
    export=False,
    namespace="",
)
6670
+
6130
6671
  # ---------------------------------
6131
6672
  # Code Generation
6132
6673
 
@@ -6134,7 +6675,7 @@ add_builtin(
6134
6675
  "static",
6135
6676
  input_types={"expr": Any},
6136
6677
  value_type=Any,
6137
- doc="""Evaluates a static Python expression and replaces it with its result.
6678
+ doc="""Evaluate a static Python expression and replaces it with its result.
6138
6679
 
6139
6680
  See the :ref:`code generation guide <static_expressions>` for more details.
6140
6681
 
@@ -6158,3 +6699,58 @@ def static(expr):
6158
6699
  which includes constant variables and variables captured in the current closure in which the function or kernel is implemented.
6159
6700
  """
6160
6701
  return expr
6702
+
6703
+
6704
# Register the ``len()`` builtin overloads. Each entry pairs an accepted input
# type with its documentation; every overload returns an ``int``, belongs to the
# "Utility" group and is not exported.
for _len_input_type, _len_doc in (
    (vector(length=Any, dtype=Scalar), "Return the number of elements in a vector."),
    (quaternion(dtype=Scalar), "Return the number of elements in a quaternion."),
    (matrix(shape=(Any, Any), dtype=Scalar), "Return the number of rows in a matrix."),
    (transformation(dtype=Float), "Return the number of elements in a transformation."),
    (array(dtype=Any), "Return the size of the first dimension in an array."),
    (Tile(dtype=Any, shape=Any), "Return the number of rows in a tile."),
):
    add_builtin(
        "len",
        input_types={"a": _len_input_type},
        value_type=int,
        doc=_len_doc,
        group="Utility",
        export=False,
    )

# Do not leak the loop variables into the module namespace.
del _len_input_type, _len_doc