warp-lang 1.4.2 py3-none-win_amd64.whl → 1.5.0 py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (158)
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +21 -2
  6. warp/build_dll.py +23 -6
  7. warp/builtins.py +1783 -2
  8. warp/codegen.py +177 -45
  9. warp/config.py +2 -2
  10. warp/context.py +321 -73
  11. warp/examples/assets/pixel.jpg +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  13. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  14. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  15. warp/examples/benchmarks/benchmark_tile.py +179 -0
  16. warp/examples/fem/example_adaptive_grid.py +37 -10
  17. warp/examples/fem/example_apic_fluid.py +3 -2
  18. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  19. warp/examples/fem/example_deformed_geometry.py +1 -1
  20. warp/examples/fem/example_diffusion_3d.py +47 -4
  21. warp/examples/fem/example_distortion_energy.py +220 -0
  22. warp/examples/fem/example_magnetostatics.py +127 -85
  23. warp/examples/fem/example_nonconforming_contact.py +5 -5
  24. warp/examples/fem/example_stokes.py +3 -1
  25. warp/examples/fem/example_streamlines.py +12 -19
  26. warp/examples/fem/utils.py +38 -15
  27. warp/examples/sim/example_cloth.py +2 -25
  28. warp/examples/sim/example_quadruped.py +2 -1
  29. warp/examples/tile/example_tile_convolution.py +58 -0
  30. warp/examples/tile/example_tile_fft.py +47 -0
  31. warp/examples/tile/example_tile_filtering.py +105 -0
  32. warp/examples/tile/example_tile_matmul.py +79 -0
  33. warp/examples/tile/example_tile_mlp.py +375 -0
  34. warp/fem/__init__.py +8 -0
  35. warp/fem/cache.py +16 -12
  36. warp/fem/dirichlet.py +1 -1
  37. warp/fem/domain.py +44 -1
  38. warp/fem/field/__init__.py +1 -2
  39. warp/fem/field/field.py +31 -19
  40. warp/fem/field/nodal_field.py +101 -49
  41. warp/fem/field/virtual.py +794 -0
  42. warp/fem/geometry/__init__.py +2 -2
  43. warp/fem/geometry/deformed_geometry.py +3 -105
  44. warp/fem/geometry/element.py +13 -0
  45. warp/fem/geometry/geometry.py +165 -5
  46. warp/fem/geometry/grid_2d.py +3 -6
  47. warp/fem/geometry/grid_3d.py +31 -28
  48. warp/fem/geometry/hexmesh.py +3 -46
  49. warp/fem/geometry/nanogrid.py +3 -2
  50. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  51. warp/fem/geometry/tetmesh.py +2 -43
  52. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  53. warp/fem/integrate.py +683 -261
  54. warp/fem/linalg.py +404 -0
  55. warp/fem/operator.py +101 -18
  56. warp/fem/polynomial.py +5 -5
  57. warp/fem/quadrature/quadrature.py +45 -21
  58. warp/fem/space/__init__.py +45 -11
  59. warp/fem/space/basis_function_space.py +451 -0
  60. warp/fem/space/basis_space.py +58 -11
  61. warp/fem/space/function_space.py +146 -5
  62. warp/fem/space/grid_2d_function_space.py +80 -66
  63. warp/fem/space/grid_3d_function_space.py +113 -68
  64. warp/fem/space/hexmesh_function_space.py +96 -108
  65. warp/fem/space/nanogrid_function_space.py +62 -110
  66. warp/fem/space/quadmesh_function_space.py +208 -0
  67. warp/fem/space/shape/__init__.py +45 -7
  68. warp/fem/space/shape/cube_shape_function.py +328 -54
  69. warp/fem/space/shape/shape_function.py +10 -1
  70. warp/fem/space/shape/square_shape_function.py +328 -60
  71. warp/fem/space/shape/tet_shape_function.py +269 -19
  72. warp/fem/space/shape/triangle_shape_function.py +238 -19
  73. warp/fem/space/tetmesh_function_space.py +69 -37
  74. warp/fem/space/topology.py +38 -0
  75. warp/fem/space/trimesh_function_space.py +179 -0
  76. warp/fem/utils.py +6 -331
  77. warp/jax_experimental.py +3 -1
  78. warp/native/array.h +15 -0
  79. warp/native/builtin.h +66 -26
  80. warp/native/bvh.h +4 -0
  81. warp/native/coloring.cpp +600 -0
  82. warp/native/cuda_util.cpp +14 -0
  83. warp/native/cuda_util.h +2 -1
  84. warp/native/fabric.h +8 -0
  85. warp/native/hashgrid.h +4 -0
  86. warp/native/marching.cu +8 -0
  87. warp/native/mat.h +14 -3
  88. warp/native/mathdx.cpp +59 -0
  89. warp/native/mesh.h +4 -0
  90. warp/native/range.h +13 -1
  91. warp/native/reduce.cpp +9 -1
  92. warp/native/reduce.cu +7 -0
  93. warp/native/runlength_encode.cpp +9 -1
  94. warp/native/runlength_encode.cu +7 -1
  95. warp/native/scan.cpp +8 -0
  96. warp/native/scan.cu +8 -0
  97. warp/native/scan.h +8 -1
  98. warp/native/sparse.cpp +8 -0
  99. warp/native/sparse.cu +8 -0
  100. warp/native/temp_buffer.h +7 -0
  101. warp/native/tile.h +1857 -0
  102. warp/native/tile_gemm.h +341 -0
  103. warp/native/tile_reduce.h +210 -0
  104. warp/native/volume_builder.cu +8 -0
  105. warp/native/volume_builder.h +8 -0
  106. warp/native/warp.cpp +10 -2
  107. warp/native/warp.cu +369 -15
  108. warp/native/warp.h +12 -2
  109. warp/optim/adam.py +39 -4
  110. warp/paddle.py +29 -12
  111. warp/render/render_opengl.py +137 -65
  112. warp/sim/graph_coloring.py +292 -0
  113. warp/sim/integrator_euler.py +4 -2
  114. warp/sim/integrator_featherstone.py +115 -44
  115. warp/sim/integrator_vbd.py +6 -0
  116. warp/sim/model.py +88 -15
  117. warp/stubs.py +569 -4
  118. warp/tape.py +12 -7
  119. warp/tests/assets/pixel.npy +0 -0
  120. warp/tests/aux_test_instancing_gc.py +18 -0
  121. warp/tests/test_array.py +39 -0
  122. warp/tests/test_codegen.py +81 -1
  123. warp/tests/test_codegen_instancing.py +30 -0
  124. warp/tests/test_collision.py +110 -0
  125. warp/tests/test_coloring.py +241 -0
  126. warp/tests/test_context.py +34 -0
  127. warp/tests/test_examples.py +18 -4
  128. warp/tests/test_fem.py +453 -113
  129. warp/tests/test_func.py +13 -0
  130. warp/tests/test_generics.py +52 -0
  131. warp/tests/test_iter.py +68 -0
  132. warp/tests/test_mat_scalar_ops.py +1 -1
  133. warp/tests/test_mesh_query_point.py +1 -1
  134. warp/tests/test_module_hashing.py +23 -0
  135. warp/tests/test_paddle.py +27 -87
  136. warp/tests/test_print.py +56 -1
  137. warp/tests/test_spatial.py +1 -1
  138. warp/tests/test_tile.py +700 -0
  139. warp/tests/test_tile_mathdx.py +144 -0
  140. warp/tests/test_tile_mlp.py +383 -0
  141. warp/tests/test_tile_reduce.py +374 -0
  142. warp/tests/test_tile_shared_memory.py +190 -0
  143. warp/tests/test_vbd.py +12 -20
  144. warp/tests/test_volume.py +43 -0
  145. warp/tests/unittest_suites.py +19 -2
  146. warp/tests/unittest_utils.py +4 -0
  147. warp/types.py +338 -72
  148. warp/utils.py +22 -1
  149. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/METADATA +33 -7
  150. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/RECORD +153 -126
  151. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/WHEEL +1 -1
  152. warp/fem/field/test.py +0 -180
  153. warp/fem/field/trial.py +0 -183
  154. warp/fem/space/collocated_function_space.py +0 -102
  155. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  156. warp/fem/space/trimesh_2d_function_space.py +0 -153
  157. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/LICENSE.md +0 -0
  158. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/top_level.txt +0 -0
warp/context.py CHANGED
@@ -12,6 +12,7 @@ import hashlib
  import inspect
  import io
  import itertools
+ import json
  import operator
  import os
  import platform
@@ -21,7 +22,7 @@ import typing
  import weakref
  from copy import copy as shallowcopy
  from pathlib import Path
- from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

  import numpy as np

@@ -101,6 +102,7 @@ class Function:
  value_func=None,
  export_func=None,
  dispatch_func=None,
+ lto_dispatch_func=None,
  module=None,
  variadic=False,
  initializer_list_func=None,
@@ -137,6 +139,7 @@ class Function:
  self.value_func = value_func # a function that takes a list of args and a list of templates and returns the value type, e.g.: load(array, index) returns the type of value being loaded
  self.export_func = export_func
  self.dispatch_func = dispatch_func
+ self.lto_dispatch_func = lto_dispatch_func
  self.input_types = {}
  self.export = export
  self.doc = doc
@@ -619,10 +622,13 @@ def call_builtin(func: Function, *params) -> Tuple[bool, Any]:


  class KernelHooks:
- def __init__(self, forward, backward):
+ def __init__(self, forward, backward, forward_smem_bytes=0, backward_smem_bytes=0):
  self.forward = forward
  self.backward = backward

+ self.forward_smem_bytes = forward_smem_bytes
+ self.backward_smem_bytes = backward_smem_bytes
+

  # caches source and compiled entry points for a kernel (will be populated after module loads)
  class Kernel:
@@ -970,8 +976,17 @@ def struct(c):
  return s


- # overload a kernel with the given argument types
- def overload(kernel, arg_types=None):
+ def overload(kernel, arg_types=Union[None, Dict[str, Any], List[Any]]):
+ """Overload a generic kernel with the given argument types.
+
+ Can be called directly or used as a function decorator.
+
+ Args:
+ kernel: The generic kernel to be instantiated with concrete types.
+ arg_types: A list of concrete argument types for the kernel or a
+ dictionary specifying generic argument names as keys and concrete
+ types as variables.
+ """
  if isinstance(kernel, Kernel):
  # handle cases where user calls us directly, e.g. wp.overload(kernel, [args...])

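For context, a brief usage sketch of the two call styles this docstring describes; the kernel and concrete types below are illustrative and not taken from the diff:

    from typing import Any

    import warp as wp

    @wp.kernel
    def scale(x: wp.array(dtype=Any), s: Any):
        i = wp.tid()
        x[i] = x[i] * s

    # called directly with a list of concrete argument types...
    wp.overload(scale, [wp.array(dtype=wp.float32), wp.float32])

    # ...or with a dict mapping generic argument names to concrete types
    wp.overload(scale, {"x": wp.array(dtype=wp.float64), "s": wp.float64})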
@@ -1073,6 +1088,7 @@ def add_builtin(
  value_func=None,
  export_func=None,
  dispatch_func=None,
+ lto_dispatch_func=None,
  doc="",
  namespace="wp::",
  variadic=False,
@@ -1113,6 +1129,9 @@ def add_builtin(
  The arguments returned must be of type `codegen.Var`.
  If not provided, all arguments passed by the users when calling
  the built-in are passed as-is as runtime arguments to the C++ function.
+ lto_dispatch_func (Callable): Same as dispatch_func, but takes an 'option' dict
+ as extra argument (indicating tile_size and target architecture) and returns
+ an LTO-IR buffer as extra return value
  doc (str): Used to generate the Python's docstring and the HTML documentation.
  namespace: Namespace for the underlying C++ function.
  variadic (bool): Whether the function declares variadic arguments.
@@ -1252,6 +1271,7 @@ def add_builtin(
  value_func=value_func if return_type is Any else None,
  export_func=export_func,
  dispatch_func=dispatch_func,
+ lto_dispatch_func=lto_dispatch_func,
  doc=doc,
  namespace=namespace,
  variadic=variadic,
@@ -1274,6 +1294,7 @@ def add_builtin(
  value_func=value_func,
  export_func=export_func,
  dispatch_func=dispatch_func,
+ lto_dispatch_func=lto_dispatch_func,
  variadic=variadic,
  initializer_list_func=initializer_list_func,
  export=export,
@@ -1540,6 +1561,8 @@ class ModuleBuilder:
  self.options = options
  self.module = module
  self.deferred_functions = []
+ self.ltoirs = {} # map from lto symbol to lto binary
+ self.ltoirs_decl = {} # map from lto symbol to lto forward declaration

  if hasher is None:
  hasher = ModuleHasher(module)
@@ -1607,9 +1630,26 @@ class ModuleBuilder:
  # use dict to preserve import order
  self.functions[func] = None

+ def build_meta(self):
+ meta = {}
+
+ for kernel in self.kernels:
+ name = kernel.get_mangled_name()
+
+ meta[name + "_cuda_kernel_forward_smem_bytes"] = kernel.adj.get_total_required_shared()
+ meta[name + "_cuda_kernel_backward_smem_bytes"] = kernel.adj.get_total_required_shared() * 2
+
+ return meta
+
  def codegen(self, device):
  source = ""

+ # code-gen LTO forward declarations
+ source += 'extern "C" {\n'
+ for fwd in self.ltoirs_decl.values():
+ source += fwd + "\n"
+ source += "}\n"
+
  # code-gen structs
  visited_structs = set()
  for struct in self.structs.keys():
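To illustrate, build_meta() produces a flat dictionary keyed by mangled kernel entry-point names, which Module.load() below serializes with json.dump and ModuleExec.get_kernel_hooks() reads back; a hypothetical example (the kernel name is invented, and the backward figure is twice the forward one, per the code above):

    {
        "wp_example_kernel_0a1b2c_cuda_kernel_forward_smem_bytes": 4096,
        "wp_example_kernel_0a1b2c_cuda_kernel_backward_smem_bytes": 8192
    }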
@@ -1639,9 +1679,9 @@ class ModuleBuilder:

  # add headers
  if device == "cpu":
- source = warp.codegen.cpu_module_header + source
+ source = warp.codegen.cpu_module_header.format(tile_size=self.options["block_dim"]) + source
  else:
- source = warp.codegen.cuda_module_header + source
+ source = warp.codegen.cuda_module_header.format(tile_size=self.options["block_dim"]) + source

  return source

@@ -1660,11 +1700,12 @@ class ModuleExec:
  instance.handle = None
  return instance

- def __init__(self, handle, module_hash, device):
+ def __init__(self, handle, module_hash, device, meta):
  self.handle = handle
  self.module_hash = module_hash
  self.device = device
  self.kernel_hooks = {}
+ self.meta = meta

  # release the loaded module
  def __del__(self):
@@ -1678,19 +1719,50 @@ class ModuleExec:

  # lookup and cache kernel entry points
  def get_kernel_hooks(self, kernel):
- hooks = self.kernel_hooks.get(kernel)
+ # Use kernel.adj as a unique key for cache lookups instead of the kernel itself.
+ # This avoids holding a reference to the kernel and is faster than using
+ # a WeakKeyDictionary with kernels as keys.
+ hooks = self.kernel_hooks.get(kernel.adj)
  if hooks is not None:
  return hooks

  name = kernel.get_mangled_name()

  if self.device.is_cuda:
- forward = runtime.core.cuda_get_kernel(
- self.device.context, self.handle, (name + "_cuda_kernel_forward").encode("utf-8")
+ forward_name = name + "_cuda_kernel_forward"
+ forward_kernel = runtime.core.cuda_get_kernel(
+ self.device.context, self.handle, forward_name.encode("utf-8")
  )
- backward = runtime.core.cuda_get_kernel(
- self.device.context, self.handle, (name + "_cuda_kernel_backward").encode("utf-8")
+
+ backward_name = name + "_cuda_kernel_backward"
+ backward_kernel = runtime.core.cuda_get_kernel(
+ self.device.context, self.handle, backward_name.encode("utf-8")
  )
+
+ # look up the required shared memory size for each kernel from module metadata
+ forward_smem_bytes = self.meta[forward_name + "_smem_bytes"]
+ backward_smem_bytes = self.meta[backward_name + "_smem_bytes"]
+
+ # configure kernels maximum shared memory size
+ max_smem_bytes = runtime.core.cuda_get_max_shared_memory(self.device.context)
+
+ if not runtime.core.cuda_configure_kernel_shared_memory(forward_kernel, forward_smem_bytes):
+ print(
+ f"Warning: Failed to configure kernel dynamic shared memory for this device, tried to configure {forward_name} kernel for {forward_smem_bytes} bytes, but maximum available is {max_smem_bytes}"
+ )
+
+ options = dict(kernel.module.options)
+ options.update(kernel.options)
+
+ if options["enable_backward"] and not runtime.core.cuda_configure_kernel_shared_memory(
+ backward_kernel, backward_smem_bytes
+ ):
+ print(
+ f"Warning: Failed to configure kernel dynamic shared memory for this device, tried to configure {backward_name} kernel for {backward_smem_bytes} bytes, but maximum available is {max_smem_bytes}"
+ )
+
+ hooks = KernelHooks(forward_kernel, backward_kernel, forward_smem_bytes, backward_smem_bytes)
+
  else:
  func = ctypes.CFUNCTYPE(None)
  forward = (
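cuda_get_max_shared_memory and cuda_configure_kernel_shared_memory are native entry points in warp.dll; as a rough sketch of the kind of driver-API call the configure helper presumably wraps (an assumption for illustration, not Warp's actual implementation):

    import ctypes

    # CUDA driver attribute for opting a kernel into large dynamic shared memory
    CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES = 8

    cuda = ctypes.CDLL("nvcuda.dll")  # libcuda.so.1 on Linux

    def configure_kernel_shared_memory(kernel_handle, smem_bytes):
        # returns True when the requested dynamic shared memory size is accepted
        status = cuda.cuFuncSetAttribute(
            ctypes.c_void_p(kernel_handle),
            CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
            smem_bytes,
        )
        return status == 0  # CUDA_SUCCESS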
@@ -1700,9 +1772,9 @@ class ModuleExec:

  func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_backward").encode("utf-8"))) or None
  )
- hooks = KernelHooks(forward, backward)
- self.kernel_hooks[kernel] = hooks
+ hooks = KernelHooks(forward, backward)

+ self.kernel_hooks[kernel.adj] = hooks
  return hooks


@@ -1712,7 +1784,8 @@ class ModuleExec:
  # build cache
  class Module:
  def __init__(self, name, loader):
- self.name = name
+ self.name = name if name is not None else "None"
+
  self.loader = loader

  # lookup the latest versions of kernels, functions, and structs by key
@@ -1720,12 +1793,14 @@ class Module:
  self.functions = {} # (key: function)
  self.structs = {} # (key: struct)

- # Set of all "live" kernels in this module.
+ # Set of all "live" kernels in this module, i.e., kernels that still have references.
+ # We keep a weak reference to every kernel ever created in this module and rely on Python GC
+ # to release kernels that no longer have any references (in user code or internal bookkeeping).
  # The difference between `live_kernels` and `kernels` is that `live_kernels` may contain
  # multiple kernels with the same key (which is essential to support closures), while `kernels`
  # only holds the latest kernel for each key. When the module is built, we compute the hash
  # of each kernel in `live_kernels` and filter out duplicates for codegen.
- self.live_kernels = weakref.WeakSet()
+ self._live_kernels = weakref.WeakSet()

  # executable modules currently loaded
  self.execs = {} # (device.context: ModuleExec)
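The private WeakSet relies on standard weakref semantics; a small standalone illustration in plain Python (unrelated to Warp's own Kernel type):

    import weakref

    class FakeKernel:
        pass

    live = weakref.WeakSet()

    k = FakeKernel()
    live.add(k)
    print(len(live))  # 1: the set holds only a weak reference

    del k             # drop the last strong reference
    print(len(live))  # 0: the entry vanishes once the object is collected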
@@ -1749,6 +1824,7 @@ class Module:
  "fast_math": False,
  "cuda_output": None, # supported values: "ptx", "cubin", or None (automatic)
  "mode": warp.config.mode,
+ "block_dim": 256,
  }

  # Module dependencies are determined by scanning each function
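Since "block_dim" is now an ordinary module option, it can presumably be overridden per module like any other entry in this dict before the module is built; a short sketch (the value 128 is arbitrary):

    import warp as wp

    # applies to kernels defined in the calling module
    wp.set_module_options({"block_dim": 128})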
@@ -1773,7 +1849,7 @@ class Module:
  self.kernels[kernel.key] = kernel

  # track all kernel objects, even if they are duplicates
- self.live_kernels.add(kernel)
+ self._live_kernels.add(kernel)

  self.find_references(kernel.adj)

@@ -1839,6 +1915,19 @@ class Module:
  # for a reload of module on next launch
  self.mark_modified()

+ @property
+ def live_kernels(self):
+ # Return a list of kernels that still have references.
+ # We return a regular list instead of the WeakSet to avoid undesirable issues
+ # if kernels are garbage collected before the caller is done using this list.
+ # Note that we should avoid retaining strong references to kernels unnecessarily
+ # so that Python GC can release kernels that no longer have user references.
+ # It is tempting to call gc.collect() here to force garbage collection,
+ # but this can have undesirable consequences (e.g., GC during graph capture),
+ # so we should avoid it as a general rule. Instead, we rely on Python's
+ # reference counting GC to collect kernels that have gone out of scope.
+ return list(self._live_kernels)
+
  # find kernel corresponding to a Python function
  def find_kernel(self, func):
  qualname = warp.codegen.make_full_qualified_name(func)
@@ -1879,9 +1968,17 @@ class Module:
  self.hasher = ModuleHasher(self)
  return self.hasher.get_module_hash()

- def load(self, device) -> ModuleExec:
+ def load(self, device, block_dim=None) -> ModuleExec:
  device = runtime.get_device(device)

+ # re-compile module if tile size (blockdim) changes
+ # todo: it would be better to have a method such as `module.get_kernel(block_dim=N)`
+ # that can return a single kernel instance with a given block size
+ if block_dim is not None:
+ if self.options["block_dim"] != block_dim:
+ self.unload()
+ self.options["block_dim"] = block_dim
+
  # compute the hash if needed
  if self.hasher is None:
  self.hasher = ModuleHasher(self)
@@ -1909,6 +2006,7 @@ class Module:
  # determine output paths
  if device.is_cpu:
  output_name = "module_codegen.o"
+ output_arch = None

  elif device.is_cuda:
  # determine whether to use PTX or CUBIN
@@ -1947,7 +2045,12 @@ class Module:
  or not warp.config.cache_kernels
  or warp.config.verify_autograd_array_access
  ):
- builder = ModuleBuilder(self, self.options, hasher=self.hasher)
+ builder_options = {
+ **self.options,
+ # Some of the Tile codegen, such as cuFFTDx and cuBLASDx, requires knowledge of the target arch
+ "output_arch": output_arch,
+ }
+ builder = ModuleBuilder(self, builder_options, hasher=self.hasher)

  # create a temporary (process unique) dir for build outputs before moving to the binary dir
  build_dir = os.path.join(
@@ -2010,6 +2113,7 @@ class Module:
  config=self.options["mode"],
  fast_math=self.options["fast_math"],
  verify_fp=warp.config.verify_fp,
+ ltoirs=builder.ltoirs.values(),
  )

  except Exception as e:
@@ -2017,6 +2121,15 @@ class Module:
  module_load_timer.extra_msg = " (error)"
  raise (e)

+ # ------------------------------------------------------------
+ # build meta data
+
+ meta = builder.build_meta()
+ meta_path = os.path.join(build_dir, "module_codegen.meta")
+
+ with open(meta_path, "w") as meta_file:
+ json.dump(meta, meta_file)
+
  # -----------------------------------------------------------
  # update cache

@@ -2053,18 +2166,23 @@ class Module:

  # -----------------------------------------------------------
  # Load CPU or CUDA binary
+
+ meta_path = os.path.join(module_dir, "module_codegen.meta")
+ with open(meta_path, "r") as meta_file:
+ meta = json.load(meta_file)
+
  if device.is_cpu:
  # LLVM modules are identified using strings, so we need to ensure uniqueness
  module_handle = f"{module_name}_{self.cpu_exec_id}"
  self.cpu_exec_id += 1
  runtime.llvm.load_obj(binary_path.encode("utf-8"), module_handle.encode("utf-8"))
- module_exec = ModuleExec(module_handle, module_hash, device)
+ module_exec = ModuleExec(module_handle, module_hash, device, meta)
  self.execs[None] = module_exec

  elif device.is_cuda:
  cuda_module = warp.build.load_cuda(binary_path, device)
  if cuda_module is not None:
- module_exec = ModuleExec(cuda_module, module_hash, device)
+ module_exec = ModuleExec(cuda_module, module_hash, device, meta)
  self.execs[device.context] = module_exec
  else:
  module_load_timer.extra_msg = " (error)"
@@ -2719,21 +2837,16 @@ class Graph:

  class Runtime:
  def __init__(self):
- if sys.version_info < (3, 7):
- raise RuntimeError("Warp requires Python 3.7 as a minimum")
+ if sys.version_info < (3, 8):
+ raise RuntimeError("Warp requires Python 3.8 as a minimum")
  if sys.version_info < (3, 9):
  warp.utils.warn(f"Python 3.9 or newer is recommended for running Warp, detected {sys.version_info}")

  bin_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "bin")

  if os.name == "nt":
- if sys.version_info >= (3, 8):
- # Python >= 3.8 this method to add dll search paths
- os.add_dll_directory(bin_path)
-
- else:
- # Python < 3.8 we add dll directory to path
- os.environ["PATH"] = bin_path + os.pathsep + os.environ["PATH"]
+ # Python >= 3.8 this method to add dll search paths
+ os.add_dll_directory(bin_path)

  warp_lib = os.path.join(bin_path, "warp.dll")
  llvm_lib = os.path.join(bin_path, "warp-clang.dll")
@@ -3205,6 +3318,8 @@ class Runtime:
  self.core.is_cuda_compatibility_enabled.restype = ctypes.c_int
  self.core.is_cutlass_enabled.argtypes = None
  self.core.is_cutlass_enabled.restype = ctypes.c_int
+ self.core.is_mathdx_enabled.argtypes = None
+ self.core.is_mathdx_enabled.restype = ctypes.c_int

  self.core.cuda_driver_version.argtypes = None
  self.core.cuda_driver_version.restype = ctypes.c_int
@@ -3329,17 +3444,58 @@ class Runtime:
  self.core.cuda_graph_destroy.restype = ctypes.c_bool

  self.core.cuda_compile_program.argtypes = [
- ctypes.c_char_p,
- ctypes.c_int,
- ctypes.c_char_p,
- ctypes.c_bool,
- ctypes.c_bool,
- ctypes.c_bool,
- ctypes.c_bool,
- ctypes.c_char_p,
+ ctypes.c_char_p, # cuda_src
+ ctypes.c_int, # arch
+ ctypes.c_char_p, # include_dir
+ ctypes.c_int, # num_cuda_include_dirs
+ ctypes.POINTER(ctypes.c_char_p), # cuda include dirs
+ ctypes.c_bool, # debug
+ ctypes.c_bool, # verbose
+ ctypes.c_bool, # verify_fp
+ ctypes.c_bool, # fast_math
+ ctypes.c_char_p, # output_path
+ ctypes.c_size_t, # num_ltoirs
+ ctypes.POINTER(ctypes.c_char_p), # ltoirs
+ ctypes.POINTER(ctypes.c_size_t), # ltoir_sizes
  ]
  self.core.cuda_compile_program.restype = ctypes.c_size_t

+ self.core.cuda_compile_fft.argtypes = [
+ ctypes.c_char_p, # lto
+ ctypes.c_char_p, # function name
+ ctypes.c_int, # num include dirs
+ ctypes.POINTER(ctypes.c_char_p), # include dirs
+ ctypes.c_char_p, # mathdx include dir
+ ctypes.c_int, # arch
+ ctypes.c_int, # size
+ ctypes.c_int, # ept
+ ctypes.c_int, # direction
+ ctypes.c_int, # precision
+ ctypes.POINTER(ctypes.c_int), # smem (out)
+ ]
+ self.core.cuda_compile_fft.restype = ctypes.c_bool
+
+ self.core.cuda_compile_dot.argtypes = [
+ ctypes.c_char_p, # lto
+ ctypes.c_char_p, # function name
+ ctypes.c_int, # num include dirs
+ ctypes.POINTER(ctypes.c_char_p), # include dirs
+ ctypes.c_char_p, # mathdx include dir
+ ctypes.c_int, # arch
+ ctypes.c_int, # M
+ ctypes.c_int, # N
+ ctypes.c_int, # K
+ ctypes.c_int, # a_precision
+ ctypes.c_int, # b_precision
+ ctypes.c_int, # c_precision
+ ctypes.c_int, # type
+ ctypes.c_int, # a_arrangement
+ ctypes.c_int, # b_arrangement
+ ctypes.c_int, # c_arrangement
+ ctypes.c_int, # num threads
+ ]
+ self.core.cuda_compile_dot.restype = ctypes.c_bool
+
  self.core.cuda_load_module.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
  self.core.cuda_load_module.restype = ctypes.c_void_p

@@ -3349,11 +3505,19 @@ class Runtime:
  self.core.cuda_get_kernel.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_char_p]
  self.core.cuda_get_kernel.restype = ctypes.c_void_p

+ self.core.cuda_get_max_shared_memory.argtypes = [ctypes.c_void_p]
+ self.core.cuda_get_max_shared_memory.restype = ctypes.c_int
+
+ self.core.cuda_configure_kernel_shared_memory.argtypes = [ctypes.c_void_p, ctypes.c_int]
+ self.core.cuda_configure_kernel_shared_memory.restype = ctypes.c_bool
+
  self.core.cuda_launch_kernel.argtypes = [
  ctypes.c_void_p,
  ctypes.c_void_p,
  ctypes.c_size_t,
  ctypes.c_int,
+ ctypes.c_int,
+ ctypes.c_int,
  ctypes.POINTER(ctypes.c_void_p),
  ctypes.c_void_p,
  ]
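For readers less familiar with ctypes, the POINTER(ctypes.c_char_p)/c_size_t pairs declared above (for example num_ltoirs, ltoirs, and ltoir_sizes) are typically populated like this; a generic illustration with made-up buffers, not Warp's actual call site:

    import ctypes

    ltoir_blobs = [b"lto-ir-blob-0", b"lto-ir-blob-1"]  # hypothetical LTO-IR buffers

    num_ltoirs = ctypes.c_size_t(len(ltoir_blobs))
    ltoirs = (ctypes.c_char_p * len(ltoir_blobs))(*ltoir_blobs)
    ltoir_sizes = (ctypes.c_size_t * len(ltoir_blobs))(*(len(b) for b in ltoir_blobs))

    # the three values are then passed as the trailing arguments of cuda_compile_program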
@@ -3382,6 +3546,23 @@ class Runtime:
  self.core.cuda_timing_end.argtypes = []
  self.core.cuda_timing_end.restype = None

+ self.core.graph_coloring.argtypes = [
+ ctypes.c_int,
+ warp.types.array_t,
+ ctypes.c_int,
+ warp.types.array_t,
+ ]
+ self.core.graph_coloring.restype = ctypes.c_int
+
+ self.core.balance_coloring.argtypes = [
+ ctypes.c_int,
+ warp.types.array_t,
+ ctypes.c_int,
+ ctypes.c_float,
+ warp.types.array_t,
+ ]
+ self.core.balance_coloring.restype = ctypes.c_float
+
  self.core.init.restype = ctypes.c_int

  except AttributeError as e:
@@ -3607,10 +3788,7 @@ class Runtime:

  def load_dll(self, dll_path):
  try:
- if sys.version_info >= (3, 8):
- dll = ctypes.CDLL(dll_path, winmode=0)
- else:
- dll = ctypes.CDLL(dll_path)
+ dll = ctypes.CDLL(dll_path, winmode=0)
  except OSError as e:
  if "GLIBCXX" in str(e):
  raise RuntimeError(
@@ -3751,7 +3929,7 @@ def is_cuda_available() -> bool:
  return get_cuda_device_count() > 0


- def is_device_available(device):
+ def is_device_available(device: Device) -> bool:
  return device in get_devices()


@@ -3811,7 +3989,7 @@ def get_cuda_devices() -> List[Device]:


  def get_preferred_device() -> Device:
- """Returns the preferred compute device, CUDA if available and CPU otherwise."""
+ """Returns the preferred compute device, ``cuda:0`` if available and ``cpu`` otherwise."""

  init()

@@ -3951,7 +4129,7 @@ def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, floa


  def get_mempool_release_threshold(device: Devicelike) -> int:
- """Get the CUDA memory pool release threshold on the device."""
+ """Get the CUDA memory pool release threshold on the device in bytes."""

  init()

@@ -3970,7 +4148,7 @@ def is_peer_access_supported(target_device: Devicelike, peer_device: Devicelike)
  """Check if `peer_device` can directly access the memory of `target_device` on this system.

  This applies to memory allocated using default CUDA allocators. For memory allocated using
- CUDA pooled allocators, use `is_mempool_access_supported()`.
+ CUDA pooled allocators, use :func:`is_mempool_access_supported()`.

  Returns:
  A Boolean value indicating if this peer access is supported by the system.
@@ -3991,7 +4169,7 @@ def is_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike) -
  """Check if `peer_device` can currently access the memory of `target_device`.

  This applies to memory allocated using default CUDA allocators. For memory allocated using
- CUDA pooled allocators, use `is_mempool_access_enabled()`.
+ CUDA pooled allocators, use :func:`is_mempool_access_enabled()`.

  Returns:
  A Boolean value indicating if this peer access is currently enabled.
@@ -4015,7 +4193,7 @@ def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike,
  a negative impact on memory consumption and allocation performance.

  This applies to memory allocated using default CUDA allocators. For memory allocated using
- CUDA pooled allocators, use `set_mempool_access_enabled()`.
+ CUDA pooled allocators, use :func:`set_mempool_access_enabled()`.
  """

  init()
@@ -4043,7 +4221,8 @@ def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike,
  def is_mempool_access_supported(target_device: Devicelike, peer_device: Devicelike) -> bool:
  """Check if `peer_device` can directly access the memory pool of `target_device`.

- If mempool access is possible, it can be managed using `set_mempool_access_enabled()` and `is_mempool_access_enabled()`.
+ If mempool access is possible, it can be managed using :func:`set_mempool_access_enabled()`
+ and :func:`is_mempool_access_enabled()`.

  Returns:
  A Boolean value indicating if this memory pool access is supported by the system.
@@ -4061,7 +4240,7 @@ def is_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike
  """Check if `peer_device` can currently access the memory pool of `target_device`.

  This applies to memory allocated using CUDA pooled allocators. For memory allocated using
- default CUDA allocators, use `is_peer_access_enabled()`.
+ default CUDA allocators, use :func:`is_peer_access_enabled()`.

  Returns:
  A Boolean value indicating if this peer access is currently enabled.
@@ -4082,7 +4261,7 @@ def set_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelik
  """Enable or disable access from `peer_device` to the memory pool of `target_device`.

  This applies to memory allocated using CUDA pooled allocators. For memory allocated using
- default CUDA allocators, use `set_peer_access_enabled()`.
+ default CUDA allocators, use :func:`set_peer_access_enabled()`.
  """

  init()
@@ -4791,7 +4970,9 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
  # represents all data required for a kernel launch
  # so that launches can be replayed quickly, use `wp.launch(..., record_cmd=True)`
  class Launch:
- def __init__(self, kernel, device, hooks=None, params=None, params_addr=None, bounds=None, max_blocks=0):
+ def __init__(
+ self, kernel, device, hooks=None, params=None, params_addr=None, bounds=None, max_blocks=0, block_dim=256
+ ):
  # retain the module executable so it doesn't get unloaded
  self.module_exec = kernel.module.load(device)
  if not self.module_exec:
@@ -4830,6 +5011,7 @@ class Launch:
  self.device = device
  self.bounds = bounds
  self.max_blocks = max_blocks
+ self.block_dim = block_dim

  def set_dim(self, dim):
  self.bounds = warp.types.launch_bounds_t(dim)
@@ -4911,6 +5093,8 @@ class Launch:
  self.hooks.forward,
  self.bounds.size,
  self.max_blocks,
+ self.block_dim,
+ self.hooks.forward_smem_bytes,
  self.params_addr,
  stream.cuda_stream,
  )
@@ -4929,6 +5113,7 @@ def launch(
  record_tape=True,
  record_cmd=False,
  max_blocks=0,
+ block_dim=256,
  ):
  """Launch a Warp kernel on the target device

@@ -4948,6 +5133,7 @@ def launch(
  record_cmd: When True the launch will be returned as a ``Launch`` command object, the launch will not occur until the user calls ``cmd.launch()``
  max_blocks: The maximum number of CUDA thread blocks to use. Only has an effect for CUDA kernel launches.
  If negative or zero, the maximum hardware value will be used.
+ block_dim: The number of threads per block.
  """

  init()
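A usage sketch of the new keyword (kernel and sizes are invented for illustration); note that, per Module.load() above, switching block_dim between launches of kernels from the same module forces the module to be unloaded and rebuilt for the new block size:

    import warp as wp

    @wp.kernel
    def saxpy(x: wp.array(dtype=float), y: wp.array(dtype=float), a: float):
        i = wp.tid()
        y[i] = a * x[i] + y[i]

    n = 1 << 20
    x = wp.ones(n, dtype=float)
    y = wp.zeros(n, dtype=float)

    # 128 threads per CUDA block instead of the default 256
    wp.launch(saxpy, dim=n, inputs=[x, y, 2.0], block_dim=128)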
@@ -5001,7 +5187,12 @@ def launch(
  kernel = kernel.add_overload(fwd_types)

  # delay load modules, including new overload if needed
- module_exec = kernel.module.load(device)
+ try:
+ module_exec = kernel.module.load(device, block_dim)
+ except Exception:
+ kernel.adj.skip_build = True
+ raise
+
  if not module_exec:
  return

@@ -5057,7 +5248,14 @@ def launch(
  )

  runtime.core.cuda_launch_kernel(
- device.context, hooks.backward, bounds.size, max_blocks, kernel_params, stream.cuda_stream
+ device.context,
+ hooks.backward,
+ bounds.size,
+ max_blocks,
+ block_dim,
+ hooks.backward_smem_bytes,
+ kernel_params,
+ stream.cuda_stream,
  )

  else:
@@ -5080,7 +5278,14 @@ def launch(
  else:
  # launch
  runtime.core.cuda_launch_kernel(
- device.context, hooks.forward, bounds.size, max_blocks, kernel_params, stream.cuda_stream
+ device.context,
+ hooks.forward,
+ bounds.size,
+ max_blocks,
+ block_dim,
+ hooks.forward_smem_bytes,
+ kernel_params,
+ stream.cuda_stream,
  )

  try:
@@ -5094,13 +5299,65 @@ def launch(
  # record file, lineno, func as metadata
  frame = inspect.currentframe().f_back
  caller = {"file": frame.f_code.co_filename, "lineno": frame.f_lineno, "func": frame.f_code.co_name}
- runtime.tape.record_launch(kernel, dim, max_blocks, inputs, outputs, device, metadata={"caller": caller})
+ runtime.tape.record_launch(
+ kernel, dim, max_blocks, inputs, outputs, device, block_dim, metadata={"caller": caller}
+ )

  # detect illegal inter-kernel read/write access patterns if verification flag is set
  if warp.config.verify_autograd_array_access:
  runtime.tape._check_kernel_array_access(kernel, fwd_args)


+ def launch_tiled(*args, **kwargs):
+ """A helper method for launching a grid with an extra trailing dimension equal to the block size.
+
+ For example, to launch a 2D grid, where each element has 64 threads assigned you would use the following:
+
+ .. code-block:: python
+
+ wp.launch_tiled(kernel, [M, N], inputs=[...], block_dim=64)
+
+ Which is equivalent to the following:
+
+ .. code-block:: python
+
+ wp.launch(kernel, [M, N, 64], inputs=[...], block_dim=64)
+
+ Inside your kernel code you can retrieve the first two indices of the thread as usual, ignoring the implicit third dimension if desired:
+
+ .. code-block:: python
+
+ @wp.kernel
+ def compute()
+
+ i, j = wp.tid()
+
+ ...
+ """
+
+ # promote dim to a list in case it was passed as a scalar or tuple
+ if "dim" not in kwargs:
+ raise RuntimeError("Launch dimensions 'dim' argument should be passed via. keyword args for wp.launch_tiled()")
+
+ if "block_dim" not in kwargs:
+ raise RuntimeError(
+ "Launch block dimension 'block_dim' argument should be passed via. keyword args for wp.launch_tiled()"
+ )
+
+ dim = kwargs["dim"]
+ if not isinstance(dim, list):
+ dim = list(dim) if isinstance(dim, tuple) else [dim]
+
+ if len(dim) > 3:
+ raise RuntimeError("wp.launch_tiled() requires a grid with fewer than 4 dimensions")
+
+ # add trailing dimension
+ kwargs["dim"] = dim + [kwargs["block_dim"]]
+
+ # forward to original launch method
+ launch(*args, **kwargs)
+
+
  def synchronize():
  """Manually synchronize the calling CPU thread with any outstanding CUDA work on all devices

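To complement the docstring above, a fuller sketch that only relies on wp.tid() and the documented keyword-argument requirements (the kernel, sizes, and per-lane striding scheme are illustrative assumptions; the new tile examples listed above, e.g. warp/examples/tile/example_tile_matmul.py, show the tile primitives themselves):

    import warp as wp

    @wp.kernel
    def scale_rows(x: wp.array2d(dtype=float), s: float):
        # launch_tiled appends a trailing dimension of size block_dim,
        # so tid() yields (row, lane); each lane strides across the columns
        i, lane = wp.tid()
        for j in range(lane, x.shape[1], 64):
            x[i, j] = x[i, j] * s

    x = wp.ones((128, 1024), dtype=float)
    wp.launch_tiled(scale_rows, dim=[128], inputs=[x, 2.0], block_dim=64)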
@@ -5619,16 +5876,6 @@ def type_str(t):
  return "Any"
  elif t == Callable:
  return "Callable"
- elif t == Tuple[int]:
- return "Tuple[int]"
- elif t == Tuple[int, int]:
- return "Tuple[int, int]"
- elif t == Tuple[int, int, int]:
- return "Tuple[int, int, int]"
- elif t == Tuple[int, int, int, int]:
- return "Tuple[int, int, int, int]"
- elif t == Tuple[int, ...]:
- return "Tuple[int, ...]"
  elif isinstance(t, int):
  return str(t)
  elif isinstance(t, List):
@@ -5663,9 +5910,13 @@ def type_str(t):
  return f"Transformation[{type_str(t._wp_scalar_type_)}]"

  raise TypeError("Invalid vector or matrix dimensions")
- elif typing.get_origin(t) in (List, Mapping, Sequence, Union, Tuple):
- args_repr = ", ".join(type_str(x) for x in typing.get_args(t))
- return f"{t.__name__}[{args_repr}]"
+ elif warp.codegen.get_type_origin(t) in (list, tuple):
+ args_repr = ", ".join(type_str(x) for x in warp.codegen.get_type_args(t))
+ return f"{t._name}[{args_repr}]"
+ elif t is Ellipsis:
+ return "..."
+ elif warp.types.is_tile(t):
+ return "Tile"

  return t.__name__

@@ -5826,9 +6077,6 @@ def export_stubs(file): # pragma: no cover
  print('Cols = TypeVar("Cols", bound=int)', file=file)
  print('DType = TypeVar("DType")', file=file)

- print('Int = TypeVar("Int")', file=file)
- print('Float = TypeVar("Float")', file=file)
- print('Scalar = TypeVar("Scalar")', file=file)
  print("Vector = Generic[Length, Scalar]", file=file)
  print("Matrix = Generic[Rows, Cols, Scalar]", file=file)
  print("Quaternion = Generic[Float]", file=file)