warp-lang 1.4.2__py3-none-manylinux2014_x86_64.whl → 1.5.1__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (166)
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +21 -2
  6. warp/build_dll.py +23 -6
  7. warp/builtins.py +1819 -7
  8. warp/codegen.py +197 -61
  9. warp/config.py +2 -2
  10. warp/context.py +379 -107
  11. warp/examples/assets/pixel.jpg +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  13. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  14. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  15. warp/examples/benchmarks/benchmark_tile.py +179 -0
  16. warp/examples/fem/example_adaptive_grid.py +37 -10
  17. warp/examples/fem/example_apic_fluid.py +3 -2
  18. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  19. warp/examples/fem/example_deformed_geometry.py +1 -1
  20. warp/examples/fem/example_diffusion_3d.py +47 -4
  21. warp/examples/fem/example_distortion_energy.py +220 -0
  22. warp/examples/fem/example_magnetostatics.py +127 -85
  23. warp/examples/fem/example_nonconforming_contact.py +5 -5
  24. warp/examples/fem/example_stokes.py +3 -1
  25. warp/examples/fem/example_streamlines.py +12 -19
  26. warp/examples/fem/utils.py +38 -15
  27. warp/examples/sim/example_cloth.py +4 -25
  28. warp/examples/sim/example_quadruped.py +2 -1
  29. warp/examples/tile/example_tile_convolution.py +58 -0
  30. warp/examples/tile/example_tile_fft.py +47 -0
  31. warp/examples/tile/example_tile_filtering.py +105 -0
  32. warp/examples/tile/example_tile_matmul.py +79 -0
  33. warp/examples/tile/example_tile_mlp.py +375 -0
  34. warp/fem/__init__.py +8 -0
  35. warp/fem/cache.py +16 -12
  36. warp/fem/dirichlet.py +1 -1
  37. warp/fem/domain.py +44 -1
  38. warp/fem/field/__init__.py +1 -2
  39. warp/fem/field/field.py +31 -19
  40. warp/fem/field/nodal_field.py +101 -49
  41. warp/fem/field/virtual.py +794 -0
  42. warp/fem/geometry/__init__.py +2 -2
  43. warp/fem/geometry/deformed_geometry.py +3 -105
  44. warp/fem/geometry/element.py +13 -0
  45. warp/fem/geometry/geometry.py +165 -7
  46. warp/fem/geometry/grid_2d.py +3 -6
  47. warp/fem/geometry/grid_3d.py +31 -28
  48. warp/fem/geometry/hexmesh.py +3 -46
  49. warp/fem/geometry/nanogrid.py +3 -2
  50. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  51. warp/fem/geometry/tetmesh.py +2 -43
  52. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  53. warp/fem/integrate.py +683 -261
  54. warp/fem/linalg.py +404 -0
  55. warp/fem/operator.py +101 -18
  56. warp/fem/polynomial.py +5 -5
  57. warp/fem/quadrature/quadrature.py +45 -21
  58. warp/fem/space/__init__.py +45 -11
  59. warp/fem/space/basis_function_space.py +451 -0
  60. warp/fem/space/basis_space.py +58 -11
  61. warp/fem/space/function_space.py +146 -5
  62. warp/fem/space/grid_2d_function_space.py +80 -66
  63. warp/fem/space/grid_3d_function_space.py +113 -68
  64. warp/fem/space/hexmesh_function_space.py +96 -108
  65. warp/fem/space/nanogrid_function_space.py +62 -110
  66. warp/fem/space/quadmesh_function_space.py +208 -0
  67. warp/fem/space/shape/__init__.py +45 -7
  68. warp/fem/space/shape/cube_shape_function.py +328 -54
  69. warp/fem/space/shape/shape_function.py +10 -1
  70. warp/fem/space/shape/square_shape_function.py +328 -60
  71. warp/fem/space/shape/tet_shape_function.py +269 -19
  72. warp/fem/space/shape/triangle_shape_function.py +238 -19
  73. warp/fem/space/tetmesh_function_space.py +69 -37
  74. warp/fem/space/topology.py +38 -0
  75. warp/fem/space/trimesh_function_space.py +179 -0
  76. warp/fem/utils.py +6 -331
  77. warp/jax_experimental.py +3 -1
  78. warp/native/array.h +15 -0
  79. warp/native/builtin.h +66 -26
  80. warp/native/bvh.h +4 -0
  81. warp/native/coloring.cpp +604 -0
  82. warp/native/cuda_util.cpp +68 -51
  83. warp/native/cuda_util.h +2 -1
  84. warp/native/fabric.h +8 -0
  85. warp/native/hashgrid.h +4 -0
  86. warp/native/marching.cu +8 -0
  87. warp/native/mat.h +14 -3
  88. warp/native/mathdx.cpp +59 -0
  89. warp/native/mesh.h +4 -0
  90. warp/native/range.h +13 -1
  91. warp/native/reduce.cpp +9 -1
  92. warp/native/reduce.cu +7 -0
  93. warp/native/runlength_encode.cpp +9 -1
  94. warp/native/runlength_encode.cu +7 -1
  95. warp/native/scan.cpp +8 -0
  96. warp/native/scan.cu +8 -0
  97. warp/native/scan.h +8 -1
  98. warp/native/sparse.cpp +8 -0
  99. warp/native/sparse.cu +8 -0
  100. warp/native/temp_buffer.h +7 -0
  101. warp/native/tile.h +1854 -0
  102. warp/native/tile_gemm.h +341 -0
  103. warp/native/tile_reduce.h +210 -0
  104. warp/native/volume_builder.cu +8 -0
  105. warp/native/volume_builder.h +8 -0
  106. warp/native/warp.cpp +10 -2
  107. warp/native/warp.cu +369 -15
  108. warp/native/warp.h +12 -2
  109. warp/optim/adam.py +39 -4
  110. warp/paddle.py +29 -12
  111. warp/render/render_opengl.py +140 -67
  112. warp/sim/graph_coloring.py +292 -0
  113. warp/sim/import_urdf.py +8 -8
  114. warp/sim/integrator_euler.py +4 -2
  115. warp/sim/integrator_featherstone.py +115 -44
  116. warp/sim/integrator_vbd.py +6 -0
  117. warp/sim/model.py +109 -32
  118. warp/sparse.py +1 -1
  119. warp/stubs.py +569 -4
  120. warp/tape.py +12 -7
  121. warp/tests/assets/pixel.npy +0 -0
  122. warp/tests/aux_test_instancing_gc.py +18 -0
  123. warp/tests/test_array.py +39 -0
  124. warp/tests/test_codegen.py +81 -1
  125. warp/tests/test_codegen_instancing.py +30 -0
  126. warp/tests/test_collision.py +110 -0
  127. warp/tests/test_coloring.py +251 -0
  128. warp/tests/test_context.py +34 -0
  129. warp/tests/test_examples.py +21 -5
  130. warp/tests/test_fem.py +453 -113
  131. warp/tests/test_func.py +34 -4
  132. warp/tests/test_generics.py +52 -0
  133. warp/tests/test_iter.py +68 -0
  134. warp/tests/test_lerp.py +13 -87
  135. warp/tests/test_mat_scalar_ops.py +1 -1
  136. warp/tests/test_matmul.py +6 -9
  137. warp/tests/test_matmul_lite.py +6 -11
  138. warp/tests/test_mesh_query_point.py +1 -1
  139. warp/tests/test_module_hashing.py +23 -0
  140. warp/tests/test_overwrite.py +45 -0
  141. warp/tests/test_paddle.py +27 -87
  142. warp/tests/test_print.py +56 -1
  143. warp/tests/test_smoothstep.py +17 -83
  144. warp/tests/test_spatial.py +1 -1
  145. warp/tests/test_static.py +3 -3
  146. warp/tests/test_tile.py +744 -0
  147. warp/tests/test_tile_mathdx.py +144 -0
  148. warp/tests/test_tile_mlp.py +383 -0
  149. warp/tests/test_tile_reduce.py +374 -0
  150. warp/tests/test_tile_shared_memory.py +190 -0
  151. warp/tests/test_vbd.py +12 -20
  152. warp/tests/test_volume.py +43 -0
  153. warp/tests/unittest_suites.py +19 -2
  154. warp/tests/unittest_utils.py +4 -2
  155. warp/types.py +340 -74
  156. warp/utils.py +23 -3
  157. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/METADATA +32 -7
  158. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/RECORD +161 -134
  159. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/WHEEL +1 -1
  160. warp/fem/field/test.py +0 -180
  161. warp/fem/field/trial.py +0 -183
  162. warp/fem/space/collocated_function_space.py +0 -102
  163. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  164. warp/fem/space/trimesh_2d_function_space.py +0 -153
  165. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/LICENSE.md +0 -0
  166. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/top_level.txt +0 -0
warp/builtins.py CHANGED
@@ -5,6 +5,9 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
  import builtins
8
+ import functools
9
+ import tempfile
10
+ from pathlib import Path
8
11
  from typing import Any, Callable, Mapping, Sequence
9
12
 
10
13
  from warp.codegen import Reference, Var, strip_reference
@@ -396,11 +399,11 @@ def scalar_infer_type(arg_types: Mapping[str, type]):
396
399
 
397
400
  scalar_types = set()
398
401
  for t in arg_types:
399
- t = strip_reference(t)
400
- if hasattr(t, "_wp_scalar_type_"):
401
- scalar_types.add(t._wp_scalar_type_)
402
- elif t in scalar_and_bool_types:
403
- scalar_types.add(t)
402
+ t_val = strip_reference(t)
403
+ if hasattr(t_val, "_wp_scalar_type_"):
404
+ scalar_types.add(t_val._wp_scalar_type_)
405
+ elif t_val in scalar_and_bool_types:
406
+ scalar_types.add(t_val)
404
407
 
405
408
  if len(scalar_types) > 1:
