warp-lang 1.5.0__py3-none-manylinux2014_x86_64.whl → 1.6.0__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the registry's advisory page for this release for more details.

Files changed (132):
  1. warp/__init__.py +5 -0
  2. warp/autograd.py +414 -191
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +40 -12
  6. warp/build_dll.py +13 -6
  7. warp/builtins.py +1124 -497
  8. warp/codegen.py +261 -136
  9. warp/config.py +1 -1
  10. warp/context.py +357 -119
  11. warp/examples/assets/square_cloth.usd +0 -0
  12. warp/examples/benchmarks/benchmark_gemm.py +27 -18
  13. warp/examples/benchmarks/benchmark_interop_paddle.py +3 -3
  14. warp/examples/benchmarks/benchmark_interop_torch.py +3 -3
  15. warp/examples/core/example_torch.py +18 -34
  16. warp/examples/fem/example_apic_fluid.py +1 -0
  17. warp/examples/fem/example_mixed_elasticity.py +1 -1
  18. warp/examples/optim/example_bounce.py +1 -1
  19. warp/examples/optim/example_cloth_throw.py +1 -1
  20. warp/examples/optim/example_diffray.py +4 -15
  21. warp/examples/optim/example_drone.py +1 -1
  22. warp/examples/optim/example_softbody_properties.py +392 -0
  23. warp/examples/optim/example_trajectory.py +1 -3
  24. warp/examples/optim/example_walker.py +5 -0
  25. warp/examples/sim/example_cartpole.py +0 -2
  26. warp/examples/sim/example_cloth.py +3 -1
  27. warp/examples/sim/example_cloth_self_contact.py +260 -0
  28. warp/examples/sim/example_granular_collision_sdf.py +4 -5
  29. warp/examples/sim/example_jacobian_ik.py +0 -2
  30. warp/examples/sim/example_quadruped.py +5 -2
  31. warp/examples/tile/example_tile_cholesky.py +79 -0
  32. warp/examples/tile/example_tile_convolution.py +2 -2
  33. warp/examples/tile/example_tile_fft.py +2 -2
  34. warp/examples/tile/example_tile_filtering.py +3 -3
  35. warp/examples/tile/example_tile_matmul.py +4 -4
  36. warp/examples/tile/example_tile_mlp.py +12 -12
  37. warp/examples/tile/example_tile_nbody.py +180 -0
  38. warp/examples/tile/example_tile_walker.py +319 -0
  39. warp/fem/geometry/geometry.py +0 -2
  40. warp/math.py +147 -0
  41. warp/native/array.h +12 -0
  42. warp/native/builtin.h +0 -1
  43. warp/native/bvh.cpp +149 -70
  44. warp/native/bvh.cu +287 -68
  45. warp/native/bvh.h +195 -85
  46. warp/native/clang/clang.cpp +5 -1
  47. warp/native/coloring.cpp +5 -1
  48. warp/native/cuda_util.cpp +91 -53
  49. warp/native/cuda_util.h +5 -0
  50. warp/native/exports.h +40 -40
  51. warp/native/intersect.h +17 -0
  52. warp/native/mat.h +41 -0
  53. warp/native/mathdx.cpp +19 -0
  54. warp/native/mesh.cpp +25 -8
  55. warp/native/mesh.cu +153 -101
  56. warp/native/mesh.h +482 -403
  57. warp/native/quat.h +40 -0
  58. warp/native/solid_angle.h +7 -0
  59. warp/native/sort.cpp +85 -0
  60. warp/native/sort.cu +34 -0
  61. warp/native/sort.h +3 -1
  62. warp/native/spatial.h +11 -0
  63. warp/native/tile.h +1187 -669
  64. warp/native/tile_reduce.h +8 -6
  65. warp/native/vec.h +41 -0
  66. warp/native/warp.cpp +8 -1
  67. warp/native/warp.cu +263 -40
  68. warp/native/warp.h +19 -5
  69. warp/optim/linear.py +22 -4
  70. warp/render/render_opengl.py +130 -64
  71. warp/sim/__init__.py +6 -1
  72. warp/sim/collide.py +270 -26
  73. warp/sim/import_urdf.py +8 -8
  74. warp/sim/integrator_euler.py +25 -7
  75. warp/sim/integrator_featherstone.py +154 -35
  76. warp/sim/integrator_vbd.py +842 -40
  77. warp/sim/model.py +134 -72
  78. warp/sparse.py +1 -1
  79. warp/stubs.py +265 -132
  80. warp/tape.py +28 -30
  81. warp/tests/aux_test_module_unload.py +15 -0
  82. warp/tests/{test_sim_grad.py → flaky_test_sim_grad.py} +104 -63
  83. warp/tests/test_array.py +74 -0
  84. warp/tests/test_assert.py +242 -0
  85. warp/tests/test_codegen.py +14 -61
  86. warp/tests/test_collision.py +2 -2
  87. warp/tests/test_coloring.py +12 -2
  88. warp/tests/test_examples.py +12 -1
  89. warp/tests/test_func.py +21 -4
  90. warp/tests/test_grad_debug.py +87 -2
  91. warp/tests/test_hash_grid.py +1 -1
  92. warp/tests/test_ipc.py +116 -0
  93. warp/tests/test_lerp.py +13 -87
  94. warp/tests/test_mat.py +138 -167
  95. warp/tests/test_math.py +47 -1
  96. warp/tests/test_matmul.py +17 -16
  97. warp/tests/test_matmul_lite.py +10 -15
  98. warp/tests/test_mesh.py +84 -60
  99. warp/tests/test_mesh_query_aabb.py +165 -0
  100. warp/tests/test_mesh_query_point.py +328 -286
  101. warp/tests/test_mesh_query_ray.py +134 -121
  102. warp/tests/test_mlp.py +2 -2
  103. warp/tests/test_operators.py +43 -0
  104. warp/tests/test_overwrite.py +47 -2
  105. warp/tests/test_quat.py +77 -0
  106. warp/tests/test_reload.py +29 -0
  107. warp/tests/test_sim_grad_bounce_linear.py +204 -0
  108. warp/tests/test_smoothstep.py +17 -83
  109. warp/tests/test_static.py +19 -3
  110. warp/tests/test_tape.py +25 -0
  111. warp/tests/test_tile.py +178 -191
  112. warp/tests/test_tile_load.py +356 -0
  113. warp/tests/test_tile_mathdx.py +61 -8
  114. warp/tests/test_tile_mlp.py +17 -17
  115. warp/tests/test_tile_reduce.py +24 -18
  116. warp/tests/test_tile_shared_memory.py +66 -17
  117. warp/tests/test_tile_view.py +165 -0
  118. warp/tests/test_torch.py +35 -0
  119. warp/tests/test_utils.py +36 -24
  120. warp/tests/test_vec.py +110 -0
  121. warp/tests/unittest_suites.py +29 -4
  122. warp/tests/unittest_utils.py +30 -13
  123. warp/thirdparty/unittest_parallel.py +2 -2
  124. warp/types.py +411 -101
  125. warp/utils.py +10 -7
  126. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/METADATA +92 -69
  127. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/RECORD +130 -119
  128. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/WHEEL +1 -1
  129. warp/examples/benchmarks/benchmark_tile.py +0 -179
  130. warp/native/tile_gemm.h +0 -341
  131. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/LICENSE.md +0 -0
  132. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/top_level.txt +0 -0
warp/builtins.py CHANGED
@@ -399,11 +399,11 @@ def scalar_infer_type(arg_types: Mapping[str, type]):
399
399
 
400
400
  scalar_types = set()
401
401
  for t in arg_types:
402
- t = strip_reference(t)
403
- if hasattr(t, "_wp_scalar_type_"):
404
- scalar_types.add(t._wp_scalar_type_)
405
- elif t in scalar_and_bool_types:
406
- scalar_types.add(t)
402
+ t_val = strip_reference(t)
403
+ if hasattr(t_val, "_wp_scalar_type_"):
404
+ scalar_types.add(t_val._wp_scalar_type_)
405
+ elif t_val in scalar_and_bool_types:
406
+ scalar_types.add(t_val)
407
407
 
408
408
  if len(scalar_types) > 1:
409
409
  raise RuntimeError(
@@ -1707,64 +1707,98 @@ add_builtin(
1707
1707
 
1708
1708
  # ------------------
1709
1709
  # Tile-based primitives
1710
- shared_memory_id = 0
1710
+
1711
+
1712
+ def tile_unpack_shape(arg_values):
1713
+ shape = arg_values["shape"]
1714
+
1715
+ if not isinstance(shape, tuple):
1716
+ # promote to tuple
1717
+ shape = (shape,)
1718
+
1719
+ # check that components are constants
1720
+ for d in shape:
1721
+ if d is None:
1722
+ raise ValueError("Tile functions require shape to be a compile time constant.")
1723
+
1724
+ return shape
1725
+
1726
+
1727
+ def tile_unpack_offset(arg_values, ndim=0):
1728
+ if "offset" in arg_values:
1729
+ offset = arg_values["offset"]
1730
+ else:
1731
+ offset = (0,) * ndim
1732
+
1733
+ if isinstance(offset, tuple):
1734
+ return offset
1735
+ else:
1736
+ # promote to tuple
1737
+ return (offset,)
1711
1738
 
1712
1739
 
1713
1740
  def tile_zeros_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1714
1741
  # return generic type (for doc builds)
1715
1742
  if arg_types is None:
1716
- return Tile(dtype=Any, M=Any, N=Any)
1717
-
1718
- if "m" not in arg_values:
1719
- raise RuntimeError("'m' keyword argument must be specified when calling tile_zeros() function")
1743
+ return Tile(dtype=Any, shape=Any)
1720
1744
 
1721
- if "n" not in arg_values:
1722
- raise RuntimeError("'n' keyword argument must be specified when calling tile_zeros() function")
1745
+ shape = tile_unpack_shape(arg_values)
1723
1746
 
1724
1747
  if "dtype" not in arg_values:
1725
- raise RuntimeError("'dtype' keyword argument must be specified when calling tile_zeros() function")
1748
+ raise TypeError("tile_zeros() missing required keyword argument 'dtype'")
1726
1749
 
1727
1750
  if "storage" not in arg_values:
1728
- raise ValueError("'storage' keyword not provided for tile_zeros")
1751
+ raise TypeError("tile_zeros() missing required keyword argument 'storage'")
1729
1752
 
1730
1753
  if arg_values["storage"] not in {"shared", "register"}:
1731
- raise ValueError(
1732
- f"'storage' keyword argument must be either 'shared' or 'register', got {arg_values['storage']}"
1733
- )
1754
+ raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
1734
1755
 
1735
- m, n = arg_values["m"], arg_values["n"]
1736
1756
  dtype = arg_values["dtype"]
1737
1757
 
1738
- return TileZeros(dtype=dtype, M=m, N=n, storage=arg_values["storage"])
1758
+ return TileZeros(dtype=dtype, shape=shape, storage=arg_values["storage"])
1739
1759
 
1740
1760
 
1741
1761
  def tile_zeros_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
1742
- m, n, dtype = arg_values["m"], arg_values["n"], arg_values["dtype"]
1762
+ shape = tile_unpack_shape(arg_values)
1763
+ dtype = arg_values["dtype"]
1743
1764
 
1744
1765
  template_args = []
1745
1766
  template_args.append(dtype)
1746
- template_args.append(m.constant)
1747
- template_args.append(n.constant)
1767
+ for d in shape:
1768
+ template_args.append(d.constant)
1748
1769
 
1749
1770
  return ([], template_args)
1750
1771
 
1751
1772
 
1752
1773
  add_builtin(
1753
1774
  "tile_zeros",
1754
- input_types={"m": int, "n": int, "dtype": Any, "storage": str},
1755
- defaults={"storage": "register"},
1775
+ input_types={"shape": Tuple[int, ...], "dtype": Any, "storage": str},
1776
+ defaults={"storage": "register", "dtype": float},
1756
1777
  value_func=tile_zeros_value_func,
1757
1778
  dispatch_func=tile_zeros_dispatch_func,
1758
1779
  variadic=False,
1759
1780
  missing_grad=True,
1760
- doc="""Allocates a tile of zero-initialized items.
1781
+ doc="""Allocate a tile of zero-initialized items.
1761
1782
 
1762
- :param m: Size of the first dimension of the output tile
1763
- :param n: Size of the second dimension of the output tile
1764
- :param dtype: Datatype of output tile's elements
1783
+ :param shape: Shape of the output tile
1784
+ :param dtype: Data type of output tile's elements (default float)
1765
1785
  :param storage: The storage location for the tile: ``"register"`` for registers
1766
1786
  (default) or ``"shared"`` for shared memory.
1767
- :returns: A zero-initialized tile with ``shape=(m,n)`` and the specified datatype""",
1787
+ :returns: A zero-initialized tile with shape and data type as specified""",
1788
+ group="Tile Primitives",
1789
+ export=False,
1790
+ )
1791
+
1792
+ # overload for scalar shape
1793
+ add_builtin(
1794
+ "tile_zeros",
1795
+ input_types={"shape": int, "dtype": Any, "storage": str},
1796
+ defaults={"storage": "register", "dtype": float},
1797
+ value_func=tile_zeros_value_func,
1798
+ dispatch_func=tile_zeros_dispatch_func,
1799
+ variadic=False,
1800
+ missing_grad=True,
1801
+ hidden=True,
1768
1802
  group="Tile Primitives",
1769
1803
  export=False,
1770
1804
  )
@@ -1773,54 +1807,63 @@ add_builtin(
1773
1807
  def tile_ones_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1774
1808
  # return generic type (for doc builds)
1775
1809
  if arg_types is None:
1776
- return Tile(dtype=Any, M=Any, N=Any)
1810
+ return Tile(dtype=Any, shape=Any)
1777
1811
 
1778
- if "m" not in arg_values:
1779
- raise RuntimeError("'m' keyword argument must be specified when calling tile_zeros() function")
1780
-
1781
- if "n" not in arg_values:
1782
- raise RuntimeError("'n' keyword argument must be specified when calling tile_zeros() function")
1812
+ shape = tile_unpack_shape(arg_values)
1783
1813
 
1784
1814
  if "dtype" not in arg_values:
1785
- raise RuntimeError("'dtype' keyword argument must be specified when calling tile_zeros() function")
1815
+ raise TypeError("tile_ones() missing required keyword argument 'dtype'")
1816
+
1817
+ if "storage" not in arg_values:
1818
+ raise TypeError("tile_ones() missing required keyword argument 'storage'")
1786
1819
 
1787
1820
  if arg_values["storage"] not in {"shared", "register"}:
1788
- raise ValueError(
1789
- f"'storage' keyword argument must be either 'shared' or 'register', got {arg_values['storage']}"
1790
- )
1821
+ raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
1791
1822
 
1792
- m, n = arg_values["m"], arg_values["n"]
1793
1823
  dtype = arg_values["dtype"]
1794
1824
 
1795
- return TileZeros(dtype=dtype, M=m, N=n, storage=arg_values["storage"])
1825
+ return TileOnes(dtype=dtype, shape=shape, storage=arg_values["storage"])
1796
1826
 
1797
1827
 
1798
1828
  def tile_ones_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
1799
- m, n, dtype = arg_values["m"], arg_values["n"], arg_values["dtype"]
1829
+ shape = tile_unpack_shape(arg_values)
1830
+ dtype = arg_values["dtype"]
1800
1831
 
1801
1832
  template_args = []
1802
1833
  template_args.append(dtype)
1803
- template_args.append(m.constant)
1804
- template_args.append(n.constant)
1834
+ for d in shape:
1835
+ template_args.append(d.constant)
1805
1836
 
1806
1837
  return ([], template_args)
1807
1838
 
1808
1839
 
1809
1840
  add_builtin(
1810
1841
  "tile_ones",
1811
- input_types={"m": int, "n": int, "dtype": Any, "storage": str},
1842
+ input_types={"shape": Tuple[int, ...], "dtype": Any, "storage": str},
1812
1843
  defaults={"storage": "register"},
1813
1844
  value_func=tile_ones_value_func,
1814
1845
  dispatch_func=tile_ones_dispatch_func,
1815
1846
  missing_grad=True,
1816
- doc="""Allocates a tile of one-initialized items.
1847
+ doc="""Allocate a tile of one-initialized items.
1817
1848
 
1818
- :param m: Size of the first dimension of the output tile
1819
- :param n: Size of the second dimension of the output tile
1820
- :param dtype: Datatype of output tile's elements
1849
+ :param shape: Shape of the output tile
1850
+ :param dtype: Data type of output tile's elements
1821
1851
  :param storage: The storage location for the tile: ``"register"`` for registers
1822
1852
  (default) or ``"shared"`` for shared memory.
1823
- :returns: A one-initialized tile with ``shape=(m,n)`` and the specified dtype""",
1853
+ :returns: A one-initialized tile with shape and data type as specified""",
1854
+ group="Tile Primitives",
1855
+ export=False,
1856
+ )
1857
+
1858
+ # overload for scalar shape
1859
+ add_builtin(
1860
+ "tile_ones",
1861
+ input_types={"shape": int, "dtype": Any, "storage": str},
1862
+ defaults={"storage": "register"},
1863
+ value_func=tile_ones_value_func,
1864
+ dispatch_func=tile_ones_dispatch_func,
1865
+ missing_grad=True,
1866
+ hidden=True,
1824
1867
  group="Tile Primitives",
1825
1868
  export=False,
1826
1869
  )
@@ -1829,14 +1872,16 @@ add_builtin(
1829
1872
  def tile_arange_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1830
1873
  # return generic type (for doc builds)
1831
1874
  if arg_types is None:
1832
- return Tile(dtype=Any, M=Any, N=Any)
1875
+ return Tile(dtype=Any, shape=Any)
1876
+
1877
+ if "args" not in arg_values:
1878
+ raise TypeError("tile_arange() requires at least one positional argument specifying the range")
1879
+
1880
+ args = arg_values["args"]
1833
1881
 
1834
1882
  start = 0
1835
1883
  stop = 0
1836
1884
  step = 1
1837
- dtype = int
1838
-
1839
- args = arg_values["args"]
1840
1885
 
1841
1886
  if len(args) == 1:
1842
1887
  start = 0
@@ -1852,7 +1897,7 @@ def tile_arange_value_func(arg_types: Mapping[str, type], arg_values: Mapping[st
1852
1897
  step = args[2]
1853
1898
 
1854
1899
  if start is None or stop is None or step is None:
1855
- raise RuntimeError("wp.tile_arange() arguments must be compile time constants")
1900
+ raise RuntimeError("tile_arange() arguments must be compile time constants")
1856
1901
 
1857
1902
  if "dtype" in arg_values:
1858
1903
  dtype = arg_values["dtype"]
@@ -1860,26 +1905,37 @@ def tile_arange_value_func(arg_types: Mapping[str, type], arg_values: Mapping[st
1860
1905
  dtype = float
1861
1906
 
1862
1907
  if arg_values["storage"] not in {"shared", "register"}:
1863
- raise ValueError(
1864
- f"'storage' keyword argument must be either 'shared' or 'register', got {arg_values['storage']}"
1865
- )
1908
+ raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
1866
1909
 
1867
1910
  return TileRange(dtype=dtype, start=start, stop=stop, step=step, storage=arg_values["storage"])
1868
1911
 
1869
1912
 
1870
1913
  def tile_arange_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
1871
- m, n, dtype = return_type.M, return_type.N, return_type.dtype
1914
+ size, dtype = return_type.size, return_type.dtype
1872
1915
 
1873
1916
  template_args = []
1874
1917
  template_args.append(dtype)
1875
- template_args.append(m)
1876
- template_args.append(n)
1918
+ template_args.append(size)
1919
+
1920
+ if "args" not in arg_values:
1921
+ raise TypeError("tile_arange() requires at least one positional argument specifying the range")
1877
1922
 
1878
- # todo: it is somewhat redundant to create new vars here since some of start,stop,step
1879
- # already exist depending on which form the function was called by the user
1880
- start = warp.codegen.Var(label=None, type=return_type.dtype, constant=return_type.start)
1881
- stop = warp.codegen.Var(label=None, type=return_type.dtype, constant=return_type.stop)
1882
- step = warp.codegen.Var(label=None, type=return_type.dtype, constant=return_type.step)
1923
+ args = arg_values["args"]
1924
+
1925
+ if len(args) == 1:
1926
+ start = warp.codegen.Var(label=None, type=return_type.dtype, constant=return_type.start)
1927
+ stop = args[0]
1928
+ step = warp.codegen.Var(label=None, type=return_type.dtype, constant=return_type.step)
1929
+ elif len(args) == 2:
1930
+ start = args[0]
1931
+ stop = args[1]
1932
+ step = warp.codegen.Var(label=None, type=return_type.dtype, constant=return_type.step)
1933
+ elif len(args) == 3:
1934
+ start = args[0]
1935
+ stop = args[1]
1936
+ step = args[2]
1937
+ else:
1938
+ raise TypeError(f"tile_arange() accepts at most 3 positional arguments, got {len(args)}")
1883
1939
 
1884
1940
  function_args = []
1885
1941
  function_args.append(start)
@@ -1897,7 +1953,7 @@ add_builtin(
1897
1953
  dispatch_func=tile_arange_dispatch_func,
1898
1954
  variadic=True,
1899
1955
  missing_grad=True,
1900
- doc="""Generates a tile of linearly spaced elements.
1956
+ doc="""Generate a tile of linearly spaced elements.
1901
1957
 
1902
1958
  :param args: Variable-length positional arguments, interpreted as:
1903
1959
 
@@ -1905,246 +1961,157 @@ add_builtin(
1905
1961
  - ``(start, stop)``: Generates values from ``start`` to ``stop - 1``
1906
1962
  - ``(start, stop, step)``: Generates values from ``start`` to ``stop - 1`` with a step size
1907
1963
 
1908
- :param dtype: Datatype of output tile's elements (optional, default: int)
1964
+ :param dtype: Data type of output tile's elements (optional, default: ``float``)
1909
1965
  :param storage: The storage location for the tile: ``"register"`` for registers
1910
1966
  (default) or ``"shared"`` for shared memory.
1911
- :returns: A tile with ``shape=(1,n)`` with linearly spaced elements of specified dtype""",
1967
+ :returns: A tile with ``shape=(n)`` with linearly spaced elements of specified data type""",
1912
1968
  group="Tile Primitives",
1913
1969
  export=False,
1914
1970
  )
1915
1971
 
1916
1972
 
1917
- def tile_load_1d_value_func(arg_types, arg_values):
1918
- # return generic type (for doc builds)
1973
+ def tile_load_tuple_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1919
1974
  if arg_types is None:
1920
- return Tile(dtype=Any, M=Any, N=Any)
1921
-
1922
- if not is_array(arg_types["a"]):
1923
- raise RuntimeError("tile_load() argument 0 must be an array")
1975
+ return array(dtype=Scalar)
1924
1976
 
1925
- if arg_types["a"].ndim != 1:
1926
- raise RuntimeError(
1927
- "tile_load() argument 0 must be 1-dimensional if using the ``wp.tile_load(array, i, n)`` syntax."
1928
- )
1977
+ a = arg_types["a"]
1929
1978
 
1930
- if not type_is_int(arg_types["i"]):
1931
- raise RuntimeError("tile_load() argument 1 must be an integer")
1979
+ shape = tile_unpack_shape(arg_values)
1980
+ offset = tile_unpack_offset(arg_values, a.ndim)
1932
1981
 
1933
- if "n" not in arg_values:
1934
- raise RuntimeError("'n' keyword argument must be specified when calling tile_load() function")
1982
+ if a.ndim != len(shape):
1983
+ raise ValueError(
1984
+ f"tile_load() array argument must have same number of dimensions as the tile shape, trying to perform an {len(shape)} dimensional load from an array with {a.ndim} dimensions."
1985
+ )
1935
1986
 
1936
- if arg_values["storage"] not in {"shared", "register"}:
1987
+ if a.ndim != len(offset):
1937
1988
  raise ValueError(
1938
- f"'storage' keyword argument must be either 'shared' or 'register', got {arg_values['storage']}"
1989
+ f"tile_load() offset argument must have the same number of dimensions as the array to load from, got {len(offset)} indices for an array with {a.ndim} dimensions"
1939
1990
  )
1940
1991
 
1941
- a = arg_types["a"]
1942
- _m, n = 1, arg_values["n"]
1992
+ if arg_values["storage"] not in {"shared", "register"}:
1993
+ raise ValueError(f"Invalid value for 'storage': {arg_values['storage']!r}. Expected 'shared' or 'register'.")
1943
1994
 
1944
- return TileLoad(a, 1, n, arg_values["storage"])
1995
+ return Tile(dtype=a.dtype, shape=shape, storage=arg_values["storage"])
1945
1996
 
1946
1997
 
1947
- def tile_load_1d_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
1948
- array = arg_values["a"]
1949
- i = arg_values["i"]
1950
- n = arg_values["n"].constant
1951
- dtype = arg_values["a"].type.dtype
1998
+ def tile_load_tuple_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
1999
+ a = args["a"]
2000
+ shape = tile_unpack_shape(args)
2001
+ offset = tile_unpack_offset(args, a.type.ndim)
1952
2002
 
1953
- template_args = []
1954
- template_args.append(dtype)
1955
- template_args.append(n)
2003
+ func_args = (a, *offset)
2004
+ template_args = (d.constant for d in shape)
1956
2005
 
1957
- return ((array, i), template_args)
2006
+ return (func_args, template_args)
1958
2007
 
1959
2008
 
1960
2009
  add_builtin(
1961
2010
  "tile_load",
1962
- input_types={"a": array(dtype=Any), "i": int, "n": int, "storage": str},
1963
- defaults={"storage": "register"},
1964
- value_func=tile_load_1d_value_func,
1965
- dispatch_func=tile_load_1d_dispatch_func,
2011
+ input_types={"a": array(dtype=Any), "shape": Tuple[int, ...], "offset": Tuple[int, ...], "storage": str},
2012
+ value_func=tile_load_tuple_value_func,
2013
+ dispatch_func=tile_load_tuple_dispatch_func,
2014
+ defaults={"offset": None, "storage": "register"},
1966
2015
  variadic=False,
1967
- doc="""Loads a 1D tile from a global memory array.
2016
+ doc="""Loads a tile from a global memory array.
1968
2017
 
1969
2018
  This method will cooperatively load a tile from global memory using all threads in the block.
1970
2019
 
1971
2020
  :param a: The source array in global memory
1972
- :param i: Offset in the source array measured in multiples of ``n``, i.e.: ``offset=i*n``
1973
- :param n: The number of elements in the tile
2021
+ :param shape: Shape of the tile to load, must have the same number of dimensions as ``a``
2022
+ :param offset: Offset in the source array to begin reading from (optional)
1974
2023
  :param storage: The storage location for the tile: ``"register"`` for registers
1975
2024
  (default) or ``"shared"`` for shared memory.
1976
- :returns: A tile with ``shape=(1,n)`` and dtype the same as the source array""",
2025
+ :returns: A tile with shape as specified and data type the same as the source array""",
1977
2026
  group="Tile Primitives",
1978
2027
  export=False,
1979
2028
  )
1980
2029
 
1981
-
1982
- def tile_load_2d_value_func(arg_types, arg_values):
1983
- # return generic type (for doc builds)
1984
- if arg_types is None:
1985
- return Tile(dtype=Any, M=Any, N=Any)
1986
-
1987
- if not is_array(arg_types["a"]):
1988
- raise RuntimeError("tile_load() argument 0 must be an array")
1989
-
1990
- if arg_types["a"].ndim != 2:
1991
- raise RuntimeError(
1992
- "tile_load() argument 0 must be 2-dimensional if using the ``wp.tile_load(array, i, j, m, n)`` syntax."
1993
- )
1994
-
1995
- if not type_is_int(arg_types["i"]):
1996
- raise RuntimeError("tile_load() argument 1 must be an integer")
1997
-
1998
- if not type_is_int(arg_types["j"]):
1999
- raise RuntimeError("tile_load() argument 1 must be an integer")
2000
-
2001
- if "m" not in arg_values:
2002
- raise RuntimeError("'m' keyword argument must be specified when calling tile_load() function")
2003
-
2004
- if "n" not in arg_values:
2005
- raise RuntimeError("'n' keyword argument must be specified when calling tile_load() function")
2006
-
2007
- if arg_values["storage"] not in {"shared", "register"}:
2008
- raise ValueError(
2009
- f"'storage' keyword argument must be either 'shared' or 'register', got {arg_values['storage']}"
2010
- )
2011
-
2012
- a = arg_types["a"]
2013
- m, n = arg_values["m"], arg_values["n"]
2014
-
2015
- return TileLoad(a, m, n, arg_values["storage"])
2016
-
2017
-
2018
- def tile_load_2d_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
2019
- array = arg_values["a"]
2020
- i, j = arg_values["i"], arg_values["j"]
2021
- m, n = arg_values["m"].constant, arg_values["n"].constant
2022
- dtype = arg_values["a"].type.dtype
2023
-
2024
- template_args = []
2025
- template_args.append(dtype)
2026
- template_args.append(m)
2027
- template_args.append(n)
2028
-
2029
- return ((array, i, j), template_args)
2030
-
2031
-
2030
+ # overload for scalar shape
2032
2031
  add_builtin(
2033
2032
  "tile_load",
2034
- input_types={"a": array(dtype=Any), "i": int, "j": int, "m": int, "n": int, "storage": str},
2035
- defaults={"storage": "register"},
2036
- value_func=tile_load_2d_value_func,
2037
- dispatch_func=tile_load_2d_dispatch_func,
2038
- variadic=False,
2039
- doc="""Loads a 2D tile from a global memory array.
2040
-
2041
- This method will cooperatively load a tile from global memory using all threads in the block.
2042
-
2043
- :param a: The source array in global memory
2044
- :param i: Offset in the source array measured in multiples of ``m``, i.e.: ``row=i*m``
2045
- :param j: Offset in the source array measured in multiples of ``n``, i.e.; ``col=j*n``
2046
- :param m: The size of the tile's first dimension
2047
- :param n: The size of the tile's second dimension
2048
- :param storage: The storage location for the tile: ``"register"`` for registers
2049
- (default) or ``"shared"`` for shared memory.
2050
- :returns: A tile with ``shape=(m,n)`` and dtype the same as the source array""",
2033
+ input_types={"a": array(dtype=Any), "shape": int, "offset": int, "storage": str},
2034
+ value_func=tile_load_tuple_value_func,
2035
+ dispatch_func=tile_load_tuple_dispatch_func,
2036
+ defaults={"offset": None, "storage": "register"},
2051
2037
  group="Tile Primitives",
2038
+ hidden=True,
2052
2039
  export=False,
2053
2040
  )
2054
2041
 
2055
2042
 
2056
- def tile_store_1d_value_func(arg_types, arg_values):
2043
+ def tile_store_value_func(arg_types, arg_values):
2057
2044
  # return generic type (for doc builds)
2058
2045
  if arg_types is None:
2059
2046
  return None
2060
2047
 
2061
- if len(arg_types) != 3:
2062
- raise RuntimeError("tile_store() requires 3 positional args")
2048
+ a = arg_types["a"]
2049
+ t = arg_types["t"]
2063
2050
 
2064
- if not is_array(arg_types["a"]):
2065
- raise RuntimeError("tile_store() argument 0 must be an array")
2051
+ c = tile_unpack_offset(arg_types, a.ndim)
2066
2052
 
2067
- if arg_types["a"].ndim != 1:
2068
- raise RuntimeError(
2069
- "tile_load() argument 0 must be a 1-dimensional array if using the ``wp.tile_store(array, i, t)`` syntax."
2053
+ if len(c) != a.ndim:
2054
+ raise ValueError(
2055
+ f"tile_store() 'a' argument must have {len(c)} dimensions, "
2056
+ f"calculated based on the provided offset arguments, but got {a.ndim} dimensions."
2070
2057
  )
2071
2058
 
2072
- if not type_is_int(arg_types["i"]):
2073
- raise RuntimeError("tile_store() argument 1 must be an integer")
2074
-
2075
- if not is_tile(arg_types["t"]):
2076
- raise RuntimeError("tile_store() argument 2 must be a tile")
2059
+ if len(t.shape) != a.ndim:
2060
+ raise ValueError(
2061
+ f"tile_store() 'a' argument must have the same number of dimensions as the 't' argument, "
2062
+ f"but got {a.ndim} dimensions for 'a' and {len(t.shape)} dimensions for 't'"
2063
+ )
2077
2064
 
2078
2065
  if not types_equal(arg_types["a"].dtype, arg_types["t"].dtype):
2079
- raise RuntimeError("tile_store() destination array must have same type as source tile")
2066
+ raise TypeError(
2067
+ f"tile_store() 'a' and 't' arguments must have the same dtype, got {arg_types['a'].dtype} and {arg_types['t'].dtype}"
2068
+ )
2080
2069
 
2081
2070
  return None
2082
2071
 
2083
2072
 
2073
+ def tile_store_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
2074
+ a = args["a"]
2075
+ t = args["t"]
2076
+
2077
+ offset = tile_unpack_offset(args, a.type.ndim)
2078
+
2079
+ func_args = (a, *offset, t)
2080
+ template_args = []
2081
+
2082
+ return (func_args, template_args)
2083
+
2084
+
2084
2085
  add_builtin(
2085
2086
  "tile_store",
2086
- input_types={"a": array(dtype=Any), "i": int, "t": Any},
2087
- value_func=tile_store_1d_value_func,
2087
+ input_types={"a": array(dtype=Any), "t": Tile(dtype=Any, shape=Any), "offset": Tuple[int, ...]},
2088
+ value_func=tile_store_value_func,
2089
+ dispatch_func=tile_store_dispatch_func,
2090
+ defaults={"offset": None},
2088
2091
  variadic=False,
2089
2092
  skip_replay=True,
2090
- doc="""Stores a 1D tile to a global memory array.
2093
+ doc="""Store a tile to a global memory array.
2091
2094
 
2092
2095
  This method will cooperatively store a tile to global memory using all threads in the block.
2093
2096
 
2094
2097
  :param a: The destination array in global memory
2095
- :param i: Offset in the destination array measured in multiples of ``n``, i.e.: ``offset=i*n``
2096
- :param t: The source tile to store data from, must have the same dtype as the destination array""",
2098
+ :param t: The source tile to store data from, must have the same data type and number of dimensions as the destination array
2099
+ :param offset: Offset in the destination array (optional)""",
2097
2100
  group="Tile Primitives",
2098
2101
  export=False,
2099
2102
  )
2100
2103
 
2101
-
2102
- def tile_store_2d_value_func(arg_types, arg_values):
2103
- # return generic type (for doc builds)
2104
- if arg_types is None:
2105
- return None
2106
-
2107
- if len(arg_types) != 4:
2108
- raise RuntimeError("tile_store() requires 4 positional args")
2109
-
2110
- if not is_array(arg_types["a"]):
2111
- raise RuntimeError("tile_store() argument 0 must be an array")
2112
-
2113
- if arg_types["a"].ndim != 2:
2114
- raise RuntimeError(
2115
- "tile_load() argument 0 must be a 2-dimensional array if using the ``wp.tile_store(array, i, j, t)`` syntax."
2116
- )
2117
-
2118
- if not type_is_int(arg_types["i"]):
2119
- raise RuntimeError("tile_store() argument 1 must be an integer")
2120
-
2121
- if not type_is_int(arg_types["j"]):
2122
- raise RuntimeError("tile_store() argument 2 must be an integer")
2123
-
2124
- if not is_tile(arg_types["t"]):
2125
- raise RuntimeError("tile_store() argument 3 must be a tile")
2126
-
2127
- if not types_equal(arg_types["a"].dtype, arg_types["t"].dtype):
2128
- raise RuntimeError("tile_store() destination array must have same type as source tile")
2129
-
2130
- return None
2131
-
2132
-
2104
+ # overload for scalar offset
2133
2105
  add_builtin(
2134
2106
  "tile_store",
2135
- input_types={"a": array(dtype=Any), "i": int, "j": int, "t": Any},
2136
- value_func=tile_store_2d_value_func,
2107
+ input_types={"a": array(dtype=Any), "t": Tile(dtype=Any, shape=Any), "offset": int},
2108
+ value_func=tile_store_value_func,
2109
+ dispatch_func=tile_store_dispatch_func,
2110
+ defaults={"offset": None},
2137
2111
  variadic=False,
2138
2112
  skip_replay=True,
2139
- doc="""Stores a tile to a global memory array.
2140
-
2141
- This method will cooperatively store a tile to global memory using all threads in the block.
2142
-
2143
- :param a: The destination array in global memory
2144
- :param i: Offset in the destination array measured in multiples of ``m``, i.e.: ``row=i*m``
2145
- :param j: Offset in the destination array measured in multiples of ``n``, i.e.; ``col=j*n``
2146
- :param t: The source tile to store data from, must have the same dtype as the destination array""",
2147
2113
  group="Tile Primitives",
2114
+ hidden=True,
2148
2115
  export=False,
2149
2116
  )
2150
2117
 
@@ -2152,132 +2119,221 @@ add_builtin(
2152
2119
  def tile_atomic_add_value_func(arg_types, arg_values):
2153
2120
  # return generic type (for doc builds)
2154
2121
  if arg_types is None:
2155
- return Tile(dtype=Any, M=Any, N=Any)
2122
+ return Tile(dtype=Any, shape=Any)
2156
2123
 
2157
- if len(arg_types) != 4:
2158
- raise RuntimeError("tile_atomic_add() requires 4 positional args")
2124
+ a = arg_types["a"]
2125
+ t = arg_types["t"]
2159
2126
 
2160
- if not is_array(arg_types["a"]):
2161
- raise RuntimeError("tile_atomic_add() argument 0 must be an array")
2127
+ c = tile_unpack_offset(arg_types, a.ndim)
2128
+ if len(c) != a.ndim:
2129
+ raise ValueError(
2130
+ f"tile_atomic_add() 'a' argument must have {len(c)} dimensions, "
2131
+ f"calculated based on the provided offset arguments, but got {a.ndim} dimensions."
2132
+ )
2162
2133
 
2163
- if not type_is_int(arg_types["x"]):
2164
- raise RuntimeError("tile_atomic_add() argument 1 must be an integer")
2134
+ if a.ndim != len(t.shape):
2135
+ raise ValueError(
2136
+ f"tile_atomic_add() 'a' argument must have the same number of dimensions as the 't' argument, "
2137
+ f"but got {a.ndim} dimensions for 'a' and {len(t.shape)} dimensions for 't'"
2138
+ )
2139
+
2140
+ if not types_equal(arg_types["a"].dtype, arg_types["t"].dtype):
2141
+ raise TypeError(
2142
+ f"tile_atomic_add() 'a' and 't' arguments must have the same dtype, got {arg_types['a'].dtype} and {arg_types['t'].dtype}"
2143
+ )
2165
2144
 
2166
- if not type_is_int(arg_types["y"]):
2167
- raise RuntimeError("tile_atomic_add() argument 2 must be an integer")
2145
+ return Tile(dtype=arg_types["t"].dtype, shape=arg_types["t"].shape)
2168
2146
 
2169
- if not is_tile(arg_types["t"]):
2170
- raise RuntimeError("tile_atomic_add() argument 3 must be a tile")
2171
2147
 
2172
- if arg_types["a"].dtype != arg_types["t"].dtype:
2173
- raise RuntimeError("tile_atomic_add() tile dtype and array dtype must match")
2148
+ def tile_atomic_add_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
2149
+ a = args["a"]
2150
+ t = args["t"]
2174
2151
 
2175
- return Tile(dtype=arg_types["t"].dtype, M=arg_types["t"].M, N=arg_types["t"].N)
2152
+ offset = tile_unpack_offset(args, a.type.ndim)
2153
+
2154
+ func_args = (a, *offset, t)
2155
+ template_args = []
2156
+
2157
+ return (func_args, template_args)
2176
2158
 
2177
2159
 
2178
2160
  add_builtin(
2179
2161
  "tile_atomic_add",
2180
- input_types={"a": array(dtype=Any), "x": int, "y": int, "t": Any},
2162
+ input_types={"a": array(dtype=Any), "t": Tile(dtype=Any, shape=Any), "offset": Tuple[int, ...]},
2181
2163
  value_func=tile_atomic_add_value_func,
2182
- variadic=True,
2164
+ dispatch_func=tile_atomic_add_dispatch_func,
2165
+ defaults={"offset": None},
2166
+ variadic=False,
2183
2167
  skip_replay=True,
2184
- doc="""Atomically add a tile to the array `a`, each element will be updated atomically.
2168
+ doc="""Atomically add a 1D tile to the array `a`, each element will be updated atomically.
2185
2169
 
2186
2170
  :param a: Array in global memory, should have the same ``dtype`` as the input tile
2187
- :param x: Offset in the destination array measured in multiples of ``m``, i.e.: ``i=x*M`` where ``M`` is the first tile dimension
2188
- :param y: Offset in the destination array measured in multiples of ``n``, i.e.: ``j=y*N`` where ``N`` is the second tile dimension
2189
2171
  :param t: Source tile to add to the destination array
2190
- :returns: A tile with the same dimensions and type as the source tile, holding the original value of the destination elements""",
2172
+ :param offset: Offset in the destination array (optional)
2173
+ :returns: A tile with the same dimensions and data type as the source tile, holding the original value of the destination elements""",
2191
2174
  group="Tile Primitives",
2192
2175
  export=False,
2193
2176
  )
2194
2177
 
2178
+ # overload for scalar offset
2179
+ add_builtin(
2180
+ "tile_atomic_add",
2181
+ input_types={"a": array(dtype=Any), "t": Tile(dtype=Any, shape=Any), "offset": int},
2182
+ value_func=tile_atomic_add_value_func,
2183
+ dispatch_func=tile_atomic_add_dispatch_func,
2184
+ defaults={"offset": None},
2185
+ variadic=False,
2186
+ skip_replay=True,
2187
+ group="Tile Primitives",
2188
+ hidden=True,
2189
+ export=False,
2190
+ )
2191
+
2195
2192
 
2196
2193
  def tile_view_value_func(arg_types, arg_values):
2197
2194
  # return generic type (for doc builds)
2198
2195
  if arg_types is None:
2199
- return Tile(dtype=Any, M=Any, N=Any)
2196
+ return Tile(dtype=Any, shape=Any)
2200
2197
 
2201
2198
  tile = arg_types["t"]
2199
+ offset = arg_types["offset"]
2202
2200
 
2203
- if "m" not in arg_values:
2204
- m = 1
2205
- else:
2206
- m = arg_values["m"]
2201
+ if len(offset) > len(tile.shape):
2202
+ raise ValueError(f"tile_view() specified too many offset coordinates {len(offset)} > {len(tile.shape)}")
2207
2203
 
2208
- if "n" not in arg_values:
2209
- n = tile.N
2204
+ if "shape" in arg_values:
2205
+ # if shape is specified take it directly, e.g.:
2206
+ # tile_view(t, offset=(i,j), shape=(m,n))
2207
+ shape = arg_values["shape"]
2208
+ strides = tile.strides
2209
+
2210
+ if len(shape) != len(tile.shape):
2211
+ raise ValueError(
2212
+ f"tile_view() if shape is specified it must have same number of dimensions as source tile, expected {len(tile.shape)}, got {len(shape)}"
2213
+ )
2210
2214
  else:
2211
- n = arg_values["n"]
2215
+ # if not specified, then take output shape from unspecified src dimensions
2216
+ # e.g.: tile[i] will return a whole row of a 2D tile
2217
+ shape = tile.shape[len(offset) :]
2218
+ strides = tile.strides[len(offset) :]
2212
2219
 
2213
- if m > tile.M or n > tile.N:
2214
- raise RuntimeError(
2215
- f"Trying to view a tile subrange with dimensions ({m}, {n}) which is larger than source tile with dimensions ({tile.M}, {tile.N})"
2216
- )
2220
+ assert len(shape) == len(strides)
2217
2221
 
2218
2222
  # force source tile to shared memory
2219
2223
  tile.storage = "shared"
2220
2224
 
2221
- output = Tile(dtype=tile.dtype, M=m, N=n, strides=tile.strides, layout=tile.layout, storage="shared", owner=False)
2225
+ output = Tile(dtype=tile.dtype, shape=shape, strides=strides, layout=tile.layout, storage="shared", owner=False)
2222
2226
  return output
2223
2227
 
2224
2228
 
2225
2229
  def tile_view_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
2226
2230
  tile = arg_values["t"]
2227
- i = arg_values["i"]
2228
-
2229
- if "j" not in arg_values:
2230
- j = warp.codegen.Var(label=None, type=int, constant=0)
2231
- else:
2232
- j = arg_values["j"]
2231
+ coord = arg_values["offset"]
2233
2232
 
2234
- template_args = []
2235
- template_args.append(return_type.M)
2236
- template_args.append(return_type.N)
2233
+ # zero-pad coord to match source array
2234
+ view_coord = [0] * len(tile.type.shape)
2235
+ for i in range(len(coord)):
2236
+ view_coord[i] = coord[i]
2237
2237
 
2238
- return ((tile, i, j), template_args)
2238
+ return ((tile, *view_coord), (return_type,))
2239
2239
 
2240
2240
 
2241
2241
  add_builtin(
2242
2242
  "tile_view",
2243
- input_types={"t": Tile(dtype=Any, M=Any, N=Any), "i": int, "j": int, "m": int, "n": int},
2243
+ input_types={"t": Tile(dtype=Any, shape=Any), "offset": Tuple[int, ...], "shape": Tuple[int, ...]},
2244
2244
  value_func=tile_view_value_func,
2245
2245
  dispatch_func=tile_view_dispatch_func,
2246
- defaults={"j": None, "m": None, "n": None},
2247
- variadic=True,
2248
- doc="""Return a subrange of a given tile from coordinates (i,j) to (i+m, j+n).
2246
+ defaults={"shape": None},
2247
+ variadic=False,
2248
+ doc="""Return a slice of a given tile [offset, offset+shape], if shape is not specified it will be inferred from the unspecified offset dimensions.
2249
2249
 
2250
2250
  :param t: Input tile to extract a subrange from
2251
- :param i: Offset in the source tile along the first dimension
2252
- :param j: Offset in the source tile along the second dimensions
2253
- :param m: Size of the subrange to return along the first dimension
2254
- :param n: Size of the subrange to return along the second dimension
2255
- :returns: A tile with dimensions (m,n) and the same datatype as the input tile""",
2251
+ :param offset: Offset in the source tile
2252
+ :param shape: Shape of the returned slice
2253
+ :returns: A tile with dimensions given by the specified shape or the remaining source tile dimensions""",
2256
2254
  group="Tile Primitives",
2255
+ missing_grad=True,
2257
2256
  export=False,
2258
2257
  )
2259
2258
 
2260
2259
 
2261
2260
  def tile_assign_value_func(arg_types, arg_values):
2262
- # return generic type (for doc builds)
2261
+ if arg_types is None:
2262
+ return None
2263
+
2264
+ # force the destination tile to shared memory
2265
+ arg_types["dst"].storage = "shared"
2263
2266
  return None
2264
2267
 
2265
2268
 
2269
+ def tile_assign_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
2270
+ dst = args["dst"]
2271
+ src = args["src"]
2272
+
2273
+ offset = tile_unpack_offset(args, len(dst.type.shape))
2274
+
2275
+ func_args = (dst, src, *offset)
2276
+ template_args = []
2277
+
2278
+ return (func_args, template_args)
2279
+
2280
+
2266
2281
  add_builtin(
2267
2282
  "tile_assign",
2268
- input_types={"dst": Tile(dtype=Any, M=Any, N=Any), "i": int, "j": int, "src": Tile(dtype=Any, M=Any, N=Any)},
2283
+ input_types={"dst": Tile(dtype=Any, shape=Any), "src": Tile(dtype=Any, shape=Any), "offset": Tuple[int, ...]},
2269
2284
  value_func=tile_assign_value_func,
2270
- # dispatch_func=tile_assign_dispatch_func,
2271
- doc="""Assign a tile to a subrange of a destination tile at coordinates (i,j).
2285
+ dispatch_func=tile_assign_dispatch_func,
2286
+ defaults={"offset": None},
2287
+ doc="""Assign a tile to a subrange of a destination tile.
2272
2288
 
2273
- :param t: The destination tile to assign to
2274
- :param i: Offset in the source tile along the first dimension
2275
- :param j: Offset in the source tile along the second dimensions
2276
- :param src: The source tile to read values from""",
2289
+ :param dst: The destination tile to assign to
2290
+ :param src: The source tile to read values from
2291
+ :param offset: Offset in the destination tile to write to""",
2277
2292
  group="Tile Primitives",
2278
2293
  export=False,
2279
2294
  )
2280
2295
 
2296
+ # handles expressions like tile[i,j] = 1.0
2297
+ add_builtin(
2298
+ "assign",
2299
+ input_types={"dst": Tile(dtype=Any, shape=Any), "i": int, "src": Scalar},
2300
+ value_func=tile_assign_value_func,
2301
+ group="Tile Primitives",
2302
+ export=False,
2303
+ hidden=True,
2304
+ missing_grad=True,
2305
+ )
2306
+
2307
+ add_builtin(
2308
+ "assign",
2309
+ input_types={"dst": Tile(dtype=Any, shape=Any), "i": int, "j": int, "src": Scalar},
2310
+ value_func=tile_assign_value_func,
2311
+ group="Tile Primitives",
2312
+ export=False,
2313
+ hidden=True,
2314
+ missing_grad=True,
2315
+ )
2316
+
2317
+ add_builtin(
2318
+ "assign",
2319
+ input_types={"dst": Tile(dtype=Any, shape=Any), "i": int, "j": int, "k": int, "src": Scalar},
2320
+ value_func=tile_assign_value_func,
2321
+ group="Tile Primitives",
2322
+ export=False,
2323
+ hidden=True,
2324
+ missing_grad=True,
2325
+ )
2326
+
2327
+ add_builtin(
2328
+ "assign",
2329
+ input_types={"dst": Tile(dtype=Any, shape=Any), "i": int, "j": int, "k": int, "l": int, "src": Scalar},
2330
+ value_func=tile_assign_value_func,
2331
+ group="Tile Primitives",
2332
+ export=False,
2333
+ hidden=True,
2334
+ missing_grad=True,
2335
+ )
2336
+
2281
2337
 
2282
2338
  def tile_value_func(arg_types, arg_values):
2283
2339
  # return generic type (for doc builds)
@@ -2285,7 +2341,7 @@ def tile_value_func(arg_types, arg_values):
2285
2341
  return Tile
2286
2342
 
2287
2343
  if len(arg_types) != 1:
2288
- raise RuntimeError("tile() requires 1 positional arg")
2344
+ raise TypeError(f"tile() takes exactly 1 positional argument but {len(arg_types)} were given")
2289
2345
 
2290
2346
  dtype = None
2291
2347
  length = None
@@ -2293,11 +2349,12 @@ def tile_value_func(arg_types, arg_values):
2293
2349
  if type_is_vector(arg_types["x"]):
2294
2350
  dtype = arg_types["x"]._wp_scalar_type_
2295
2351
  length = arg_types["x"]._shape_[0]
2352
+ shape = (length, warp.codegen.options["block_dim"])
2296
2353
  else:
2297
2354
  dtype = arg_types["x"]
2298
- length = 1
2355
+ shape = (warp.codegen.options["block_dim"],)
2299
2356
 
2300
- return Tile(dtype=dtype, M=length, N=warp.codegen.options["block_dim"], op="tile")
2357
+ return Tile(dtype=dtype, shape=shape, op="tile")
2301
2358
 
2302
2359
 
2303
2360
  add_builtin(
@@ -2305,14 +2362,14 @@ add_builtin(
2305
2362
  input_types={"x": Any},
2306
2363
  value_func=tile_value_func,
2307
2364
  variadic=True,
2308
- doc="""Constructs a new Tile from per-thread kernel values.
2365
+ doc="""Construct a new tile from per-thread kernel values.
2309
2366
 
2310
2367
  This function converts values computed using scalar kernel code to a tile representation for input into collective operations.
2311
2368
 
2312
2369
  * If the input value is a scalar, then the resulting tile has ``shape=(1, block_dim)``
2313
2370
  * If the input value is a vector, then the resulting tile has ``shape=(length(vector), block_dim)``
2314
2371
 
2315
- :param x: A per-thread local value, e.g.: scalar, vector, or matrix.
2372
+ :param x: A per-thread local value, e.g. scalar, vector, or matrix.
2316
2373
  :returns: A tile with first dimension according to the value type length and a second dimension equal to ``block_dim``
2317
2374
 
2318
2375
  This example shows how to create a linear sequence from thread variables:
@@ -2331,7 +2388,7 @@ add_builtin(
2331
2388
 
2332
2389
  .. code-block:: text
2333
2390
 
2334
- tile(m=1, n=16, storage=register) = [[0 2 4 6 8 ...]]
2391
+ [0 2 4 6 8 10 12 14 16 18 20 22 24 26 28 30] = tile(shape=(16), storage=register)
2335
2392
 
2336
2393
  """,
2337
2394
  group="Tile Primitives",
@@ -2345,38 +2402,40 @@ def untile_value_func(arg_types, arg_values):
2345
2402
  return Scalar
2346
2403
 
2347
2404
  if len(arg_types) != 1:
2348
- raise RuntimeError("untile() requires 1 positional arg")
2405
+ raise TypeError(f"untile() takes exactly 1 positional argument but {len(arg_types)} were given")
2349
2406
 
2350
2407
  t = arg_types["a"]
2351
2408
 
2352
2409
  if not is_tile(t):
2353
- raise RuntimeError(f"untile() accepts arguments of type tile only, got {arg_types[0]}")
2410
+ raise TypeError(f"untile() argument must be a tile, got {t!r}")
2354
2411
 
2355
- if t.N != warp.codegen.options["block_dim"]:
2356
- raise RuntimeError(
2357
- f"untile() argument must have the same length as the block width, got {t.N}, expected {warp.codegen.options['block_dim']}"
2412
+ if t.shape[-1] != warp.codegen.options["block_dim"]:
2413
+ raise ValueError(
2414
+ f"untile() argument last dimension {t.shape[-1]} does not match the expected block width {warp.codegen.options['block_dim']}"
2358
2415
  )
2359
2416
 
2360
- if t.M == 1:
2417
+ if len(t.shape) == 1:
2361
2418
  return t.dtype
2362
- elif t.M > 1:
2363
- return warp.types.vector(t.M, t.dtype)
2419
+ elif len(t.shape) == 2:
2420
+ return warp.types.vector(t.shape[0], t.dtype)
2421
+ else:
2422
+ raise ValueError(f"untile() argument must have a positive size in dimension 0, but got {t.shape[0]}")
2364
2423
 
2365
2424
 
2366
2425
  add_builtin(
2367
2426
  "untile",
2368
- input_types={"a": Any},
2427
+ input_types={"a": Tile(dtype=Any, shape=Any)},
2369
2428
  value_func=untile_value_func,
2370
2429
  variadic=True,
2371
- doc="""Convert a Tile back to per-thread values.
2430
+ doc="""Convert a tile back to per-thread values.
2372
2431
 
2373
2432
  This function converts a block-wide tile back to per-thread values.
2374
2433
 
2375
- * If the input tile is 1-dimensional then the resulting value will be a per-thread scalar
2376
- * If the input tile is 2-dimensional then the resulting value will be a per-thread vector of length M
2434
+ * If the input tile is 1D, then the resulting value will be a per-thread scalar
2435
+ * If the input tile is 2D, then the resulting value will be a per-thread vector of length M
2377
2436
 
2378
2437
  :param a: A tile with dimensions ``shape=(M, block_dim)``
2379
- :returns: A single value per-thread with the same dtype as the tile
2438
+ :returns: A single value per-thread with the same data type as the tile
2380
2439
 
2381
2440
  This example shows how to create a linear sequence from thread variables:
2382
2441
 
@@ -2390,7 +2449,7 @@ add_builtin(
2390
2449
  t = wp.tile(i)*2
2391
2450
 
2392
2451
  # convert back to per-thread values
2393
- s = wp.untile()
2452
+ s = wp.untile(t)
2394
2453
 
2395
2454
  print(s)
2396
2455
 
@@ -2417,21 +2476,38 @@ def tile_extract_value_func(arg_types, arg_values):
2417
2476
  if arg_types is None:
2418
2477
  return Scalar
2419
2478
 
2420
- if len(arg_types) != 3:
2421
- raise RuntimeError("tile_extract() requires 3 positional args")
2422
-
2423
- if not is_tile(arg_types["a"]):
2424
- raise RuntimeError("tile_extract() argument 0 must be a tile")
2479
+ # force the input tile to shared memory
2480
+ arg_types["a"].storage = "shared"
2425
2481
 
2426
2482
  return arg_types["a"].dtype
2427
2483
 
2428
2484
 
2429
2485
  add_builtin(
2430
2486
  "tile_extract",
2431
- input_types={"a": Tile(dtype=Any, M=Any, N=Any), "i": int, "j": int},
2487
+ input_types={"a": Tile(dtype=Any, shape=Any), "i": int},
2432
2488
  value_func=tile_extract_value_func,
2433
- variadic=True,
2434
- doc="""Extracts a single element from the tile and returns it as a scalar type.
2489
+ variadic=False,
2490
+ doc="""Extract a single element from the tile and return it as a scalar type.
2491
+
2492
+ This function will extract an element from the tile and broadcast its value to all threads in the block.
2493
+
2494
+ Note that this may incur additional synchronization if the source tile is a register tile.
2495
+
2496
+ :param a: Tile to extract the element from
2497
+ :param i: Coordinate of element on first dimension
2498
+ :returns: The value of the element at the specified tile location with the same data type as the input tile""",
2499
+ group="Tile Primitives",
2500
+ hidden=True,
2501
+ export=False,
2502
+ )
2503
+
2504
+
2505
+ add_builtin(
2506
+ "tile_extract",
2507
+ input_types={"a": Tile(dtype=Any, shape=Any), "i": int, "j": int},
2508
+ value_func=tile_extract_value_func,
2509
+ variadic=False,
2510
+ doc="""Extract a single element from the tile and return it as a scalar type.
2435
2511
 
2436
2512
  This function will extract an element from the tile and broadcast its value to all threads in the block.
2437
2513
 
@@ -2440,8 +2516,52 @@ add_builtin(
2440
2516
  :param a: Tile to extract the element from
2441
2517
  :param i: Coordinate of element on first dimension
2442
2518
  :param j: Coordinate of element on the second dimension
2443
- :returns: The value of the element at the specified tile location, with the same type as the input tile's per-element dtype""",
2519
+ :returns: The value of the element at the specified tile location with the same data type as the input tile""",
2444
2520
  group="Tile Primitives",
2521
+ hidden=True,
2522
+ export=False,
2523
+ )
2524
+
2525
+ add_builtin(
2526
+ "tile_extract",
2527
+ input_types={"a": Tile(dtype=Any, shape=Any), "i": int, "j": int, "k": int},
2528
+ value_func=tile_extract_value_func,
2529
+ variadic=False,
2530
+ doc="""Extract a single element from the tile and return it as a scalar type.
2531
+
2532
+ This function will extract an element from the tile and broadcast its value to all threads in the block.
2533
+
2534
+ Note that this may incur additional synchronization if the source tile is a register tile.
2535
+
2536
+ :param a: Tile to extract the element from
2537
+ :param i: Coordinate of element on first dimension
2538
+ :param j: Coordinate of element on the second dimension
2539
+ :param k: Coordinate of element on the third dimension
2540
+ :returns: The value of the element at the specified tile location with the same data type as the input tile""",
2541
+ group="Tile Primitives",
2542
+ hidden=True,
2543
+ export=False,
2544
+ )
2545
+
2546
+ add_builtin(
2547
+ "tile_extract",
2548
+ input_types={"a": Tile(dtype=Any, shape=Any), "i": int, "j": int, "k": int, "l": int},
2549
+ value_func=tile_extract_value_func,
2550
+ variadic=False,
2551
+ doc="""Extract a single element from the tile and return it as a scalar type.
2552
+
2553
+ This function will extract an element from the tile and broadcast its value to all threads in the block.
2554
+
2555
+ Note that this may incur additional synchronization if the source tile is a register tile.
2556
+
2557
+ :param a: Tile to extract the element from
2558
+ :param i: Coordinate of element on first dimension
2559
+ :param j: Coordinate of element on the second dimension
2560
+ :param k: Coordinate of element on the third dimension
2561
+ :param l: Coordinate of element on the fourth dimension
2562
+ :returns: The value of the element at the specified tile location, with the same data type as the input tile""",
2563
+ group="Tile Primitives",
2564
+ hidden=True,
2445
2565
  export=False,
2446
2566
  )
2447
2567
 
@@ -2452,12 +2572,12 @@ def tile_transpose_value_func(arg_types, arg_values):
2452
2572
  return Tile
2453
2573
 
2454
2574
  if len(arg_types) != 1:
2455
- raise RuntimeError("tile_transpose() requires 1 positional args")
2575
+ raise TypeError(f"tile_transpose() takes exactly 1 positional argument but {len(arg_types)} were given")
2456
2576
 
2457
2577
  t = arg_types["a"]
2458
2578
 
2459
2579
  if not is_tile(t):
2460
- raise RuntimeError("tile_transpose() argument 0 must be a tile")
2580
+ raise TypeError(f"tile_transpose() argument must be a tile, got {t!r}")
2461
2581
 
2462
2582
  layout = None
2463
2583
 
@@ -2472,8 +2592,7 @@ def tile_transpose_value_func(arg_types, arg_values):
2472
2592
 
2473
2593
  return Tile(
2474
2594
  dtype=t.dtype,
2475
- M=t.N,
2476
- N=t.M,
2595
+ shape=t.shape[::-1],
2477
2596
  op="transpose",
2478
2597
  storage=t.storage,
2479
2598
  strides=t.strides[::-1],
@@ -2484,12 +2603,13 @@ def tile_transpose_value_func(arg_types, arg_values):
2484
2603
 
2485
2604
  add_builtin(
2486
2605
  "tile_transpose",
2487
- input_types={"a": Tile(dtype=Any, M=Any, N=Any)},
2606
+ input_types={"a": Tile(dtype=Any, shape=Any)},
2488
2607
  value_func=tile_transpose_value_func,
2489
2608
  variadic=True,
2490
2609
  doc="""Transpose a tile.
2491
2610
 
2492
- For shared memory tiles this operation will alias the input tile, register tiles will first be transferred to shared memory before transposition.
2611
+ For shared memory tiles, this operation will alias the input tile.
2612
+ Register tiles will first be transferred to shared memory before transposition.
2493
2613
 
2494
2614
  :param a: Tile to transpose with ``shape=(M,N)``
2495
2615
  :returns: Tile with ``shape=(N,M)``""",
@@ -2503,41 +2623,36 @@ def tile_broadcast_value_func(arg_types, arg_values):
2503
2623
  if arg_types is None:
2504
2624
  return Tile
2505
2625
 
2506
- if len(arg_types) != 3:
2507
- raise RuntimeError("tile_broadcast() requires 1 positional args")
2508
-
2509
2626
  t = arg_types["a"]
2510
- m = arg_values["m"]
2511
- n = arg_values["n"]
2512
2627
 
2513
- if not is_tile(t):
2514
- raise RuntimeError("tile_broadcast() argument 0 must be a tile")
2628
+ # target shape and strides
2629
+ target_shape = tile_unpack_shape(arg_values)
2630
+ target_strides = [0] * len(target_shape)
2515
2631
 
2516
- # try to broadcast last dimension
2517
- if t.N == 1:
2518
- stride_n = 0
2519
- elif t.N == n:
2520
- stride_n = t.strides[1]
2521
- else:
2522
- raise RuntimeError(
2523
- f"Broadcast dimension must be 1 or match destination, shape(src) = {t.m, t.n}, shape(dest) = {m, n}"
2524
- )
2632
+ offset = len(target_shape) - len(t.shape)
2525
2633
 
2526
- # try to broadcast first dimension
2527
- if t.M == 1:
2528
- stride_m = 0
2529
- elif t.M == m:
2530
- stride_m = t.strides[0]
2531
- else:
2532
- raise RuntimeError(
2533
- f"Broadcast dimension must be 1 or match destination, shape(src) = {t.m, t.n}, shape(dest) = {m, n}"
2534
- )
2634
+ # compute target strides
2635
+ for i in reversed(range(len(target_shape))):
2636
+ j = i - offset
2637
+
2638
+ if j < 0:
2639
+ target_strides[i] = 0
2640
+ else:
2641
+ # try to broadcast each dimension
2642
+ if t.shape[j] == 1:
2643
+ target_strides[i] = 0
2644
+ elif t.shape[j] == target_shape[i]:
2645
+ target_strides[i] = t.strides[j]
2646
+ else:
2647
+ raise ValueError(
2648
+ f"tile_broadcast() cannot broadcast dimension {t.shape[j]} into {target_shape[i]} at index {i}"
2649
+ )
2535
2650
 
2536
2651
  # force the input tile to shared memory
2537
2652
  t.storage = "shared"
2538
2653
 
2539
2654
  tile_type = Tile(
2540
- dtype=t.dtype, M=m, N=n, op="broadcast", storage=t.storage, strides=(stride_m, stride_n), owner=False
2655
+ dtype=t.dtype, shape=target_shape, op="broadcast", storage=t.storage, strides=target_strides, owner=False
2541
2656
  )
2542
2657
  return tile_type
2543
2658
 
@@ -2546,8 +2661,8 @@ def tile_broadcast_dispatch_func(arg_types: Mapping[str, type], return_type: Any
2546
2661
  tile = arg_values["a"]
2547
2662
 
2548
2663
  template_args = []
2549
- template_args.append(return_type.M)
2550
- template_args.append(return_type.N)
2664
+ template_args.append(return_type.shape[0])
2665
+ template_args.append(return_type.shape[1])
2551
2666
  template_args.append(return_type.strides[0])
2552
2667
  template_args.append(return_type.strides[1])
2553
2668
 
@@ -2556,15 +2671,18 @@ def tile_broadcast_dispatch_func(arg_types: Mapping[str, type], return_type: Any
2556
2671
 
2557
2672
  add_builtin(
2558
2673
  "tile_broadcast",
2559
- input_types={"a": Tile(dtype=Any, M=Any, N=Any), "m": int, "n": int},
2674
+ input_types={"a": Tile(dtype=Any, shape=Any), "shape": Tuple[int, ...]},
2560
2675
  value_func=tile_broadcast_value_func,
2561
2676
  dispatch_func=tile_broadcast_dispatch_func,
2562
- variadic=True,
2677
+ variadic=False,
2563
2678
  doc="""Broadcast a tile.
2564
2679
 
2565
- This method will attempt to broadcast the input tile ``a`` to the destination shape (m, n), broadcasting follows NumPy broadcast rules.
2680
+ This function will attempt to broadcast the input tile ``a`` to the destination shape (m, n).
2681
+
2682
+ Broadcasting follows NumPy broadcast rules.
2566
2683
 
2567
2684
  :param a: Tile to broadcast
2685
+ :param shape: The shape to broadcast to
2568
2686
  :returns: Tile with broadcast ``shape=(m, n)``""",
2569
2687
  group="Tile Primitives",
2570
2688
  export=False,
@@ -2574,19 +2692,10 @@ add_builtin(
2574
2692
  def tile_matmul_value_func(arg_types, arg_values):
2575
2693
  # return generic type (for doc builds)
2576
2694
  if arg_types is None:
2577
- return Tile(dtype=Any, M=Any, N=Any)
2695
+ return Tile(dtype=Any, shape=Any)
2578
2696
 
2579
2697
  if len(arg_types) != 3:
2580
- raise RuntimeError("tile_matmul() requires 4 positional args")
2581
-
2582
- if not is_tile(arg_types["a"]):
2583
- raise RuntimeError("tile_matmul() argument 0 must be a tile")
2584
-
2585
- if not is_tile(arg_types["b"]):
2586
- raise RuntimeError("tile_matmul() argument 1 must be an tile")
2587
-
2588
- if not isinstance(arg_types["out"], Tile):
2589
- raise RuntimeError("tile_matmul() output argument must be a tile")
2698
+ raise TypeError(f"tile_matmul() takes exactly 3 positional arguments but {len(arg_types)} were given")
2590
2699
 
2591
2700
  return None
2592
2701
 
@@ -2621,17 +2730,17 @@ add_builtin(
2621
2730
  def tile_sum_value_func(arg_types, arg_values):
2622
2731
  # return generic type (for doc builds)
2623
2732
  if arg_types is None:
2624
- return Tile(dtype=Any, M=1, N=1)
2733
+ return Tile(dtype=Any, shape=(1,))
2625
2734
 
2626
2735
  if len(arg_types) != 1:
2627
- raise RuntimeError("tile_sum() requires 1 positional args")
2736
+ raise TypeError(f"tile_sum() takes exactly 1 positional argument but {len(arg_types)} were given")
2628
2737
 
2629
2738
  a = arg_types["a"]
2630
2739
 
2631
2740
  if not is_tile(a):
2632
- raise RuntimeError("tile_sum() argument 0 must be a tile")
2741
+ raise TypeError(f"tile_sum() argument must be a tile, got {a!r}")
2633
2742
 
2634
- return Tile(dtype=a.dtype, M=1, N=1, op="sum")
2743
+ return Tile(dtype=a.dtype, shape=(1,), op="sum")
2635
2744
 
2636
2745
 
2637
2746
  add_builtin(
@@ -2642,7 +2751,7 @@ add_builtin(
2642
2751
  doc="""Cooperatively compute the sum of the tile elements using all threads in the block.
2643
2752
 
2644
2753
  :param a: The tile to compute the sum of
2645
- :returns: A single-element tile with dimensions of (1,1) holding the sum
2754
+ :returns: A single-element tile holding the sum
2646
2755
 
2647
2756
  Example:
2648
2757
 
@@ -2651,18 +2760,18 @@ add_builtin(
2651
2760
  @wp.kernel
2652
2761
  def compute():
2653
2762
 
2654
- t = wp.tile_ones(dtype=float, m=16, n=16)
2763
+ t = wp.tile_ones(dtype=float, shape=(16, 16))
2655
2764
  s = wp.tile_sum(t)
2656
2765
 
2657
- print(t)
2766
+ print(s)
2658
2767
 
2659
- wp.launch(compute, dim=[64], inputs=[])
2768
+ wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
2660
2769
 
2661
2770
  Prints:
2662
2771
 
2663
2772
  .. code-block:: text
2664
2773
 
2665
- tile(m=1, n=1, storage=register) = [[256]]
2774
+ [256] = tile(shape=(1), storage=register)
2666
2775
 
2667
2776
  """,
2668
2777
  group="Tile Primitives",
@@ -2673,17 +2782,17 @@ add_builtin(
2673
2782
  def tile_min_value_func(arg_types, arg_values):
2674
2783
  # return generic type (for doc builds)
2675
2784
  if arg_types is None:
2676
- return Tile(dtype=Any, M=1, N=1)
2785
+ return Tile(dtype=Any, shape=(1,))
2677
2786
 
2678
2787
  if len(arg_types) != 1:
2679
- raise RuntimeError("tile_min() requires 1 positional args")
2788
+ raise TypeError(f"tile_min() takes exactly 1 positional argument but {len(arg_types)} were given")
2680
2789
 
2681
2790
  a = arg_types["a"]
2682
2791
 
2683
2792
  if not is_tile(a):
2684
- raise RuntimeError("tile_min() argument 0 must be a tile")
2793
+ raise TypeError(f"tile_min() argument must be a tile, got {a!r}")
2685
2794
 
2686
- return Tile(dtype=a.dtype, M=1, N=1, op="min")
2795
+ return Tile(dtype=a.dtype, shape=(1,), op="min")
2687
2796
 
2688
2797
 
2689
2798
  add_builtin(
@@ -2694,7 +2803,7 @@ add_builtin(
2694
2803
  doc="""Cooperatively compute the minimum of the tile elements using all threads in the block.
2695
2804
 
2696
2805
  :param a: The tile to compute the minimum of
2697
- :returns: A single-element tile with dimensions of (1,1) holding the minimum value
2806
+ :returns: A single-element tile holding the minimum value
2698
2807
 
2699
2808
  Example:
2700
2809
 
@@ -2703,18 +2812,19 @@ add_builtin(
2703
2812
  @wp.kernel
2704
2813
  def compute():
2705
2814
 
2706
- t = wp.tile_arange(start=--10, stop=10, dtype=float)
2815
+ t = wp.tile_arange(64, 128)
2707
2816
  s = wp.tile_min(t)
2708
2817
 
2709
- print(t)
2818
+ print(s)
2819
+
2710
2820
 
2711
- wp.launch(compute, dim=[64], inputs=[])
2821
+ wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
2712
2822
 
2713
2823
  Prints:
2714
2824
 
2715
2825
  .. code-block:: text
2716
2826
 
2717
- tile(m=1, n=1, storage=register) = [[-10]]
2827
+ [64] = tile(shape=(1), storage=register)
2718
2828
 
2719
2829
  """,
2720
2830
  group="Tile Primitives",
@@ -2725,28 +2835,28 @@ add_builtin(
2725
2835
  def tile_max_value_func(arg_types, arg_values):
2726
2836
  # return generic type (for doc builds)
2727
2837
  if arg_types is None:
2728
- return Tile(dtype=Any, M=1, N=1)
2838
+ return Tile(dtype=Any, shape=(1,))
2729
2839
 
2730
2840
  if len(arg_types) != 1:
2731
- raise RuntimeError("tile_max() requires 1 positional args")
2841
+ raise TypeError(f"tile_max() takes exactly 1 positional argument but {len(arg_types)} were given")
2732
2842
 
2733
2843
  a = arg_types["a"]
2734
2844
 
2735
2845
  if not is_tile(a):
2736
- raise RuntimeError("tile_max() argument 0 must be a tile")
2846
+ raise TypeError(f"tile_max() argument must be a tile, got {a!r}")
2737
2847
 
2738
- return Tile(dtype=a.dtype, M=1, N=1, op="min")
2848
+ return Tile(dtype=a.dtype, shape=(1,), op="min")
2739
2849
 
2740
2850
 
2741
2851
  add_builtin(
2742
2852
  "tile_max",
2743
- input_types={"a": Tile},
2853
+ input_types={"a": Tile(dtype=Any, shape=Any)},
2744
2854
  value_func=tile_max_value_func,
2745
- variadic=True,
2855
+ variadic=False,
2746
2856
  doc="""Cooperatively compute the maximum of the tile elements using all threads in the block.
2747
2857
 
2748
2858
  :param a: The tile to compute the maximum from
2749
- :returns: A single-element tile with dimensions of (1,1) holding the maximum value
2859
+ :returns: A single-element tile holding the maximum value
2750
2860
 
2751
2861
  Example:
2752
2862
 
@@ -2755,18 +2865,18 @@ add_builtin(
2755
2865
  @wp.kernel
2756
2866
  def compute():
2757
2867
 
2758
- t = wp.tile_arange(start=--10, stop=10, dtype=float)
2759
- s = wp.tile_min(t)
2868
+ t = wp.tile_arange(64, 128)
2869
+ s = wp.tile_max(t)
2760
2870
 
2761
- print(t)
2871
+ print(s)
2762
2872
 
2763
- wp.launch(compute, dim=[64], inputs=[])
2873
+ wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
2764
2874
 
2765
2875
  Prints:
2766
2876
 
2767
2877
  .. code-block:: text
2768
2878
 
2769
- tile(m=1, n=1, storage=register) = [[10]]
2879
+ [127] = tile(shape=(1), storage=register)
2770
2880
 
2771
2881
  """,
2772
2882
  group="Tile Primitives",
@@ -2777,15 +2887,14 @@ add_builtin(
2777
2887
  # does type propagation for load()
2778
2888
  def tile_reduce_value_func(arg_types, arg_values):
2779
2889
  if arg_types is None:
2780
- return Tile(dtype=Any, M=Any, N=Any)
2890
+ return Tile(dtype=Any, shape=(1,))
2781
2891
 
2782
2892
  a = arg_types["a"]
2783
2893
 
2784
- # check all args are tiles
2785
2894
  if not is_tile(a):
2786
- raise RuntimeError(f"tile_reduce() arguments must be tiles, got type {a}")
2895
+ raise TypeError(f"tile_reduce() 'a' argument must be a tile, got {a!r}")
2787
2896
 
2788
- return Tile(dtype=a.dtype, M=1, N=1, op="reduce")
2897
+ return Tile(dtype=a.dtype, shape=(1,), op="reduce")
2789
2898
 
2790
2899
 
2791
2900
  def tile_reduce_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
@@ -2796,7 +2905,7 @@ def tile_reduce_dispatch_func(input_types: Mapping[str, type], return_type: Any,
2796
2905
 
2797
2906
  add_builtin(
2798
2907
  "tile_reduce",
2799
- input_types={"op": Callable, "a": Any},
2908
+ input_types={"op": Callable, "a": Tile(dtype=Any, shape=Any)},
2800
2909
  value_func=tile_reduce_value_func,
2801
2910
  native_func="tile_reduce",
2802
2911
  doc="""Apply a custom reduction operator across the tile.
@@ -2804,8 +2913,8 @@ add_builtin(
2804
2913
  This function cooperatively performs a reduction using the provided operator across the tile.
2805
2914
 
2806
2915
  :param op: A callable function that accepts two arguments and returns one argument, may be a user function or builtin
2807
- :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
2808
- :returns: A single-element tile with ``shape=(1,1)`` with the same datatype as the input tile.
2916
+ :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's data type
2917
+ :returns: A single-element tile with the same data type as the input tile.
2809
2918
 
2810
2919
  Example:
2811
2920
 
@@ -2819,13 +2928,13 @@ add_builtin(
2819
2928
 
2820
2929
  print(s)
2821
2930
 
2822
- wp.launch(factorial, dim=[16], inputs=[], block_dim=16)
2931
+ wp.launch_tiled(factorial, dim=[1], inputs=[], block_dim=16)
2823
2932
 
2824
2933
  Prints:
2825
2934
 
2826
2935
  .. code-block:: text
2827
2936
 
2828
- tile(m=1, n=1, storage=register) = [[362880]]
2937
+ [362880] = tile(shape=(1), storage=register)
2829
2938
  """,
2830
2939
  group="Tile Primitives",
2831
2940
  export=False,
@@ -2837,26 +2946,19 @@ add_builtin(
2837
2946
  # does type propagation for load()
2838
2947
  def tile_unary_map_value_func(arg_types, arg_values):
2839
2948
  if arg_types is None:
2840
- return Tile(dtype=Any, M=Any, N=Any)
2949
+ return Tile(dtype=Any, shape=Any)
2841
2950
 
2842
2951
  a = arg_types["a"]
2843
2952
 
2844
- # check all args are tiles
2845
2953
  if not is_tile(a):
2846
- raise RuntimeError(f"tile_map() arguments must be tiles, got type {a}")
2954
+ raise TypeError(f"tile_map() 'a' argument must be a tile, got {a!r}")
2847
2955
 
2848
2956
  return TileUnaryMap(a)
2849
2957
 
2850
2958
 
2851
- # def tile_map_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
2852
- # func_args = (args["op"], *args["args"])
2853
- # template_args = ()
2854
- # return (func_args, template_args)
2855
-
2856
-
2857
2959
  add_builtin(
2858
2960
  "tile_map",
2859
- input_types={"op": Callable, "a": Any},
2961
+ input_types={"op": Callable, "a": Tile(dtype=Any, shape=Any)},
2860
2962
  value_func=tile_unary_map_value_func,
2861
2963
  # dispatch_func=tile_map_dispatch_func,
2862
2964
  # variadic=True,
@@ -2866,8 +2968,8 @@ add_builtin(
2866
2968
  This function cooperatively applies a unary function to each element of the tile using all threads in the block.
2867
2969
 
2868
2970
  :param op: A callable function that accepts one argument and returns one argument, may be a user function or builtin
2869
- :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
2870
- :returns: A tile with the same dimensions and datatype as the input tile.
2971
+ :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's data type
2972
+ :returns: A tile with the same dimensions and data type as the input tile.
2871
2973
 
2872
2974
  Example:
2873
2975
 
@@ -2881,13 +2983,13 @@ add_builtin(
2881
2983
 
2882
2984
  print(s)
2883
2985
 
2884
- wp.launch(compute, dim=[16], inputs=[])
2986
+ wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=16)
2885
2987
 
2886
2988
  Prints:
2887
2989
 
2888
2990
  .. code-block:: text
2889
2991
 
2890
- tile(m=1, n=10, storage=register) = [[0 0.0998334 0.198669 0.29552 ...]]
2992
+ [0 0.0998334 0.198669 0.29552 0.389418 0.479426 0.564642 0.644218 0.717356 0.783327] = tile(shape=(10), storage=register)
2891
2993
  """,
2892
2994
  group="Tile Primitives",
2893
2995
  export=False,
@@ -2896,34 +2998,37 @@ add_builtin(
2896
2998
 
2897
2999
  def tile_binary_map_value_func(arg_types, arg_values):
2898
3000
  if arg_types is None:
2899
- return Tile(dtype=Any, M=Any, N=Any)
3001
+ return Tile(dtype=Any, shape=Any)
2900
3002
 
2901
3003
  a = arg_types["a"]
2902
3004
  b = arg_types["b"]
2903
3005
 
2904
3006
  # check all args are tiles
2905
3007
  if not is_tile(a):
2906
- raise RuntimeError(f"tile_map() arguments must be tiles, got type {a}")
3008
+ raise TypeError(f"tile_map() 'a' argument must be a tile, got {a!r}")
2907
3009
 
2908
3010
  if not is_tile(b):
2909
- raise RuntimeError(f"tile_map() arguments must be tiles, got type {b}")
3011
+ raise TypeError(f"tile_map() 'b' argument must be a tile, got {b!r}")
2910
3012
 
2911
- # use first argument to define output type
3013
+ # ensure types equal
2912
3014
  if not types_equal(a.dtype, b.dtype):
2913
- raise RuntimeError(f"tile_map() arguments must all have the same type {a.dtype} != {b.dtype}")
3015
+ raise TypeError(f"tile_map() arguments must have the same dtype, got {a.dtype} and {b.dtype}")
2914
3016
 
2915
- if a.M != b.M:
2916
- raise RuntimeError(f"tile_map() arguments must all have the same m dimension {a.M} != {b.M}")
3017
+ if len(a.shape) != len(b.shape):
3018
+ raise ValueError(
3019
+ f"tile_map() shapes must have the same number of dimensions, got {len(a.shape)} and {len(b.shape)}"
3020
+ )
2917
3021
 
2918
- if a.N != b.N:
2919
- raise RuntimeError(f"tile_map() arguments must all have the same n dimension {a.N} != {b.N}")
3022
+ for i in range(len(a.shape)):
3023
+ if a.shape[i] != b.shape[i]:
3024
+ raise ValueError(f"tile_map() shapes do not match on dimension {i}, got {a.shape[i]} and {b.shape[i]}")
2920
3025
 
2921
3026
  return TileBinaryMap(a, b)
2922
3027
 
2923
3028
 
2924
3029
  add_builtin(
2925
3030
  "tile_map",
2926
- input_types={"op": Callable, "a": Any, "b": Any},
3031
+ input_types={"op": Callable, "a": Tile(dtype=Any, shape=Any), "b": Tile(dtype=Any, shape=Any)},
2927
3032
  value_func=tile_binary_map_value_func,
2928
3033
  # dispatch_func=tile_map_dispatch_func,
2929
3034
  # variadic=True,
@@ -2946,19 +3051,19 @@ add_builtin(
2946
3051
  def compute():
2947
3052
 
2948
3053
  a = wp.tile_arange(0.0, 1.0, 0.1, dtype=float)
2949
- b = wp.tile_ones(m=1, n=10, dtype=float)
3054
+ b = wp.tile_ones(shape=10, dtype=float)
2950
3055
 
2951
3056
  s = wp.tile_map(wp.add, a, b)
2952
3057
 
2953
3058
  print(s)
2954
3059
 
2955
- wp.launch(compute, dim=[16], inputs=[])
3060
+ wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=16)
2956
3061
 
2957
3062
  Prints:
2958
3063
 
2959
3064
  .. code-block:: text
2960
3065
 
2961
- tile(m=1, n=10, storage=register) = [[1 1.1 1.2 1.3 1.4 1.5 1.6 1.7 1.8 1.9]]""",
3066
+ [1 1.1 1.2 1.3 1.4 1.5 1.6 1.7 1.8 1.9] = tile(shape=(10), storage=register)""",
2962
3067
  group="Tile Primitives",
2963
3068
  export=False,
2964
3069
  )
@@ -3073,6 +3178,18 @@ add_builtin(
3073
3178
  )
3074
3179
 
3075
3180
 
3181
+ def mlp_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
3182
+ warp.utils.warn(
3183
+ "wp.mlp() is deprecated and will be removed in a future\nversion. Use tile primitives instead.",
3184
+ category=DeprecationWarning,
3185
+ )
3186
+
3187
+ func_args = tuple(args.values())
3188
+ template_args = ()
3189
+
3190
+ return (func_args, template_args)
3191
+
3192
+
3076
3193
  add_builtin(
3077
3194
  "mlp",
3078
3195
  input_types={
@@ -3084,9 +3201,13 @@ add_builtin(
3084
3201
  "out": array(dtype=float, ndim=2),
3085
3202
  },
3086
3203
  value_type=None,
3204
+ dispatch_func=mlp_dispatch_func,
3087
3205
  skip_replay=True,
3088
3206
  doc="""Evaluate a multi-layer perceptron (MLP) layer in the form: ``out = act(weights*x + bias)``.
3089
3207
 
3208
+ .. deprecated:: 1.6
3209
+ Use :doc:`tile primitives </modules/tiles>` instead.
3210
+
3090
3211
  :param weights: A layer's network weights with dimensions ``(m, n)``.
3091
3212
  :param bias: An array with dimensions ``(n)``.
3092
3213
  :param activation: A ``wp.func`` function that takes a single scalar float as input and returns a scalar float as output
@@ -4665,6 +4786,19 @@ def atomic_op_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str,
4665
4786
  return arr_type.dtype
4666
4787
 
4667
4788
 
4789
+ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
4790
+ # as this is a codegen callback, we can mark the fact that this func writes to an array here
4791
+ if warp.config.verify_autograd_array_access:
4792
+ arr = args["arr"]
4793
+ arr.mark_write()
4794
+
4795
+ func_args = tuple(args.values())
4796
+ # we don't need to specify template arguments for atomic ops
4797
+ template_args = ()
4798
+
4799
+ return (func_args, template_args)
4800
+
4801
+
4668
4802
  for array_type in array_types:
4669
4803
  # don't list indexed array operations explicitly in docs
4670
4804
  hidden = array_type == indexedarray
@@ -4675,6 +4809,7 @@ for array_type in array_types:
4675
4809
  input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
4676
4810
  constraint=atomic_op_constraint,
4677
4811
  value_func=atomic_op_value_func,
4812
+ dispatch_func=atomic_op_dispatch_func,
4678
4813
  doc="Atomically add ``value`` onto ``arr[i]`` and return the old value.",
4679
4814
  group="Utility",
4680
4815
  skip_replay=True,
@@ -4685,6 +4820,7 @@ for array_type in array_types:
4685
4820
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
4686
4821
  constraint=atomic_op_constraint,
4687
4822
  value_func=atomic_op_value_func,
4823
+ dispatch_func=atomic_op_dispatch_func,
4688
4824
  doc="Atomically add ``value`` onto ``arr[i,j]`` and return the old value.",
4689
4825
  group="Utility",
4690
4826
  skip_replay=True,
@@ -4695,6 +4831,7 @@ for array_type in array_types:
4695
4831
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
4696
4832
  constraint=atomic_op_constraint,
4697
4833
  value_func=atomic_op_value_func,
4834
+ dispatch_func=atomic_op_dispatch_func,
4698
4835
  doc="Atomically add ``value`` onto ``arr[i,j,k]`` and return the old value.",
4699
4836
  group="Utility",
4700
4837
  skip_replay=True,
@@ -4705,6 +4842,7 @@ for array_type in array_types:
4705
4842
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
4706
4843
  constraint=atomic_op_constraint,
4707
4844
  value_func=atomic_op_value_func,
4845
+ dispatch_func=atomic_op_dispatch_func,
4708
4846
  doc="Atomically add ``value`` onto ``arr[i,j,k,l]`` and return the old value.",
4709
4847
  group="Utility",
4710
4848
  skip_replay=True,
@@ -4716,6 +4854,7 @@ for array_type in array_types:
4716
4854
  input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
4717
4855
  constraint=atomic_op_constraint,
4718
4856
  value_func=atomic_op_value_func,
4857
+ dispatch_func=atomic_op_dispatch_func,
4719
4858
  doc="Atomically subtract ``value`` onto ``arr[i]`` and return the old value.",
4720
4859
  group="Utility",
4721
4860
  skip_replay=True,
@@ -4726,6 +4865,7 @@ for array_type in array_types:
4726
4865
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
4727
4866
  constraint=atomic_op_constraint,
4728
4867
  value_func=atomic_op_value_func,
4868
+ dispatch_func=atomic_op_dispatch_func,
4729
4869
  doc="Atomically subtract ``value`` onto ``arr[i,j]`` and return the old value.",
4730
4870
  group="Utility",
4731
4871
  skip_replay=True,
@@ -4736,6 +4876,7 @@ for array_type in array_types:
4736
4876
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
4737
4877
  constraint=atomic_op_constraint,
4738
4878
  value_func=atomic_op_value_func,
4879
+ dispatch_func=atomic_op_dispatch_func,
4739
4880
  doc="Atomically subtract ``value`` onto ``arr[i,j,k]`` and return the old value.",
4740
4881
  group="Utility",
4741
4882
  skip_replay=True,
@@ -4746,6 +4887,7 @@ for array_type in array_types:
4746
4887
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
4747
4888
  constraint=atomic_op_constraint,
4748
4889
  value_func=atomic_op_value_func,
4890
+ dispatch_func=atomic_op_dispatch_func,
4749
4891
  doc="Atomically subtract ``value`` onto ``arr[i,j,k,l]`` and return the old value.",
4750
4892
  group="Utility",
4751
4893
  skip_replay=True,
@@ -4757,6 +4899,7 @@ for array_type in array_types:
4757
4899
  input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
4758
4900
  constraint=atomic_op_constraint,
4759
4901
  value_func=atomic_op_value_func,
4902
+ dispatch_func=atomic_op_dispatch_func,
4760
4903
  doc="""Compute the minimum of ``value`` and ``arr[i]``, atomically update the array, and return the old value.
4761
4904
 
4762
4905
  The operation is only atomic on a per-component basis for vectors and matrices.""",
@@ -4769,6 +4912,7 @@ for array_type in array_types:
4769
4912
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
4770
4913
  constraint=atomic_op_constraint,
4771
4914
  value_func=atomic_op_value_func,
4915
+ dispatch_func=atomic_op_dispatch_func,
4772
4916
  doc="""Compute the minimum of ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
4773
4917
 
4774
4918
  The operation is only atomic on a per-component basis for vectors and matrices.""",
@@ -4781,6 +4925,7 @@ for array_type in array_types:
4781
4925
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
4782
4926
  constraint=atomic_op_constraint,
4783
4927
  value_func=atomic_op_value_func,
4928
+ dispatch_func=atomic_op_dispatch_func,
4784
4929
  doc="""Compute the minimum of ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
4785
4930
 
4786
4931
  The operation is only atomic on a per-component basis for vectors and matrices.""",
@@ -4793,6 +4938,7 @@ for array_type in array_types:
4793
4938
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
4794
4939
  constraint=atomic_op_constraint,
4795
4940
  value_func=atomic_op_value_func,
4941
+ dispatch_func=atomic_op_dispatch_func,
4796
4942
  doc="""Compute the minimum of ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
4797
4943
 
4798
4944
  The operation is only atomic on a per-component basis for vectors and matrices.""",
@@ -4806,6 +4952,7 @@ for array_type in array_types:
4806
4952
  input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
4807
4953
  constraint=atomic_op_constraint,
4808
4954
  value_func=atomic_op_value_func,
4955
+ dispatch_func=atomic_op_dispatch_func,
4809
4956
  doc="""Compute the maximum of ``value`` and ``arr[i]``, atomically update the array, and return the old value.
4810
4957
 
4811
4958
  The operation is only atomic on a per-component basis for vectors and matrices.""",
@@ -4818,6 +4965,7 @@ for array_type in array_types:
4818
4965
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
4819
4966
  constraint=atomic_op_constraint,
4820
4967
  value_func=atomic_op_value_func,
4968
+ dispatch_func=atomic_op_dispatch_func,
4821
4969
  doc="""Compute the maximum of ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
4822
4970
 
4823
4971
  The operation is only atomic on a per-component basis for vectors and matrices.""",
@@ -4830,6 +4978,7 @@ for array_type in array_types:
4830
4978
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
4831
4979
  constraint=atomic_op_constraint,
4832
4980
  value_func=atomic_op_value_func,
4981
+ dispatch_func=atomic_op_dispatch_func,
4833
4982
  doc="""Compute the maximum of ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
4834
4983
 
4835
4984
  The operation is only atomic on a per-component basis for vectors and matrices.""",
@@ -4842,6 +4991,7 @@ for array_type in array_types:
4842
4991
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
4843
4992
  constraint=atomic_op_constraint,
4844
4993
  value_func=atomic_op_value_func,
4994
+ dispatch_func=atomic_op_dispatch_func,
4845
4995
  doc="""Compute the maximum of ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
4846
4996
 
4847
4997
  The operation is only atomic on a per-component basis for vectors and matrices.""",
@@ -4977,6 +5127,43 @@ add_builtin(
4977
5127
  )
4978
5128
 
4979
5129
 
5130
+ # implements vector[idx] += scalar
5131
+ add_builtin(
5132
+ "augassign_add",
5133
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
5134
+ value_type=None,
5135
+ hidden=True,
5136
+ group="Utility",
5137
+ )
5138
+
5139
+ # implements quaternion[idx] += scalar
5140
+ add_builtin(
5141
+ "augassign_add",
5142
+ input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
5143
+ value_type=None,
5144
+ hidden=True,
5145
+ group="Utility",
5146
+ )
5147
+
5148
+ # implements vector[idx] -= scalar
5149
+ add_builtin(
5150
+ "augassign_sub",
5151
+ input_types={"a": vector(length=Any, dtype=Scalar), "i": int, "value": Scalar},
5152
+ value_type=None,
5153
+ hidden=True,
5154
+ group="Utility",
5155
+ )
5156
+
5157
+ # implements quaternion[idx] -= scalar
5158
+ add_builtin(
5159
+ "augassign_sub",
5160
+ input_types={"a": quaternion(dtype=Scalar), "i": int, "value": Scalar},
5161
+ value_type=None,
5162
+ hidden=True,
5163
+ group="Utility",
5164
+ )
5165
+
5166
+
4980
5167
  def matrix_index_row_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
4981
5168
  mat_type = arg_types["a"]
4982
5169
  row_type = mat_type._wp_row_type_
@@ -5026,22 +5213,42 @@ def matrix_vector_sametype(arg_types: Mapping[str, Any]):
5026
5213
  return mat_size == vec_size and mat_type == vec_type
5027
5214
 
5028
5215
 
5029
- # implements matrix[i,j] = scalar
5216
+ # implements matrix[i,j] = scalar
5217
+ add_builtin(
5218
+ "assign",
5219
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
5220
+ value_func=matrix_assign_value_func,
5221
+ hidden=True,
5222
+ group="Utility",
5223
+ )
5224
+
5225
+
5226
+ # implements matrix[i] = vector
5227
+ add_builtin(
5228
+ "assign",
5229
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
5230
+ constraint=matrix_vector_sametype,
5231
+ value_func=matrix_assign_value_func,
5232
+ hidden=True,
5233
+ group="Utility",
5234
+ )
5235
+
5236
+
5237
+ # implements matrix[i,j] += scalar
5030
5238
  add_builtin(
5031
- "assign",
5239
+ "augassign_add",
5032
5240
  input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
5033
- value_func=matrix_assign_value_func,
5241
+ value_type=None,
5034
5242
  hidden=True,
5035
5243
  group="Utility",
5036
5244
  )
5037
5245
 
5038
5246
 
5039
- # implements matrix[i] = vector
5247
+ # implements matrix[i,j] -= scalar
5040
5248
  add_builtin(
5041
- "assign",
5042
- input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "value": vector(length=Any, dtype=Scalar)},
5043
- constraint=matrix_vector_sametype,
5044
- value_func=matrix_assign_value_func,
5249
+ "augassign_sub",
5250
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar), "i": int, "j": int, "value": Scalar},
5251
+ value_type=None,
5045
5252
  hidden=True,
5046
5253
  group="Utility",
5047
5254
  )
@@ -5613,19 +5820,19 @@ add_builtin("unot", input_types={"a": array(dtype=Any)}, value_type=builtins.boo
5613
5820
  # Tile operators
5614
5821
  def tile_unary_value_func(arg_types, arg_values):
5615
5822
  if arg_types is None:
5616
- return Tile(dtype=Any, M=Any, N=Any)
5823
+ return Tile(dtype=Any, shape=Any)
5617
5824
 
5618
5825
  t = arg_types["x"]
5619
5826
 
5620
5827
  if not is_tile(t):
5621
- raise RuntimeError("Expected tile for unary expression")
5828
+ raise TypeError(f"Expected tile for unary expression, got {t}")
5622
5829
 
5623
5830
  return TileUnaryMap(t)
5624
5831
 
5625
5832
 
5626
5833
  def tile_scalar_mul_value_func(arg_types, arg_values):
5627
5834
  if arg_types is None:
5628
- return Tile(dtype=Any, M=Any, N=Any)
5835
+ return Tile(dtype=Any, shape=Any)
5629
5836
 
5630
5837
  x = arg_types["x"]
5631
5838
  y = arg_types["y"]
@@ -5633,25 +5840,21 @@ def tile_scalar_mul_value_func(arg_types, arg_values):
5633
5840
  # tile*scalar
5634
5841
  if is_tile(x):
5635
5842
  if x.dtype != y:
5636
- raise RuntimeError(
5637
- "Scalar factor should have the same type as tile for tile*scalar, tile type: {x} scalar type: {y}"
5638
- )
5843
+ raise TypeError(f"Scalar factor type {y} does not match tile type {x.dtype} for tile*scalar")
5639
5844
 
5640
- return TileBinaryMap(x, TileConstant(y, x.M, x.N))
5845
+ return TileBinaryMap(x, TileConstant(y, x.shape))
5641
5846
 
5642
5847
  # scalar*tile
5643
5848
  if is_tile(y):
5644
5849
  if y.dtype != x:
5645
- raise RuntimeError(
5646
- "Scalar factor should have the same type as tile for scalar*tile, tile type: {x} scalar type: {y}"
5647
- )
5850
+ raise TypeError(f"Scalar factor type {x} does not match tile type {y.dtype} for scalar*tile")
5648
5851
 
5649
- return TileBinaryMap(TileConstant(x, y.M, y.N), y)
5852
+ return TileBinaryMap(TileConstant(x, y.shape), y)
5650
5853
 
5651
5854
 
5652
5855
  add_builtin(
5653
5856
  "neg",
5654
- input_types={"x": Tile(dtype=Any, M=Any, N=Any)},
5857
+ input_types={"x": Tile(dtype=Any, shape=Any)},
5655
5858
  value_func=tile_unary_value_func,
5656
5859
  doc="Negate each element of a tile",
5657
5860
  export=False,
@@ -5661,7 +5864,7 @@ add_builtin(
5661
5864
 
5662
5865
  add_builtin(
5663
5866
  "add",
5664
- input_types={"a": Tile(dtype=Any, M=Any, N=Any), "b": Tile(dtype=Any, M=Any, N=Any)},
5867
+ input_types={"a": Tile(dtype=Any, shape=Any), "b": Tile(dtype=Any, shape=Any)},
5665
5868
  value_func=tile_binary_map_value_func,
5666
5869
  # dispatch_func=tile_map_dispatch_func,
5667
5870
  # variadic=True,
@@ -5671,9 +5874,22 @@ add_builtin(
5671
5874
  export=False,
5672
5875
  )
5673
5876
 
5877
+ add_builtin(
5878
+ "sub",
5879
+ input_types={"a": Tile(dtype=Any, shape=Any), "b": Tile(dtype=Any, shape=Any)},
5880
+ value_func=tile_binary_map_value_func,
5881
+ # dispatch_func=tile_map_dispatch_func,
5882
+ # variadic=True,
5883
+ native_func="tile_sub",
5884
+ doc="Subtract each element b from a",
5885
+ group="Tile Primitives",
5886
+ export=False,
5887
+ )
5888
+
5889
+
5674
5890
  add_builtin(
5675
5891
  "mul",
5676
- input_types={"x": Tile(dtype=Any, M=Any, N=Any), "y": Scalar},
5892
+ input_types={"x": Tile(dtype=Any, shape=Any), "y": Scalar},
5677
5893
  value_func=tile_scalar_mul_value_func,
5678
5894
  doc="Multiply each element of a tile by a scalar",
5679
5895
  export=False,
@@ -5683,7 +5899,7 @@ add_builtin(
5683
5899
 
5684
5900
  add_builtin(
5685
5901
  "mul",
5686
- input_types={"x": Scalar, "y": Tile(dtype=Any, M=Any, N=Any)},
5902
+ input_types={"x": Scalar, "y": Tile(dtype=Any, shape=Any)},
5687
5903
  value_func=tile_scalar_mul_value_func,
5688
5904
  doc="Multiply each element of a tile by a scalar",
5689
5905
  export=False,
@@ -5692,6 +5908,70 @@ add_builtin(
5692
5908
  )
5693
5909
 
5694
5910
 
5911
+ def tile_diag_add_value_func(arg_types, arg_values):
5912
+ if arg_types is None:
5913
+ return Tile(dtype=Any, shape=Any)
5914
+
5915
+ a = arg_types["a"]
5916
+ d = arg_types["d"]
5917
+
5918
+ if not is_tile(a):
5919
+ raise TypeError(f"tile_diag_add() 'a' argument must be a tile, got {a!r}")
5920
+
5921
+ if not is_tile(d):
5922
+ raise TypeError(f"tile_diag_add() 'd' argument must be a tile, got {d!r}")
5923
+
5924
+ if not types_equal(a.dtype, d.dtype):
5925
+ raise TypeError(f"tile_diag_add() arguments must have the same dtype, got {a.dtype} and {d.dtype}")
5926
+
5927
+ if len(a.shape) != 2:
5928
+ raise TypeError("tile_diag_add() argument 'a' must be a 2D tile")
5929
+
5930
+ if len(d.shape) != 1:
5931
+ raise TypeError("tile_diag_add() argument 'd' must be a 1D tile")
5932
+
5933
+ if a.shape[0] != a.shape[1]:
5934
+ raise ValueError("tile_diag_add() 'a' argument must be square")
5935
+
5936
+ if a.shape[0] != d.shape[0]:
5937
+ raise ValueError(
5938
+ f"tile_diag_add() 'd' argument must have the same number of elements as the number of rows in 'a', "
5939
+ f"got {d.shape[0]} elements in 'd' and {a.shape[0]} rows in 'a'"
5940
+ )
5941
+
5942
+ # use first argument to define output type
5943
+ return Tile(dtype=a.dtype, shape=a.shape, storage="shared")
5944
+
5945
+
5946
+ def tile_diag_add_lto_dispatch_func(
5947
+ arg_types: Mapping[str, type],
5948
+ return_type: Any,
5949
+ return_values: List[Var],
5950
+ arg_values: Mapping[str, Var],
5951
+ options: Mapping[str, Any],
5952
+ builder: warp.context.ModuleBuilder,
5953
+ ):
5954
+ a = arg_values["a"]
5955
+ d = arg_values["d"]
5956
+ # force the storage type of the input variables to shared memory
5957
+ a.type.storage = "shared"
5958
+ d.type.storage = "shared"
5959
+ out = return_values[0]
5960
+ return ((a, d, out), [], [], 0)
5961
+
5962
+
5963
+ add_builtin(
5964
+ "tile_diag_add",
5965
+ input_types={"a": Tile(dtype=Any, shape=Any), "d": Tile(dtype=Any, shape=Any)},
5966
+ value_func=tile_diag_add_value_func,
5967
+ lto_dispatch_func=tile_diag_add_lto_dispatch_func,
5968
+ native_func="tile_diag_add",
5969
+ doc="Add a square matrix and a diagonal matrix 'd' represented as a 1D tile",
5970
+ group="Tile Primitives",
5971
+ export=False,
5972
+ )
5973
+
5974
+
5695
5975
  ##
5696
5976
  ## MathDx, LTOIR-based, Tile functions
5697
5977
  ##
@@ -5703,24 +5983,25 @@ add_builtin(
5703
5983
  def tile_matmul_generic_value_func(arg_types, arg_values):
5704
5984
  # return generic type (for doc builds)
5705
5985
  if arg_types is None:
5706
- return Tile(dtype=Any, M=Any, N=Any)
5986
+ return Tile(dtype=Any, shape=Any)
5707
5987
 
5708
5988
  a = arg_types["a"]
5709
5989
  b = arg_types["b"]
5710
5990
 
5711
5991
  if not is_tile(a):
5712
- raise RuntimeError("tile_matmul() argument 0 must be a tile")
5992
+ raise TypeError(f"tile_matmul() 'a' argument must be a tile, got {a!r}")
5993
+
5713
5994
  if not is_tile(b):
5714
- raise RuntimeError("tile_matmul() argument 1 must be an tile")
5995
+ raise TypeError(f"tile_matmul() 'b' argument must be a tile, got {b!r}")
5715
5996
 
5716
5997
  # out = wp.tile_matmul(a, b)
5717
5998
  if len(arg_types) == 2:
5718
- return Tile(dtype=a.dtype, M=a.M, N=b.N, storage="shared")
5999
+ return Tile(dtype=a.dtype, shape=(a.shape[0], b.shape[1]), storage="shared")
5719
6000
 
5720
6001
  # wp.tile_matmul(a, b, out)
5721
6002
  elif len(arg_types) == 3:
5722
6003
  if not is_tile(arg_types["out"]):
5723
- raise RuntimeError("tile_matmul() output argument must be a tile")
6004
+ raise TypeError(f"tile_matmul() 'out' argument must be a tile, got {arg_types['out']!r}")
5724
6005
 
5725
6006
  return None
5726
6007
 
@@ -5743,16 +6024,20 @@ def tile_matmul_generic_lto_dispatch_func(
5743
6024
  accumulate = 1 # for tile_matmul(a,b,c) case we want to add to c value
5744
6025
  out = arg_values["out"]
5745
6026
 
5746
- if any(not is_tile(arg.type) for arg in [a, b, out]):
5747
- raise RuntimeError("tile_matmul() requires three Tile arguments")
6027
+ if not is_tile(out.type):
6028
+ raise TypeError(f"tile_matmul() 'out' argument must be a tile, got {out!r}")
5748
6029
 
5749
6030
  if any(arg.type.dtype not in [float16, float32, float64, vec2h, vec2f, vec2d] for arg in [a, b, out]):
5750
- raise RuntimeError(
6031
+ raise TypeError(
5751
6032
  "tile_matmul() arguments must be tiles of float16, float32 or float64, vec2h, vec2f, vec2d entries"
5752
6033
  )
5753
6034
 
5754
- if (a.type.N != b.type.M) or (a.type.M != out.type.M) or (b.type.N != out.type.N):
5755
- raise RuntimeError("tile_matmul(A, B, C) requires sizes of A, B and C to be consistent for a matmul")
6035
+ if (
6036
+ (a.type.shape[1] != b.type.shape[0])
6037
+ or (a.type.shape[0] != out.type.shape[0])
6038
+ or (b.type.shape[1] != out.type.shape[1])
6039
+ ):
6040
+ raise ValueError("tile_matmul(A, B, C) requires sizes of A, B and C to be consistent for a matmul")
5756
6041
 
5757
6042
  # set the storage type to the inputs to shared
5758
6043
  a.type.storage = "shared"
@@ -5774,18 +6059,18 @@ def tile_matmul_generic_lto_dispatch_func(
5774
6059
  return ("wp::vec2f", 5, 1)
5775
6060
  if dtype == vec2d:
5776
6061
  return ("wp::vec2d", 6, 1)
5777
- raise RuntimeError("Unsupported input type in tile_matmul")
6062
+ raise TypeError("Unsupported input type in tile_matmul")
5778
6063
 
5779
6064
  def cublasdx_arrangement_map(layout):
5780
6065
  if layout == "colmajor":
5781
6066
  return 0 # CUBLASDX_ARRANGEMENT_COL_MAJOR
5782
6067
  if layout == "rowmajor":
5783
6068
  return 1 # CUBLASDX_ARRANGEMENT_ROW_MAJOR
5784
- raise RuntimeError("Unsupported layout in tile_matmul")
6069
+ raise ValueError("Unsupported layout in tile_matmul")
5785
6070
 
5786
6071
  # generate the LTO
5787
- M, K = a.type.M, a.type.N
5788
- _, N = b.type.M, b.type.N
6072
+ M, K = a.type.shape[0], a.type.shape[1]
6073
+ _, N = b.type.shape[0], b.type.shape[1]
5789
6074
  num_threads = options["block_dim"]
5790
6075
  arch = options["output_arch"]
5791
6076
 
@@ -5798,7 +6083,8 @@ def tile_matmul_generic_lto_dispatch_func(
5798
6083
  c_arrangement = cublasdx_arrangement_map(clayout)
5799
6084
 
5800
6085
  if a_type != b_type or a_type != c_type:
5801
- raise RuntimeError("time_matmul(A, B, C) requires all inputs to be real or complex")
6086
+ raise TypeError("time_matmul(A, B, C) requires all inputs to be real or complex")
6087
+
5802
6088
  element_type = a_type
5803
6089
 
5804
6090
  lto_symbol = f"dot_{M}_{N}_{K}_{arch}_{num_threads}_{a_arrangement}_{b_arrangement}_{c_arrangement}_{a_prec}_{b_prec}_{c_prec}_{element_type}"
@@ -5893,15 +6179,16 @@ def tile_matmul_generic_lto_dispatch_func(
5893
6179
  ),
5894
6180
  template_args,
5895
6181
  [lto_forward, lto_backward_A, lto_backward_B],
6182
+ 0,
5896
6183
  )
5897
6184
 
5898
6185
 
5899
6186
  add_builtin(
5900
6187
  "tile_matmul",
5901
6188
  input_types={
5902
- "a": Tile(dtype=Any, M=Any, N=Any),
5903
- "b": Tile(dtype=Any, M=Any, N=Any),
5904
- "out": Tile(dtype=Any, M=Any, N=Any),
6189
+ "a": Tile(dtype=Any, shape=Any),
6190
+ "b": Tile(dtype=Any, shape=Any),
6191
+ "out": Tile(dtype=Any, shape=Any),
5905
6192
  },
5906
6193
  value_func=tile_matmul_generic_value_func,
5907
6194
  lto_dispatch_func=tile_matmul_generic_lto_dispatch_func,
@@ -5925,7 +6212,7 @@ add_builtin(
5925
6212
 
5926
6213
  add_builtin(
5927
6214
  "tile_matmul",
5928
- input_types={"a": Tile(dtype=Any, M=Any, N=Any), "b": Tile(dtype=Any, M=Any, N=Any)},
6215
+ input_types={"a": Tile(dtype=Any, shape=Any), "b": Tile(dtype=Any, shape=Any)},
5929
6216
  value_func=tile_matmul_generic_value_func,
5930
6217
  lto_dispatch_func=tile_matmul_generic_lto_dispatch_func,
5931
6218
  variadic=False,
@@ -5952,16 +6239,23 @@ add_builtin(
5952
6239
  ##
5953
6240
  def tile_fft_generic_value_func(arg_types, arg_values):
5954
6241
  if arg_types is None:
5955
- return Tile(dtype=Any, M=Any, N=Any)
6242
+ return Tile(dtype=Any, shape=Any)
5956
6243
 
5957
6244
  if len(arg_types) != 1:
5958
- raise RuntimeError("tile_fft() requires 1 positional args")
6245
+ raise TypeError(f"tile_fft() takes exactly 1 positional argument but {len(arg_types)} were given")
5959
6246
 
5960
- if not is_tile(arg_types["inout"]):
5961
- raise RuntimeError("tile_fft() argument 0 must be a tile")
6247
+ inout = arg_types["inout"]
5962
6248
 
5963
- if arg_types["inout"].storage != "register":
5964
- raise RuntimeError("tile_fft() input/output argument must have register memory storage")
6249
+ if not is_tile(inout):
6250
+ raise TypeError(f"tile_fft() argument must be a tile, got {inout!r}")
6251
+
6252
+ if inout.storage != "register":
6253
+ raise ValueError(f"tile_fft() argument must have 'register' storage, got {inout.storage}")
6254
+
6255
+ if inout.dtype not in [vec2f, vec2d]:
6256
+ raise TypeError(
6257
+ f"tile_fft() argument must be a tile of vec2f or vec2d (interpreted as complex) entries, got {inout.dtype!r}"
6258
+ )
5965
6259
 
5966
6260
  return None
5967
6261
 
@@ -5978,19 +6272,13 @@ def tile_fft_generic_lto_dispatch_func(
5978
6272
  inout = arg_values["inout"]
5979
6273
  inout.type.storage = "register"
5980
6274
 
5981
- if not is_tile(inout.type):
5982
- raise RuntimeError("tile_fft() arguments must be a single tile with register storage")
5983
-
5984
- if inout.type.dtype not in [vec2f, vec2d]:
5985
- raise RuntimeError("tile_fft() argument must be a tile of vec2f or vec2d (interpreted as complex) entries")
5986
-
5987
6275
  # see libcufftdx.hpp
5988
6276
  if direction == "forward":
5989
6277
  dir = 0 # CUFFTDX_DIRECTION_FORWARD
5990
6278
  elif direction == "inverse":
5991
6279
  dir = 1 # CUFFTDX_DIRECTION_INVERSE
5992
6280
  else:
5993
- raise RuntimeError("Invalid direction")
6281
+ raise ValueError(f"Invalid direction: {direction!r}. Expected 'forward' or 'inverse'.")
5994
6282
 
5995
6283
  if inout.type.dtype == vec2f:
5996
6284
  dtype = "wp::vec2f"
@@ -5999,10 +6287,10 @@ def tile_fft_generic_lto_dispatch_func(
5999
6287
  dtype = "wp::vec2d"
6000
6288
  precision = 6 # COMMONDX_PRECISION_F64
6001
6289
  else:
6002
- raise RuntimeError("Unsupported datatype")
6290
+ raise TypeError(f"Unsupported data type, got {dtype!r}")
6003
6291
 
6004
6292
  # M FFTs of size N each
6005
- batch, size = inout.type.M, inout.type.N
6293
+ batch, size = inout.type.shape[0], inout.type.shape[1]
6006
6294
  num_threads = options["block_dim"]
6007
6295
  arch = options["output_arch"]
6008
6296
  ept = size // num_threads
@@ -6034,7 +6322,7 @@ def tile_fft_generic_lto_dispatch_func(
6034
6322
  lto_code.close()
6035
6323
  if lto_code_path.exists():
6036
6324
  lto_code_path.unlink()
6037
- raise RuntimeError("Failed to compile tile_matmul")
6325
+ raise RuntimeError("Failed to compile tile_fft")
6038
6326
 
6039
6327
  with open(lto_code.name, "rb") as f:
6040
6328
  lto_code_data = f.read()
@@ -6044,17 +6332,20 @@ def tile_fft_generic_lto_dispatch_func(
6044
6332
 
6045
6333
  builder.ltoirs[lto_symbol] = lto_code_data
6046
6334
 
6335
+ shared_memory_bytes = Tile.round_up(shared_memory_size.value)
6336
+
6047
6337
  return (
6048
6338
  (
6049
6339
  Var(lto_symbol, str, False, True, False),
6050
6340
  Var(dtype, str, False, True, False),
6051
- Var(str(shared_memory_size.value), str, False, True, False),
6341
+ Var(str(shared_memory_bytes), str, False, True, False),
6052
6342
  Var(str(batch), str, False, True, False),
6053
6343
  Var(str(ept), str, False, True, False),
6054
6344
  inout,
6055
6345
  ),
6056
6346
  [],
6057
6347
  [lto_code_data],
6348
+ shared_memory_bytes,
6058
6349
  )
6059
6350
 
6060
6351
 
@@ -6068,6 +6359,8 @@ add_builtin(
6068
6359
 
6069
6360
  This function cooperatively computes the forward FFT on a tile of data inplace, treating each row individually.
6070
6361
 
6362
+ Note that computing the adjoint is not yet supported.
6363
+
6071
6364
  Supported datatypes are:
6072
6365
  * vec2f, vec2d
6073
6366
 
@@ -6087,6 +6380,8 @@ add_builtin(
6087
6380
 
6088
6381
  This function cooperatively computes the inverse FFT on a tile of data inplace, treating each row individually.
6089
6382
 
6383
+ Note that computing the adjoint is not yet supported.
6384
+
6090
6385
  Supported datatypes are:
6091
6386
  * vec2f, vec2d
6092
6387
 
@@ -6096,6 +6391,283 @@ add_builtin(
6096
6391
  namespace="",
6097
6392
  )
6098
6393
 
6394
+
6395
+ ##
6396
+ ## Cholesky
6397
+ ##
6398
+ def tile_cholesky_generic_value_func(arg_types, arg_values):
6399
+ if arg_types is None:
6400
+ return Tile(dtype=Any, shape=Any)
6401
+
6402
+ if len(arg_types) != 1:
6403
+ raise TypeError("tile_cholesky() requires 1 positional args")
6404
+
6405
+ a = arg_types["A"]
6406
+
6407
+ if not is_tile(a):
6408
+ raise TypeError(f"tile_cholesky() argument must be a tile, got {a!r}")
6409
+
6410
+ if len(a.shape) != 2:
6411
+ raise ValueError("tile_cholesky() argumust must be a 2D tile")
6412
+
6413
+ if a.shape[0] != a.shape[1]:
6414
+ raise ValueError("tile_cholesky() argument must be square")
6415
+
6416
+ return Tile(dtype=a.dtype, shape=a.shape, storage="shared")
6417
+
6418
+
6419
+ cusolver_function_map = {"getrf": 0, "getrf_no_pivot": 1, "potrf": 2, "potrs": 3}
6420
+
6421
+ cusolver_type_map = {float32: ("wp::float32", 5), float64: ("wp::float64", 6)}
6422
+
6423
+ cusolver_fill_mode_map = {"upper": 0, "lower": 1}
6424
+
6425
+
6426
+ def tile_cholesky_generic_lto_dispatch_func(
6427
+ arg_types: Mapping[str, type],
6428
+ return_type: Any,
6429
+ return_values: List[Var],
6430
+ arg_values: Mapping[str, Var],
6431
+ options: Mapping[str, Any],
6432
+ builder: warp.context.ModuleBuilder,
6433
+ ):
6434
+ a = arg_values["A"]
6435
+ # force source tile to shared memory
6436
+ a.type.storage = "shared"
6437
+
6438
+ if a.type.dtype not in cusolver_type_map.keys():
6439
+ raise TypeError("tile_cholesky() argument must be a tile of float32 or float64 entries")
6440
+
6441
+ if len(return_values) != 1:
6442
+ raise TypeError("tile_cholesky() returns one output")
6443
+ out = return_values[0]
6444
+
6445
+ dtype, precision_enum = cusolver_type_map[a.type.dtype]
6446
+
6447
+ # We already ensured a is square in tile_cholesky_generic_value_func()
6448
+ M, N = a.type.shape[0], a.type.shape[1]
6449
+ if out.type.shape[0] != M or out.type.shape[1] != M:
6450
+ raise ValueError("tile_cholesky() output tile must be square")
6451
+
6452
+ num_threads = options["block_dim"]
6453
+ arch = options["output_arch"]
6454
+ lto_symbol = f"potrf_{M}_{N}_{arch}_{precision_enum}"
6455
+
6456
+ # early out if LTO for this combination already exists for this module
6457
+ if lto_symbol in builder.ltoirs:
6458
+ return lto_symbol, builder.ltoirs[lto_symbol]
6459
+
6460
+ # otherwise compile LTO
6461
+ lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
6462
+ universal_fatbin_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
6463
+
6464
+ # cuSOLVERDx only support col-major input/outputs,
6465
+ # so we use upper to mimic a row-major input
6466
+ result = warp.context.runtime.core.cuda_compile_solver(
6467
+ universal_fatbin_code.name.encode("utf-8"),
6468
+ lto_code.name.encode("utf-8"),
6469
+ lto_symbol.encode("utf-8"),
6470
+ 0,
6471
+ None,
6472
+ None,
6473
+ arch,
6474
+ M,
6475
+ N,
6476
+ cusolver_function_map["potrf"],
6477
+ precision_enum,
6478
+ cusolver_fill_mode_map["upper"],
6479
+ num_threads,
6480
+ )
6481
+
6482
+ if not result:
6483
+ for f in [lto_code, universal_fatbin_code]:
6484
+ f.close()
6485
+ if Path(f.name).exists():
6486
+ Path(f.name).unlink()
6487
+ raise RuntimeError("Failed to compile tile_cholesky")
6488
+
6489
+ else:
6490
+ with open(lto_code.name, "rb") as f:
6491
+ lto_code_data = f.read()
6492
+ with open(universal_fatbin_code.name, "rb") as f:
6493
+ universal_fatbin_code_data = f.read()
6494
+ for f in [lto_code, universal_fatbin_code]:
6495
+ f.close()
6496
+ Path(f.name).unlink()
6497
+
6498
+ builder.ltoirs[lto_symbol] = lto_code_data
6499
+ builder.ltoirs_decl[lto_symbol] = f"void {lto_symbol}({dtype}*, unsigned);"
6500
+ builder.fatbins["cholesky"] = universal_fatbin_code_data
6501
+
6502
+ return ((Var(lto_symbol, str, False, True, False), a, out), [], [lto_code_data], 0)
6503
+
6504
+
6505
+ add_builtin(
6506
+ "tile_cholesky",
6507
+ input_types={"A": Tile},
6508
+ value_func=tile_cholesky_generic_value_func,
6509
+ lto_dispatch_func=tile_cholesky_generic_lto_dispatch_func,
6510
+ variadic=True,
6511
+ doc="""Compute the Cholesky factorization L of a matrix A.
6512
+ L is lower triangular and satisfies LL^T = A.
6513
+
6514
+ Note that computing the adjoint is not yet supported.
6515
+
6516
+ Supported datatypes are:
6517
+ * float32
6518
+ * float64
6519
+
6520
+ :param A: A square, symmetric positive-definite, matrix.
6521
+ :returns L: A square, lower triangular, matrix, such that LL^T = A""",
6522
+ group="Tile Primitives",
6523
+ export=False,
6524
+ namespace="",
6525
+ )
6526
+
6527
+
6528
+ def tile_cholesky_solve_generic_value_func(arg_types, arg_values):
6529
+ if arg_types is None:
6530
+ return None
6531
+
6532
+ if len(arg_types) != 2:
6533
+ raise TypeError("tile_cholesky_solve() requires exactly 2 positional args")
6534
+
6535
+ l = arg_types["L"]
6536
+ x = arg_types["x"]
6537
+
6538
+ if not is_tile(l):
6539
+ raise TypeError(f"tile_cholesky_solve() 'L' argument must be a tile, got {l!r}")
6540
+
6541
+ if not is_tile(x):
6542
+ raise TypeError(f"tile_cholesky_solve() 'x' argument must be a tile, got {l!r}")
6543
+
6544
+ if not types_equal(l.dtype, x.dtype):
6545
+ raise TypeError(f"tile_cholesky_solve() arguments must have the same dtype, got {l.dtype} and {x.dtype}")
6546
+
6547
+ if l.shape[0] != l.shape[1]:
6548
+ raise ValueError("tile_cholesky_solve() 'L' argument must be square")
6549
+
6550
+ if len(x.shape) != 1:
6551
+ raise TypeError("tile_cholesky_solve() 'x' argument must be a 1D tile")
6552
+
6553
+ if x.shape[0] != l.shape[0]:
6554
+ raise ValueError(
6555
+ f"tile_cholesky_solve() 'x' argument must have the same number of elements as the number of rows in 'L', "
6556
+ f"got {x.shape[0]} elements in 'x' and {l.shape[0]} rows in 'L'"
6557
+ )
6558
+
6559
+ return Tile(dtype=l.dtype, shape=x.shape, storage="shared")
6560
+
6561
+
6562
+ def tile_cholesky_solve_generic_lto_dispatch_func(
6563
+ arg_types: Mapping[str, type],
6564
+ return_type: Any,
6565
+ return_values: List[Var],
6566
+ arg_values: Mapping[str, Var],
6567
+ options: Mapping[str, Any],
6568
+ builder: warp.context.ModuleBuilder,
6569
+ ):
6570
+ L = arg_values["L"]
6571
+ x = arg_values["x"]
6572
+ # force the storage type of the input variables to shared memory
6573
+ L.type.storage = "shared"
6574
+ x.type.storage = "shared"
6575
+
6576
+ if len(return_values) != 1:
6577
+ raise TypeError(f"tile_cholesky_solve() must return exactly one value, got {len(return_values)}")
6578
+
6579
+ y = return_values[0]
6580
+
6581
+ if any(T not in cusolver_type_map.keys() for T in [x.type.dtype, L.type.dtype]):
6582
+ raise TypeError("tile_cholesky_solve() arguments be tiles of float64 or float32")
6583
+
6584
+ dtype, precision_enum = cusolver_type_map[L.type.dtype]
6585
+ M, N = L.type.shape[0], L.type.shape[1]
6586
+
6587
+ if len(y.type.shape) != 1:
6588
+ raise TypeError("tile_cholesky_solve() output vector must be 1D")
6589
+
6590
+ if y.type.shape[0] != M:
6591
+ raise ValueError(
6592
+ "tile_cholesky_solve() output vector must have same number of elements as the number of rows in 'L' "
6593
+ f"got {y.type.shape[0]} elements in output and {M} rows in 'L'"
6594
+ )
6595
+
6596
+ num_threads = options["block_dim"]
6597
+ arch = options["output_arch"]
6598
+ lto_symbol = f"potrs_{M}_{N}_{arch}_{precision_enum}"
6599
+
6600
+ # early out if LTO for this combination already exists for this module
6601
+ if lto_symbol in builder.ltoirs:
6602
+ return lto_symbol, builder.ltoirs[lto_symbol]
6603
+
6604
+ # otherwise compile LTO
6605
+ lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
6606
+ universal_fatbin_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
6607
+
6608
+ # cuSOLVERDx only support col-major input/outputs,
6609
+ # so we use upper to mimic a row-major input
6610
+ result = warp.context.runtime.core.cuda_compile_solver(
6611
+ universal_fatbin_code.name.encode("utf-8"),
6612
+ lto_code.name.encode("utf-8"),
6613
+ lto_symbol.encode("utf-8"),
6614
+ 0,
6615
+ None,
6616
+ None,
6617
+ arch,
6618
+ M,
6619
+ N,
6620
+ cusolver_function_map["potrs"],
6621
+ precision_enum,
6622
+ cusolver_fill_mode_map["upper"],
6623
+ num_threads,
6624
+ )
6625
+
6626
+ if not result:
6627
+ for f in [lto_code, universal_fatbin_code]:
6628
+ f.close()
6629
+ if Path(f.name).exists():
6630
+ Path(f.name).unlink()
6631
+ raise RuntimeError("Failed to compile tile_cholesky_solve")
6632
+
6633
+ else:
6634
+ with open(lto_code.name, "rb") as f:
6635
+ lto_code_data = f.read()
6636
+ with open(universal_fatbin_code.name, "rb") as f:
6637
+ universal_fatbin_code_data = f.read()
6638
+ for f in [lto_code, universal_fatbin_code]:
6639
+ f.close()
6640
+ Path(f.name).unlink()
6641
+
6642
+ builder.ltoirs[lto_symbol] = lto_code_data
6643
+ builder.ltoirs_decl[lto_symbol] = f"void {lto_symbol}({dtype}*, {dtype}*);"
6644
+ builder.fatbins["cholesky"] = universal_fatbin_code_data
6645
+
6646
+ return ((Var(lto_symbol, str, False, True, False), L, x, y), [], [lto_code_data], 0)
6647
+
6648
+
6649
+ add_builtin(
6650
+ "tile_cholesky_solve",
6651
+ input_types={"L": Tile, "x": Tile},
6652
+ value_func=tile_cholesky_solve_generic_value_func,
6653
+ lto_dispatch_func=tile_cholesky_solve_generic_lto_dispatch_func,
6654
+ variadic=True,
6655
+ doc="""With L such that LL^T = A, solve for x in Ax = y
6656
+
6657
+ Note that computing the adjoint is not yet supported.
6658
+
6659
+ Supported datatypes are:
6660
+ * float32
6661
+ * float64
6662
+
6663
+ :param L: A square, lower triangular, matrix, such that LL^T = A
6664
+ :param x: An 1D tile of length M
6665
+ :returns y: An 1D tile of length M such that LL^T y = x""",
6666
+ group="Tile Primitives",
6667
+ export=False,
6668
+ namespace="",
6669
+ )
6670
+
6099
6671
  # ---------------------------------
6100
6672
  # Code Generation
6101
6673
 
@@ -6103,7 +6675,7 @@ add_builtin(
6103
6675
  "static",
6104
6676
  input_types={"expr": Any},
6105
6677
  value_type=Any,
6106
- doc="""Evaluates a static Python expression and replaces it with its result.
6678
+ doc="""Evaluate a static Python expression and replaces it with its result.
6107
6679
 
6108
6680
  See the :ref:`code generation guide <static_expressions>` for more details.
6109
6681
 
@@ -6127,3 +6699,58 @@ def static(expr):
6127
6699
  which includes constant variables and variables captured in the current closure in which the function or kernel is implemented.
6128
6700
  """
6129
6701
  return expr
6702
+
6703
+
6704
+ add_builtin(
6705
+ "len",
6706
+ input_types={"a": vector(length=Any, dtype=Scalar)},
6707
+ value_type=int,
6708
+ doc="Return the number of elements in a vector.",
6709
+ group="Utility",
6710
+ export=False,
6711
+ )
6712
+
6713
+ add_builtin(
6714
+ "len",
6715
+ input_types={"a": quaternion(dtype=Scalar)},
6716
+ value_type=int,
6717
+ doc="Return the number of elements in a quaternion.",
6718
+ group="Utility",
6719
+ export=False,
6720
+ )
6721
+
6722
+ add_builtin(
6723
+ "len",
6724
+ input_types={"a": matrix(shape=(Any, Any), dtype=Scalar)},
6725
+ value_type=int,
6726
+ doc="Return the number of rows in a matrix.",
6727
+ group="Utility",
6728
+ export=False,
6729
+ )
6730
+
6731
+ add_builtin(
6732
+ "len",
6733
+ input_types={"a": transformation(dtype=Float)},
6734
+ value_type=int,
6735
+ doc="Return the number of elements in a transformation.",
6736
+ group="Utility",
6737
+ export=False,
6738
+ )
6739
+
6740
+ add_builtin(
6741
+ "len",
6742
+ input_types={"a": array(dtype=Any)},
6743
+ value_type=int,
6744
+ doc="Return the size of the first dimension in an array.",
6745
+ group="Utility",
6746
+ export=False,
6747
+ )
6748
+
6749
+ add_builtin(
6750
+ "len",
6751
+ input_types={"a": Tile(dtype=Any, shape=Any)},
6752
+ value_type=int,
6753
+ doc="Return the number of rows in a tile.",
6754
+ group="Utility",
6755
+ export=False,
6756
+ )