warp-lang 1.4.1__py3-none-manylinux2014_x86_64.whl → 1.5.0__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic.

Files changed (164)
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +21 -2
  6. warp/build_dll.py +23 -6
  7. warp/builtins.py +1920 -111
  8. warp/codegen.py +186 -62
  9. warp/config.py +2 -2
  10. warp/context.py +322 -73
  11. warp/examples/assets/pixel.jpg +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  13. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  14. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  15. warp/examples/benchmarks/benchmark_tile.py +179 -0
  16. warp/examples/core/example_dem.py +2 -1
  17. warp/examples/core/example_mesh_intersect.py +3 -3
  18. warp/examples/fem/example_adaptive_grid.py +37 -10
  19. warp/examples/fem/example_apic_fluid.py +3 -2
  20. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  21. warp/examples/fem/example_deformed_geometry.py +1 -1
  22. warp/examples/fem/example_diffusion_3d.py +47 -4
  23. warp/examples/fem/example_distortion_energy.py +220 -0
  24. warp/examples/fem/example_magnetostatics.py +127 -85
  25. warp/examples/fem/example_nonconforming_contact.py +5 -5
  26. warp/examples/fem/example_stokes.py +3 -1
  27. warp/examples/fem/example_streamlines.py +12 -19
  28. warp/examples/fem/utils.py +38 -15
  29. warp/examples/optim/example_walker.py +2 -2
  30. warp/examples/sim/example_cloth.py +2 -25
  31. warp/examples/sim/example_jacobian_ik.py +6 -2
  32. warp/examples/sim/example_quadruped.py +2 -1
  33. warp/examples/tile/example_tile_convolution.py +58 -0
  34. warp/examples/tile/example_tile_fft.py +47 -0
  35. warp/examples/tile/example_tile_filtering.py +105 -0
  36. warp/examples/tile/example_tile_matmul.py +79 -0
  37. warp/examples/tile/example_tile_mlp.py +375 -0
  38. warp/fem/__init__.py +8 -0
  39. warp/fem/cache.py +16 -12
  40. warp/fem/dirichlet.py +1 -1
  41. warp/fem/domain.py +44 -1
  42. warp/fem/field/__init__.py +1 -2
  43. warp/fem/field/field.py +31 -19
  44. warp/fem/field/nodal_field.py +101 -49
  45. warp/fem/field/virtual.py +794 -0
  46. warp/fem/geometry/__init__.py +2 -2
  47. warp/fem/geometry/deformed_geometry.py +3 -105
  48. warp/fem/geometry/element.py +13 -0
  49. warp/fem/geometry/geometry.py +165 -5
  50. warp/fem/geometry/grid_2d.py +3 -6
  51. warp/fem/geometry/grid_3d.py +31 -28
  52. warp/fem/geometry/hexmesh.py +3 -46
  53. warp/fem/geometry/nanogrid.py +3 -2
  54. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  55. warp/fem/geometry/tetmesh.py +2 -43
  56. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  57. warp/fem/integrate.py +683 -261
  58. warp/fem/linalg.py +404 -0
  59. warp/fem/operator.py +101 -18
  60. warp/fem/polynomial.py +5 -5
  61. warp/fem/quadrature/quadrature.py +45 -21
  62. warp/fem/space/__init__.py +45 -11
  63. warp/fem/space/basis_function_space.py +451 -0
  64. warp/fem/space/basis_space.py +58 -11
  65. warp/fem/space/function_space.py +146 -5
  66. warp/fem/space/grid_2d_function_space.py +80 -66
  67. warp/fem/space/grid_3d_function_space.py +113 -68
  68. warp/fem/space/hexmesh_function_space.py +96 -108
  69. warp/fem/space/nanogrid_function_space.py +62 -110
  70. warp/fem/space/quadmesh_function_space.py +208 -0
  71. warp/fem/space/shape/__init__.py +45 -7
  72. warp/fem/space/shape/cube_shape_function.py +328 -54
  73. warp/fem/space/shape/shape_function.py +10 -1
  74. warp/fem/space/shape/square_shape_function.py +328 -60
  75. warp/fem/space/shape/tet_shape_function.py +269 -19
  76. warp/fem/space/shape/triangle_shape_function.py +238 -19
  77. warp/fem/space/tetmesh_function_space.py +69 -37
  78. warp/fem/space/topology.py +38 -0
  79. warp/fem/space/trimesh_function_space.py +179 -0
  80. warp/fem/utils.py +6 -331
  81. warp/jax_experimental.py +3 -1
  82. warp/native/array.h +55 -40
  83. warp/native/builtin.h +124 -43
  84. warp/native/bvh.h +4 -0
  85. warp/native/coloring.cpp +600 -0
  86. warp/native/cuda_util.cpp +14 -0
  87. warp/native/cuda_util.h +2 -1
  88. warp/native/fabric.h +8 -0
  89. warp/native/hashgrid.h +4 -0
  90. warp/native/marching.cu +8 -0
  91. warp/native/mat.h +14 -3
  92. warp/native/mathdx.cpp +59 -0
  93. warp/native/mesh.h +4 -0
  94. warp/native/range.h +13 -1
  95. warp/native/reduce.cpp +9 -1
  96. warp/native/reduce.cu +7 -0
  97. warp/native/runlength_encode.cpp +9 -1
  98. warp/native/runlength_encode.cu +7 -1
  99. warp/native/scan.cpp +8 -0
  100. warp/native/scan.cu +8 -0
  101. warp/native/scan.h +8 -1
  102. warp/native/sparse.cpp +8 -0
  103. warp/native/sparse.cu +8 -0
  104. warp/native/temp_buffer.h +7 -0
  105. warp/native/tile.h +1857 -0
  106. warp/native/tile_gemm.h +341 -0
  107. warp/native/tile_reduce.h +210 -0
  108. warp/native/volume_builder.cu +8 -0
  109. warp/native/volume_builder.h +8 -0
  110. warp/native/warp.cpp +10 -2
  111. warp/native/warp.cu +369 -15
  112. warp/native/warp.h +12 -2
  113. warp/optim/adam.py +39 -4
  114. warp/paddle.py +29 -12
  115. warp/render/render_opengl.py +137 -65
  116. warp/sim/graph_coloring.py +292 -0
  117. warp/sim/integrator_euler.py +4 -2
  118. warp/sim/integrator_featherstone.py +115 -44
  119. warp/sim/integrator_vbd.py +6 -0
  120. warp/sim/model.py +90 -17
  121. warp/stubs.py +651 -85
  122. warp/tape.py +12 -7
  123. warp/tests/assets/pixel.npy +0 -0
  124. warp/tests/aux_test_instancing_gc.py +18 -0
  125. warp/tests/test_array.py +207 -48
  126. warp/tests/test_closest_point_edge_edge.py +8 -8
  127. warp/tests/test_codegen.py +120 -1
  128. warp/tests/test_codegen_instancing.py +30 -0
  129. warp/tests/test_collision.py +110 -0
  130. warp/tests/test_coloring.py +241 -0
  131. warp/tests/test_context.py +34 -0
  132. warp/tests/test_examples.py +18 -4
  133. warp/tests/test_fabricarray.py +33 -0
  134. warp/tests/test_fem.py +453 -113
  135. warp/tests/test_func.py +48 -1
  136. warp/tests/test_generics.py +52 -0
  137. warp/tests/test_iter.py +68 -0
  138. warp/tests/test_mat_scalar_ops.py +1 -1
  139. warp/tests/test_mesh_query_point.py +5 -4
  140. warp/tests/test_module_hashing.py +23 -0
  141. warp/tests/test_paddle.py +27 -87
  142. warp/tests/test_print.py +191 -1
  143. warp/tests/test_spatial.py +1 -1
  144. warp/tests/test_tile.py +700 -0
  145. warp/tests/test_tile_mathdx.py +144 -0
  146. warp/tests/test_tile_mlp.py +383 -0
  147. warp/tests/test_tile_reduce.py +374 -0
  148. warp/tests/test_tile_shared_memory.py +190 -0
  149. warp/tests/test_vbd.py +12 -20
  150. warp/tests/test_volume.py +43 -0
  151. warp/tests/unittest_suites.py +23 -2
  152. warp/tests/unittest_utils.py +4 -0
  153. warp/types.py +339 -73
  154. warp/utils.py +22 -1
  155. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/METADATA +33 -7
  156. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/RECORD +159 -132
  157. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/WHEEL +1 -1
  158. warp/fem/field/test.py +0 -180
  159. warp/fem/field/trial.py +0 -183
  160. warp/fem/space/collocated_function_space.py +0 -102
  161. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  162. warp/fem/space/trimesh_2d_function_space.py +0 -153
  163. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/LICENSE.md +0 -0
  164. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/top_level.txt +0 -0
warp/context.py CHANGED
@@ -12,6 +12,7 @@ import hashlib
  import inspect
  import io
  import itertools
+ import json
  import operator
  import os
  import platform
@@ -21,7 +22,7 @@ import typing
  import weakref
  from copy import copy as shallowcopy
  from pathlib import Path
- from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

  import numpy as np

@@ -101,6 +102,7 @@ class Function:
  value_func=None,
  export_func=None,
  dispatch_func=None,
+ lto_dispatch_func=None,
  module=None,
  variadic=False,
  initializer_list_func=None,
@@ -137,6 +139,7 @@ class Function:
  self.value_func = value_func # a function that takes a list of args and a list of templates and returns the value type, e.g.: load(array, index) returns the type of value being loaded
  self.export_func = export_func
  self.dispatch_func = dispatch_func
+ self.lto_dispatch_func = lto_dispatch_func
  self.input_types = {}
  self.export = export
  self.doc = doc
@@ -619,10 +622,13 @@ def call_builtin(func: Function, *params) -> Tuple[bool, Any]:


  class KernelHooks:
- def __init__(self, forward, backward):
+ def __init__(self, forward, backward, forward_smem_bytes=0, backward_smem_bytes=0):
  self.forward = forward
  self.backward = backward

+ self.forward_smem_bytes = forward_smem_bytes
+ self.backward_smem_bytes = backward_smem_bytes
+

  # caches source and compiled entry points for a kernel (will be populated after module loads)
  class Kernel:
@@ -970,8 +976,17 @@ def struct(c):
  return s


- # overload a kernel with the given argument types
- def overload(kernel, arg_types=None):
+ def overload(kernel, arg_types=Union[None, Dict[str, Any], List[Any]]):
+ """Overload a generic kernel with the given argument types.
+
+ Can be called directly or used as a function decorator.
+
+ Args:
+ kernel: The generic kernel to be instantiated with concrete types.
+ arg_types: A list of concrete argument types for the kernel or a
+ dictionary specifying generic argument names as keys and concrete
+ types as variables.
+ """
  if isinstance(kernel, Kernel):
  # handle cases where user calls us directly, e.g. wp.overload(kernel, [args...])
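For context on the new docstring above: wp.overload() accepts either a list of concrete types or a dict keyed by generic argument name, and can also be used as a decorator. A minimal usage sketch (the kernel and types below are illustrative, not taken from this diff):

    from typing import Any
    import warp as wp

    @wp.kernel
    def scale(x: wp.array(dtype=Any), s: Any):
        i = wp.tid()
        x[i] = s * x[i]

    # decorator form: re-declare the kernel with concrete annotations
    @wp.overload
    def scale(x: wp.array(dtype=wp.float32), s: wp.float32):
        ...

    # direct form: pass a list of concrete types (or a dict of argument name -> type)
    wp.overload(scale, [wp.array(dtype=wp.float64), wp.float64])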
 
@@ -1073,6 +1088,7 @@ def add_builtin(
  value_func=None,
  export_func=None,
  dispatch_func=None,
+ lto_dispatch_func=None,
  doc="",
  namespace="wp::",
  variadic=False,
@@ -1113,6 +1129,9 @@ def add_builtin(
  The arguments returned must be of type `codegen.Var`.
  If not provided, all arguments passed by the users when calling
  the built-in are passed as-is as runtime arguments to the C++ function.
+ lto_dispatch_func (Callable): Same as dispatch_func, but takes an 'option' dict
+ as extra argument (indicating tile_size and target architecture) and returns
+ an LTO-IR buffer as extra return value
  doc (str): Used to generate the Python's docstring and the HTML documentation.
  namespace: Namespace for the underlying C++ function.
  variadic (bool): Whether the function declares variadic arguments.
@@ -1249,8 +1268,10 @@ def add_builtin(
  key,
  input_types=arg_types,
  value_type=return_type,
+ value_func=value_func if return_type is Any else None,
  export_func=export_func,
  dispatch_func=dispatch_func,
+ lto_dispatch_func=lto_dispatch_func,
  doc=doc,
  namespace=namespace,
  variadic=variadic,
@@ -1273,6 +1294,7 @@ def add_builtin(
  value_func=value_func,
  export_func=export_func,
  dispatch_func=dispatch_func,
+ lto_dispatch_func=lto_dispatch_func,
  variadic=variadic,
  initializer_list_func=initializer_list_func,
  export=export,
@@ -1539,6 +1561,8 @@ class ModuleBuilder:
  self.options = options
  self.module = module
  self.deferred_functions = []
+ self.ltoirs = {} # map from lto symbol to lto binary
+ self.ltoirs_decl = {} # map from lto symbol to lto forward declaration

  if hasher is None:
  hasher = ModuleHasher(module)
@@ -1606,9 +1630,26 @@ class ModuleBuilder:
  # use dict to preserve import order
  self.functions[func] = None

+ def build_meta(self):
+ meta = {}
+
+ for kernel in self.kernels:
+ name = kernel.get_mangled_name()
+
+ meta[name + "_cuda_kernel_forward_smem_bytes"] = kernel.adj.get_total_required_shared()
+ meta[name + "_cuda_kernel_backward_smem_bytes"] = kernel.adj.get_total_required_shared() * 2
+
+ return meta
+
  def codegen(self, device):
  source = ""

+ # code-gen LTO forward declarations
+ source += 'extern "C" {\n'
+ for fwd in self.ltoirs_decl.values():
+ source += fwd + "\n"
+ source += "}\n"
+
  # code-gen structs
  visited_structs = set()
  for struct in self.structs.keys():
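For context on build_meta() above: the resulting dictionary is written to module_codegen.meta during Module.load() and read back by ModuleExec.get_kernel_hooks() to configure dynamic shared memory. A hypothetical example of its contents (the mangled kernel name is made up):

    {
        "my_kernel_7be52346_cuda_kernel_forward_smem_bytes": 8192,
        "my_kernel_7be52346_cuda_kernel_backward_smem_bytes": 16384
    }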
@@ -1638,9 +1679,9 @@ class ModuleBuilder:

  # add headers
  if device == "cpu":
- source = warp.codegen.cpu_module_header + source
+ source = warp.codegen.cpu_module_header.format(tile_size=self.options["block_dim"]) + source
  else:
- source = warp.codegen.cuda_module_header + source
+ source = warp.codegen.cuda_module_header.format(tile_size=self.options["block_dim"]) + source

  return source

@@ -1659,11 +1700,12 @@ class ModuleExec:
  instance.handle = None
  return instance

- def __init__(self, handle, module_hash, device):
+ def __init__(self, handle, module_hash, device, meta):
  self.handle = handle
  self.module_hash = module_hash
  self.device = device
  self.kernel_hooks = {}
+ self.meta = meta

  # release the loaded module
  def __del__(self):
@@ -1677,19 +1719,50 @@ class ModuleExec:

  # lookup and cache kernel entry points
  def get_kernel_hooks(self, kernel):
- hooks = self.kernel_hooks.get(kernel)
+ # Use kernel.adj as a unique key for cache lookups instead of the kernel itself.
+ # This avoids holding a reference to the kernel and is faster than using
+ # a WeakKeyDictionary with kernels as keys.
+ hooks = self.kernel_hooks.get(kernel.adj)
  if hooks is not None:
  return hooks

  name = kernel.get_mangled_name()

  if self.device.is_cuda:
- forward = runtime.core.cuda_get_kernel(
- self.device.context, self.handle, (name + "_cuda_kernel_forward").encode("utf-8")
+ forward_name = name + "_cuda_kernel_forward"
+ forward_kernel = runtime.core.cuda_get_kernel(
+ self.device.context, self.handle, forward_name.encode("utf-8")
  )
- backward = runtime.core.cuda_get_kernel(
- self.device.context, self.handle, (name + "_cuda_kernel_backward").encode("utf-8")
+
+ backward_name = name + "_cuda_kernel_backward"
+ backward_kernel = runtime.core.cuda_get_kernel(
+ self.device.context, self.handle, backward_name.encode("utf-8")
  )
+
+ # look up the required shared memory size for each kernel from module metadata
+ forward_smem_bytes = self.meta[forward_name + "_smem_bytes"]
+ backward_smem_bytes = self.meta[backward_name + "_smem_bytes"]
+
+ # configure kernels maximum shared memory size
+ max_smem_bytes = runtime.core.cuda_get_max_shared_memory(self.device.context)
+
+ if not runtime.core.cuda_configure_kernel_shared_memory(forward_kernel, forward_smem_bytes):
+ print(
+ f"Warning: Failed to configure kernel dynamic shared memory for this device, tried to configure {forward_name} kernel for {forward_smem_bytes} bytes, but maximum available is {max_smem_bytes}"
+ )
+
+ options = dict(kernel.module.options)
+ options.update(kernel.options)
+
+ if options["enable_backward"] and not runtime.core.cuda_configure_kernel_shared_memory(
+ backward_kernel, backward_smem_bytes
+ ):
+ print(
+ f"Warning: Failed to configure kernel dynamic shared memory for this device, tried to configure {backward_name} kernel for {backward_smem_bytes} bytes, but maximum available is {max_smem_bytes}"
+ )
+
+ hooks = KernelHooks(forward_kernel, backward_kernel, forward_smem_bytes, backward_smem_bytes)
+
  else:
  func = ctypes.CFUNCTYPE(None)
  forward = (
@@ -1699,9 +1772,9 @@ class ModuleExec:
  func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_backward").encode("utf-8"))) or None
  )

- hooks = KernelHooks(forward, backward)
- self.kernel_hooks[kernel] = hooks
+ hooks = KernelHooks(forward, backward)

+ self.kernel_hooks[kernel.adj] = hooks
  return hooks


@@ -1711,7 +1784,8 @@ class ModuleExec:
  # build cache
  class Module:
  def __init__(self, name, loader):
- self.name = name
+ self.name = name if name is not None else "None"
+
  self.loader = loader

  # lookup the latest versions of kernels, functions, and structs by key
@@ -1719,12 +1793,14 @@ class Module:
  self.functions = {} # (key: function)
  self.structs = {} # (key: struct)

- # Set of all "live" kernels in this module.
+ # Set of all "live" kernels in this module, i.e., kernels that still have references.
+ # We keep a weak reference to every kernel ever created in this module and rely on Python GC
+ # to release kernels that no longer have any references (in user code or internal bookkeeping).
  # The difference between `live_kernels` and `kernels` is that `live_kernels` may contain
  # multiple kernels with the same key (which is essential to support closures), while `kernels`
  # only holds the latest kernel for each key. When the module is built, we compute the hash
  # of each kernel in `live_kernels` and filter out duplicates for codegen.
- self.live_kernels = weakref.WeakSet()
+ self._live_kernels = weakref.WeakSet()

  # executable modules currently loaded
  self.execs = {} # (device.context: ModuleExec)
@@ -1748,6 +1824,7 @@ class Module:
  "fast_math": False,
  "cuda_output": None, # supported values: "ptx", "cubin", or None (automatic)
  "mode": warp.config.mode,
+ "block_dim": 256,
  }

  # Module dependencies are determined by scanning each function
@@ -1772,7 +1849,7 @@ class Module:
  self.kernels[kernel.key] = kernel

  # track all kernel objects, even if they are duplicates
- self.live_kernels.add(kernel)
+ self._live_kernels.add(kernel)

  self.find_references(kernel.adj)

@@ -1838,6 +1915,19 @@ class Module:
  # for a reload of module on next launch
  self.mark_modified()

+ @property
+ def live_kernels(self):
+ # Return a list of kernels that still have references.
+ # We return a regular list instead of the WeakSet to avoid undesirable issues
+ # if kernels are garbage collected before the caller is done using this list.
+ # Note that we should avoid retaining strong references to kernels unnecessarily
+ # so that Python GC can release kernels that no longer have user references.
+ # It is tempting to call gc.collect() here to force garbage collection,
+ # but this can have undesirable consequences (e.g., GC during graph capture),
+ # so we should avoid it as a general rule. Instead, we rely on Python's
+ # reference counting GC to collect kernels that have gone out of scope.
+ return list(self._live_kernels)
+
  # find kernel corresponding to a Python function
  def find_kernel(self, func):
  qualname = warp.codegen.make_full_qualified_name(func)
@@ -1878,9 +1968,17 @@ class Module:
  self.hasher = ModuleHasher(self)
  return self.hasher.get_module_hash()

- def load(self, device) -> ModuleExec:
+ def load(self, device, block_dim=None) -> ModuleExec:
  device = runtime.get_device(device)

+ # re-compile module if tile size (blockdim) changes
+ # todo: it would be better to have a method such as `module.get_kernel(block_dim=N)`
+ # that can return a single kernel instance with a given block size
+ if block_dim is not None:
+ if self.options["block_dim"] != block_dim:
+ self.unload()
+ self.options["block_dim"] = block_dim
+
  # compute the hash if needed
  if self.hasher is None:
  self.hasher = ModuleHasher(self)
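For context on the block_dim handling above: a module is compiled for one block size at a time, so alternating block sizes between launches triggers an unload and rebuild. A minimal sketch of the user-visible behavior (kernel and sizes are illustrative):

    wp.launch(my_kernel, dim=n, inputs=[a], block_dim=256)  # module built with block_dim=256
    wp.launch(my_kernel, dim=n, inputs=[a], block_dim=128)  # different block size: module is unloaded and rebuilt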
@@ -1908,6 +2006,7 @@ class Module:
  # determine output paths
  if device.is_cpu:
  output_name = "module_codegen.o"
+ output_arch = None

  elif device.is_cuda:
  # determine whether to use PTX or CUBIN
@@ -1946,7 +2045,12 @@ class Module:
  or not warp.config.cache_kernels
  or warp.config.verify_autograd_array_access
  ):
- builder = ModuleBuilder(self, self.options, hasher=self.hasher)
+ builder_options = {
+ **self.options,
+ # Some of the Tile codegen, such as cuFFTDx and cuBLASDx, requires knowledge of the target arch
+ "output_arch": output_arch,
+ }
+ builder = ModuleBuilder(self, builder_options, hasher=self.hasher)

  # create a temporary (process unique) dir for build outputs before moving to the binary dir
  build_dir = os.path.join(
@@ -2009,6 +2113,7 @@ class Module:
  config=self.options["mode"],
  fast_math=self.options["fast_math"],
  verify_fp=warp.config.verify_fp,
+ ltoirs=builder.ltoirs.values(),
  )

  except Exception as e:
@@ -2016,6 +2121,15 @@ class Module:
  module_load_timer.extra_msg = " (error)"
  raise (e)

+ # ------------------------------------------------------------
+ # build meta data
+
+ meta = builder.build_meta()
+ meta_path = os.path.join(build_dir, "module_codegen.meta")
+
+ with open(meta_path, "w") as meta_file:
+ json.dump(meta, meta_file)
+
  # -----------------------------------------------------------
  # update cache

@@ -2052,18 +2166,23 @@ class Module:

  # -----------------------------------------------------------
  # Load CPU or CUDA binary
+
+ meta_path = os.path.join(module_dir, "module_codegen.meta")
+ with open(meta_path, "r") as meta_file:
+ meta = json.load(meta_file)
+
  if device.is_cpu:
  # LLVM modules are identified using strings, so we need to ensure uniqueness
  module_handle = f"{module_name}_{self.cpu_exec_id}"
  self.cpu_exec_id += 1
  runtime.llvm.load_obj(binary_path.encode("utf-8"), module_handle.encode("utf-8"))
- module_exec = ModuleExec(module_handle, module_hash, device)
+ module_exec = ModuleExec(module_handle, module_hash, device, meta)
  self.execs[None] = module_exec

  elif device.is_cuda:
  cuda_module = warp.build.load_cuda(binary_path, device)
  if cuda_module is not None:
- module_exec = ModuleExec(cuda_module, module_hash, device)
+ module_exec = ModuleExec(cuda_module, module_hash, device, meta)
  self.execs[device.context] = module_exec
  else:
  module_load_timer.extra_msg = " (error)"
@@ -2718,21 +2837,16 @@ class Graph:

  class Runtime:
  def __init__(self):
- if sys.version_info < (3, 7):
- raise RuntimeError("Warp requires Python 3.7 as a minimum")
+ if sys.version_info < (3, 8):
+ raise RuntimeError("Warp requires Python 3.8 as a minimum")
  if sys.version_info < (3, 9):
  warp.utils.warn(f"Python 3.9 or newer is recommended for running Warp, detected {sys.version_info}")

  bin_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "bin")

  if os.name == "nt":
- if sys.version_info >= (3, 8):
- # Python >= 3.8 this method to add dll search paths
- os.add_dll_directory(bin_path)
-
- else:
- # Python < 3.8 we add dll directory to path
- os.environ["PATH"] = bin_path + os.pathsep + os.environ["PATH"]
+ # Python >= 3.8 this method to add dll search paths
+ os.add_dll_directory(bin_path)

  warp_lib = os.path.join(bin_path, "warp.dll")
  llvm_lib = os.path.join(bin_path, "warp-clang.dll")
@@ -3204,6 +3318,8 @@ class Runtime:
  self.core.is_cuda_compatibility_enabled.restype = ctypes.c_int
  self.core.is_cutlass_enabled.argtypes = None
  self.core.is_cutlass_enabled.restype = ctypes.c_int
+ self.core.is_mathdx_enabled.argtypes = None
+ self.core.is_mathdx_enabled.restype = ctypes.c_int

  self.core.cuda_driver_version.argtypes = None
  self.core.cuda_driver_version.restype = ctypes.c_int
@@ -3328,17 +3444,58 @@ class Runtime:
  self.core.cuda_graph_destroy.restype = ctypes.c_bool

  self.core.cuda_compile_program.argtypes = [
- ctypes.c_char_p,
- ctypes.c_int,
- ctypes.c_char_p,
- ctypes.c_bool,
- ctypes.c_bool,
- ctypes.c_bool,
- ctypes.c_bool,
- ctypes.c_char_p,
+ ctypes.c_char_p, # cuda_src
+ ctypes.c_int, # arch
+ ctypes.c_char_p, # include_dir
+ ctypes.c_int, # num_cuda_include_dirs
+ ctypes.POINTER(ctypes.c_char_p), # cuda include dirs
+ ctypes.c_bool, # debug
+ ctypes.c_bool, # verbose
+ ctypes.c_bool, # verify_fp
+ ctypes.c_bool, # fast_math
+ ctypes.c_char_p, # output_path
+ ctypes.c_size_t, # num_ltoirs
+ ctypes.POINTER(ctypes.c_char_p), # ltoirs
+ ctypes.POINTER(ctypes.c_size_t), # ltoir_sizes
  ]
  self.core.cuda_compile_program.restype = ctypes.c_size_t

+ self.core.cuda_compile_fft.argtypes = [
+ ctypes.c_char_p, # lto
+ ctypes.c_char_p, # function name
+ ctypes.c_int, # num include dirs
+ ctypes.POINTER(ctypes.c_char_p), # include dirs
+ ctypes.c_char_p, # mathdx include dir
+ ctypes.c_int, # arch
+ ctypes.c_int, # size
+ ctypes.c_int, # ept
+ ctypes.c_int, # direction
+ ctypes.c_int, # precision
+ ctypes.POINTER(ctypes.c_int), # smem (out)
+ ]
+ self.core.cuda_compile_fft.restype = ctypes.c_bool
+
+ self.core.cuda_compile_dot.argtypes = [
+ ctypes.c_char_p, # lto
+ ctypes.c_char_p, # function name
+ ctypes.c_int, # num include dirs
+ ctypes.POINTER(ctypes.c_char_p), # include dirs
+ ctypes.c_char_p, # mathdx include dir
+ ctypes.c_int, # arch
+ ctypes.c_int, # M
+ ctypes.c_int, # N
+ ctypes.c_int, # K
+ ctypes.c_int, # a_precision
+ ctypes.c_int, # b_precision
+ ctypes.c_int, # c_precision
+ ctypes.c_int, # type
+ ctypes.c_int, # a_arrangement
+ ctypes.c_int, # b_arrangement
+ ctypes.c_int, # c_arrangement
+ ctypes.c_int, # num threads
+ ]
+ self.core.cuda_compile_dot.restype = ctypes.c_bool
+
  self.core.cuda_load_module.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
  self.core.cuda_load_module.restype = ctypes.c_void_p

@@ -3348,11 +3505,19 @@ class Runtime:
  self.core.cuda_get_kernel.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_char_p]
  self.core.cuda_get_kernel.restype = ctypes.c_void_p

+ self.core.cuda_get_max_shared_memory.argtypes = [ctypes.c_void_p]
+ self.core.cuda_get_max_shared_memory.restype = ctypes.c_int
+
+ self.core.cuda_configure_kernel_shared_memory.argtypes = [ctypes.c_void_p, ctypes.c_int]
+ self.core.cuda_configure_kernel_shared_memory.restype = ctypes.c_bool
+
  self.core.cuda_launch_kernel.argtypes = [
  ctypes.c_void_p,
  ctypes.c_void_p,
  ctypes.c_size_t,
  ctypes.c_int,
+ ctypes.c_int,
+ ctypes.c_int,
  ctypes.POINTER(ctypes.c_void_p),
  ctypes.c_void_p,
  ]
@@ -3381,6 +3546,23 @@ class Runtime:
  self.core.cuda_timing_end.argtypes = []
  self.core.cuda_timing_end.restype = None

+ self.core.graph_coloring.argtypes = [
+ ctypes.c_int,
+ warp.types.array_t,
+ ctypes.c_int,
+ warp.types.array_t,
+ ]
+ self.core.graph_coloring.restype = ctypes.c_int
+
+ self.core.balance_coloring.argtypes = [
+ ctypes.c_int,
+ warp.types.array_t,
+ ctypes.c_int,
+ ctypes.c_float,
+ warp.types.array_t,
+ ]
+ self.core.balance_coloring.restype = ctypes.c_float
+
  self.core.init.restype = ctypes.c_int

  except AttributeError as e:
@@ -3606,10 +3788,7 @@ class Runtime:

  def load_dll(self, dll_path):
  try:
- if sys.version_info >= (3, 8):
- dll = ctypes.CDLL(dll_path, winmode=0)
- else:
- dll = ctypes.CDLL(dll_path)
+ dll = ctypes.CDLL(dll_path, winmode=0)
  except OSError as e:
  if "GLIBCXX" in str(e):
  raise RuntimeError(
@@ -3750,7 +3929,7 @@ def is_cuda_available() -> bool:
  return get_cuda_device_count() > 0


- def is_device_available(device):
+ def is_device_available(device: Device) -> bool:
  return device in get_devices()


@@ -3810,7 +3989,7 @@ def get_cuda_devices() -> List[Device]:


  def get_preferred_device() -> Device:
- """Returns the preferred compute device, CUDA if available and CPU otherwise."""
+ """Returns the preferred compute device, ``cuda:0`` if available and ``cpu`` otherwise."""

  init()

@@ -3950,7 +4129,7 @@ def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, floa


  def get_mempool_release_threshold(device: Devicelike) -> int:
- """Get the CUDA memory pool release threshold on the device."""
+ """Get the CUDA memory pool release threshold on the device in bytes."""

  init()

@@ -3969,7 +4148,7 @@ def is_peer_access_supported(target_device: Devicelike, peer_device: Devicelike)
  """Check if `peer_device` can directly access the memory of `target_device` on this system.

  This applies to memory allocated using default CUDA allocators. For memory allocated using
- CUDA pooled allocators, use `is_mempool_access_supported()`.
+ CUDA pooled allocators, use :func:`is_mempool_access_supported()`.

  Returns:
  A Boolean value indicating if this peer access is supported by the system.
@@ -3990,7 +4169,7 @@ def is_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike) -
  """Check if `peer_device` can currently access the memory of `target_device`.

  This applies to memory allocated using default CUDA allocators. For memory allocated using
- CUDA pooled allocators, use `is_mempool_access_enabled()`.
+ CUDA pooled allocators, use :func:`is_mempool_access_enabled()`.

  Returns:
  A Boolean value indicating if this peer access is currently enabled.
@@ -4014,7 +4193,7 @@ def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike,
  a negative impact on memory consumption and allocation performance.

  This applies to memory allocated using default CUDA allocators. For memory allocated using
- CUDA pooled allocators, use `set_mempool_access_enabled()`.
+ CUDA pooled allocators, use :func:`set_mempool_access_enabled()`.
  """

  init()
@@ -4042,7 +4221,8 @@ def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike,
  def is_mempool_access_supported(target_device: Devicelike, peer_device: Devicelike) -> bool:
  """Check if `peer_device` can directly access the memory pool of `target_device`.

- If mempool access is possible, it can be managed using `set_mempool_access_enabled()` and `is_mempool_access_enabled()`.
+ If mempool access is possible, it can be managed using :func:`set_mempool_access_enabled()`
+ and :func:`is_mempool_access_enabled()`.

  Returns:
  A Boolean value indicating if this memory pool access is supported by the system.
@@ -4060,7 +4240,7 @@ def is_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike
  """Check if `peer_device` can currently access the memory pool of `target_device`.

  This applies to memory allocated using CUDA pooled allocators. For memory allocated using
- default CUDA allocators, use `is_peer_access_enabled()`.
+ default CUDA allocators, use :func:`is_peer_access_enabled()`.

  Returns:
  A Boolean value indicating if this peer access is currently enabled.
@@ -4081,7 +4261,7 @@ def set_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelik
  """Enable or disable access from `peer_device` to the memory pool of `target_device`.

  This applies to memory allocated using CUDA pooled allocators. For memory allocated using
- default CUDA allocators, use `set_peer_access_enabled()`.
+ default CUDA allocators, use :func:`set_peer_access_enabled()`.
  """

  init()
@@ -4790,7 +4970,9 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
  # represents all data required for a kernel launch
  # so that launches can be replayed quickly, use `wp.launch(..., record_cmd=True)`
  class Launch:
- def __init__(self, kernel, device, hooks=None, params=None, params_addr=None, bounds=None, max_blocks=0):
+ def __init__(
+ self, kernel, device, hooks=None, params=None, params_addr=None, bounds=None, max_blocks=0, block_dim=256
+ ):
  # retain the module executable so it doesn't get unloaded
  self.module_exec = kernel.module.load(device)
  if not self.module_exec:
@@ -4829,6 +5011,7 @@ class Launch:
  self.device = device
  self.bounds = bounds
  self.max_blocks = max_blocks
+ self.block_dim = block_dim

  def set_dim(self, dim):
  self.bounds = warp.types.launch_bounds_t(dim)
@@ -4910,6 +5093,8 @@ class Launch:
  self.hooks.forward,
  self.bounds.size,
  self.max_blocks,
+ self.block_dim,
+ self.hooks.forward_smem_bytes,
  self.params_addr,
  stream.cuda_stream,
  )
@@ -4928,6 +5113,7 @@ def launch(
  record_tape=True,
  record_cmd=False,
  max_blocks=0,
+ block_dim=256,
  ):
  """Launch a Warp kernel on the target device

@@ -4947,6 +5133,7 @@ def launch(
  record_cmd: When True the launch will be returned as a ``Launch`` command object, the launch will not occur until the user calls ``cmd.launch()``
  max_blocks: The maximum number of CUDA thread blocks to use. Only has an effect for CUDA kernel launches.
  If negative or zero, the maximum hardware value will be used.
+ block_dim: The number of threads per block.
  """

  init()
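For context on the new block_dim launch argument documented above, a minimal sketch combining it with the existing record_cmd replay path (kernel and sizes are illustrative):

    # direct launch with an explicit number of threads per block
    wp.launch(my_kernel, dim=n, inputs=[a], block_dim=128)

    # record once, replay later; the recorded Launch keeps its block_dim
    cmd = wp.launch(my_kernel, dim=n, inputs=[a], block_dim=128, record_cmd=True)
    cmd.launch()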
@@ -5000,7 +5187,12 @@ def launch(
  kernel = kernel.add_overload(fwd_types)

  # delay load modules, including new overload if needed
- module_exec = kernel.module.load(device)
+ try:
+ module_exec = kernel.module.load(device, block_dim)
+ except Exception:
+ kernel.adj.skip_build = True
+ raise
+
  if not module_exec:
  return

@@ -5056,7 +5248,14 @@ def launch(
  )

  runtime.core.cuda_launch_kernel(
- device.context, hooks.backward, bounds.size, max_blocks, kernel_params, stream.cuda_stream
+ device.context,
+ hooks.backward,
+ bounds.size,
+ max_blocks,
+ block_dim,
+ hooks.backward_smem_bytes,
+ kernel_params,
+ stream.cuda_stream,
  )

  else:
@@ -5079,7 +5278,14 @@ def launch(
  else:
  # launch
  runtime.core.cuda_launch_kernel(
- device.context, hooks.forward, bounds.size, max_blocks, kernel_params, stream.cuda_stream
+ device.context,
+ hooks.forward,
+ bounds.size,
+ max_blocks,
+ block_dim,
+ hooks.forward_smem_bytes,
+ kernel_params,
+ stream.cuda_stream,
  )

  try:
@@ -5093,13 +5299,65 @@ def launch(
  # record file, lineno, func as metadata
  frame = inspect.currentframe().f_back
  caller = {"file": frame.f_code.co_filename, "lineno": frame.f_lineno, "func": frame.f_code.co_name}
- runtime.tape.record_launch(kernel, dim, max_blocks, inputs, outputs, device, metadata={"caller": caller})
+ runtime.tape.record_launch(
+ kernel, dim, max_blocks, inputs, outputs, device, block_dim, metadata={"caller": caller}
+ )

  # detect illegal inter-kernel read/write access patterns if verification flag is set
  if warp.config.verify_autograd_array_access:
  runtime.tape._check_kernel_array_access(kernel, fwd_args)


+ def launch_tiled(*args, **kwargs):
+ """A helper method for launching a grid with an extra trailing dimension equal to the block size.
+
+ For example, to launch a 2D grid, where each element has 64 threads assigned you would use the following:
+
+ .. code-block:: python
+
+ wp.launch_tiled(kernel, [M, N], inputs=[...], block_dim=64)
+
+ Which is equivalent to the following:
+
+ .. code-block:: python
+
+ wp.launch(kernel, [M, N, 64], inputs=[...], block_dim=64)
+
+ Inside your kernel code you can retrieve the first two indices of the thread as usual, ignoring the implicit third dimension if desired:
+
+ .. code-block:: python
+
+ @wp.kernel
+ def compute()
+
+ i, j = wp.tid()
+
+ ...
+ """
+
+ # promote dim to a list in case it was passed as a scalar or tuple
+ if "dim" not in kwargs:
+ raise RuntimeError("Launch dimensions 'dim' argument should be passed via. keyword args for wp.launch_tiled()")
+
+ if "block_dim" not in kwargs:
+ raise RuntimeError(
+ "Launch block dimension 'block_dim' argument should be passed via. keyword args for wp.launch_tiled()"
+ )
+
+ dim = kwargs["dim"]
+ if not isinstance(dim, list):
+ dim = list(dim) if isinstance(dim, tuple) else [dim]
+
+ if len(dim) > 3:
+ raise RuntimeError("wp.launch_tiled() requires a grid with fewer than 4 dimensions")
+
+ # add trailing dimension
+ kwargs["dim"] = dim + [kwargs["block_dim"]]
+
+ # forward to original launch method
+ launch(*args, **kwargs)
+
+
  def synchronize():
  """Manually synchronize the calling CPU thread with any outstanding CUDA work on all devices

@@ -5618,16 +5876,6 @@ def type_str(t):
  return "Any"
  elif t == Callable:
  return "Callable"
- elif t == Tuple[int]:
- return "Tuple[int]"
- elif t == Tuple[int, int]:
- return "Tuple[int, int]"
- elif t == Tuple[int, int, int]:
- return "Tuple[int, int, int]"
- elif t == Tuple[int, int, int, int]:
- return "Tuple[int, int, int, int]"
- elif t == Tuple[int, ...]:
- return "Tuple[int, ...]"
  elif isinstance(t, int):
  return str(t)
  elif isinstance(t, List):
@@ -5662,9 +5910,13 @@ def type_str(t):
  return f"Transformation[{type_str(t._wp_scalar_type_)}]"

  raise TypeError("Invalid vector or matrix dimensions")
- elif typing.get_origin(t) in (List, Mapping, Sequence, Union, Tuple):
- args_repr = ", ".join(type_str(x) for x in typing.get_args(t))
- return f"{t.__name__}[{args_repr}]"
+ elif warp.codegen.get_type_origin(t) in (list, tuple):
+ args_repr = ", ".join(type_str(x) for x in warp.codegen.get_type_args(t))
+ return f"{t._name}[{args_repr}]"
+ elif t is Ellipsis:
+ return "..."
+ elif warp.types.is_tile(t):
+ return "Tile"

  return t.__name__

@@ -5825,9 +6077,6 @@ def export_stubs(file): # pragma: no cover
  print('Cols = TypeVar("Cols", bound=int)', file=file)
  print('DType = TypeVar("DType")', file=file)

- print('Int = TypeVar("Int")', file=file)
- print('Float = TypeVar("Float")', file=file)
- print('Scalar = TypeVar("Scalar")', file=file)
  print("Vector = Generic[Length, Scalar]", file=file)
  print("Matrix = Generic[Rows, Cols, Scalar]", file=file)
  print("Quaternion = Generic[Float]", file=file)