406
409
  raise RuntimeError(
@@ -1702,6 +1705,1267 @@ add_builtin(
1702
1705
  group="Spatial Math",
1703
1706
  )
1704
1707
 
1708
+ # ------------------
1709
+ # Tile-based primitives
1710
+ shared_memory_id = 0
1711
+
1712
+
1713
+ def tile_zeros_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1714
+ # return generic type (for doc builds)
1715
+ if arg_types is None:
1716
+ return Tile(dtype=Any, M=Any, N=Any)
1717
+
1718
+ if "m" not in arg_values:
1719
+ raise RuntimeError("'m' keyword argument must be specified when calling tile_zeros() function")
1720
+
1721
+ if "n" not in arg_values:
1722
+ raise RuntimeError("'n' keyword argument must be specified when calling tile_zeros() function")
1723
+
1724
+ if "dtype" not in arg_values:
1725
+ raise RuntimeError("'dtype' keyword argument must be specified when calling tile_zeros() function")
1726
+
1727
+ if "storage" not in arg_values:
1728
+ raise ValueError("'storage' keyword not provided for tile_zeros")
1729
+
1730
+ if arg_values["storage"] not in {"shared", "register"}:
1731
+ raise ValueError(
1732
+ f"'storage' keyword argument must be either 'shared' or 'register', got {arg_values['storage']}"
1733
+ )
1734
+
1735
+ m, n = arg_values["m"], arg_values["n"]
1736
+ dtype = arg_values["dtype"]
1737
+
1738
+ return TileZeros(dtype=dtype, M=m, N=n, storage=arg_values["storage"])
1739
+
1740
+
1741
+ def tile_zeros_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
1742
+ m, n, dtype = arg_values["m"], arg_values["n"], arg_values["dtype"]
1743
+
1744
+ template_args = []
1745
+ template_args.append(dtype)
1746
+ template_args.append(m.constant)
1747
+ template_args.append(n.constant)
1748
+
1749
+ return ([], template_args)
1750
+
1751
+
1752
+ add_builtin(
1753
+ "tile_zeros",
1754
+ input_types={"m": int, "n": int, "dtype": Any, "storage": str},
1755
+ defaults={"storage": "register"},
1756
+ value_func=tile_zeros_value_func,
1757
+ dispatch_func=tile_zeros_dispatch_func,
1758
+ variadic=False,
1759
+ missing_grad=True,
1760
+ doc="""Allocates a tile of zero-initialized items.
1761
+
1762
+ :param m: Size of the first dimension of the output tile
1763
+ :param n: Size of the second dimension of the output tile
1764
+ :param dtype: Datatype of output tile's elements
1765
+ :param storage: The storage location for the tile: ``"register"`` for registers
1766
+ (default) or ``"shared"`` for shared memory.
1767
+ :returns: A zero-initialized tile with ``shape=(m,n)`` and the specified datatype""",
1768
+ group="Tile Primitives",
1769
+ export=False,
1770
+ )
1771
+
1772
+
1773
+ def tile_ones_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1774
+ # return generic type (for doc builds)
1775
+ if arg_types is None:
1776
+ return Tile(dtype=Any, M=Any, N=Any)
1777
+
1778
+ if "m" not in arg_values:
1779
+ raise RuntimeError("'m' keyword argument must be specified when calling tile_zeros() function")
1780
+
1781
+ if "n" not in arg_values:
1782
+ raise RuntimeError("'n' keyword argument must be specified when calling tile_zeros() function")
1783
+
1784
+ if "dtype" not in arg_values:
1785
+ raise RuntimeError("'dtype' keyword argument must be specified when calling tile_zeros() function")
1786
+
1787
+ if arg_values["storage"] not in {"shared", "register"}:
1788
+ raise ValueError(
1789
+ f"'storage' keyword argument must be either 'shared' or 'register', got {arg_values['storage']}"
1790
+ )
1791
+
1792
+ m, n = arg_values["m"], arg_values["n"]
1793
+ dtype = arg_values["dtype"]
1794
+
1795
+ return TileZeros(dtype=dtype, M=m, N=n, storage=arg_values["storage"])
1796
+
1797
+
1798
+ def tile_ones_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
1799
+ m, n, dtype = arg_values["m"], arg_values["n"], arg_values["dtype"]
1800
+
1801
+ template_args = []
1802
+ template_args.append(dtype)
1803
+ template_args.append(m.constant)
1804
+ template_args.append(n.constant)
1805
+
1806
+ return ([], template_args)
1807
+
1808
+
1809
+ add_builtin(
1810
+ "tile_ones",
1811
+ input_types={"m": int, "n": int, "dtype": Any, "storage": str},
1812
+ defaults={"storage": "register"},
1813
+ value_func=tile_ones_value_func,
1814
+ dispatch_func=tile_ones_dispatch_func,
1815
+ missing_grad=True,
1816
+ doc="""Allocates a tile of one-initialized items.
1817
+
1818
+ :param m: Size of the first dimension of the output tile
1819
+ :param n: Size of the second dimension of the output tile
1820
+ :param dtype: Datatype of output tile's elements
1821
+ :param storage: The storage location for the tile: ``"register"`` for registers
1822
+ (default) or ``"shared"`` for shared memory.
1823
+ :returns: A one-initialized tile with ``shape=(m,n)`` and the specified dtype""",
1824
+ group="Tile Primitives",
1825
+ export=False,
1826
+ )
1827
+
1828
+
1829
+ def tile_arange_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
1830
+ # return generic type (for doc builds)
1831
+ if arg_types is None:
1832
+ return Tile(dtype=Any, M=Any, N=Any)
1833
+
1834
+ start = 0
1835
+ stop = 0
1836
+ step = 1
1837
+ dtype = int
1838
+
1839
+ args = arg_values["args"]
1840
+
1841
+ if len(args) == 1:
1842
+ start = 0
1843
+ stop = args[0]
1844
+
1845
+ elif len(args) == 2:
1846
+ start = args[0]
1847
+ stop = args[1]
1848
+
1849
+ elif len(args) == 3:
1850
+ start = args[0]
1851
+ stop = args[1]
1852
+ step = args[2]
1853
+
1854
+ if start is None or stop is None or step is None:
1855
+ print(args)
1856
+ raise RuntimeError("wp.tile_arange() arguments must be compile time constants")
1857
+
1858
+ if "dtype" in arg_values:
1859
+ dtype = arg_values["dtype"]
1860
+ else:
1861
+ dtype = float
1862
+
1863
+ if arg_values["storage"] not in {"shared", "register"}:
1864
+ raise ValueError(
1865
+ f"'storage' keyword argument must be either 'shared' or 'register', got {arg_values['storage']}"
1866
+ )
1867
+
1868
+ return TileRange(dtype=dtype, start=start, stop=stop, step=step, storage=arg_values["storage"])
1869
+
1870
+
1871
+ def tile_arange_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
1872
+ m, n, dtype = return_type.M, return_type.N, return_type.dtype
1873
+
1874
+ template_args = []
1875
+ template_args.append(dtype)
1876
+ template_args.append(m)
1877
+ template_args.append(n)
1878
+
1879
+ # todo: it is somewhat redundant to create new vars here since some of start,stop,step
1880
+ # already exist depending on which form the function was called by the user
1881
+ start = warp.codegen.Var(label=None, type=return_type.dtype, constant=return_type.start)
1882
+ stop = warp.codegen.Var(label=None, type=return_type.dtype, constant=return_type.stop)
1883
+ step = warp.codegen.Var(label=None, type=return_type.dtype, constant=return_type.step)
1884
+
1885
+ function_args = []
1886
+ function_args.append(start)
1887
+ function_args.append(stop)
1888
+ function_args.append(step)
1889
+
1890
+ return (function_args, template_args)
1891
+
1892
+
1893
+ add_builtin(
1894
+ "tile_arange",
1895
+ input_types={"*args": Scalar, "dtype": Any, "storage": str},
1896
+ defaults={"dtype": None, "storage": "register"},
1897
+ value_func=tile_arange_value_func,
1898
+ dispatch_func=tile_arange_dispatch_func,
1899
+ variadic=True,
1900
+ missing_grad=True,
1901
+ doc="""Generates a tile of linearly spaced elements.
1902
+
1903
+ :param args: Variable-length positional arguments, interpreted as:
1904
+
1905
+ - ``(stop,)``: Generates values from ``0`` to ``stop - 1``
1906
+ - ``(start, stop)``: Generates values from ``start`` to ``stop - 1``
1907
+ - ``(start, stop, step)``: Generates values from ``start`` to ``stop - 1`` with a step size
1908
+
1909
+ :param dtype: Datatype of output tile's elements (optional, default: int)
1910
+ :param storage: The storage location for the tile: ``"register"`` for registers
1911
+ (default) or ``"shared"`` for shared memory.
1912
+ :returns: A tile with ``shape=(1,n)`` with linearly spaced elements of specified dtype""",
1913
+ group="Tile Primitives",
1914
+ export=False,
1915
+ )
1916
+
1917
+
1918
+ def tile_load_1d_value_func(arg_types, arg_values):
1919
+ # return generic type (for doc builds)
1920
+ if arg_types is None:
1921
+ return Tile(dtype=Any, M=Any, N=Any)
1922
+
1923
+ if not is_array(arg_types["a"]):
1924
+ raise RuntimeError("tile_load() argument 0 must be an array")
1925
+
1926
+ if arg_types["a"].ndim != 1:
1927
+ raise RuntimeError(
1928
+ "tile_load() argument 0 must be 1-dimensional if using the ``wp.tile_load(array, i, n)`` syntax."
1929
+ )
1930
+
1931
+ if not type_is_int(arg_types["i"]):
1932
+ raise RuntimeError("tile_load() argument 1 must be an integer")
1933
+
1934
+ if "n" not in arg_values:
1935
+ raise RuntimeError("'n' keyword argument must be specified when calling tile_load() function")
1936
+
1937
+ if arg_values["storage"] not in {"shared", "register"}:
1938
+ raise ValueError(
1939
+ f"'storage' keyword argument must be either 'shared' or 'register', got {arg_values['storage']}"
1940
+ )
1941
+
1942
+ a = arg_types["a"]
1943
+ _m, n = 1, arg_values["n"]
1944
+
1945
+ return TileLoad(a, 1, n, arg_values["storage"])
1946
+
1947
+
1948
+ def tile_load_1d_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
1949
+ array = arg_values["a"]
1950
+ i = arg_values["i"]
1951
+ n = arg_values["n"].constant
1952
+ dtype = arg_values["a"].type.dtype
1953
+
1954
+ template_args = []
1955
+ template_args.append(dtype)
1956
+ template_args.append(n)
1957
+
1958
+ return ((array, i), template_args)
1959
+
1960
+
1961
+ add_builtin(
1962
+ "tile_load",
1963
+ input_types={"a": array(dtype=Any), "i": int, "n": int, "storage": str},
1964
+ defaults={"storage": "register"},
1965
+ value_func=tile_load_1d_value_func,
1966
+ dispatch_func=tile_load_1d_dispatch_func,
1967
+ variadic=False,
1968
+ doc="""Loads a 1D tile from a global memory array.
1969
+
1970
+ This method will cooperatively load a tile from global memory using all threads in the block.
1971
+
1972
+ :param a: The source array in global memory
1973
+ :param i: Offset in the source array measured in multiples of ``n``, i.e.: ``offset=i*n``
1974
+ :param n: The number of elements in the tile
1975
+ :param storage: The storage location for the tile: ``"register"`` for registers
1976
+ (default) or ``"shared"`` for shared memory.
1977
+ :returns: A tile with ``shape=(1,n)`` and dtype the same as the source array""",
1978
+ group="Tile Primitives",
1979
+ export=False,
1980
+ )
1981
+
1982
+
1983
+ def tile_load_2d_value_func(arg_types, arg_values):
1984
+ # return generic type (for doc builds)
1985
+ if arg_types is None:
1986
+ return Tile(dtype=Any, M=Any, N=Any)
1987
+
1988
+ if not is_array(arg_types["a"]):
1989
+ raise RuntimeError("tile_load() argument 0 must be an array")
1990
+
1991
+ if arg_types["a"].ndim != 2:
1992
+ raise RuntimeError(
1993
+ "tile_load() argument 0 must be 2-dimensional if using the ``wp.tile_load(array, i, j, m, n)`` syntax."
1994
+ )
1995
+
1996
+ if not type_is_int(arg_types["i"]):
1997
+ raise RuntimeError("tile_load() argument 1 must be an integer")
1998
+
1999
+ if not type_is_int(arg_types["j"]):
2000
+ raise RuntimeError("tile_load() argument 1 must be an integer")
2001
+
2002
+ if "m" not in arg_values:
2003
+ raise RuntimeError("'m' keyword argument must be specified when calling tile_load() function")
2004
+
2005
+ if "n" not in arg_values:
2006
+ raise RuntimeError("'n' keyword argument must be specified when calling tile_load() function")
2007
+
2008
+ if arg_values["storage"] not in {"shared", "register"}:
2009
+ raise ValueError(
2010
+ f"'storage' keyword argument must be either 'shared' or 'register', got {arg_values['storage']}"
2011
+ )
2012
+
2013
+ a = arg_types["a"]
2014
+ m, n = arg_values["m"], arg_values["n"]
2015
+
2016
+ return TileLoad(a, m, n, arg_values["storage"])
2017
+
2018
+
2019
+ def tile_load_2d_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
2020
+ array = arg_values["a"]
2021
+ i, j = arg_values["i"], arg_values["j"]
2022
+ m, n = arg_values["m"].constant, arg_values["n"].constant
2023
+ dtype = arg_values["a"].type.dtype
2024
+
2025
+ template_args = []
2026
+ template_args.append(dtype)
2027
+ template_args.append(m)
2028
+ template_args.append(n)
2029
+
2030
+ return ((array, i, j), template_args)
2031
+
2032
+
2033
+ add_builtin(
2034
+ "tile_load",
2035
+ input_types={"a": array(dtype=Any), "i": int, "j": int, "m": int, "n": int, "storage": str},
2036
+ defaults={"storage": "register"},
2037
+ value_func=tile_load_2d_value_func,
2038
+ dispatch_func=tile_load_2d_dispatch_func,
2039
+ variadic=False,
2040
+ doc="""Loads a 2D tile from a global memory array.
2041
+
2042
+ This method will cooperatively load a tile from global memory using all threads in the block.
2043
+
2044
+ :param a: The source array in global memory
2045
+ :param i: Offset in the source array measured in multiples of ``m``, i.e.: ``row=i*m``
2046
 + :param j: Offset in the source array measured in multiples of ``n``, i.e.: ``col=j*n``
2047
+ :param m: The size of the tile's first dimension
2048
+ :param n: The size of the tile's second dimension
2049
+ :param storage: The storage location for the tile: ``"register"`` for registers
2050
+ (default) or ``"shared"`` for shared memory.
2051
+ :returns: A tile with ``shape=(m,n)`` and dtype the same as the source array""",
2052
+ group="Tile Primitives",
2053
+ export=False,
2054
+ )
2055
+
2056
+
2057
+ def tile_store_1d_value_func(arg_types, arg_values):
2058
+ # return generic type (for doc builds)
2059
+ if arg_types is None:
2060
+ return None
2061
+
2062
+ if len(arg_types) != 3:
2063
+ raise RuntimeError("tile_store() requires 3 positional args")
2064
+
2065
+ if not is_array(arg_types["a"]):
2066
+ raise RuntimeError("tile_store() argument 0 must be an array")
2067
+
2068
+ if arg_types["a"].ndim != 1:
2069
+ raise RuntimeError(
2070
+ "tile_load() argument 0 must be a 1-dimensional array if using the ``wp.tile_store(array, i, t)`` syntax."
2071
+ )
2072
+
2073
+ if not type_is_int(arg_types["i"]):
2074
+ raise RuntimeError("tile_store() argument 1 must be an integer")
2075
+
2076
+ if not is_tile(arg_types["t"]):
2077
+ raise RuntimeError("tile_store() argument 2 must be a tile")
2078
+
2079
+ if not types_equal(arg_types["a"].dtype, arg_types["t"].dtype):
2080
+ raise RuntimeError("tile_store() destination array must have same type as source tile")
2081
+
2082
+ return None
2083
+
2084
+
2085
+ add_builtin(
2086
+ "tile_store",
2087
+ input_types={"a": array(dtype=Any), "i": int, "t": Tile(dtype=Any, M=Any, N=Any)},
2088
+ value_func=tile_store_1d_value_func,
2089
+ variadic=False,
2090
+ skip_replay=True,
2091
+ doc="""Stores a 1D tile to a global memory array.
2092
+
2093
+ This method will cooperatively store a tile to global memory using all threads in the block.
2094
+
2095
+ :param a: The destination array in global memory
2096
+ :param i: Offset in the destination array measured in multiples of ``n``, i.e.: ``offset=i*n``
2097
+ :param t: The source tile to store data from, must have the same dtype as the destination array""",
2098
+ group="Tile Primitives",
2099
+ export=False,
2100
+ )
2101
+
2102
+
2103
+ def tile_store_2d_value_func(arg_types, arg_values):
2104
+ # return generic type (for doc builds)
2105
+ if arg_types is None:
2106
+ return None
2107
+
2108
+ if len(arg_types) != 4:
2109
+ raise RuntimeError("tile_store() requires 4 positional args")
2110
+
2111
+ if not is_array(arg_types["a"]):
2112
+ raise RuntimeError("tile_store() argument 0 must be an array")
2113
+
2114
+ if arg_types["a"].ndim != 2:
2115
+ raise RuntimeError(
2116
+ "tile_load() argument 0 must be a 2-dimensional array if using the ``wp.tile_store(array, i, j, t)`` syntax."
2117
+ )
2118
+
2119
+ if not type_is_int(arg_types["i"]):
2120
+ raise RuntimeError("tile_store() argument 1 must be an integer")
2121
+
2122
+ if not type_is_int(arg_types["j"]):
2123
+ raise RuntimeError("tile_store() argument 2 must be an integer")
2124
+
2125
+ if not is_tile(arg_types["t"]):
2126
+ raise RuntimeError("tile_store() argument 3 must be a tile")
2127
+
2128
+ if not types_equal(arg_types["a"].dtype, arg_types["t"].dtype):
2129
+ raise RuntimeError("tile_store() destination array must have same type as source tile")
2130
+
2131
+ return None
2132
+
2133
+
2134
+ add_builtin(
2135
+ "tile_store",
2136
+ input_types={"a": array(dtype=Any), "i": int, "j": int, "t": Tile(dtype=Any, M=Any, N=Any)},
2137
+ value_func=tile_store_2d_value_func,
2138
+ variadic=False,
2139
+ skip_replay=True,
2140
+ doc="""Stores a tile to a global memory array.
2141
+
2142
+ This method will cooperatively store a tile to global memory using all threads in the block.
2143
+
2144
+ :param a: The destination array in global memory
2145
+ :param i: Offset in the destination array measured in multiples of ``m``, i.e.: ``row=i*m``
2146
 + :param j: Offset in the destination array measured in multiples of ``n``, i.e.: ``col=j*n``
2147
+ :param t: The source tile to store data from, must have the same dtype as the destination array""",
2148
+ group="Tile Primitives",
2149
+ export=False,
2150
+ )
2151
+
2152
+
2153
+ def tile_atomic_add_value_func(arg_types, arg_values):
2154
+ # return generic type (for doc builds)
2155
+ if arg_types is None:
2156
+ return Tile(dtype=Any, M=Any, N=Any)
2157
+
2158
+ if len(arg_types) != 4:
2159
+ raise RuntimeError("tile_atomic_add() requires 4 positional args")
2160
+
2161
+ if not is_array(arg_types["a"]):
2162
+ raise RuntimeError("tile_atomic_add() argument 0 must be an array")
2163
+
2164
+ if not type_is_int(arg_types["x"]):
2165
+ raise RuntimeError("tile_atomic_add() argument 1 must be an integer")
2166
+
2167
+ if not type_is_int(arg_types["y"]):
2168
+ raise RuntimeError("tile_atomic_add() argument 2 must be an integer")
2169
+
2170
+ if not is_tile(arg_types["t"]):
2171
+ raise RuntimeError("tile_atomic_add() argument 3 must be a tile")
2172
+
2173
+ if arg_types["a"].dtype != arg_types["t"].dtype:
2174
+ raise RuntimeError("tile_atomic_add() tile dtype and array dtype must match")
2175
+
2176
+ return Tile(dtype=arg_types["t"].dtype, M=arg_types["t"].M, N=arg_types["t"].N)
2177
+
2178
+
2179
+ add_builtin(
2180
+ "tile_atomic_add",
2181
+ input_types={"a": array(dtype=Any), "x": int, "y": int, "t": Tile(dtype=Any, M=Any, N=Any)},
2182
+ value_func=tile_atomic_add_value_func,
2183
+ variadic=True,
2184
+ skip_replay=True,
2185
+ doc="""Atomically add a tile to the array `a`, each element will be updated atomically.
2186
+
2187
+ :param a: Array in global memory, should have the same ``dtype`` as the input tile
2188
+ :param x: Offset in the destination array measured in multiples of ``m``, i.e.: ``i=x*M`` where ``M`` is the first tile dimension
2189
+ :param y: Offset in the destination array measured in multiples of ``n``, i.e.: ``j=y*N`` where ``N`` is the second tile dimension
2190
+ :param t: Source tile to add to the destination array
2191
+ :returns: A tile with the same dimensions and type as the source tile, holding the original value of the destination elements""",
2192
+ group="Tile Primitives",
2193
+ export=False,
2194
+ )
2195
+
2196
+
2197
+ def tile_view_value_func(arg_types, arg_values):
2198
+ # return generic type (for doc builds)
2199
+ if arg_types is None:
2200
+ return Tile(dtype=Any, M=Any, N=Any)
2201
+
2202
+ tile = arg_types["t"]
2203
+
2204
+ if "m" not in arg_values:
2205
+ m = 1
2206
+ else:
2207
+ m = arg_values["m"]
2208
+
2209
+ if "n" not in arg_values:
2210
+ n = tile.N
2211
+ else:
2212
+ n = arg_values["n"]
2213
+
2214
+ if m > tile.M or n > tile.N:
2215
+ raise RuntimeError(
2216
+ f"Trying to view a tile subrange with dimensions ({m}, {n}) which is larger than source tile with dimensions ({tile.M}, {tile.N})"
2217
+ )
2218
+
2219
+ # force source tile to shared memory
2220
+ tile.storage = "shared"
2221
+
2222
+ output = Tile(dtype=tile.dtype, M=m, N=n, strides=tile.strides, layout=tile.layout, storage="shared", owner=False)
2223
+ return output
2224
+
2225
+
2226
+ def tile_view_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
2227
+ tile = arg_values["t"]
2228
+ i = arg_values["i"]
2229
+
2230
+ if "j" not in arg_values:
2231
+ j = warp.codegen.Var(label=None, type=int, constant=0)
2232
+ else:
2233
+ j = arg_values["j"]
2234
+
2235
+ template_args = []
2236
+ template_args.append(return_type.M)
2237
+ template_args.append(return_type.N)
2238
+
2239
+ return ((tile, i, j), template_args)
2240
+
2241
+
2242
+ add_builtin(
2243
+ "tile_view",
2244
+ input_types={"t": Tile(dtype=Any, M=Any, N=Any), "i": int, "j": int, "m": int, "n": int},
2245
+ value_func=tile_view_value_func,
2246
+ dispatch_func=tile_view_dispatch_func,
2247
+ defaults={"j": None, "m": None, "n": None},
2248
+ variadic=True,
2249
+ doc="""Return a subrange of a given tile from coordinates (i,j) to (i+m, j+n).
2250
+
2251
+ :param t: Input tile to extract a subrange from
2252
+ :param i: Offset in the source tile along the first dimension
2253
 + :param j: Offset in the source tile along the second dimension
2254
+ :param m: Size of the subrange to return along the first dimension
2255
+ :param n: Size of the subrange to return along the second dimension
2256
+ :returns: A tile with dimensions (m,n) and the same datatype as the input tile""",
2257
+ group="Tile Primitives",
2258
+ export=False,
2259
+ )
2260
+
2261
+
2262
+ def tile_assign_value_func(arg_types, arg_values):
2263
+ # return generic type (for doc builds)
2264
+ return None
2265
+
2266
+
2267
+ add_builtin(
2268
+ "tile_assign",
2269
+ input_types={"dst": Tile(dtype=Any, M=Any, N=Any), "i": int, "j": int, "src": Tile(dtype=Any, M=Any, N=Any)},
2270
+ value_func=tile_assign_value_func,
2271
+ # dispatch_func=tile_assign_dispatch_func,
2272
+ doc="""Assign a tile to a subrange of a destination tile at coordinates (i,j).
2273
+
2274
 + :param dst: The destination tile to assign to
2275
+ :param i: Offset in the source tile along the first dimension
2276
 + :param j: Offset in the source tile along the second dimension
2277
+ :param src: The source tile to read values from""",
2278
+ group="Tile Primitives",
2279
+ export=False,
2280
+ )
2281
+
2282
+
2283
+ def tile_value_func(arg_types, arg_values):
2284
+ # return generic type (for doc builds)
2285
+ if arg_types is None:
2286
+ return Tile
2287
+
2288
+ if len(arg_types) != 1:
2289
+ raise RuntimeError("tile() requires 1 positional arg")
2290
+
2291
+ dtype = None
2292
+ length = None
2293
+
2294
+ if type_is_vector(arg_types["x"]):
2295
+ dtype = arg_types["x"]._wp_scalar_type_
2296
+ length = arg_types["x"]._shape_[0]
2297
+ else:
2298
+ dtype = arg_types["x"]
2299
+ length = 1
2300
+
2301
+ return Tile(dtype=dtype, M=length, N=warp.codegen.options["block_dim"], op="tile")
2302
+
2303
+
2304
# register the tile() builtin (per-thread values -> block-wide tile)
add_builtin(
    "tile",
    input_types={"x": Any},
    value_func=tile_value_func,
    variadic=True,
    doc="""Constructs a new Tile from per-thread kernel values.

    This function converts values computed using scalar kernel code to a tile representation for input into collective operations.

    * If the input value is a scalar, then the resulting tile has ``shape=(1, block_dim)``
    * If the input value is a vector, then the resulting tile has ``shape=(length(vector), block_dim)``

    :param x: A per-thread local value, e.g.: scalar, vector, or matrix.
    :returns: A tile with first dimension according to the value type length and a second dimension equal to ``block_dim``

    This example shows how to create a linear sequence from thread variables:

    .. code-block:: python

        @wp.kernel
        def compute():
            i = wp.tid()
            t = wp.tile(i*2)
            print(t)

        wp.launch(compute, dim=16, inputs=[], block_dim=16)

    Prints:

    .. code-block:: text

        tile(m=1, n=16, storage=register) = [[0 2 4 6 8 ...]]

    """,
    group="Tile Primitives",
    export=False,
)
+
2342
+
2343
def untile_value_func(arg_types, arg_values):
    """Resolve the return type of untile().

    Returns the per-thread scalar dtype for single-row tiles, or a vector of
    length M for multi-row tiles. The tile width must equal the block
    dimension, since each tile column maps to one thread.
    """
    # return generic type (for doc builds)
    if arg_types is None:
        return Scalar

    if len(arg_types) != 1:
        raise RuntimeError("untile() requires 1 positional arg")

    t = arg_types["a"]

    if not is_tile(t):
        # fix: arg_types is keyed by name, so the previous arg_types[0] would
        # raise a KeyError while formatting this message; report the type itself
        raise RuntimeError(f"untile() accepts arguments of type tile only, got {t}")

    if t.N != warp.codegen.options["block_dim"]:
        raise RuntimeError(
            f"untile() argument must have the same length as the block width, got {t.N}, expected {warp.codegen.options['block_dim']}"
        )

    if t.M == 1:
        return t.dtype
    elif t.M > 1:
        return warp.types.vector(t.M, t.dtype)
+
2366
+
2367
# register the untile() builtin (block-wide tile -> per-thread values)
add_builtin(
    "untile",
    input_types={"a": Tile(dtype=Any, M=Any, N=Any)},
    value_func=untile_value_func,
    variadic=True,
    doc="""Convert a Tile back to per-thread values.

    This function converts a block-wide tile back to per-thread values.

    * If the input tile is 1-dimensional then the resulting value will be a per-thread scalar
    * If the input tile is 2-dimensional then the resulting value will be a per-thread vector of length M

    :param a: A tile with dimensions ``shape=(M, block_dim)``
    :returns: A single value per-thread with the same dtype as the tile

    This example shows how to create a linear sequence from thread variables:

    .. code-block:: python

        @wp.kernel
        def compute():
            i = wp.tid()

            # create block-wide tile
            t = wp.tile(i)*2

            # convert back to per-thread values
            s = wp.untile(t)

            print(s)

        wp.launch(compute, dim=16, inputs=[], block_dim=16)

    Prints:

    .. code-block:: text

        0
        2
        4
        6
        8
        ...
    """,
    group="Tile Primitives",
    export=False,
)
+
2415
+
2416
def tile_extract_value_func(arg_types, arg_values):
    """Resolve the return type of tile_extract(): the tile's element dtype."""
    # generic type for doc builds
    if arg_types is None:
        return Scalar

    if len(arg_types) != 3:
        raise RuntimeError("tile_extract() requires 3 positional args")

    tile_arg = arg_types["a"]
    if not is_tile(tile_arg):
        raise RuntimeError("tile_extract() argument 0 must be a tile")

    return tile_arg.dtype
2429
+
2430
# register the tile_extract() builtin
add_builtin(
    "tile_extract",
    input_types={"a": Tile(dtype=Any, M=Any, N=Any), "i": int, "j": int},
    value_func=tile_extract_value_func,
    variadic=True,
    doc="""Extracts a single element from the tile and returns it as a scalar type.

    This function will extract an element from the tile and broadcast its value to all threads in the block.

    Note that this may incur additional synchronization if the source tile is a register tile.

    :param a: Tile to extract the element from
    :param i: Coordinate of element on first dimension
    :param j: Coordinate of element on the second dimension
    :returns: The value of the element at the specified tile location, with the same type as the input tile's per-element dtype""",
    group="Tile Primitives",
    export=False,
)
+
2449
+
2450
def tile_transpose_value_func(arg_types, arg_values):
    """Resolve the return type of tile_transpose(): a non-owning (N, M) view."""
    # generic type for doc builds
    if arg_types is None:
        return Tile

    if len(arg_types) != 1:
        raise RuntimeError("tile_transpose() requires 1 positional args")

    src = arg_types["a"]
    if not is_tile(src):
        raise RuntimeError("tile_transpose() argument 0 must be a tile")

    # a transpose flips the storage layout; any other layout value maps to None
    flipped_layout = {"rowmajor": "colmajor", "colmajor": "rowmajor"}.get(src.layout)

    # the result aliases the input, which is forced to shared memory
    src.storage = "shared"

    return Tile(
        dtype=src.dtype,
        M=src.N,
        N=src.M,
        op="transpose",
        storage=src.storage,
        strides=src.strides[::-1],
        layout=flipped_layout,
        owner=False,
    )
+
2485
+
2486
# register the tile_transpose() builtin
add_builtin(
    "tile_transpose",
    input_types={"a": Tile(dtype=Any, M=Any, N=Any)},
    value_func=tile_transpose_value_func,
    variadic=True,
    doc="""Transpose a tile.

    For shared memory tiles this operation will alias the input tile, register tiles will first be transferred to shared memory before transposition.

    :param a: Tile to transpose with ``shape=(M,N)``
    :returns: Tile with ``shape=(N,M)``""",
    group="Tile Primitives",
    export=False,
)
+
2501
+
2502
def tile_broadcast_value_func(arg_types, arg_values):
    """Resolve the return type of tile_broadcast().

    Applies NumPy-style broadcast rules per dimension: a source extent of 1
    broadcasts (stride 0), a matching extent keeps its stride, anything else
    is an error.
    """
    # return generic type (for doc builds)
    if arg_types is None:
        return Tile

    # fix: the message previously claimed 1 arg while checking for 3
    if len(arg_types) != 3:
        raise RuntimeError("tile_broadcast() requires 3 positional args")

    t = arg_types["a"]
    m = arg_values["m"]
    n = arg_values["n"]

    if not is_tile(t):
        raise RuntimeError("tile_broadcast() argument 0 must be a tile")

    # try to broadcast last dimension
    if t.N == 1:
        stride_n = 0
    elif t.N == n:
        stride_n = t.strides[1]
    else:
        # fix: use t.M/t.N — the lowercase t.m/t.n attributes do not exist and
        # would raise AttributeError while formatting the error message
        raise RuntimeError(
            f"Broadcast dimension must be 1 or match destination, shape(src) = {t.M, t.N}, shape(dest) = {m, n}"
        )

    # try to broadcast first dimension
    if t.M == 1:
        stride_m = 0
    elif t.M == m:
        stride_m = t.strides[0]
    else:
        raise RuntimeError(
            f"Broadcast dimension must be 1 or match destination, shape(src) = {t.M, t.N}, shape(dest) = {m, n}"
        )

    # force the input tile to shared memory, since the result aliases it
    t.storage = "shared"

    tile_type = Tile(
        dtype=t.dtype, M=m, N=n, op="broadcast", storage=t.storage, strides=(stride_m, stride_n), owner=False
    )
    return tile_type
+
2545
+
2546
def tile_broadcast_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
    """Forward the source tile plus the resolved shape/strides as template args."""
    source = arg_values["a"]

    # the destination shape and strides were computed by the value func
    template_args = [
        return_type.M,
        return_type.N,
        return_type.strides[0],
        return_type.strides[1],
    ]

    return ((source,), template_args)
+
2557
+
2558
# register the tile_broadcast() builtin
add_builtin(
    "tile_broadcast",
    input_types={"a": Tile(dtype=Any, M=Any, N=Any), "m": int, "n": int},
    value_func=tile_broadcast_value_func,
    dispatch_func=tile_broadcast_dispatch_func,
    variadic=True,
    doc="""Broadcast a tile.

    This function will attempt to broadcast the input tile ``a`` to the destination shape (m, n), broadcasting follows NumPy broadcast rules.

    :param a: Tile to broadcast
    :param m: Size of the destination shape along the first dimension
    :param n: Size of the destination shape along the second dimension
    :returns: Tile with broadcast ``shape=(m, n)``""",
    group="Tile Primitives",
    export=False,
)
+
2574
+
2575
def tile_matmul_value_func(arg_types, arg_values):
    """Validate tile_matmul() arguments; the op writes into ``out`` and returns nothing."""
    # return generic type (for doc builds)
    if arg_types is None:
        return Tile(dtype=Any, M=Any, N=Any)

    # fix: the message previously claimed 4 args while checking for 3
    if len(arg_types) != 3:
        raise RuntimeError("tile_matmul() requires 3 positional args")

    if not is_tile(arg_types["a"]):
        raise RuntimeError("tile_matmul() argument 0 must be a tile")

    if not is_tile(arg_types["b"]):
        raise RuntimeError("tile_matmul() argument 1 must be a tile")

    if not isinstance(arg_types["out"], Tile):
        raise RuntimeError("tile_matmul() output argument must be a tile")

    return None
+
2594
+
2595
def tile_matmul_dispatch_func(arg_types: Mapping[str, type], return_type: Any, arg_values: Mapping[str, Var]):
    """Route tile_matmul() operands, pinning all three tiles to shared memory."""
    operands = (arg_values["a"], arg_values["b"], arg_values["out"])

    # the cooperative matmul reads and writes through shared-memory tiles
    for var in operands:
        var.type.storage = "shared"

    return (operands, [])
+
2608
+
2609
# hidden scalar fallback for tile matrix multiply (out += a*b)
add_builtin(
    "tile_matmul_scalar",
    input_types={"a": Tile, "b": Tile, "out": Tile},
    value_func=tile_matmul_value_func,
    dispatch_func=tile_matmul_dispatch_func,
    variadic=True,
    doc="Compute matrix product and accumulate out += a*b.",
    group="Tile Primitives",
    hidden=True,
    export=False,
)
+
2621
+
2622
def tile_sum_value_func(arg_types, arg_values):
    """Resolve the return type of tile_sum(): a (1, 1) tile of the input dtype."""
    # generic type for doc builds
    if arg_types is None:
        return Tile(dtype=Any, M=1, N=1)

    if len(arg_types) != 1:
        raise RuntimeError("tile_sum() requires 1 positional args")

    src = arg_types["a"]
    if not is_tile(src):
        raise RuntimeError("tile_sum() argument 0 must be a tile")

    return Tile(dtype=src.dtype, M=1, N=1, op="sum")
+
2637
+
2638
# register the tile_sum() cooperative reduction
add_builtin(
    "tile_sum",
    input_types={"a": Tile},
    value_func=tile_sum_value_func,
    variadic=True,
    doc="""Cooperatively compute the sum of the tile elements using all threads in the block.

    :param a: The tile to compute the sum of
    :returns: A single-element tile with dimensions of (1,1) holding the sum

    Example:

    .. code-block:: python

        @wp.kernel
        def compute():

            t = wp.tile_ones(dtype=float, m=16, n=16)
            s = wp.tile_sum(t)

            print(s)

        wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)

    Prints:

    .. code-block:: text

        tile(m=1, n=1, storage=register) = [[256]]

    """,
    group="Tile Primitives",
    export=False,
)
+
2673
+
2674
def tile_min_value_func(arg_types, arg_values):
    """Resolve the return type of tile_min(): a (1, 1) tile of the input dtype."""
    # generic type for doc builds
    if arg_types is None:
        return Tile(dtype=Any, M=1, N=1)

    if len(arg_types) != 1:
        raise RuntimeError("tile_min() requires 1 positional args")

    src = arg_types["a"]
    if not is_tile(src):
        raise RuntimeError("tile_min() argument 0 must be a tile")

    return Tile(dtype=src.dtype, M=1, N=1, op="min")
+
2689
+
2690
# register the tile_min() cooperative reduction
add_builtin(
    "tile_min",
    input_types={"a": Tile},
    value_func=tile_min_value_func,
    variadic=True,
    doc="""Cooperatively compute the minimum of the tile elements using all threads in the block.

    :param a: The tile to compute the minimum of
    :returns: A single-element tile with dimensions of (1,1) holding the minimum value

    Example:

    .. code-block:: python

        @wp.kernel
        def compute():

            t = wp.tile_arange(64, 128)
            s = wp.tile_min(t)

            print(s)


        wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)

    Prints:

    .. code-block:: text

        tile(m=1, n=1, storage=register) = [[64 ]]

    """,
    group="Tile Primitives",
    export=False,
)
+
2726
+
2727
def tile_max_value_func(arg_types, arg_values):
    """Resolve the return type of tile_max(): a (1, 1) tile of the input dtype."""
    # return generic type (for doc builds)
    if arg_types is None:
        return Tile(dtype=Any, M=1, N=1)

    if len(arg_types) != 1:
        raise RuntimeError("tile_max() requires 1 positional args")

    a = arg_types["a"]

    if not is_tile(a):
        raise RuntimeError("tile_max() argument 0 must be a tile")

    # fix: op was mislabeled "min" (copy-paste from tile_min_value_func)
    return Tile(dtype=a.dtype, M=1, N=1, op="max")
+
2742
+
2743
# register the tile_max() cooperative reduction
add_builtin(
    "tile_max",
    input_types={"a": Tile},
    value_func=tile_max_value_func,
    variadic=True,
    doc="""Cooperatively compute the maximum of the tile elements using all threads in the block.

    :param a: The tile to compute the maximum from
    :returns: A single-element tile with dimensions of (1,1) holding the maximum value

    Example:

    .. code-block:: python

        @wp.kernel
        def compute():

            t = wp.tile_arange(64, 128)
            s = wp.tile_max(t)

            print(s)

        wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)

    Prints:

    .. code-block:: text

        tile(m=1, n=1, storage=register) = [[127 ]]

    """,
    group="Tile Primitives",
    export=False,
)
+
2778
+
2779
# does type propagation for tile_reduce()
def tile_reduce_value_func(arg_types, arg_values):
    """Resolve the return type of tile_reduce(): a (1, 1) tile of the input dtype."""
    if arg_types is None:
        return Tile(dtype=Any, M=Any, N=Any)

    a = arg_types["a"]

    # check all args are tiles
    if not is_tile(a):
        raise RuntimeError(f"tile_reduce() arguments must be tiles, got type {a}")

    return Tile(dtype=a.dtype, M=1, N=1, op="reduce")
+
2792
+
2793
def tile_reduce_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
    """Forward the reduction operator followed by the variadic tile arguments."""
    # NOTE(review): this reads an "args" key although tile_reduce() declares
    # parameters ("op", "a") — confirm whether this dispatcher is still wired up
    func_args = (args["op"],) + tuple(args["args"])
    return (func_args, ())
+
2798
+
2799
# register the tile_reduce() custom cooperative reduction
add_builtin(
    "tile_reduce",
    input_types={"op": Callable, "a": Tile(dtype=Any, M=Any, N=Any)},
    value_func=tile_reduce_value_func,
    native_func="tile_reduce",
    doc="""Apply a custom reduction operator across the tile.

    This function cooperatively performs a reduction using the provided operator across the tile.

    :param op: A callable function that accepts two arguments and returns one argument, may be a user function or builtin
    :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
    :returns: A single-element tile with ``shape=(1,1)`` with the same datatype as the input tile.

    Example:

    .. code-block:: python

        @wp.kernel
        def factorial():

            t = wp.tile_arange(1, 10, dtype=int)
            s = wp.tile_reduce(wp.mul, t)

            print(s)

        wp.launch_tiled(factorial, dim=[1], inputs=[], block_dim=16)

    Prints:

    .. code-block:: text

        tile(m=1, n=1, storage=register) = [[362880]]
    """,
    group="Tile Primitives",
    export=False,
)
+
2836
+ # maps
2837
+
2838
+
2839
# does type propagation for unary tile_map()
def tile_unary_map_value_func(arg_types, arg_values):
    """Resolve the return type of unary tile_map(): same shape/dtype as the input."""
    if arg_types is None:
        return Tile(dtype=Any, M=Any, N=Any)

    a = arg_types["a"]

    # check all args are tiles
    if not is_tile(a):
        raise RuntimeError(f"tile_map() arguments must be tiles, got type {a}")

    return TileUnaryMap(a)
+
2852
+
2853
+ # def tile_map_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
2854
+ # func_args = (args["op"], *args["args"])
2855
+ # template_args = ()
2856
+ # return (func_args, template_args)
2857
+
2858
+
2859
# register the unary overload of tile_map()
add_builtin(
    "tile_map",
    input_types={"op": Callable, "a": Tile(dtype=Any, M=Any, N=Any)},
    value_func=tile_unary_map_value_func,
    # dispatch_func=tile_map_dispatch_func,
    # variadic=True,
    native_func="tile_unary_map",
    doc="""Apply a unary function onto the tile.

    This function cooperatively applies a unary function to each element of the tile using all threads in the block.

    :param op: A callable function that accepts one argument and returns one argument, may be a user function or builtin
    :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
    :returns: A tile with the same dimensions and datatype as the input tile.

    Example:

    .. code-block:: python

        @wp.kernel
        def compute():

            t = wp.tile_arange(0.0, 1.0, 0.1, dtype=float)
            s = wp.tile_map(wp.sin, t)

            print(s)

        wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=16)

    Prints:

    .. code-block:: text

        tile(m=1, n=10, storage=register) = [[0 0.0998334 0.198669 0.29552 ...]]
    """,
    group="Tile Primitives",
    export=False,
)
+
2898
+
2899
def tile_binary_map_value_func(arg_types, arg_values):
    """Resolve the return type of binary tile_map(); inputs must agree in shape and dtype."""
    # generic type for doc builds
    if arg_types is None:
        return Tile(dtype=Any, M=Any, N=Any)

    lhs = arg_types["a"]
    rhs = arg_types["b"]

    for operand in (lhs, rhs):
        if not is_tile(operand):
            raise RuntimeError(f"tile_map() arguments must be tiles, got type {operand}")

    # use first argument to define output type
    if not types_equal(lhs.dtype, rhs.dtype):
        raise RuntimeError(f"tile_map() arguments must all have the same type {lhs.dtype} != {rhs.dtype}")

    if lhs.M != rhs.M:
        raise RuntimeError(f"tile_map() arguments must all have the same m dimension {lhs.M} != {rhs.M}")

    if lhs.N != rhs.N:
        raise RuntimeError(f"tile_map() arguments must all have the same n dimension {lhs.N} != {rhs.N}")

    return TileBinaryMap(lhs, rhs)
+
2925
+
2926
# register the binary overload of tile_map()
add_builtin(
    "tile_map",
    input_types={"op": Callable, "a": Tile(dtype=Any, M=Any, N=Any), "b": Tile(dtype=Any, M=Any, N=Any)},
    value_func=tile_binary_map_value_func,
    # dispatch_func=tile_map_dispatch_func,
    # variadic=True,
    native_func="tile_binary_map",
    doc="""Apply a binary function onto the tile.

    This function cooperatively applies a binary function to each element of the tiles using all threads in the block.
    Both input tiles must have the same dimensions and datatype.

    :param op: A callable function that accepts two arguments and returns one argument, all of the same type, may be a user function or builtin
    :param a: The first input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
    :param b: The second input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
    :returns: A tile with the same dimensions and datatype as the input tiles.

    Example:

    .. code-block:: python

        @wp.kernel
        def compute():

            a = wp.tile_arange(0.0, 1.0, 0.1, dtype=float)
            b = wp.tile_ones(m=1, n=10, dtype=float)

            s = wp.tile_map(wp.add, a, b)

            print(s)

        wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=16)

    Prints:

    .. code-block:: text

        tile(m=1, n=10, storage=register) = [[1 1.1 1.2 1.3 1.4 1.5 1.6 1.7 1.8 1.9]]""",
    group="Tile Primitives",
    export=False,
)
+
2968
+
1705
2969
  # ---------------------------------
1706
2970
  # Linear Algebra
1707
2971
 
@@ -2389,6 +3653,16 @@ add_builtin(
2389
3653
  "iter_next", input_types={"query": mesh_query_aabb_t}, value_type=int, group="Utility", export=False, hidden=True
2390
3654
  )
2391
3655
 
3656
# Python's reversed() over a Warp range maps to the native iter_reverse()
add_builtin(
    "reversed",
    input_types={"range": range_t},
    value_type=range_t,
    native_func="iter_reverse",
    group="Utility",
    doc="""Returns the range in reversed order.""",
    export=False,
)
2392
3666
  # ---------------------------------
2393
3667
  # Volumes
2394
3668
 
@@ -2804,7 +4078,11 @@ add_builtin(
2804
4078
  doc="Return a random float between [low, high).",
2805
4079
  )
2806
4080
# register the randn() builtin
add_builtin(
    "randn",
    input_types={"state": uint32},
    value_type=float,
    group="Random",
    # trimmed a stray trailing space from the doc string
    doc="Sample a normal (Gaussian) distribution of mean 0 and variance 1.",
)
2809
4087
 
2810
4088
  add_builtin(
@@ -2976,12 +4254,20 @@ add_builtin(
2976
4254
  )
2977
4255
 
2978
4256
 
4257
def printf_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
    """Validate printf() arguments; printf never produces a value."""
    # None during doc builds, when there is nothing to validate
    if arg_types is None:
        return None

    # enforce the variadic-argument cap
    variadic_args = arg_types.get("args", ())
    if len(variadic_args) > 32:
        raise RuntimeError("the maximum number of variadic arguments that can be passed to `printf` is 32")

    return None
+
2979
4265
def printf_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
    """Map the printf() arguments onto the underlying C++ function call."""
    # "args" is absent when printf is invoked with a format string only,
    # so fall back to an empty tuple of variadic arguments
    variadic = args.get("args", ())
    return ((args["fmt"], *variadic), ())
@@ -2992,6 +4278,7 @@ add_builtin(
2992
4278
  input_types={"fmt": str, "*args": Any},
2993
4279
  namespace="",
2994
4280
  variadic=True,
4281
+ value_func=printf_value_func,
2995
4282
  dispatch_func=printf_dispatch_func,
2996
4283
  group="Utility",
2997
4284
  doc="Allows printing formatted strings using C-style format specifiers.",
@@ -3380,6 +4667,19 @@ def atomic_op_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str,
3380
4667
  return arr_type.dtype
3381
4668
 
3382
4669
 
4670
+ def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
4671
+ # as this is a codegen callback, we can mark the fact that this func writes to an array here
4672
+ if warp.config.verify_autograd_array_access:
4673
+ arr = args["arr"]
4674
+ arr.mark_write()
4675
+
4676
+ func_args = tuple(args.values())
4677
+ # we don't need to specify template arguments for atomic ops
4678
+ template_args = ()
4679
+
4680
+ return (func_args, template_args)
4681
+
4682
+
3383
4683
  for array_type in array_types:
3384
4684
  # don't list indexed array operations explicitly in docs
3385
4685
  hidden = array_type == indexedarray
@@ -3390,6 +4690,7 @@ for array_type in array_types:
3390
4690
  input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
3391
4691
  constraint=atomic_op_constraint,
3392
4692
  value_func=atomic_op_value_func,
4693
+ dispatch_func=atomic_op_dispatch_func,
3393
4694
  doc="Atomically add ``value`` onto ``arr[i]`` and return the old value.",
3394
4695
  group="Utility",
3395
4696
  skip_replay=True,
@@ -3400,6 +4701,7 @@ for array_type in array_types:
3400
4701
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
3401
4702
  constraint=atomic_op_constraint,
3402
4703
  value_func=atomic_op_value_func,
4704
+ dispatch_func=atomic_op_dispatch_func,
3403
4705
  doc="Atomically add ``value`` onto ``arr[i,j]`` and return the old value.",
3404
4706
  group="Utility",
3405
4707
  skip_replay=True,
@@ -3410,6 +4712,7 @@ for array_type in array_types:
3410
4712
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
3411
4713
  constraint=atomic_op_constraint,
3412
4714
  value_func=atomic_op_value_func,
4715
+ dispatch_func=atomic_op_dispatch_func,
3413
4716
  doc="Atomically add ``value`` onto ``arr[i,j,k]`` and return the old value.",
3414
4717
  group="Utility",
3415
4718
  skip_replay=True,
@@ -3420,6 +4723,7 @@ for array_type in array_types:
3420
4723
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
3421
4724
  constraint=atomic_op_constraint,
3422
4725
  value_func=atomic_op_value_func,
4726
+ dispatch_func=atomic_op_dispatch_func,
3423
4727
  doc="Atomically add ``value`` onto ``arr[i,j,k,l]`` and return the old value.",
3424
4728
  group="Utility",
3425
4729
  skip_replay=True,
@@ -3431,6 +4735,7 @@ for array_type in array_types:
3431
4735
  input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
3432
4736
  constraint=atomic_op_constraint,
3433
4737
  value_func=atomic_op_value_func,
4738
+ dispatch_func=atomic_op_dispatch_func,
3434
4739
  doc="Atomically subtract ``value`` onto ``arr[i]`` and return the old value.",
3435
4740
  group="Utility",
3436
4741
  skip_replay=True,
@@ -3441,6 +4746,7 @@ for array_type in array_types:
3441
4746
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
3442
4747
  constraint=atomic_op_constraint,
3443
4748
  value_func=atomic_op_value_func,
4749
+ dispatch_func=atomic_op_dispatch_func,
3444
4750
  doc="Atomically subtract ``value`` onto ``arr[i,j]`` and return the old value.",
3445
4751
  group="Utility",
3446
4752
  skip_replay=True,
@@ -3451,6 +4757,7 @@ for array_type in array_types:
3451
4757
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
3452
4758
  constraint=atomic_op_constraint,
3453
4759
  value_func=atomic_op_value_func,
4760
+ dispatch_func=atomic_op_dispatch_func,
3454
4761
  doc="Atomically subtract ``value`` onto ``arr[i,j,k]`` and return the old value.",
3455
4762
  group="Utility",
3456
4763
  skip_replay=True,
@@ -3461,6 +4768,7 @@ for array_type in array_types:
3461
4768
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
3462
4769
  constraint=atomic_op_constraint,
3463
4770
  value_func=atomic_op_value_func,
4771
+ dispatch_func=atomic_op_dispatch_func,
3464
4772
  doc="Atomically subtract ``value`` onto ``arr[i,j,k,l]`` and return the old value.",
3465
4773
  group="Utility",
3466
4774
  skip_replay=True,
@@ -3472,6 +4780,7 @@ for array_type in array_types:
3472
4780
  input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
3473
4781
  constraint=atomic_op_constraint,
3474
4782
  value_func=atomic_op_value_func,
4783
+ dispatch_func=atomic_op_dispatch_func,
3475
4784
  doc="""Compute the minimum of ``value`` and ``arr[i]``, atomically update the array, and return the old value.
3476
4785
 
3477
4786
  The operation is only atomic on a per-component basis for vectors and matrices.""",
@@ -3484,6 +4793,7 @@ for array_type in array_types:
3484
4793
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
3485
4794
  constraint=atomic_op_constraint,
3486
4795
  value_func=atomic_op_value_func,
4796
+ dispatch_func=atomic_op_dispatch_func,
3487
4797
  doc="""Compute the minimum of ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
3488
4798
 
3489
4799
  The operation is only atomic on a per-component basis for vectors and matrices.""",
@@ -3496,6 +4806,7 @@ for array_type in array_types:
3496
4806
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
3497
4807
  constraint=atomic_op_constraint,
3498
4808
  value_func=atomic_op_value_func,
4809
+ dispatch_func=atomic_op_dispatch_func,
3499
4810
  doc="""Compute the minimum of ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
3500
4811
 
3501
4812
  The operation is only atomic on a per-component basis for vectors and matrices.""",
@@ -3508,6 +4819,7 @@ for array_type in array_types:
3508
4819
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
3509
4820
  constraint=atomic_op_constraint,
3510
4821
  value_func=atomic_op_value_func,
4822
+ dispatch_func=atomic_op_dispatch_func,
3511
4823
  doc="""Compute the minimum of ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
3512
4824
 
3513
4825
  The operation is only atomic on a per-component basis for vectors and matrices.""",
@@ -3521,6 +4833,7 @@ for array_type in array_types:
3521
4833
  input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
3522
4834
  constraint=atomic_op_constraint,
3523
4835
  value_func=atomic_op_value_func,
4836
+ dispatch_func=atomic_op_dispatch_func,
3524
4837
  doc="""Compute the maximum of ``value`` and ``arr[i]``, atomically update the array, and return the old value.
3525
4838
 
3526
4839
  The operation is only atomic on a per-component basis for vectors and matrices.""",
@@ -3533,6 +4846,7 @@ for array_type in array_types:
3533
4846
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
3534
4847
  constraint=atomic_op_constraint,
3535
4848
  value_func=atomic_op_value_func,
4849
+ dispatch_func=atomic_op_dispatch_func,
3536
4850
  doc="""Compute the maximum of ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
3537
4851
 
3538
4852
  The operation is only atomic on a per-component basis for vectors and matrices.""",
@@ -3545,6 +4859,7 @@ for array_type in array_types:
3545
4859
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
3546
4860
  constraint=atomic_op_constraint,
3547
4861
  value_func=atomic_op_value_func,
4862
+ dispatch_func=atomic_op_dispatch_func,
3548
4863
  doc="""Compute the maximum of ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
3549
4864
 
3550
4865
  The operation is only atomic on a per-component basis for vectors and matrices.""",
@@ -3557,6 +4872,7 @@ for array_type in array_types:
3557
4872
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
3558
4873
  constraint=atomic_op_constraint,
3559
4874
  value_func=atomic_op_value_func,
4875
+ dispatch_func=atomic_op_dispatch_func,
3560
4876
  doc="""Compute the maximum of ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
3561
4877
 
3562
4878
  The operation is only atomic on a per-component basis for vectors and matrices.""",
@@ -3775,6 +5091,15 @@ for t in scalar_types + vector_types + (bool,):
3775
5091
  hidden=True,
3776
5092
  )
3777
5093
 
5094
+ add_builtin(
5095
+ "expect_neq",
5096
+ input_types={"a": t, "b": t},
5097
+ value_type=None,
5098
+ doc="Prints an error to stdout if ``a`` and ``b`` are not equal",
5099
+ group="Utility",
5100
+ hidden=True,
5101
+ )
5102
+
3778
5103
 
3779
5104
  def expect_eq_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
3780
5105
  if not types_equal(arg_types["a"], arg_types["b"]):
@@ -4315,6 +5640,493 @@ for t in int_types:
4315
5640
 
4316
5641
# Unary logical-not operator overload for array arguments; always yields a bool.
# doc is intentionally empty (operator overloads are not listed in the reference).
add_builtin("unot", input_types={"a": array(dtype=Any)}, value_type=builtins.bool, doc="", group="Operators")
4317
5642
 
5643
+
5644
# Tile operators
def tile_unary_value_func(arg_types, arg_values):
    """Resolve the return type of an element-wise unary op over a tile."""
    # No argument types are available during documentation builds;
    # fall back to a fully generic tile signature.
    if arg_types is None:
        return Tile(dtype=Any, M=Any, N=Any)

    tile_arg = arg_types["x"]
    if not is_tile(tile_arg):
        raise RuntimeError("Expected tile for unary expression")

    return TileUnaryMap(tile_arg)
5655
+
5656
+
5657
def tile_scalar_mul_value_func(arg_types, arg_values):
    """Resolve the return type for tile*scalar and scalar*tile multiplication.

    The scalar operand must have exactly the tile's element type; the result
    is an element-wise binary map of the tile against a broadcast constant.
    """
    if arg_types is None:
        # Generic signature for documentation builds.
        return Tile(dtype=Any, M=Any, N=Any)

    x = arg_types["x"]
    y = arg_types["y"]

    # tile*scalar
    if is_tile(x):
        if x.dtype != y:
            # Bug fix: these messages were plain strings, so the {x}/{y}
            # placeholders were printed literally instead of interpolated.
            raise RuntimeError(
                f"Scalar factor should have the same type as tile for tile*scalar, tile type: {x} scalar type: {y}"
            )

        return TileBinaryMap(x, TileConstant(y, x.M, x.N))

    # scalar*tile
    if is_tile(y):
        if y.dtype != x:
            # Bug fix: labels were swapped here — in this branch y is the tile
            # and x is the scalar.
            raise RuntimeError(
                f"Scalar factor should have the same type as tile for scalar*tile, tile type: {y} scalar type: {x}"
            )

        return TileBinaryMap(TileConstant(x, y.M, y.N), y)
5681
+
5682
+
5683
# Unary negation operator overload for tiles; the result type is derived by
# tile_unary_value_func (element-wise map over the input tile), lowering to
# the native tile_neg implementation.
add_builtin(
    "neg",
    input_types={"x": Tile(dtype=Any, M=Any, N=Any)},
    value_func=tile_unary_value_func,
    doc="Negate each element of a tile",
    export=False,
    native_func="tile_neg",
    group="Operators",
)
5692
+
5693
# Element-wise tile + tile addition; the result type comes from
# tile_binary_map_value_func (defined elsewhere in this file) and the call
# lowers to the native tile_add implementation.
add_builtin(
    "add",
    input_types={"a": Tile(dtype=Any, M=Any, N=Any), "b": Tile(dtype=Any, M=Any, N=Any)},
    value_func=tile_binary_map_value_func,
    # dispatch_func=tile_map_dispatch_func,
    # variadic=True,
    native_func="tile_add",
    doc="Add each element of two tiles together",
    group="Tile Primitives",
    export=False,
)
5704
+
5705
# Scalar multiplication overloads: both tile*scalar and scalar*tile lower to
# the native tile_mul; tile_scalar_mul_value_func verifies that the scalar's
# type matches the tile's element dtype.
add_builtin(
    "mul",
    input_types={"x": Tile(dtype=Any, M=Any, N=Any), "y": Scalar},
    value_func=tile_scalar_mul_value_func,
    doc="Multiply each element of a tile by a scalar",
    export=False,
    native_func="tile_mul",
    group="Operators",
)

# Mirror overload with the scalar on the left-hand side.
add_builtin(
    "mul",
    input_types={"x": Scalar, "y": Tile(dtype=Any, M=Any, N=Any)},
    value_func=tile_scalar_mul_value_func,
    doc="Multiply each element of a tile by a scalar",
    export=False,
    native_func="tile_mul",
    group="Operators",
)
5724
+
5725
+
5726
+ ##
5727
+ ## MathDx, LTOIR-based, Tile functions
5728
+ ##
5729
+
5730
+
5731
+ ##
5732
+ ## Matmul
5733
+ ##
5734
def tile_matmul_generic_value_func(arg_types, arg_values):
    """Resolve the return type for both tile_matmul() overloads.

    The 2-argument form ``out = tile_matmul(a, b)`` allocates and returns a
    shared-memory result tile; the 3-argument form ``tile_matmul(a, b, out)``
    accumulates into the caller-provided tile and returns nothing.
    """
    # return generic type (for doc builds)
    if arg_types is None:
        return Tile(dtype=Any, M=Any, N=Any)

    a = arg_types["a"]
    b = arg_types["b"]

    if not is_tile(a):
        raise RuntimeError("tile_matmul() argument 0 must be a tile")
    if not is_tile(b):
        # Fixed typo in error message: "an tile" -> "a tile"
        raise RuntimeError("tile_matmul() argument 1 must be a tile")

    # out = wp.tile_matmul(a, b)
    if len(arg_types) == 2:
        return Tile(dtype=a.dtype, M=a.M, N=b.N, storage="shared")

    # wp.tile_matmul(a, b, out)
    elif len(arg_types) == 3:
        if not is_tile(arg_types["out"]):
            raise RuntimeError("tile_matmul() output argument must be a tile")

        return None
5757
+
5758
+
5759
def tile_matmul_generic_lto_dispatch_func(
    arg_types: Mapping[str, type],
    return_type: Any,
    return_values: List[Var],
    arg_values: Mapping[str, Var],
    options: Mapping[str, Any],
    builder: warp.context.ModuleBuilder,
):
    """Build cuBLASDx LTO-IR kernels for a ``tile_matmul()`` call site.

    Compiles (and caches on ``builder.ltoirs``) three GEMM variants — the
    forward product plus the two adjoint products used by the backward pass —
    and returns the argument pack, template args, and LTO blobs that codegen
    splices into the generated kernel.
    """
    a = arg_values["a"]
    b = arg_values["b"]

    if len(return_values) > 0:
        accumulate = 0  # for c = tile_matmul(a,b) case we want to overwrite c value
        out = return_values[0]
    else:
        accumulate = 1  # for tile_matmul(a,b,c) case we want to add to c value
        out = arg_values["out"]

    if any(not is_tile(arg.type) for arg in [a, b, out]):
        raise RuntimeError("tile_matmul() requires three Tile arguments")

    if any(arg.type.dtype not in [float16, float32, float64, vec2h, vec2f, vec2d] for arg in [a, b, out]):
        raise RuntimeError(
            "tile_matmul() arguments must be tiles of float16, float32 or float64, vec2h, vec2f, vec2d entries"
        )

    if (a.type.N != b.type.M) or (a.type.M != out.type.M) or (b.type.N != out.type.N):
        raise RuntimeError("tile_matmul(A, B, C) requires sizes of A, B and C to be consistent for a matmul")

    # set the storage type of the inputs and output to shared
    a.type.storage = "shared"
    b.type.storage = "shared"
    out.type.storage = "shared"
    template_args = [accumulate]

    # Maps Python/Warp types to (C++ type name, precision enum, complex flag)
    def cublasdx_type_map(dtype):
        if dtype == float16:
            return ("wp::float16", 3, 0)
        if dtype == float32:
            return ("wp::float32", 5, 0)
        if dtype == float64:
            return ("wp::float64", 6, 0)
        if dtype == vec2h:
            return ("wp::vec2h", 3, 1)
        if dtype == vec2f:
            return ("wp::vec2f", 5, 1)
        if dtype == vec2d:
            return ("wp::vec2d", 6, 1)
        raise RuntimeError("Unsupported input type in tile_matmul")

    def cublasdx_arrangement_map(layout):
        if layout == "colmajor":
            return 0  # CUBLASDX_ARRANGEMENT_COL_MAJOR
        if layout == "rowmajor":
            return 1  # CUBLASDX_ARRANGEMENT_ROW_MAJOR
        raise RuntimeError("Unsupported layout in tile_matmul")

    # generate the LTO
    M, K = a.type.M, a.type.N
    _, N = b.type.M, b.type.N
    num_threads = options["block_dim"]
    arch = options["output_arch"]

    def make_function(M, N, K, adtype, bdtype, cdtype, alayout, blayout, clayout):
        # Compiles (or fetches from the module cache) one (symbol, ltoir) pair
        # for a unique GEMM configuration.
        (a_dtype, a_prec, a_type) = cublasdx_type_map(adtype)
        (b_dtype, b_prec, b_type) = cublasdx_type_map(bdtype)
        (c_dtype, c_prec, c_type) = cublasdx_type_map(cdtype)
        a_arrangement = cublasdx_arrangement_map(alayout)
        b_arrangement = cublasdx_arrangement_map(blayout)
        c_arrangement = cublasdx_arrangement_map(clayout)

        if a_type != b_type or a_type != c_type:
            # Fixed typo in error message: "time_matmul" -> "tile_matmul"
            raise RuntimeError("tile_matmul(A, B, C) requires all inputs to be real or complex")
        element_type = a_type

        lto_symbol = f"dot_{M}_{N}_{K}_{arch}_{num_threads}_{a_arrangement}_{b_arrangement}_{c_arrangement}_{a_prec}_{b_prec}_{c_prec}_{element_type}"

        # early out if LTO for this combination already exists for this module
        if lto_symbol in builder.ltoirs:
            return lto_symbol, builder.ltoirs[lto_symbol]

        # otherwise compile LTO (delete=False so the compiler can reopen it by name)
        lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
        result = warp.context.runtime.core.cuda_compile_dot(
            lto_code.name.encode("utf-8"),
            lto_symbol.encode("utf-8"),
            0,
            None,
            None,
            arch,
            M,
            N,
            K,
            a_prec,
            b_prec,
            c_prec,
            element_type,
            a_arrangement,
            b_arrangement,
            c_arrangement,
            num_threads,
        )
        lto_code_path = Path(lto_code.name)
        if not result:
            lto_code.close()
            if lto_code_path.exists():
                lto_code_path.unlink()
            raise RuntimeError("Failed to compile tile_matmul")
        else:
            with open(lto_code.name, "rb") as f:
                lto_code_data = f.read()
            lto_code.close()
            lto_code_path.unlink()

        builder.ltoirs[lto_symbol] = lto_code_data
        builder.ltoirs_decl[lto_symbol] = (
            f"void {lto_symbol}({c_dtype}, {a_dtype}*, {b_dtype}*, {c_dtype}, {c_dtype}*);"
        )

        return lto_symbol, lto_code_data

    def tile_flip_layout(layout):
        # Transpose ~= flipped storage layout; any other value falls through to None.
        if layout == "rowmajor":
            return "colmajor"
        elif layout == "colmajor":
            return "rowmajor"

    # C += A * B
    (fun_forward, lto_forward) = make_function(
        M, N, K, a.type.dtype, b.type.dtype, out.type.dtype, a.type.layout, b.type.layout, out.type.layout
    )
    # adjA += adjC * B^T - Transpose ~= flipped layout
    (fun_backward_A, lto_backward_A) = make_function(
        M,
        K,
        N,
        out.type.dtype,
        b.type.dtype,
        a.type.dtype,
        out.type.layout,
        tile_flip_layout(b.type.layout),
        a.type.layout,
    )
    # adjB += A^T * adjC - Transpose ~= flipped layout
    (fun_backward_B, lto_backward_B) = make_function(
        K,
        N,
        M,
        a.type.dtype,
        out.type.dtype,
        b.type.dtype,
        tile_flip_layout(a.type.layout),
        out.type.layout,
        b.type.layout,
    )

    return (
        (
            Var(fun_forward, str, False, True, False),
            Var(fun_backward_A, str, False, True, False),
            Var(fun_backward_B, str, False, True, False),
            a,
            b,
            out,
        ),
        template_args,
        [lto_forward, lto_backward_A, lto_backward_B],
    )
5928
+
5929
+
5930
# In-place overload: accumulates the product into a caller-provided tile.
add_builtin(
    "tile_matmul",
    input_types={
        "a": Tile(dtype=Any, M=Any, N=Any),
        "b": Tile(dtype=Any, M=Any, N=Any),
        "out": Tile(dtype=Any, M=Any, N=Any),
    },
    value_func=tile_matmul_generic_value_func,
    lto_dispatch_func=tile_matmul_generic_lto_dispatch_func,
    variadic=False,
    # Fixed doubled word in doc string: "will be automatically be migrated"
    doc="""Computes the matrix product and accumulates ``out += a*b``.

    Supported datatypes are:
        * fp16, fp32, fp64 (real)
        * vec2h, vec2f, vec2d (complex)

    All input and output tiles must have the same datatype. Tile data will automatically be migrated
    to shared memory if necessary and will use TensorCore operations when available.

    :param a: A tile with ``shape=(M, K)``
    :param b: A tile with ``shape=(K, N)``
    :param out: A tile with ``shape=(M, N)``
    """,
    group="Tile Primitives",
    export=False,
)
5956
+
5957
# Value-returning overload: allocates and returns the product tile.
add_builtin(
    "tile_matmul",
    input_types={"a": Tile(dtype=Any, M=Any, N=Any), "b": Tile(dtype=Any, M=Any, N=Any)},
    value_func=tile_matmul_generic_value_func,
    lto_dispatch_func=tile_matmul_generic_lto_dispatch_func,
    variadic=False,
    # Fixed doubled word in doc string: "will be automatically be migrated"
    doc="""Computes the matrix product ``out = a*b``.

    Supported datatypes are:
        * fp16, fp32, fp64 (real)
        * vec2h, vec2f, vec2d (complex)

    Both input tiles must have the same datatype. Tile data will automatically be migrated
    to shared memory if necessary and will use TensorCore operations when available.

    :param a: A tile with ``shape=(M, K)``
    :param b: A tile with ``shape=(K, N)``
    :returns: A tile with ``shape=(M, N)``
    """,
    group="Tile Primitives",
    export=False,
)
5979
+
5980
+
5981
+ ##
5982
+ ## FFT
5983
+ ##
5984
def tile_fft_generic_value_func(arg_types, arg_values):
    """Validate the single in/out tile argument shared by tile_fft()/tile_ifft()."""
    # Documentation builds pass no argument types; return a generic tile.
    if arg_types is None:
        return Tile(dtype=Any, M=Any, N=Any)

    if len(arg_types) != 1:
        raise RuntimeError("tile_fft() requires 1 positional args")

    tile_arg = arg_types["inout"]
    if not is_tile(tile_arg):
        raise RuntimeError("tile_fft() argument 0 must be a tile")

    if tile_arg.storage != "register":
        raise RuntimeError("tile_fft() input/output argument must have register memory storage")

    # The FFT is performed in place, so there is no return value.
    return None
5998
+
5999
+
6000
def tile_fft_generic_lto_dispatch_func(
    arg_types: Mapping[str, type],
    return_type: Any,
    return_values: List[Var],
    arg_values: Mapping[str, Var],
    options: Mapping[str, Any],
    builder: warp.context.ModuleBuilder,
    direction: str = None,
):
    """Build (and cache) cuFFTDx LTO-IR for a ``tile_fft()``/``tile_ifft()`` call site.

    ``direction`` is bound via functools.partial at registration time to either
    "forward" or "inverse". Returns the argument pack, template args, and LTO
    blob consumed by codegen.
    """
    inout = arg_values["inout"]

    # Validate before touching .type.storage so a non-tile argument raises the
    # intended RuntimeError instead of an AttributeError.
    if not is_tile(inout.type):
        raise RuntimeError("tile_fft() arguments must be a single tile with register storage")

    inout.type.storage = "register"

    if inout.type.dtype not in [vec2f, vec2d]:
        raise RuntimeError("tile_fft() argument must be a tile of vec2f or vec2d (interpreted as complex) entries")

    # see libcufftdx.hpp
    # (renamed local from `dir` to avoid shadowing the builtin)
    if direction == "forward":
        fft_dir = 0  # CUFFTDX_DIRECTION_FORWARD
    elif direction == "inverse":
        fft_dir = 1  # CUFFTDX_DIRECTION_INVERSE
    else:
        raise RuntimeError("Invalid direction")

    if inout.type.dtype == vec2f:
        dtype = "wp::vec2f"
        precision = 5  # COMMONDX_PRECISION_F32
    elif inout.type.dtype == vec2d:
        dtype = "wp::vec2d"
        precision = 6  # COMMONDX_PRECISION_F64
    else:
        raise RuntimeError("Unsupported datatype")

    # M FFTs of size N each
    batch, size = inout.type.M, inout.type.N
    num_threads = options["block_dim"]
    arch = options["output_arch"]
    ept = size // num_threads
    lto_symbol = f"fft_{size}_{ept}_{arch}_{direction}_{precision}"

    # early out if LTO for this combination already exists for this module
    # NOTE(review): this cache-hit path returns a 2-tuple while the path below
    # returns a 3-tuple argument pack — looks inconsistent; confirm upstream.
    if lto_symbol in builder.ltoirs:
        return lto_symbol, builder.ltoirs[lto_symbol]

    # otherwise compile LTO (delete=False so the compiler can reopen it by name)
    lto_code = tempfile.NamedTemporaryFile(prefix="warp", delete=False)
    shared_memory_size = ctypes.c_int(0)

    result = warp.context.runtime.core.cuda_compile_fft(
        lto_code.name.encode("utf-8"),
        lto_symbol.encode("utf-8"),
        0,
        None,
        None,
        arch,
        size,
        ept,
        fft_dir,
        precision,
        ctypes.byref(shared_memory_size),
    )
    lto_code_path = Path(lto_code.name)
    if not result:
        lto_code.close()
        if lto_code_path.exists():
            lto_code_path.unlink()
        # Fixed copy-paste bug: the failure message referred to tile_matmul.
        raise RuntimeError("Failed to compile tile_fft")

    with open(lto_code.name, "rb") as f:
        lto_code_data = f.read()

    lto_code.close()
    lto_code_path.unlink()

    builder.ltoirs[lto_symbol] = lto_code_data

    return (
        (
            Var(lto_symbol, str, False, True, False),
            Var(dtype, str, False, True, False),
            Var(str(shared_memory_size.value), str, False, True, False),
            Var(str(batch), str, False, True, False),
            Var(str(ept), str, False, True, False),
            inout,
        ),
        [],
        [lto_code_data],
    )
6090
+
6091
+
6092
# Forward FFT: shares the generic value/dispatch functions with tile_ifft,
# binding direction="forward" via functools.partial.
# namespace="" — presumably the generated LTO symbol lives outside the wp::
# namespace; confirm against codegen.
add_builtin(
    "tile_fft",
    input_types={"inout": Tile},
    value_func=tile_fft_generic_value_func,
    lto_dispatch_func=functools.partial(tile_fft_generic_lto_dispatch_func, direction="forward"),
    variadic=True,
    doc="""Compute the forward FFT along the second dimension of a 2D tile of data.

    This function cooperatively computes the forward FFT on a tile of data inplace, treating each row individually.

    Supported datatypes are:
        * vec2f, vec2d

    :param inout: The input/output tile""",
    group="Tile Primitives",
    export=False,
    namespace="",
)
6110
+
6111
# Inverse FFT: identical registration to tile_fft except direction="inverse"
# is bound into the shared dispatch function via functools.partial.
add_builtin(
    "tile_ifft",
    input_types={"inout": Tile},
    value_func=tile_fft_generic_value_func,
    lto_dispatch_func=functools.partial(tile_fft_generic_lto_dispatch_func, direction="inverse"),
    variadic=True,
    doc="""Compute the inverse FFT along the second dimension of a 2D tile of data.

    This function cooperatively computes the inverse FFT on a tile of data inplace, treating each row individually.

    Supported datatypes are:
        * vec2f, vec2d

    :param inout: The input/output tile""",
    group="Tile Primitives",
    export=False,
    namespace="",
)
6129
+
4318
6130
  # ---------------------------------
4319
6131
  # Code Generation
4320
6132